├── .gitignore
├── README.md
├── acdc_processing.py
├── main.py
└── utils
    ├── args_utils.py
    ├── data_utils.py
    ├── dataset_utils.py
    └── model.py

/.gitignore:
--------------------------------------------------------------------------------
1 | ACDC/
2 | Task027_ACDC/
3 | utils/__pycache__/
4 | output_model/
5 | python.log
6 | .DS_Store
7 | lightning_logs/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FCT-Pytorch
2 | PyTorch implementation of The Fully Convolutional Transformer (FCT)
3 | 
4 | ## note
5 | This repo can:
6 | 1. Reproduce the original author's TensorFlow results. Note (see the issues of the original repo) that they only use the ACDC training set, split 7:2:1 into train:validation:test; with that split you can get a Dice score of about 92.9.
7 | 
8 | 2. Reach a Dice score of about 90 on the official test set if you train on the whole training set (train on ACDC/training and evaluate on ACDC/testing).
9 | 
10 | 
11 | ## training
12 | 1. Get the ACDC dataset, and remember to delete the `.md` file in your ACDC dataset folder.
13 | 2. Run `python main.py` to start training.
14 | 
--------------------------------------------------------------------------------
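The hard-coded paths in `main.py` and `acdc_processing.py` below assume the raw ACDC data sits next to the scripts, roughly as sketched here (patient numbering follows the official ACDC release):

```
ACDC/
├── training/   # patient001 … patient100, read by main.py as the training set
└── testing/    # patient101 … patient150, read by main.py as the validation set
```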
/acdc_processing.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | from collections import OrderedDict
16 | from batchgenerators.utilities.file_and_folder_operations import *
17 | import shutil
18 | import numpy as np
19 | from sklearn.model_selection import KFold
20 | 
21 | 
22 | def convert_to_submission(source_dir, target_dir):
23 |     niftis = subfiles(source_dir, join=False, suffix=".nii.gz")
24 |     patientids = np.unique([i[:10] for i in niftis])
25 |     maybe_mkdir_p(target_dir)
26 |     for p in patientids:
27 |         files_of_that_patient = subfiles(source_dir, prefix=p, suffix=".nii.gz", join=False)
28 |         assert len(files_of_that_patient)
29 |         files_of_that_patient.sort()
30 |         # first is ED, second is ES
31 |         shutil.copy(join(source_dir, files_of_that_patient[0]), join(target_dir, p + "_ED.nii.gz"))
32 |         shutil.copy(join(source_dir, files_of_that_patient[1]), join(target_dir, p + "_ES.nii.gz"))
33 | 
34 | 
35 | if __name__ == "__main__":
36 |     folder_train = "ACDC/training"
37 |     folder_test = "ACDC/testing"
38 |     out_folder = "Task027_ACDC"
39 | 
40 |     maybe_mkdir_p(join(out_folder, "imagesTr"))
41 |     maybe_mkdir_p(join(out_folder, "imagesTs"))
42 |     maybe_mkdir_p(join(out_folder, "labelsTr"))
43 |     maybe_mkdir_p(join(out_folder, "labelsTs"))
44 | 
45 |     # train
46 |     all_train_files = []
47 |     patient_dirs_train = subfolders(folder_train, prefix="patient")
48 |     for p in patient_dirs_train:
49 |         current_dir = p
50 |         data_files_train = [i for i in subfiles(current_dir, suffix=".nii.gz") if i.find("_gt") == -1 and i.find("_4d") == -1]
51 |         corresponding_seg_files = [i[:-7] + "_gt.nii.gz" for i in data_files_train]
52 |         for d, s in zip(data_files_train, corresponding_seg_files):
53 |             patient_identifier = d.split("/")[-1][:-7]
54 |             all_train_files.append(patient_identifier + ".nii.gz")
55 |             shutil.copy(d, join(out_folder, "imagesTr", patient_identifier + ".nii.gz"))
56 |             shutil.copy(s, join(out_folder, "labelsTr", patient_identifier + "_gt.nii.gz"))
57 | 
58 |     # test
59 |     all_test_files = []
60 |     patient_dirs_test = subfolders(folder_test, prefix="patient")
61 |     for p in patient_dirs_test:
62 |         current_dir = p
63 |         data_files_test = [i for i in subfiles(current_dir, suffix=".nii.gz") if i.find("_gt") == -1 and i.find("_4d") == -1]
64 |         for d in data_files_test:
65 |             patient_identifier = d.split("/")[-1][:-7]
66 |             all_test_files.append(patient_identifier + ".nii.gz")
67 |             shutil.copy(d, join(out_folder, "imagesTs", patient_identifier + ".nii.gz"))
68 |             shutil.copy(d[:-7] + "_gt.nii.gz", join(out_folder, "labelsTs", patient_identifier + "_gt.nii.gz"))  # gt path derived from the image path, as in the train loop
69 | 
70 | 
71 |     json_dict = OrderedDict()
72 |     json_dict['name'] = "ACDC"
73 |     json_dict['description'] = "cardiac cine MRI segmentation"
74 |     json_dict['tensorImageSize'] = "4D"
75 |     json_dict['reference'] = "see ACDC challenge"
76 |     json_dict['licence'] = "see ACDC challenge"
77 |     json_dict['release'] = "0.0"
78 |     json_dict['modality'] = {
79 |         "0": "MRI",
80 |     }
81 |     json_dict['labels'] = {
82 |         "0": "background",
83 |         "1": "RV",
84 |         "2": "MLV",
85 |         "3": "LVC"
86 |     }
87 |     json_dict['numTraining'] = len(all_train_files)
88 |     json_dict['numTest'] = len(all_test_files)
89 |     json_dict['training'] = [{'image': f"./imagesTr/{i}",
90 |                               "label": f"./labelsTr/{i[:-7]}_gt.nii.gz"} for i in
91 |                              all_train_files]
92 |     json_dict['test'] = [{'image': f"./imagesTs/{i}",
93 |                           "label": f"./labelsTs/{i[:-7]}_gt.nii.gz"} for i in
94 |                          all_test_files]
95 | 
96 |     save_json(json_dict, os.path.join(out_folder, "dataset.json"))
97 | 
98 |     # # create a dummy split (patients need to be separated)
99 |     # splits = []
100 |     # patients = np.unique([i[:10] for i in 
all_train_files]) 101 | # patientids = [i[:-12] for i in all_train_files] 102 | 103 | # kf = KFold(5, True, 12345) 104 | # for tr, val in kf.split(patients): 105 | # splits.append(OrderedDict()) 106 | # tr_patients = patients[tr] 107 | # splits[-1]['train'] = [i[:-12] for i in all_train_files if i[:10] in tr_patients] 108 | # val_patients = patients[val] 109 | # splits[-1]['val'] = [i[:-12] for i in all_train_files if i[:10] in val_patients] 110 | 111 | # save_pickle(splits, "/media/fabian/nnunet/Task027_ACDC/splits_final.pkl") -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | import monai.metrics 5 | import lightning as L 6 | import torch.nn.functional as F 7 | import matplotlib.pyplot as plt 8 | from torch.utils.data import DataLoader,TensorDataset 9 | from datetime import datetime 10 | from lightning.pytorch.callbacks.lr_monitor import LearningRateMonitor 11 | # from google.colab import drive 12 | 13 | from torch.utils.tensorboard import SummaryWriter 14 | from utils.args_utils import parse_args 15 | from utils.data_utils import get_acdc,convert_masks 16 | from utils.model import FCT 17 | from utils.dataset_utils import ACDCTrainDataset 18 | 19 | args = parse_args() 20 | 21 | def get_lr_scheduler(args,optimizer): 22 | if args.lr_scheduler == 'none': 23 | return None 24 | if args.lr_scheduler == 'ReduceLROnPlateau': 25 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 26 | optimizer, 27 | mode='min', 28 | factor=args.lr_factor, 29 | verbose=True, 30 | threshold=1e-6, 31 | patience=5, 32 | min_lr=args.min_lr) 33 | return scheduler 34 | if args.lr_scheduler == 'CosineAnnealingWarmRestarts': 35 | scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts( 36 | optimizer, 37 | T_0=500 38 | ) 39 | return scheduler 40 | 41 | @torch.no_grad() 42 | def init_weights(m): 43 | """ 44 | Initialize the weights 45 | """ 46 | if isinstance(m, nn.Conv2d): 47 | torch.nn.init.kaiming_normal_(m.weight) 48 | if m.bias is not None: 49 | torch.nn.init.zeros_(m.bias) 50 | 51 | 52 | def main(): 53 | # model instatation 54 | model = FCT(args) 55 | model.apply(init_weights) 56 | 57 | # get data 58 | # training 59 | acdc_data, _, _ = get_acdc('ACDC/training', input_size=(args.img_size,args.img_size,1)) 60 | acdc_data = ACDCTrainDataset(acdc_data[0], acdc_data[1],args) 61 | train_dataloader = DataLoader(acdc_data, batch_size=args.batch_size,num_workers=args.workers) 62 | # validation 63 | acdc_data, _, _ = get_acdc('ACDC/testing', input_size=(args.img_size,args.img_size,1)) 64 | acdc_data[1] = convert_masks(acdc_data[1]) 65 | acdc_data[0] = np.transpose(acdc_data[0], (0, 3, 1, 2)) # for the channels 66 | acdc_data[1] = np.transpose(acdc_data[1], (0, 3, 1, 2)) # for the channels 67 | acdc_data[0] = torch.Tensor(acdc_data[0]) # convert to tensors 68 | acdc_data[1] = torch.Tensor(acdc_data[1]) # convert to tensors 69 | acdc_data = TensorDataset(acdc_data[0], acdc_data[1]) 70 | validation_dataloader = DataLoader(acdc_data, batch_size=args.batch_size,num_workers=args.workers) 71 | 72 | # resume 73 | # TODO need debug 74 | if args.resume: 75 | if args.new_param: 76 | model = FCT.load_from_checkpoint('lightning_logs/version_2/checkpoints/epoch=74-step=4500.ckpt',args=args) 77 | else: 78 | # load weights,old hyper parameter and optimizer state 79 | model = FCT.load_from_checkpoint('this is path') 80 | 81 | precision = 
'16-mixed' if args.amp else 32
82 |     lr_monitor = LearningRateMonitor(logging_interval='epoch')
83 | 
84 |     trainer = L.Trainer(precision=precision,max_epochs=args.max_epoch,callbacks=[lr_monitor])
85 |     trainer.fit(model=model,train_dataloaders=train_dataloader,val_dataloaders=validation_dataloader)
86 | 
87 | 
88 | if __name__ == '__main__':
89 |     main()
--------------------------------------------------------------------------------
/utils/args_utils.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | 
3 | 
4 | def parse_args():
5 |     parser = argparse.ArgumentParser(description='FCT for medical image')
6 | 
7 |     # misc
8 |     # https://arxiv.org/pdf/2109.08203.pdf suggests 3407, or any other seed you like
9 |     parser.add_argument('--random_seed',default=0,type=int,help='random seed, default 0')
10 |     parser.add_argument("--workers", default=4, type=int, help="number of workers")
11 |     parser.add_argument("--img_size",default=224,type=int)
12 |     parser.add_argument("--amp",default=True,help="use automatic mixed precision, default True")
13 | 
14 |     # train parameters
15 |     parser.add_argument('--batch_size',default=32,type=int,help='batch size, default 32')
16 |     parser.add_argument('--max_epoch',default=120,type=int,help='maximum number of training epochs')
17 |     parser.add_argument('--lr',default=1e-4,type=float,help='learning rate; when resuming, 0 means keep the checkpoint lr, any other value overrides it')
18 |     parser.add_argument('--decay',default=0.001,type=float,help='weight decay (L2 penalty)')
19 |     parser.add_argument('--lr_factor',default=0.5,type=float,help='dynamic learning rate factor; when the loss plateaus, new_lr = old_lr * lr_factor')
20 |     parser.add_argument('--min_lr',default=6e-6,type=float,help='minimum learning rate for the scheduler')
21 |     parser.add_argument('--lr_scheduler',default='ReduceLROnPlateau')
22 |     parser.add_argument('--resume',action='store_true')
23 |     parser.add_argument('--new_param',action='store_true',help='use new hyperparameters when resuming from a checkpoint')
24 | 
25 |     parser.add_argument('--colab',action='store_true')
26 |     # network parameters
27 | 
28 | 
29 |     args = parser.parse_args()
30 | 
31 |     return args
32 | 
--------------------------------------------------------------------------------
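The same options can be mirrored programmatically, for example when driving the trainer from a notebook instead of the command line. A minimal sketch (not part of the repo) that simply restates the defaults defined in `parse_args()`:

```python
# Sketch: build the training arguments without the CLI, mirroring utils/args_utils.py defaults.
from argparse import Namespace

args = Namespace(
    random_seed=0, workers=4, img_size=224, amp=True,              # misc
    batch_size=32, max_epoch=120, lr=1e-4, decay=0.001,            # training
    lr_factor=0.5, min_lr=6e-6, lr_scheduler='ReduceLROnPlateau',  # LR schedule
    resume=False, new_param=False, colab=False,
)
```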
/utils/data_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import nibabel as nib
4 | import numpy as np
5 | from monai import data, transforms
6 | from monai.data import load_decathlon_datalist
7 | 
8 | def get_loader(args):
9 |     data_dir = args.data_dir
10 |     datalist_json = os.path.join(data_dir, 'dataset.json')
11 |     train_transform = transforms.Compose(
12 |         [
13 |             transforms.LoadImaged(keys=["image", "label"]),
14 |             transforms.AddChanneld(keys=["image", "label"]),
15 |             transforms.Orientationd(keys=["image", "label"], axcodes="RAS"),
16 |             transforms.Spacingd(keys=["image", "label"],pixdim=[1.5,1.5,10]),
17 |             transforms.NormalizeIntensityd(keys=["image", "label"]),
18 |             # transforms.Resized(keys=['image', 'label'],spatial_size=[256,256,8],mode='trilinear')
19 |         ]
20 |     )
21 | 
22 |     val_transform = transforms.Compose(
23 |         [
24 |             transforms.LoadImaged(keys=["image", "label"]),
25 |         ]
26 |     )
27 |     if args.predict_mode:
28 |         pass
29 |     else:
30 |         datalist = load_decathlon_datalist(datalist_json, True, "training", base_dir=data_dir)
31 |         if args.use_normal_dataset:
32 |             train_ds = data.Dataset(data=datalist, transform=train_transform)
33 |         else:
34 |             train_ds = data.CacheDataset(
35 |                 data=datalist, transform=train_transform, cache_num=24, cache_rate=1.0, num_workers=args.workers
36 |             )
37 |         train_loader = data.DataLoader(
38 |             train_ds,
39 |             batch_size=args.batch_size,
40 |             shuffle=True,
41 |             num_workers=args.workers,
42 |             pin_memory=True,
43 |         )
44 |         val_files = load_decathlon_datalist(datalist_json, True, "test", base_dir=data_dir)
45 |         val_ds = data.Dataset(data=val_files, transform=val_transform)
46 |         val_loader = data.DataLoader(
47 |             val_ds, batch_size=1, shuffle=False, num_workers=args.workers, pin_memory=True
48 |         )
49 |         loader = [train_loader, val_loader]
50 |     return loader
51 | 
52 | def get_acdc(path,input_size=(224,224,1)):
53 |     """
54 |     Read images and masks for the ACDC dataset
55 |     """
56 |     all_imgs = []
57 |     all_gt = []
58 |     all_header = []
59 |     all_affine = []
60 |     info = []
61 |     for root, directories, files in os.walk(path):
62 |         for file in files:
63 |             if ".gz" in file and "frame" in file:
64 |                 if "_gt" not in file:
65 |                     img_path = root + "/" + file
66 |                     img = nib.load(img_path).get_fdata()
67 |                     all_header.append(nib.load(img_path).header)
68 |                     all_affine.append(nib.load(img_path).affine)
69 |                     for idx in range(img.shape[2]):
70 |                         i = cv2.resize(img[:,:,idx], (input_size[0], input_size[1]), interpolation=cv2.INTER_NEAREST)
71 |                         all_imgs.append(i)
72 | 
73 |                 else:
74 |                     img_path = root + "/" + file
75 |                     img = nib.load(img_path).get_fdata()
76 |                     for idx in range(img.shape[2]):
77 |                         i = cv2.resize(img[:,:,idx], (input_size[0], input_size[1]), interpolation=cv2.INTER_NEAREST)
78 |                         all_gt.append(i)
79 | 
80 | 
81 |     data = [all_imgs, all_gt, info]
82 | 
83 | 
84 |     data[0] = np.expand_dims(data[0], axis=3)
85 |     if path[-9:] != "true_test":
86 |         data[1] = np.expand_dims(data[1], axis=3)
87 | 
88 |     return data, all_affine, all_header
89 | 
90 | def convert_masks(y, data="acdc"):
91 |     """
92 |     Given one mask array with several classes, create one binary mask per class
93 |     """
94 | 
95 |     if data == "acdc":
96 |         # initialize
97 |         masks = np.zeros((y.shape[0], y.shape[1], y.shape[2], 4))
98 | 
99 |         for i in range(y.shape[0]):
100 |             masks[i][:,:,0] = np.where(y[i]==0, 1, 0)[:,:,-1]
101 |             masks[i][:,:,1] = np.where(y[i]==1, 1, 0)[:,:,-1]
102 |             masks[i][:,:,2] = np.where(y[i]==2, 1, 0)[:,:,-1]
103 |             masks[i][:,:,3] = np.where(y[i]==3, 1, 0)[:,:,-1]
104 | 
105 |     elif data == "synapse":
106 |         masks = np.zeros((y.shape[0], y.shape[1], y.shape[2], 9))
107 | 
108 |         for i in range(y.shape[0]):
109 |             masks[i][:,:,0] = np.where(y[i]==0, 1, 0)[:,:,-1]
110 |             masks[i][:,:,1] = np.where(y[i]==1, 1, 0)[:,:,-1]
111 |             masks[i][:,:,2] = np.where(y[i]==2, 1, 0)[:,:,-1]
112 |             masks[i][:,:,3] = np.where(y[i]==3, 1, 0)[:,:,-1]
113 |             masks[i][:,:,4] = np.where(y[i]==4, 1, 0)[:,:,-1]
114 |             masks[i][:,:,5] = np.where(y[i]==5, 1, 0)[:,:,-1]
115 |             masks[i][:,:,6] = np.where(y[i]==6, 1, 0)[:,:,-1]
116 |             masks[i][:,:,7] = np.where(y[i]==7, 1, 0)[:,:,-1]
117 |             masks[i][:,:,8] = np.where(y[i]==8, 1, 0)[:,:,-1]
118 | 
119 |     else:
120 |         print("Data set not recognized")
121 | 
122 |     return masks
123 | 
124 | def convert_mask_single(y):
125 |     """
126 |     Given one mask with several classes, create one binary mask per class
127 |     y: shape (w,h)
128 |     """
129 |     mask = np.zeros((4, y.shape[0], y.shape[1]))
130 |     mask[0, :, :] = np.where(y == 0, 1, 0)
131 |     mask[1, :, :] = np.where(y == 1, 1, 0)
132 |     mask[2, :, :] = np.where(y == 2, 1, 0)
133 |     mask[3, :, :] = np.where(y == 3, 1, 0)
134 | 
135 |     return mask
136 | 
137 | def visualize(image_raw,mask):
138 |     """
139 |     image_raw: gray image with shape [width,height,1]
140 |     mask: segmentation mask with shape [num_class,width,height]
141 |     this function returns an image that uses one color per class to visualize the masks 
in raw image 142 | """ 143 | # Convert grayscale image to RGB 144 | image = cv2.cvtColor(image_raw, cv2.COLOR_GRAY2RGB) 145 | 146 | # Get the number of classes (i.e. channels) in the mask 147 | num_class = mask.shape[0] 148 | 149 | # Define colors for each class (using a simple color map) 150 | colors = [] 151 | for i in range(1, num_class): # skip first class (background) 152 | hue = int(i/float(num_class-1) * 179) 153 | color = np.zeros((1, 1, 3), dtype=np.uint8) 154 | color[0, 0, 0] = hue 155 | color[0, 0, 1:] = 255 156 | color = cv2.cvtColor(color, cv2.COLOR_HSV2RGB) 157 | colors.append(color) 158 | 159 | # Overlay each non-background class mask with a different color on the original image 160 | for i in range(1, num_class): 161 | class_mask = mask[i, :, :] 162 | class_mask = np.repeat(class_mask[:, :, np.newaxis], 3, axis=2) 163 | class_mask = class_mask.astype(image.dtype) 164 | class_mask = class_mask * colors[i-1] 165 | image = cv2.addWeighted(image, 1.0, class_mask, 0.5, 0.0) 166 | 167 | return image -------------------------------------------------------------------------------- /utils/dataset_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms as transforms 3 | import numpy as np 4 | import random 5 | import PIL.Image 6 | import cv2 7 | from torch.utils.data import Dataset 8 | from .data_utils import convert_mask_single,visualize 9 | 10 | class ACDCTrainDataset(Dataset): 11 | def __init__(self,x,y,args) -> None: 12 | super().__init__() 13 | self.x = x 14 | self.y = y 15 | self.transform = transforms.Compose([ 16 | transforms.RandomVerticalFlip(), 17 | transforms.RandomHorizontalFlip(), 18 | transforms.RandomAffine(degrees=20,translate=(0.2,0.2)) 19 | ]) 20 | self.img_size = args.img_size 21 | 22 | def __len__(self): 23 | return self.x.shape[0] 24 | 25 | def __getitem__(self, index): 26 | seed = np.random.randint(2147483647) 27 | 28 | x = PIL.Image.fromarray(self.x[index].reshape(self.img_size, self.img_size)) 29 | y = PIL.Image.fromarray(self.y[index].reshape(self.img_size, self.img_size)) 30 | 31 | torch.manual_seed(seed) 32 | tar_x = np.array(self.transform(x)) 33 | # cv2.imwrite('x.jpg',tar_x) 34 | 35 | torch.manual_seed(seed) 36 | tar_y = np.array(self.transform(y)) 37 | tar_y = convert_mask_single(tar_y) 38 | # cv2.imwrite('y_0.jpg',tar_y[0]* 255) 39 | # cv2.imwrite('y_1.jpg',tar_y[1]* 255) 40 | # cv2.imwrite('y_2.jpg',tar_y[2]* 255) 41 | # cv2.imwrite('y_3.jpg',tar_y[3]* 255) 42 | 43 | 44 | # vis_img = visualize(tar_x,tar_y) 45 | # cv2.imwrite('test.jpg',vis_img) 46 | tar_x = tar_x.reshape(1,self.img_size,self.img_size) 47 | torch.manual_seed(0) 48 | return torch.tensor(tar_x).float(),torch.tensor(tar_y).float() -------------------------------------------------------------------------------- /utils/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import torch.nn.functional as F 5 | import lightning as L 6 | from torchvision.ops.stochastic_depth import StochasticDepth 7 | 8 | 9 | class Convolutional_Attention(nn.Module): 10 | def __init__(self, 11 | channels, 12 | num_heads, 13 | img_size, 14 | proj_drop=0.0, 15 | kernel_size=3, 16 | stride_kv=1, 17 | stride_q=1, 18 | padding_kv="same", 19 | padding_q="same", 20 | attention_bias=True 21 | ): 22 | super().__init__() 23 | self.stride_kv = stride_kv 24 | self.stride_q = stride_q 25 | self.num_heads = num_heads 26 | 
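# q, k and v are each produced by a depthwise 3x3 convolution (groups=channels) followed by
# ReLU and a LayerNorm over [C, H, W]; forward() then flattens the maps to (B, H*W, C) token
# sequences, runs nn.MultiheadAttention over them, and reshapes the output back to (B, C, H, W).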
self.proj_drop = proj_drop 27 | 28 | self.layer_q = nn.Sequential( 29 | nn.Conv2d(channels, channels, kernel_size, stride_q, padding_q, bias=attention_bias, groups=channels), 30 | nn.ReLU(), 31 | ) 32 | self.layernorm_q = nn.LayerNorm([channels,img_size,img_size], eps=1e-5) 33 | 34 | self.layer_k = nn.Sequential( 35 | nn.Conv2d(channels, channels, kernel_size, stride_kv, padding_kv, bias=attention_bias, groups=channels), 36 | nn.ReLU(), 37 | ) 38 | self.layernorm_k = nn.LayerNorm([channels,img_size,img_size], eps=1e-5) 39 | 40 | self.layer_v = nn.Sequential( 41 | nn.Conv2d(channels, channels, kernel_size, stride_kv, padding_kv, bias=attention_bias, groups=channels), 42 | nn.ReLU(), 43 | ) 44 | self.layernorm_v = nn.LayerNorm([channels,img_size,img_size], eps=1e-5) 45 | 46 | self.attention = nn.MultiheadAttention(embed_dim=channels, 47 | bias=attention_bias, 48 | batch_first=True, 49 | dropout=self.proj_drop, 50 | num_heads=self.num_heads) 51 | 52 | def _build_projection(self, x, mode): 53 | # x shape [batch,channel,size,size] 54 | # mode:0->q,1->k,2->v,for torch.script can not script str 55 | 56 | if mode == 0: 57 | x1 = self.layer_q(x) 58 | proj = self.layernorm_q(x1) 59 | elif mode == 1: 60 | x1 = self.layer_k(x) 61 | proj = self.layernorm_k(x1) 62 | elif mode == 2: 63 | x1 = self.layer_v(x) 64 | proj = self.layernorm_v(x1) 65 | 66 | return proj 67 | 68 | def get_qkv(self, x): 69 | q = self._build_projection(x, 0) 70 | k = self._build_projection(x, 1) 71 | v = self._build_projection(x, 2) 72 | 73 | return q, k, v 74 | 75 | def forward(self, x): 76 | q, k, v = self.get_qkv(x) 77 | q = q.view(q.shape[0], q.shape[1], q.shape[2]*q.shape[3]) 78 | k = k.view(k.shape[0], k.shape[1], k.shape[2]*k.shape[3]) 79 | v = v.view(v.shape[0], v.shape[1], v.shape[2]*v.shape[3]) 80 | q = q.permute(0, 2, 1) 81 | k = k.permute(0, 2, 1) 82 | v = v.permute(0, 2, 1) 83 | x1 = self.attention(query=q, value=v, key=k, need_weights=False) 84 | 85 | x1 = x1[0].permute(0, 2, 1) 86 | x1 = x1.view(x1.shape[0], x1.shape[1], np.sqrt( 87 | x1.shape[2]).astype(int), np.sqrt(x1.shape[2]).astype(int)) 88 | 89 | return x1 90 | 91 | 92 | class Transformer(nn.Module): 93 | 94 | def __init__(self, 95 | # in_channels, 96 | out_channels, 97 | num_heads, 98 | dpr, 99 | img_size, 100 | proj_drop=0.5, 101 | attention_bias=True, 102 | padding_q="same", 103 | padding_kv="same", 104 | stride_kv=1, 105 | stride_q=1): 106 | super().__init__() 107 | 108 | self.attention_output = Convolutional_Attention(channels=out_channels, 109 | num_heads=num_heads, 110 | img_size=img_size, 111 | proj_drop=proj_drop, 112 | padding_q=padding_q, 113 | padding_kv=padding_kv, 114 | stride_kv=stride_kv, 115 | stride_q=stride_q, 116 | attention_bias=attention_bias, 117 | ) 118 | 119 | self.stochastic_depth = StochasticDepth(dpr,mode='batch') 120 | self.conv1 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding="same") 121 | self.layernorm = nn.LayerNorm([out_channels, img_size, img_size]) 122 | self.wide_focus = Wide_Focus(out_channels, out_channels) 123 | 124 | def forward(self, x): 125 | x1 = self.attention_output(x) 126 | x1 = self.stochastic_depth(x1) 127 | x2 = self.conv1(x1) + x 128 | 129 | x3 = self.layernorm(x2) 130 | x3 = self.wide_focus(x3) 131 | x3 = self.stochastic_depth(x3) 132 | 133 | out = x3 + x2 134 | return out 135 | 136 | 137 | class Wide_Focus(nn.Module): 138 | """ 139 | Wide-Focus module. 
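    Three parallel 3x3 convolutions with dilation 1, 2 and 3 (each followed by GELU and
    Dropout(0.3)) are applied to the same input, their outputs are summed, and the sum is
    passed through a final 3x3 convolution block.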
140 | """ 141 | 142 | def __init__(self, 143 | in_channels, 144 | out_channels): 145 | super().__init__() 146 | 147 | self.layer1 = nn.Sequential( 148 | nn.Conv2d(in_channels,out_channels, kernel_size=3, stride=1, padding="same"), 149 | nn.GELU(), 150 | nn.Dropout(0.3) 151 | ) 152 | self.layer_dilation2 = nn.Sequential( 153 | nn.Conv2d(in_channels,out_channels, kernel_size=3, stride=1, padding="same", dilation=2), 154 | nn.GELU(), 155 | nn.Dropout(0.3) 156 | ) 157 | self.layer_dilation3 = nn.Sequential( 158 | nn.Conv2d(in_channels,out_channels, kernel_size=3, stride=1, padding="same", dilation=3), 159 | nn.GELU(), 160 | nn.Dropout(0.3) 161 | ) 162 | self.layer4 = nn.Sequential( 163 | nn.Conv2d(in_channels,out_channels, kernel_size=3, stride=1, padding="same"), 164 | nn.GELU(), 165 | nn.Dropout(0.3) 166 | ) 167 | 168 | def forward(self, x): 169 | x1 = self.layer1(x) 170 | x2 = self.layer_dilation2(x) 171 | x3 = self.layer_dilation3(x) 172 | added = x1 + x2 + x3 173 | x_out = self.layer4(added) 174 | return x_out 175 | 176 | 177 | class Block_decoder(nn.Module): 178 | def __init__(self, in_channels, out_channels, att_heads, dpr, img_size): 179 | super().__init__() 180 | self.layernorm = nn.LayerNorm([in_channels, img_size, img_size]) 181 | # img size *= 2 182 | self.upsample = nn.Upsample(scale_factor=2,mode='bilinear') 183 | self.layer1 = nn.Sequential( 184 | nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding="same"), 185 | nn.ReLU() 186 | ) 187 | self.layer2 = nn.Sequential( 188 | nn.Conv2d(out_channels*2, out_channels, kernel_size=3, stride=1, padding="same"), 189 | nn.ReLU() 190 | ) 191 | self.layer3 = nn.Sequential( 192 | nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding="same"), 193 | nn.ReLU(), 194 | nn.Dropout(0.3) 195 | ) 196 | self.trans = Transformer(out_channels, att_heads, dpr, img_size * 2) 197 | 198 | def forward(self, x, skip): 199 | x1 = self.layernorm(x) 200 | x1 = self.upsample(x1) 201 | x1 = self.layer1(x1) 202 | x1 = torch.cat((skip, x1), axis=1) 203 | x1 = self.layer2(x1) 204 | x1 = self.layer3(x1) 205 | out = self.trans(x1) 206 | return out 207 | 208 | 209 | class DS_out(nn.Module): 210 | def __init__(self, in_channels, out_channels, img_size): 211 | super().__init__() 212 | # img size *= 2 213 | self.upsample = nn.Upsample(scale_factor=2) 214 | self.layernorm = nn.LayerNorm([in_channels,img_size*2,img_size*2], eps=1e-5) 215 | self.conv1 = nn.Sequential( 216 | nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding="same"), 217 | nn.ReLU() 218 | ) 219 | self.conv2 = nn.Sequential( 220 | nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding="same"), 221 | nn.ReLU() 222 | ) 223 | self.conv3 = nn.Sequential( 224 | nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding="same"), 225 | ) 226 | 227 | def forward(self, x): 228 | x1 = self.upsample(x) 229 | x1 = self.layernorm(x1) 230 | x1 = self.conv1(x1) 231 | x1 = self.conv2(x1) 232 | out = self.conv3(x1) 233 | 234 | return out 235 | 236 | 237 | class Block_encoder_without_skip(nn.Module): 238 | def __init__(self, in_channels, out_channels, att_heads, dpr, img_size): 239 | super().__init__() 240 | """LayerNorm Example 241 | >>> # Image Example 242 | >>> N, C, H, W = 20, 5, 10, 10 243 | >>> input = torch.randn(N, C, H, W) 244 | >>> # Normalize over the last three dimensions (i.e. 
the channel and spatial dimensions) 245 | >>> # as shown in the image below 246 | >>> layer_norm = nn.LayerNorm([C, H, W]) 247 | >>> output = layer_norm(input) 248 | """ 249 | self.layernorm = nn.LayerNorm([in_channels, img_size, img_size]) 250 | self.layer1 = nn.Sequential( 251 | nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding="same"), 252 | nn.ReLU() 253 | ) 254 | # img_size => img_size // 2 255 | self.layer2 = nn.Sequential( 256 | nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding="same"), 257 | nn.ReLU(), 258 | nn.Dropout(0.3), 259 | nn.MaxPool2d((2, 2)) 260 | ) 261 | self.trans = Transformer(out_channels, att_heads, dpr, img_size // 2) 262 | 263 | def forward(self, x): 264 | x = self.layernorm(x) 265 | x1 = self.layer1(x) 266 | x1 = self.layer2(x1) 267 | x1 = self.trans(x1) 268 | return x1 269 | 270 | 271 | class Block_encoder_with_skip(nn.Module): 272 | def __init__(self, in_channels, out_channels, att_heads, dpr, img_size): 273 | super().__init__() 274 | self.layernorm = nn.LayerNorm([in_channels, img_size, img_size]) 275 | self.layer1 = nn.Sequential( 276 | nn.Conv2d(1, in_channels, kernel_size=3, stride=1, padding="same"), 277 | nn.ReLU() 278 | ) 279 | self.layer2 = nn.Sequential( 280 | nn.Conv2d(in_channels * 2, out_channels, kernel_size=3, stride=1, padding="same"), 281 | nn.ReLU() 282 | ) 283 | # image size /= 2 284 | self.layer3 = nn.Sequential( 285 | nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding="same"), 286 | nn.ReLU(), 287 | nn.Dropout(0.3), 288 | nn.MaxPool2d((2, 2)) 289 | ) 290 | self.trans = Transformer(out_channels, att_heads, dpr, img_size // 2) 291 | 292 | def forward(self, x, scale_img): 293 | x1 = self.layernorm(x) 294 | x1 = torch.cat((self.layer1(scale_img), x1), axis=1) 295 | x1 = self.layer2(x1) 296 | x1 = self.layer3(x1) 297 | x1 = self.trans(x1) 298 | return x1 299 | 300 | 301 | class FCT(L.LightningModule): 302 | def __init__(self, args): 303 | super().__init__() 304 | 305 | self.drp_out = 0.3 306 | self.img_size = args.img_size 307 | self.loss_fn = nn.BCEWithLogitsLoss() 308 | self.args = args 309 | 310 | # attention heads and filters per block 311 | att_heads = [2, 4, 8, 16, 32, 16, 8, 4, 2] 312 | filters = [32, 64, 128, 256, 512, 256, 128, 64, 32] 313 | 314 | # number of blocks used in the model 315 | blocks = len(filters) 316 | 317 | stochastic_depth_rate = 0.5 318 | 319 | # probability for each block 320 | dpr = [x for x in np.linspace(0, stochastic_depth_rate, blocks)] 321 | 322 | # Multi-scale input 323 | self.scale_img = nn.AvgPool2d(2, 2) 324 | 325 | # model 326 | # [N,1,img_size,img_size] => [N,filters[0],img_size // 2,img_size // 2] 327 | self.block_1 = Block_encoder_without_skip(1, filters[0], att_heads[0], dpr[0], self.img_size) 328 | self.block_2 = Block_encoder_with_skip(filters[0], filters[1], att_heads[1], dpr[1], self.img_size // 2) 329 | self.block_3 = Block_encoder_with_skip(filters[1], filters[2], att_heads[2], dpr[2], self.img_size // 4) 330 | self.block_4 = Block_encoder_with_skip(filters[2], filters[3], att_heads[3], dpr[3], self.img_size // 8) 331 | self.block_5 = Block_encoder_without_skip(filters[3], filters[4], att_heads[4], dpr[4], self.img_size // 16) 332 | self.block_6 = Block_decoder(filters[4], filters[5], att_heads[5], dpr[5], self.img_size // 32) 333 | self.block_7 = Block_decoder(filters[5], filters[6], att_heads[6], dpr[6], self.img_size // 16) 334 | self.block_8 = Block_decoder(filters[6], filters[7], att_heads[7], dpr[7], self.img_size // 8) 335 | 
self.block_9 = Block_decoder(filters[7], filters[8], att_heads[8], dpr[8], self.img_size // 4) 336 | 337 | self.ds7 = DS_out(filters[6], 4, self.img_size // 8) 338 | self.ds8 = DS_out(filters[7], 4, self.img_size // 4) 339 | self.ds9 = DS_out(filters[8], 4, self.img_size // 2) 340 | 341 | def forward(self, x): 342 | 343 | # Multi-scale input 344 | scale_img_2 = self.scale_img(x) # x shape[batch_size,channel(1),224,224] 345 | scale_img_3 = self.scale_img(scale_img_2) # shape[batch,1,56,56] 346 | scale_img_4 = self.scale_img(scale_img_3) # shape[batch,1,28,28] 347 | 348 | x = self.block_1(x) 349 | # print(f"Block 1 out -> {list(x.size())}") 350 | skip1 = x 351 | x = self.block_2(x, scale_img_2) 352 | # print(f"Block 2 out -> {list(x.size())}") 353 | skip2 = x 354 | x = self.block_3(x, scale_img_3) 355 | # print(f"Block 3 out -> {list(x.size())}") 356 | skip3 = x 357 | x = self.block_4(x, scale_img_4) 358 | # print(f"Block 4 out -> {list(x.size())}") 359 | skip4 = x 360 | 361 | x = self.block_5(x) 362 | # print(f"Block 5 out -> {list(x.size())}") 363 | x = self.block_6(x, skip4) 364 | # print(f"Block 6 out -> {list(x.size())}") 365 | x = self.block_7(x, skip3) 366 | # print(f"Block 7 out -> {list(x.size())}") 367 | skip7 = x 368 | x = self.block_8(x, skip2) 369 | # print(f"Block 8 out -> {list(x.size())}") 370 | skip8 = x 371 | x = self.block_9(x, skip1) 372 | # print(f"Block 9 out -> {list(x.size())}") 373 | skip9 = x 374 | 375 | out7 = self.ds7(skip7) 376 | # print(f"DS 7 out -> {list(out7.size())}") 377 | out8 = self.ds8(skip8) 378 | # print(f"DS 8 out -> {list(out8.size())}") 379 | out9 = self.ds9(skip9) 380 | # print(f"DS 9 out -> {list(out9.size())}") 381 | 382 | return out7, out8, out9 383 | 384 | def training_step(self, batch, batch_idx): 385 | x, y = batch 386 | pred_y = self(x) 387 | # dsn 388 | down1 = F.interpolate(y, self.img_size // 2) 389 | down2 = F.interpolate(y, self.img_size // 4) 390 | loss = (self.loss_fn(pred_y[2], y) * 0.57 + self.loss_fn(pred_y[1], down1) * 0.29 + self.loss_fn(pred_y[0], down2) * 0.14) 391 | self.log("loss/train_loss", loss,on_epoch=True) 392 | 393 | # train dice 394 | y_pred = torch.argmax(pred_y[2], axis=1) 395 | y_pred_onehot = F.one_hot(y_pred, 4).permute(0, 3, 1, 2) 396 | dice = self.compute_dice(y_pred_onehot, y) 397 | dice_LV = dice[3] 398 | dice_RV = dice[1] 399 | dice_MYO = dice[2] 400 | self.log('dice/all_train_dice', dice[1:].mean(), on_epoch=True) 401 | self.log('dice/train_LV_dice', dice_LV, on_epoch=True) 402 | self.log('dice/train_RV_dice', dice_RV, on_epoch=True) 403 | self.log('dice/train_MYO_dice', dice_MYO, on_epoch=True) 404 | # save grad 405 | for name, params in self.named_parameters(): 406 | if params.grad is not None: 407 | self.log(f'abs_{name}',params.grad.abs().mean(), on_epoch=True) 408 | return loss 409 | 410 | def validation_step(self, batch, batch_idx): 411 | x, y = batch 412 | pred_y = self(x) 413 | # dsn 414 | down1 = F.interpolate(y, self.img_size // 2) 415 | down2 = F.interpolate(y, self.img_size // 4) 416 | loss = (self.loss_fn(pred_y[2], y) * 0.57 + self.loss_fn(pred_y[1], down1) * 0.29 + self.loss_fn(pred_y[0], down2) * 0.14) 417 | self.log("loss/validation_loss", loss,on_epoch=True) 418 | 419 | # train dice 420 | y_pred = torch.argmax(pred_y[2], axis=1) 421 | y_pred_onehot = F.one_hot(y_pred, 4).permute(0, 3, 1, 2) 422 | dice = self.compute_dice(y_pred_onehot, y) 423 | dice_LV = dice[3] 424 | dice_RV = dice[1] 425 | dice_MYO = dice[2] 426 | self.log('dice/all_validate_dice', dice[1:].mean(), on_epoch=True) 427 | 
self.log('dice/LV_dice', dice_LV, on_epoch=True) 428 | self.log('dice/RV_dice', dice_RV, on_epoch=True) 429 | self.log('dice/MYO_dice', dice_MYO, on_epoch=True) 430 | return loss 431 | 432 | def configure_optimizers(self): 433 | optimizer = torch.optim.Adam(self.parameters(), lr=self.args.lr) 434 | return { 435 | "optimizer": optimizer, 436 | "lr_scheduler": { 437 | "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau( 438 | optimizer, mode='min', factor=self.args.lr_factor, min_lr=self.args.min_lr,patience=5), 439 | "monitor": "loss/train_loss", 440 | "name": "lr" 441 | } 442 | } 443 | 444 | 445 | @torch.no_grad() 446 | def compute_dice(self, pred_y, y): 447 | """ 448 | Computes the Dice coefficient for each class in the ACDC dataset. 449 | Assumes binary masks with shape (num_masks, num_classes, height, width). 450 | """ 451 | epsilon = 1e-6 452 | num_masks = pred_y.shape[0] 453 | num_classes = pred_y.shape[1] 454 | dice_scores = torch.zeros((num_classes,), device=self.device) 455 | 456 | for c in range(num_classes): 457 | intersection = torch.sum(pred_y[:, c] * y[:, c]) 458 | sum_masks = torch.sum(pred_y[:, c]) + torch.sum(y[:, c]) 459 | dice_scores[c] = (2. * intersection + epsilon) / (sum_masks + epsilon) 460 | 461 | return dice_scores 462 | --------------------------------------------------------------------------------
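The three tensors returned by `FCT.forward` are the deep-supervision heads consumed in `training_step`: 4-channel logit maps at one quarter, one half and full input resolution. A minimal shape check (a sketch, assuming the repo root is the working directory; the `Namespace` only needs the single field the constructor reads):

```python
# Sketch: verify the deep-supervision output resolutions of FCT.
import torch
from argparse import Namespace
from utils.model import FCT

args = Namespace(img_size=224)        # FCT.__init__ only reads img_size
model = FCT(args).eval()

with torch.no_grad():
    out7, out8, out9 = model(torch.randn(1, 1, 224, 224))

print(out7.shape)  # torch.Size([1, 4, 56, 56])   -> matched against y resized to img_size // 4
print(out8.shape)  # torch.Size([1, 4, 112, 112]) -> matched against y resized to img_size // 2
print(out9.shape)  # torch.Size([1, 4, 224, 224]) -> matched against the full-resolution mask y
```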