├── figures └── main_r2.png ├── Loss └── Dice_CE_Loss.py ├── README.md ├── inference.py ├── Data_Loader └── loaders.py ├── train.py └── models ├── classification_model_Mamba_Encoder.py └── seg_model.py /figures/main_r2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kabbas570/CAMS-Net/HEAD/figures/main_r2.png -------------------------------------------------------------------------------- /Loss/Dice_CE_Loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | DEVICE = "cuda" if torch.cuda.is_available() else "cpu" 6 | 7 | # Ensure the weight tensor matches the number of classes 8 | W_1 = torch.tensor([1, 1, 1, 1, 2], device=DEVICE, dtype=torch.float) 9 | 10 | def to_one_hot(targets, num_classes): 11 | batch_size, height, width = targets.size() 12 | one_hot = torch.zeros(batch_size, num_classes, height, width, device=DEVICE) 13 | return one_hot.scatter_(1, targets.unsqueeze(1), 1) 14 | 15 | class DiceCELoss(nn.Module): 16 | def __init__(self, weight=W_1, size_average=True): 17 | super(DiceCELoss, self).__init__() 18 | self.weight = weight 19 | self.size_average = size_average 20 | 21 | def forward(self, inputs, targets, smooth=1): 22 | num_classes = inputs.size(1) 23 | 24 | # Convert targets to one-hot encoding 25 | targets_one_hot = to_one_hot(targets, num_classes) 26 | 27 | # Apply softmax to inputs to get probabilities for Dice loss calculation 28 | inputs_softmax = F.softmax(inputs, dim=1) 29 | 30 | # Dice Loss with class weights 31 | dice_loss = 0 32 | for i in range(num_classes): 33 | input_flat = inputs_softmax[:, i].contiguous().view(-1) 34 | target_flat = targets_one_hot[:, i].contiguous().view(-1) 35 | intersection = (input_flat * target_flat).sum() 36 | # Incorporate class weights into Dice loss calculation 37 | dice_loss += self.weight[i] * (1 - ((2. * intersection + smooth) / 38 | (input_flat.sum() + target_flat.sum() + smooth))) 39 | dice_loss /= self.weight.sum() # Normalize by sum of weights 40 | 41 | # Cross-Entropy Loss 42 | ce_loss = F.cross_entropy(inputs, targets, weight=self.weight) 43 | 44 | # Combine losses 45 | loss = dice_loss + ce_loss 46 | 47 | return loss.mean() if self.size_average else loss.sum() 48 | 49 | # Example usage 50 | # Assuming inputs and targets are defined elsewhere with appropriate shapes 51 | # loss_fn = DiceCELoss() 52 | # loss = loss_fn(inputs, targets) 53 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CAMS: Convolution and Attention-Free Mamba-based Cardiac Image Segmentation 2 | 3 | ## [Accepted at WACV 2025 (IEEE/CVF Winter Conference on Applications of Computer Vision)](https://wacv2025.thecvf.com/) 4 | 5 | You can view or download the metadata [HERE](https://arxiv.org/abs/2406.05786) 6 | 7 | 8 | # Overview: 9 | This paper demonstrates that convolution and self-attention, while widely used, are not the only effective methods for segmentation. Breaking with convention, we present a Convolution and self-attention-free Mamba-based semantic Segmentation Network named CAMS-Net for the task of medical image segmentation. 10 | 11 | ![image](https://github.com/kabbas570/CAMS-Net/blob/052ac53b678a907be29aaca6b4abdd7dbd973d7a/figures/main_r2.png) 12 | 13 | ## Key Contributions: 14 | ***First Convolution and Self-attention-Free Architecture:*** To the best of our knowledge, we are the first to propose a convolution and self-attention-free Mamba-based segmentation network. 15 | 16 | ***Linearly Interconnected Factorized Mamba (LIFM):*** LIFM block to reduce the trainable parameters of Mamba and improve its non-linearity. LIFM implements a weight-sharing strategy for different scanning directions, specifically for the two scanning direction strategies of vision Mamba, to reduce the computational complexity further whilst maintaining accuracy. 17 | 18 | ***Mamba Channel Mamba Spatial Aggregators:*** These modules learn information along the channel and spatial dimensions of the features, respectively. 19 | 20 | ## Evaluation: 21 | Our approach was evaluated on two modalities using publicly available datasets: 22 | 23 | ***M&Ms-2 Dataset*** 24 | 25 | ***CMR×Recon Segmentation Dataset*** 26 | 27 | Results demonstrate state-of-the-art segmentation performance across diverse cardiac imaging modalities. 28 | 29 | # Training Steps 30 | 31 | ## Segmentation Model 32 | 33 | ## ImageNet Pretrained Weights 34 | The imageNet pre-trained weights are available for the following two sizes, as we trained for the M&Ms-2 Dataset and CMR×Recon Segmentation Dataset. 35 | 36 | [Mamba-Encoder for input size with spatial dim of 256 x 256](https://drive.google.com/open?id=1IMmrYufVxRek3sVfrY1FZDax0NVpQy7g&usp=drive_copy) 37 | 38 | [Mamba-Encoder for input size with spatial dim of 160 x 160](https://drive.google.com/open?id=1zmxagS6x7_osxNpoQxICSvNltVuNxJTC&usp=drive_copy) 39 | 40 | 41 | 42 | 43 | 44 | # Citation 45 | @article{khan2024cams, 46 | 47 | title={CAMS: Convolution and Attention-Free Mamba-based Cardiac Image Segmentation}, 48 | 49 | author={Khan, Abbas and Asad, Muhammad and Benning, Martin and Roney, Caroline and Slabaugh, Gregory}, 50 | 51 | booktitle={Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision}, 52 | 53 | year={2025} 54 | } 55 | 56 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.utils.data import DataLoader 3 | from torch.utils.data import Dataset 4 | import SimpleITK as sitk 5 | import os 6 | import torch 7 | import matplotlib.pyplot as plt 8 | ########### Dataloader ############# 9 | NUM_WORKERS = 8 10 | PIN_MEMORY=True 11 | DIM_ = 256 12 | 13 | from loader import Data_Loader_val 14 | 15 | import matplotlib.pyplot as plt 16 | from tqdm import tqdm 17 | import torch.optim as optim 18 | from medpy import metric 19 | import kornia 20 | DEVICE = "cuda" if torch.cuda.is_available() else "cpu" 21 | 22 | def make_edges(image,three): 23 | three = np.stack((three,)*3, axis=2) 24 | three =torch.tensor(three) 25 | three = np.transpose(three, (2,0,1)) ## to bring channel first 26 | three= torch.unsqueeze(three,axis = 0) 27 | magnitude, edges=kornia.filters.canny(three, low_threshold=0.1, high_threshold=0.2, kernel_size=(7, 7), sigma=(1, 1), hysteresis=True, eps=1e-06) 28 | image[np.where(edges[0,0,:,:]!=0)] = 1 29 | return image 30 | 31 | def normalize(x): 32 | return np.array((x - np.min(x)) / (np.max(x) - np.min(x))) 33 | 34 | def blend(image,LV,MYO,RV,three): 35 | 36 | image = normalize(image) 37 | image = np.stack((image,)*3, axis=2) 38 | image[np.where(LV==1)] = [0.9,0.9,0] 39 | image[np.where(MYO==1)] = [0.9,0,0] 40 | image[np.where(RV==1)] = [0,0,0.9] 41 | image = make_edges(image,three) 42 | return image 43 | 44 | #### Specify all the paths here ##### 45 | 46 | def calculate_metric_percase(pred, gt): 47 | gt = gt[0,:] 48 | pred = pred[0,:] 49 | dice = metric.binary.dc(pred, gt) 50 | hd = metric.binary.hd95(pred, gt) 51 | return dice, hd 52 | 53 | def Average(lst): 54 | return sum(lst) / len(lst) 55 | 56 | 57 | Dice_LV_LA = [] 58 | HD_LV_LA = [] 59 | 60 | Dice_MYO_LA = [] 61 | HD_MYO_LA = [] 62 | 63 | Dice_RV_LA = [] 64 | HD_RV_LA = [] 65 | 66 | def check_Dice_Score(loader, model1, device=DEVICE): 67 | 68 | Dice_LV = 0 69 | HD_LV = 0 70 | 71 | Dice_MYO = 0 72 | HD_MYO = 0 73 | 74 | Dice_RV = 0 75 | HD_RV = 0 76 | 77 | loop = tqdm(loader) 78 | model1.eval() 79 | 80 | 81 | for batch_idx, (img,gt,name) in enumerate(loop): 82 | 83 | img = img.to(device=DEVICE,dtype=torch.float) 84 | gt = gt.to(device=DEVICE,dtype=torch.float) 85 | 86 | 87 | with torch.no_grad(): 88 | 89 | pre_2d = model1(img) 90 | pred = torch.argmax(pre_2d, dim=1) 91 | 92 | out_LV = torch.zeros_like(pred) 93 | out_LV[torch.where(pred==1)] = 1 94 | out_MYO = torch.zeros_like(pred) 95 | out_MYO[torch.where(pred==2)] = 1 96 | out_RV = torch.zeros_like(pred) 97 | out_RV[torch.where(pred==3)] = 1 98 | 99 | 100 | single_lv,single_hd_lv = 0,0 101 | single_myo,single_hd_myo = 0,0 102 | single_rv,single_hd_rv = 0,0 103 | 104 | if torch.sum(out_LV)!=0: 105 | single_lv,single_hd_lv = calculate_metric_percase(out_LV.detach().cpu().numpy(),gt[:,1,:].detach().cpu().numpy()) 106 | if torch.sum(out_MYO)!=0: 107 | single_myo,single_hd_myo = calculate_metric_percase(out_MYO.detach().cpu().numpy(),gt[:,2,:].detach().cpu().numpy()) 108 | if torch.sum(out_RV)!=0: 109 | single_rv,single_hd_rv = calculate_metric_percase(out_RV.detach().cpu().numpy(),gt[:,3,:].detach().cpu().numpy()) 110 | 111 | Dice_LV+=single_lv 112 | HD_LV+=single_hd_lv 113 | 114 | Dice_MYO+=single_myo 115 | HD_MYO+=single_hd_myo 116 | 117 | Dice_RV+=single_rv 118 | HD_RV+=single_hd_rv 119 | 120 | img = img.detach().cpu().numpy() 121 | gt = gt.detach().cpu().numpy() 122 | 123 | out_LV = out_LV.detach().cpu().numpy() 124 | out_MYO = out_MYO.detach().cpu().numpy() 125 | out_RV = out_RV.detach().cpu().numpy() 126 | 127 | pred_blend = blend(img[0,0,:],out_LV[0,:],out_MYO[0,:],out_RV[0,:],1-gt[0,0,:]) 128 | plt.imsave(viz_pred_path + name[0] + '.png', pred_blend) 129 | 130 | 131 | print("for fold -->", fold) 132 | print(' :: Dice Scores ::') 133 | print(f"Dice_LV : {Dice_LV/len(loader)}") 134 | print(f"Dice_MYO : {Dice_MYO/len(loader)}") 135 | print(f"Dice_RV : {Dice_RV/len(loader)}") 136 | print(' :: HD Scores :: ') 137 | print(f"HD_LV : {HD_LV/len(loader)}") 138 | print(f"HD_MYO : {HD_MYO/len(loader)}") 139 | print(f"HD_RV : {HD_RV/len(loader)}") 140 | print(" ") 141 | 142 | return Dice_LV/len(loader), Dice_MYO/len(loader),Dice_RV/len(loader), HD_LV/len(loader),HD_MYO/len(loader),HD_RV/len(loader) 143 | for fold in range(1,6): 144 | 145 | fold = str(fold) ## training fold number 146 | 147 | path_to_checkpoints = "/data/scratch/acw676/Mamba_Data/NEW_DATA/Data_Aug/MNM_OTHER/F"+fold+_"Mamba_1.pth.tar" 148 | 149 | viz_gt_path = '/data/scratch/acw676/Mamba_Data/NEW_DATA/Visual_Res/swinunet/gts/F'+fold+'/' 150 | viz_pred_path = '/data/scratch/acw676/Mamba_Data/NEW_DATA/Data_Aug/MNM_OTHER/Visual_R_MNM/Mamba_UNet/F'+fold+'/' 151 | val_imgs = "/data/scratch/acw676/Mamba_Data/NEW_DATA/Aug_MnM2/new_split_mix/F"+fold+"/val/imgs/" 152 | 153 | Batch_Size = 1 154 | val_loader = Data_Loader_val(val_imgs,batch_size = Batch_Size) 155 | print(len(val_loader)) ### same here 156 | 157 | from pre1 import EncoderDecoder,CustomEncoder 158 | encoder = CustomEncoder() 159 | model_1 = EncoderDecoder(encoder) 160 | 161 | 162 | def eval_(): 163 | model = model_1.to(device=DEVICE,dtype=torch.float) 164 | checkpoint = torch.load(path_to_checkpoints,map_location=DEVICE) 165 | model.load_state_dict(checkpoint['state_dict']) 166 | 167 | Dice_LV, Dice_MYO,Dice_RV,HD_LV,HD_MYO,HD_RV= check_Dice_Score(val_loader, model, device=DEVICE) 168 | 169 | Dice_LV_LA.append(Dice_LV) 170 | Dice_MYO_LA.append(Dice_MYO) 171 | Dice_RV_LA.append(Dice_RV) 172 | 173 | HD_LV_LA.append(HD_LV) 174 | HD_MYO_LA.append(HD_MYO) 175 | HD_RV_LA.append(HD_RV) 176 | 177 | 178 | if __name__ == "__main__": 179 | eval_() 180 | 181 | print("Average Five Fold Dice_LV_LA --> ", Average(Dice_LV_LA)) 182 | print("Average Five Fold Dice_MYO_LA --> ", Average(Dice_MYO_LA)) 183 | print("Average Five Fold Dice_RV_LA --> ", Average(Dice_RV_LA)) 184 | 185 | print("Average Five Fold HD_LV_LA --> ", Average(HD_LV_LA)) 186 | print("Average Five Fold HD_MYO_LA --> ", Average(HD_MYO_LA)) 187 | print("Average Five Fold HD_RV_LA --> ", Average(HD_RV_LA)) 188 | -------------------------------------------------------------------------------- /Data_Loader/loaders.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.utils.data import DataLoader 3 | from torch.utils.data import Dataset 4 | import SimpleITK as sitk 5 | import os 6 | import torch 7 | import matplotlib.pyplot as plt 8 | #from typing import List, Union, Tuple 9 | 10 | import torchio as tio 11 | ########### Dataloader ############# 12 | NUM_WORKERS = 8 13 | PIN_MEMORY=True 14 | DIM_ = 160 15 | 16 | def crop_center_3D(img,cropx=DIM_,cropy=DIM_): 17 | z,x,y = img.shape 18 | startx = x//2 - cropx//2 19 | starty = (y)//2 - cropy//2 20 | return img[:,startx:startx+cropx, starty:starty+cropy] 21 | 22 | def Cropping_3d(org_dim3,org_dim1,org_dim2,DIM_,img_):# org_dim3->numof channels 23 | 24 | if org_dim1DIM_ and org_dim2>DIM_: 31 | img_ = crop_center_3D(img_) 32 | ## two dims are different #### 33 | if org_dim1=DIM_: 34 | padding1=int((DIM_-org_dim1)//2) 35 | temp=np.zeros([org_dim3,DIM_,org_dim2]) 36 | temp[:,padding1:org_dim1+padding1,:] = img_[:,:,:] 37 | img_=temp 38 | img_ = crop_center_3D(img_) 39 | if org_dim1==DIM_ and org_dim2DIM_ and org_dim2 [H,W,C] 105 | img = sitk.GetArrayFromImage(img) ## --> [C,H,W] 106 | img = Normalization_1(img) 107 | gt_path = os.path.join(self.gt_folder,str(self.images_name[index]).zfill(3)) 108 | gt_path = gt_path[:-11]+'_gt.nii.gz' 109 | gt = sitk.ReadImage(gt_path) ## --> [H,W,C] 110 | gt = sitk.GetArrayFromImage(gt) ## --> [C,H,W] 111 | gt = gt.astype(np.float64) 112 | 113 | gt = np.expand_dims(gt, axis=0) 114 | img = np.expand_dims(img, axis=0) 115 | 116 | C = img.shape[0] 117 | H = img.shape[1] 118 | W = img.shape[2] 119 | img = Cropping_3d(C,H,W,DIM_,img) 120 | 121 | C = gt.shape[0] 122 | H = gt.shape[1] 123 | W = gt.shape[2] 124 | gt = Cropping_3d(C,H,W,DIM_,gt) 125 | 126 | ## apply augmentaitons here ### 127 | 128 | img = np.expand_dims(img, axis=3) 129 | gt = np.expand_dims(gt, axis=3) 130 | 131 | d = {} 132 | d['Image'] = tio.Image(tensor = img, type=tio.INTENSITY) 133 | d['Mask'] = tio.Image(tensor = gt, type=tio.LABEL) 134 | sample = tio.Subject(d) 135 | if self.transformations is not None: 136 | transformed_tensor = self.transformations(sample) 137 | img = transformed_tensor['Image'].data 138 | gt = transformed_tensor['Mask'].data 139 | 140 | gt = gt[...,0] 141 | img = img[...,0] 142 | 143 | gt = generate_label(gt) 144 | 145 | return img,gt 146 | 147 | def Data_Loader_io_transforms(images_folder,batch_size,num_workers=NUM_WORKERS,pin_memory=PIN_MEMORY): 148 | test_ids = Dataset_io(images_folder=images_folder) 149 | data_loader = DataLoader(test_ids,batch_size=batch_size,num_workers=num_workers,pin_memory=pin_memory,shuffle=True) 150 | return data_loader 151 | 152 | 153 | class Dataset_val(Dataset): 154 | def __init__(self, images_folder): ## If I apply Data Augmentation here, the validation loss becomes None. 155 | self.images_folder = images_folder 156 | self.gt_folder = self.images_folder[:-5] + 'gts' 157 | self.images_name = os.listdir(images_folder) 158 | def __len__(self): 159 | return len(self.images_name) 160 | def __getitem__(self, index): 161 | 162 | img_path = os.path.join(self.images_folder,str(self.images_name[index]).zfill(3)) 163 | img = sitk.ReadImage(img_path) ## --> [H,W,C] 164 | img = sitk.GetArrayFromImage(img) ## --> [C,H,W] 165 | img = Normalization_1(img) 166 | gt_path = os.path.join(self.gt_folder,str(self.images_name[index]).zfill(3)) 167 | gt_path = gt_path[:-11]+'_gt.nii.gz' 168 | gt = sitk.ReadImage(gt_path) ## --> [H,W,C] 169 | gt = sitk.GetArrayFromImage(gt) ## --> [C,H,W] 170 | gt = gt.astype(np.float64) 171 | 172 | gt = np.expand_dims(gt, axis=0) 173 | img = np.expand_dims(img, axis=0) 174 | 175 | C = img.shape[0] 176 | H = img.shape[1] 177 | W = img.shape[2] 178 | img = Cropping_3d(C,H,W,DIM_,img) 179 | 180 | C = gt.shape[0] 181 | H = gt.shape[1] 182 | W = gt.shape[2] 183 | gt = Cropping_3d(C,H,W,DIM_,gt) 184 | gt = generate_label(gt) 185 | 186 | return img,gt 187 | 188 | def Data_Loader_val(images_folder,batch_size,num_workers=NUM_WORKERS,pin_memory=PIN_MEMORY): 189 | test_ids = Dataset_val(images_folder=images_folder) 190 | data_loader = DataLoader(test_ids,batch_size=batch_size,num_workers=num_workers,pin_memory=pin_memory,shuffle=True) 191 | return data_loader 192 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.optim as optim 3 | import numpy as np 4 | DEVICE = "cuda" if torch.cuda.is_available() else "cpu" 5 | from tqdm import tqdm 6 | import matplotlib.pyplot as plt 7 | 8 | def get_lr(optimizer): 9 | for param_group in optimizer.param_groups: 10 | return param_group['lr'] 11 | 12 | 13 | def check_Dice_Score(loader, model1, device=DEVICE): 14 | 15 | Dice_score_LA = 0 16 | Dice_score_RA = 0 17 | Dice_score_LV = 0 18 | Dice_score_RV = 0 19 | 20 | loop = tqdm(loader) 21 | model1.eval() 22 | 23 | for batch_idx, (img,gt) in enumerate(loop): 24 | 25 | img = img.to(device=DEVICE,dtype=torch.float) 26 | gt = gt.to(device=DEVICE,dtype=torch.float) 27 | 28 | with torch.no_grad(): 29 | pre_2d= model1(img) 30 | 31 | ## segemntaiton ## 32 | 33 | pred = torch.argmax(pre_2d, dim=1) 34 | 35 | out_LA = torch.zeros_like(pred) 36 | out_LA[torch.where(pred==1)] = 1 37 | 38 | out_RA = torch.zeros_like(pred) 39 | out_RA[torch.where(pred==2)] = 1 40 | 41 | out_LV = torch.zeros_like(pred) 42 | out_LV[torch.where(pred==3)] = 1 43 | 44 | out_RV = torch.zeros_like(pred) 45 | out_RV[torch.where(pred==4)] = 1 46 | 47 | 48 | single_LA = (2 * (out_LA * gt[:,1,:]).sum()) / ( 49 | (out_LA + gt[:,1,:]).sum() + 1e-8) 50 | 51 | Dice_score_LA +=single_LA 52 | 53 | single_RA = (2 * (out_RA * gt[:,2,:]).sum()) / ( 54 | (out_RA + gt[:,2,:]).sum() + 1e-8) 55 | 56 | Dice_score_RA +=single_RA 57 | 58 | 59 | single_LV = (2 * (out_LV * gt[:,3,:]).sum()) / ( 60 | (out_LV + gt[:,3,:]).sum() + 1e-8) 61 | 62 | Dice_score_LV +=single_LV 63 | 64 | single_RV = (2 * (out_RV * gt[:,4,:]).sum()) / ( 65 | (out_RV + gt[:,4,:]).sum() + 1e-8) 66 | 67 | Dice_score_RV +=single_RV 68 | 69 | 70 | ## segemntaiton ## 71 | print(f"Dice_score_LA : {Dice_score_LA/len(loader)}") 72 | print(f"Dice_score_RA : {Dice_score_RA/len(loader)}") 73 | print(f"Dice_score_LV : {Dice_score_LV/len(loader)}") 74 | print(f"Dice_score_RV : {Dice_score_RV/len(loader)}") 75 | 76 | Overall_Dicescore_LA = (Dice_score_LA + Dice_score_RA + Dice_score_LV + Dice_score_RV )/4 77 | 78 | print(f"Overall_Dicescore_LA : {Overall_Dicescore_LA/len(loader)}") 79 | 80 | return Overall_Dicescore_LA/len(loader) 81 | 82 | 83 | def train_fn(loader_train1,loader_valid1,model1, optimizer1, scaler1,loss_fn_DC1): ### Loader_1--> ED and Loader2-->ES 84 | 85 | train_losses1_seg = [] # loss of each batch 86 | valid_losses1_seg = [] # loss of each batch 87 | 88 | 89 | loop = tqdm(loader_train1) 90 | model1.train() 91 | 92 | for param_group in optimizer1.param_groups: 93 | print(f"Current learning rate: {param_group['lr']}") 94 | 95 | 96 | for batch_idx,(img,gt) in enumerate(loop): 97 | 98 | img = img.to(device=DEVICE,dtype=torch.float) 99 | gt = gt.to(device=DEVICE,dtype=torch.float) 100 | 101 | 102 | with torch.cuda.amp.autocast(): 103 | pre_2d = model1(img) ## loss1 is for 4 classes 104 | gt = torch.argmax(gt, dim=1) ## used for Loss1 105 | ## segmentation losses ## 106 | loss = loss_fn_DC1(pre_2d,gt) 107 | 108 | loss.backward() 109 | torch.nn.utils.clip_grad_norm_(model1.parameters(), 2.0) 110 | optimizer1.first_step(zero_grad=True) 111 | 112 | loss_fn_DC1(model1(img),gt).backward() 113 | torch.nn.utils.clip_grad_norm_(model1.parameters(), 2.0) 114 | optimizer1.second_step(zero_grad=True) 115 | 116 | # update tqdm loop 117 | loop.set_postfix(loss = loss.item()) ## loss = loss1.item() 118 | 119 | train_losses1_seg.append(float(loss)) 120 | 121 | loop_v = tqdm(loader_valid1) 122 | model1.eval() 123 | for batch_idx,(img,gt) in enumerate(loop_v): 124 | img = img.to(device=DEVICE,dtype=torch.float) 125 | gt = gt.to(device=DEVICE,dtype=torch.float) 126 | 127 | with torch.no_grad(): 128 | pre_2d = model1(img) ## loss1 is for 4 classes 129 | ## segmentation losses ## 130 | gt = torch.argmax(gt, dim=1) ## used for Loss1 131 | loss = loss_fn_DC1(pre_2d,gt) 132 | 133 | # backward 134 | loop_v.set_postfix(loss = loss.item()) 135 | valid_losses1_seg.append(float(loss)) 136 | 137 | train_loss_per_epoch1_seg = np.average(train_losses1_seg) 138 | valid_loss_per_epoch1_seg = np.average(valid_losses1_seg) 139 | 140 | avg_train_losses1_seg.append(train_loss_per_epoch1_seg) 141 | avg_valid_losses1_seg.append(valid_loss_per_epoch1_seg) 142 | 143 | 144 | return train_loss_per_epoch1_seg,valid_loss_per_epoch1_seg 145 | 146 | 147 | from Loss3 import DiceCELoss 148 | loss_fn_DC1 = DiceCELoss() 149 | 150 | 151 | from loader import Data_Loader_val,Data_Loader_train 152 | for fold in range(1,6): 153 | 154 | from pre2 import EncoderDecoder,CustomEncoder 155 | 156 | 157 | fold = str(fold) ## training fold number 158 | 159 | ### Data is arranged as follows; 160 | # Data_CMR/F1/train/imgs 161 | # /gts 162 | 163 | # Data_CMR/F2/train/imgs 164 | # /gts 165 | 166 | # Data_CMR/F3/train/imgs 167 | # /gts 168 | 169 | # Data_CMR/F4/train/imgs 170 | # /gts 171 | 172 | # Data_CMR/F5/train/imgs 173 | # /gts 174 | 175 | 176 | 177 | 178 | train_imgs = "Data_CMR/F"+fold+"/train/imgs/" ## ABSOLUTE PATHs 179 | val_imgs = "Data_CMR/F"+fold+"/val/imgs/" 180 | 181 | Batch_Size = 16 182 | Max_Epochs = 500 183 | 184 | train_loader = Data_Loader_train(train_imgs,batch_size = Batch_Size) # Data_Loader_io_transforms 185 | val_loader = Data_Loader_val(val_imgs,batch_size = 1) 186 | 187 | 188 | print(len(train_loader)) ### this shoud be = Total_images/ batch size 189 | print(len(val_loader)) ### same here 190 | #print(len(test_loader)) ### same here 191 | 192 | avg_train_losses1_seg = [] # losses of all training epochs 193 | avg_valid_losses1_seg = [] #losses of all training epochs 194 | 195 | avg_valid_DS_ValSet_seg = [] # all training epochs 196 | avg_valid_DS_TrainSet_seg = [] # all training epochs 197 | 198 | path_to_save_Learning_Curve = '/data/scratch/acw676/Mamba_Data/NEW_DATA/Data_Aug/DC1_CMR/'+'/F'+fold+'Mamba_1' 199 | path_to_save_check_points = '/data/scratch/acw676/Mamba_Data/NEW_DATA/Data_Aug/DC1_CMR/'+'/F'+fold+'Mamba_1' 200 | 201 | ### 3 - this function will save the check-points 202 | def save_checkpoint(state, filename=path_to_save_check_points+".pth.tar"): 203 | print("=> Saving checkpoint") 204 | torch.save(state, filename) 205 | 206 | encoder = CustomEncoder() 207 | 208 | ### Freeze Or Unfreeze this part to use the pre-trained weights #### 209 | 210 | # pretrained_weights_path = "/data/scratch/acw676/Mamba_Data/imageNet_weights/Mamba_160.pth.tar" 211 | # checkpoint = torch.load(pretrained_weights_path, map_location=DEVICE) 212 | # encoder.load_state_dict(checkpoint['state_dict']) 213 | 214 | Mamba_Model = EncoderDecoder(encoder) 215 | 216 | model_1 = Mamba_Model # SwinUNET_R 217 | epoch_len = len(str(Max_Epochs)) 218 | 219 | # Variable to keep track of maximum Dice validation score 220 | 221 | def main(): 222 | max_dice_val = 0.0 223 | model1 = model_1.to(device=DEVICE,dtype=torch.float) 224 | scaler1 = torch.cuda.amp.GradScaler() 225 | 226 | optimizer1 = optim.AdamW(model1.parameters(),betas=(0.5, 0.55),lr=0.001) # 0.00005 227 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer1,milestones=[100,200,300,400,500], gamma=0.5) 228 | 229 | for epoch in range(Max_Epochs): 230 | 231 | train_loss_seg ,valid_loss_seg = train_fn(train_loader,val_loader, model1, optimizer1,scaler1,loss_fn_DC1) 232 | scheduler.step() 233 | 234 | print_msg1 = (f'[{epoch:>{epoch_len}}/{Max_Epochs:>{epoch_len}}] ' + 235 | f'train_loss_seg: {train_loss_seg:.5f} ' + 236 | f'valid_loss_seg: {valid_loss_seg:.5f}') 237 | 238 | 239 | print(print_msg1) 240 | 241 | Dice_val = check_Dice_Score(val_loader, model1, device=DEVICE) 242 | avg_valid_DS_ValSet_seg.append(Dice_val.detach().cpu().numpy()) 243 | 244 | 245 | if Dice_val > max_dice_val: 246 | max_dice_val = Dice_val 247 | # Save the checkpoint 248 | checkpoint = { 249 | "state_dict": model1.state_dict(), 250 | "optimizer": optimizer1.state_dict(), 251 | } 252 | save_checkpoint(checkpoint) 253 | 254 | if __name__ == "__main__": 255 | main() 256 | 257 | fig = plt.figure(figsize=(10,8)) 258 | 259 | plt.plot(range(1,len(avg_train_losses1_seg)+1),avg_train_losses1_seg, label='Training Segmentation Loss') 260 | plt.plot(range(1,len(avg_valid_losses1_seg)+1),avg_valid_losses1_seg,label='Validation Segmentation Loss') 261 | 262 | plt.plot(range(1,len(avg_valid_DS_ValSet_seg)+1),avg_valid_DS_ValSet_seg,label='Validation DS') 263 | 264 | # find position of lowest validation loss 265 | minposs = avg_valid_losses1_seg.index(min(avg_valid_losses1_seg))+1 266 | plt.axvline(minposs,linestyle='--', color='r',label='Early Stopping Checkpoint') 267 | font1 = {'size':20} 268 | plt.title("Learning Curve Graph",fontdict = font1) 269 | plt.xlabel('epochs') 270 | plt.ylabel('loss') 271 | plt.ylim(-1, 1) # consistent scale 272 | plt.xlim(0, len(avg_train_losses1_seg)+1) # consistent scale 273 | plt.grid(True) 274 | plt.legend() 275 | plt.tight_layout() 276 | plt.show() 277 | fig.savefig(path_to_save_Learning_Curve+'.png', bbox_inches='tight') 278 | -------------------------------------------------------------------------------- /models/classification_model_Mamba_Encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from mamba_ssm import Mamba 4 | 5 | D_state = 2 6 | Expand = 1 7 | drop_rate = 0.10 8 | 9 | def flip_Dim(x,dim): 10 | if dim == '12': 11 | x1=x 12 | if dim == '14': 13 | x1 = x.permute(0, 1, 3, 2).contiguous() 14 | if dim == '98': 15 | x1 = torch.flip(x,[2,3]) 16 | if dim == '96': 17 | x1 = torch.flip(x,[2,3]).permute(0, 1, 3, 2).contiguous() 18 | if dim == '78': 19 | x1 = torch.flip(x,[2]) 20 | if dim == '74': 21 | x1 = torch.flip(x,[2]).permute(0, 1, 3, 2).contiguous() 22 | if dim == '32': 23 | x1 = torch.flip(x,[3]) 24 | if dim == '36': 25 | x1 = torch.flip(x,[3]).permute(0, 1, 3, 2).contiguous() 26 | return x1 27 | 28 | def flip_Dim_back(x,dim): 29 | if dim == '12': 30 | x1=x 31 | if dim == '14': 32 | x1 = x.permute(0, 1, 3, 2).contiguous() 33 | if dim == '98': 34 | x1 = torch.flip(x,[2,3]) 35 | if dim == '96': 36 | x1 = torch.flip(x,[2,3]).permute(0, 1, 3, 2).contiguous() 37 | if dim == '78': 38 | x1 = torch.flip(x,[2]) 39 | if dim == '74': 40 | x1 = torch.flip(x,[3]).permute(0, 1, 3, 2).contiguous() 41 | if dim == '32': 42 | x1 = torch.flip(x,[3]) 43 | if dim == '36': 44 | x1 = torch.flip(x,[2]).permute(0, 1, 3, 2).contiguous() 45 | return x1 46 | 47 | class Linear_Layer(nn.Module): 48 | def __init__(self, n_channels,out_channels): 49 | super(Linear_Layer, self).__init__() 50 | self.n_channels = n_channels 51 | self.out_channels = out_channels 52 | 53 | self.linear1 = nn.Linear(self.n_channels,self.out_channels) 54 | self.drop = nn.Dropout(p=drop_rate) 55 | 56 | self.m = nn.SiLU() 57 | #self.norm = nn.LayerNorm(normalized_shape=out_channels) 58 | 59 | def forward(self, x1): 60 | 61 | b,c,h,w = x1.shape 62 | x1 = x1.permute(0,2,3,1).flatten(start_dim=1,end_dim=2) ## [B,H*W,C] 63 | x1 = self.linear1(x1) 64 | #x1 = self.norm(x1) 65 | x1 = self.m(x1) 66 | x1 = self.drop(x1) 67 | x1 = x1.view(b,h,w,self.out_channels).permute(0,3,1,2) 68 | return x1 69 | 70 | class Linear_Layer_Last(nn.Module): 71 | def __init__(self, n_channels,out_channels): 72 | super(Linear_Layer_Last, self).__init__() 73 | self.n_channels = n_channels 74 | self.out_channels = out_channels 75 | 76 | self.linear1 = nn.Linear(self.n_channels,self.out_channels) 77 | 78 | def forward(self, x1): 79 | 80 | b,c,h,w = x1.shape 81 | x1 = x1.permute(0,2,3,1).flatten(start_dim=1,end_dim=2) ## [B,H*W,C] 82 | x1 = self.linear1(x1) 83 | x1 = x1.view(b,h,w,self.out_channels).permute(0,3,1,2) 84 | return x1 85 | 86 | 87 | class SSM_spa(nn.Module): 88 | def __init__(self, sp_dim1,sp_dim2): 89 | super(SSM_spa, self).__init__() 90 | 91 | self.sp_dim2 = sp_dim2 92 | self.ssm2 = Mamba( 93 | d_model = sp_dim1*sp_dim1, 94 | out_c = sp_dim2*sp_dim2, 95 | d_state = D_state, 96 | expand=Expand) 97 | 98 | self.norm = nn.LayerNorm(normalized_shape=sp_dim2*sp_dim2) 99 | 100 | def forward(self, x1): 101 | 102 | b,c,h,w = x1.shape 103 | x1 = x1.flatten(start_dim=2,end_dim=3) ## [B,H*W,C] 104 | 105 | x1 = self.ssm2(x1) 106 | x1 = self.norm(x1) 107 | x1 = x1.view(b,c,self.sp_dim2,self.sp_dim2) 108 | return x1 109 | 110 | class SSM_cha(nn.Module): 111 | def __init__(self, n_channels,out_channels): 112 | super(SSM_cha, self).__init__() 113 | self.n_channels = n_channels 114 | self.out_channels = out_channels 115 | 116 | self.ssm1 = Mamba( 117 | d_model = self.n_channels, 118 | out_c = self.out_channels, 119 | d_state=D_state, 120 | expand=Expand, 121 | ) 122 | 123 | self.norm = nn.LayerNorm(normalized_shape=out_channels) 124 | 125 | def forward(self, x1): 126 | 127 | b,c,h,w = x1.shape 128 | x1 = x1.permute(0,2,3,1).flatten(start_dim=1,end_dim=2) ## [B,H*W,C] 129 | 130 | x1 = self.ssm1(x1) 131 | x1 = self.norm(x1) 132 | 133 | x1 = x1.view(b,h,w,self.out_channels).permute(0,3,1,2) 134 | return x1 135 | 136 | class Linear_Layer_SP_Res(nn.Module): 137 | def __init__(self, sp_dim1,sp_dim2,in_channels,out_channels): 138 | super(Linear_Layer_SP_Res, self).__init__() 139 | self.sp_dim1 = sp_dim1 140 | self.sp_dim2 = sp_dim2 141 | 142 | self.in_channels = in_channels 143 | self.out_channels = out_channels 144 | 145 | self.linear1 = nn.Linear(sp_dim1*sp_dim1,sp_dim2*sp_dim2) 146 | self.drop = nn.Dropout(p=drop_rate) 147 | 148 | 149 | 150 | self.lin_chan = Linear_Layer(self.in_channels,self.out_channels) 151 | #self.norm = nn.LayerNorm(normalized_shape=sp_dim2*sp_dim2) 152 | 153 | self.m = nn.SiLU() 154 | 155 | def forward(self, x1): 156 | 157 | b,c,h,w = x1.shape 158 | x1 = x1.flatten(start_dim=2,end_dim=3) 159 | 160 | x1 = self.linear1(x1) 161 | #x1 = self.norm(x1) 162 | x1 = self.m(x1) 163 | x1 = self.drop(x1) 164 | 165 | x1 = x1.view(b,c,h,w) 166 | x1 = self.lin_chan(x1) 167 | return x1 168 | 169 | class SSM_cha_Last(nn.Module): 170 | def __init__(self, n_channels,out_channels): 171 | super(SSM_cha_Last, self).__init__() 172 | self.n_channels = n_channels 173 | self.out_channels = out_channels 174 | 175 | self.ssm1 = Mamba( 176 | d_model = self.n_channels, 177 | out_c = self.out_channels, 178 | d_state=D_state, 179 | expand=Expand, 180 | ) 181 | 182 | 183 | def forward(self, x1): 184 | 185 | b,c,h,w = x1.shape 186 | x1 = x1.permute(0,2,3,1).flatten(start_dim=1,end_dim=2) ## [B,H*W,C] 187 | 188 | x1 = self.ssm1(x1) 189 | 190 | x1 = x1.view(b,h,w,self.out_channels).permute(0,3,1,2) 191 | return x1 192 | 193 | 194 | class Branch_3(nn.Module): 195 | def __init__(self,in_channels, out_channels,sp_dim1,sp_dim2): 196 | super().__init__() 197 | self.branch3 = nn.Sequential( 198 | SSM_spa(sp_dim1,sp_dim2), 199 | Linear_Layer(in_channels, out_channels), 200 | ) 201 | 202 | self.br_r = Linear_Layer_SP_Res(sp_dim1, sp_dim2,in_channels, out_channels) 203 | def forward(self, x): 204 | return (self.branch3(x) + self.br_r(x)) 205 | 206 | 207 | class Branch_2(nn.Module): 208 | def __init__(self,in_channels, out_channels): 209 | super().__init__() 210 | self.branch2 = nn.Sequential( 211 | SSM_cha(in_channels, out_channels), 212 | Linear_Layer(out_channels, out_channels), 213 | SSM_cha(out_channels, out_channels), 214 | nn.Dropout(p=drop_rate), 215 | ) 216 | def forward(self, x): 217 | return self.branch2(x) 218 | 219 | class Branch_2_Last(nn.Module): 220 | def __init__(self,in_channels, out_channels): 221 | super().__init__() 222 | self.branch2 = nn.Sequential( 223 | SSM_cha_Last(in_channels, out_channels), 224 | Linear_Layer_Last(out_channels, out_channels), 225 | SSM_cha_Last(out_channels, out_channels), 226 | ) 227 | def forward(self, x): 228 | return self.branch2(x) 229 | 230 | class Branch12_Last(nn.Module): 231 | def __init__(self, in_channels, out_channels): 232 | super().__init__() 233 | self.linear = Linear_Layer_Last(in_channels, out_channels) 234 | self.br2 = Branch_2_Last(in_channels, out_channels) 235 | 236 | def forward(self, x): 237 | x1 = self.linear(x) 238 | x2 = self.br2(x) 239 | return (x1+x2) 240 | 241 | 242 | class Branch12(nn.Module): 243 | def __init__(self, in_channels, out_channels): 244 | super().__init__() 245 | self.linear = Linear_Layer(in_channels, out_channels) 246 | self.br2 = Branch_2(in_channels, out_channels) 247 | self.drop = nn.Dropout(p=drop_rate) 248 | 249 | def forward(self, x): 250 | x1 = self.linear(x) 251 | x2 = self.br2(x) 252 | return self.drop(x1+x2) 253 | 254 | class Branch123(nn.Module): 255 | def __init__(self, in_channels, out_channels, sp_dim1,sp_dim2): 256 | super().__init__() 257 | self.linear = Linear_Layer(in_channels, out_channels) 258 | self.br2 = Branch_2(in_channels, out_channels) 259 | self.br3 = Branch_3(in_channels, out_channels, sp_dim1,sp_dim2) 260 | 261 | self.drop = nn.Dropout(p=drop_rate) 262 | 263 | def forward(self, x): 264 | x1 = self.linear(x) 265 | x2 = self.br2(x) 266 | x3 = self.br3(x) 267 | return self.drop(x1+x2+x3) 268 | 269 | 270 | class Down(nn.Module): 271 | def __init__(self, in_channels, out_channels): 272 | super().__init__() 273 | self.maxpool_br12 = nn.Sequential( 274 | nn.AvgPool2d(2), 275 | Branch12(in_channels, out_channels) 276 | ) 277 | def forward(self, x): 278 | return self.maxpool_br12(x) 279 | 280 | class Down_1(nn.Module): 281 | def __init__(self, in_channels, out_channels, sp_dim1,sp_dim2): 282 | super().__init__() 283 | self.maxpool_br123 = nn.Sequential( 284 | nn.AvgPool2d(2), 285 | Branch123(in_channels, out_channels, sp_dim1,sp_dim2) 286 | ) 287 | def forward(self, x): 288 | return self.maxpool_br123(x) 289 | 290 | class Up_1(nn.Module): 291 | def __init__(self, in_channels, out_channels,sp_dim1,sp_dim2): 292 | super().__init__() 293 | 294 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 295 | self.br123 = Branch123(in_channels+in_channels//2, out_channels,sp_dim1,sp_dim2) 296 | 297 | def forward(self, x1, x2): 298 | x1 = self.up(x1) 299 | x = torch.cat([x2, x1], dim=1) 300 | return self.br123(x) 301 | 302 | import math 303 | def positionalencoding2d(d_model, height, width): 304 | if d_model % 4 != 0: 305 | raise ValueError("Cannot use sin/cos positional encoding with " 306 | "odd dimension (got dim={:d})".format(d_model)) 307 | pe = torch.zeros(d_model, height, width) 308 | # Each dimension use half of d_model 309 | d_model = int(d_model / 2) 310 | div_term = torch.exp(torch.arange(0., d_model, 2) * 311 | -(math.log(10000.0) / d_model)) 312 | pos_w = torch.arange(0., width).unsqueeze(1) 313 | pos_h = torch.arange(0., height).unsqueeze(1) 314 | pe[0:d_model:2, :, :] = torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1) 315 | pe[1:d_model:2, :, :] = torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1) 316 | pe[d_model::2, :, :] = torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width) 317 | pe[d_model + 1::2, :, :] = torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width) 318 | 319 | return pe 320 | 321 | 322 | BN1 = 10 #36 # 10 323 | BN2 = 5 #18 # 5 324 | Base = 64 325 | image_size = 160 # 576 #160 326 | 327 | from timm.models.layers import to_2tuple 328 | class PatchEmbed(nn.Module): # [2,1,160,160] -->[2,1600,96] 329 | def __init__(self, img_size=image_size, patch_size=2, in_chans=3, embed_dim=Base, Apply_Norm=False): 330 | super().__init__() 331 | img_size = to_2tuple(img_size) 332 | patch_size = to_2tuple(patch_size) 333 | patches_resolution = [img_size[0] // 334 | patch_size[0], img_size[1] // patch_size[1]] 335 | self.img_size = img_size 336 | self.patch_size = patch_size 337 | self.patches_resolution = patches_resolution 338 | self.num_patches = patches_resolution[0] * patches_resolution[1] 339 | 340 | self.in_chans = in_chans 341 | self.embed_dim = embed_dim 342 | 343 | self.proj = nn.Conv2d(in_chans, embed_dim, 344 | kernel_size=patch_size, stride=patch_size) 345 | 346 | self.norm = nn.LayerNorm(embed_dim) 347 | self.Apply_Norm = Apply_Norm 348 | 349 | def forward(self, x): 350 | B, C, H, W = x.shape 351 | # FIXME look at relaxing size constraints 352 | assert H == self.img_size[0] and W == self.img_size[1], \ 353 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 354 | x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C 355 | if self.Apply_Norm: 356 | x = self.norm(x) 357 | x = x.transpose(1, 2).view(B, self.embed_dim, H//2, W//2) 358 | return x 359 | 360 | class OutConv(nn.Module): 361 | def __init__(self, in_channels, out_channels): 362 | super(OutConv, self).__init__() 363 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) 364 | 365 | def forward(self, x): 366 | return self.conv(x) 367 | 368 | 369 | DEVICE = "cuda" if torch.cuda.is_available() else "cpu" 370 | class CustomEncoder(nn.Module): 371 | def __init__(self, n_channels=1): 372 | super(CustomEncoder, self).__init__() 373 | self.n_channels = n_channels 374 | 375 | self.pm = PatchEmbed() 376 | self.inc = Branch12(Base,Base) 377 | self.down1 = Down(Base,2*Base) 378 | self.down2 = Down(2*Base,4*Base) 379 | 380 | self.down3 = Down_1(4*Base,8*Base,BN1,BN1) 381 | self.down4 = Down_1(8*Base,16*Base,BN2,BN2) 382 | self.m = nn.SiLU() 383 | 384 | self.pos_embed = positionalencoding2d(Base,image_size//2,image_size//2).to(DEVICE) 385 | 386 | self.avgpool = nn.AdaptiveAvgPool2d(1) 387 | self.classifier = nn.Sequential( 388 | nn.Linear(1024, 1280), 389 | nn.Hardswish(inplace=True), 390 | nn.Dropout(p=0.2, inplace=True), 391 | nn.Linear(1280, 1000), 392 | ) 393 | 394 | def forward(self, inp): 395 | 396 | #inp = inp.repeat(1,3,1,1) 397 | inp = self.pm(inp) 398 | inp = inp + self.pos_embed 399 | 400 | x1_12 = self.inc(inp) 401 | x1_98 = flip_Dim(inp,'98') 402 | x1_98 = self.inc(x1_98) 403 | x1_98 = flip_Dim_back(x1_98,'98') 404 | x1 = x1_12 + x1_98 405 | 406 | x1 = self.m(x1) 407 | 408 | x2_12 = self.down1(x1) 409 | 410 | x2_98 = flip_Dim(x1,'98') 411 | x2_98 = self.down1(x2_98) 412 | x2_98 = flip_Dim_back(x2_98,'98') 413 | 414 | x2 = x2_12 + x2_98 415 | x2 = self.m(x2) 416 | 417 | x3_12 = self.down2(x2) 418 | 419 | x3_98 = flip_Dim(x2,'98') 420 | x3_98 = self.down2(x3_98) 421 | x3_98 = flip_Dim_back(x3_98,'98') 422 | 423 | x3 = x3_12 + x3_98 424 | x3 = self.m(x3) 425 | 426 | x4_12 = self.down3(x3) 427 | 428 | x4_98 = flip_Dim(x3,'98') 429 | x4_98 = self.down3(x4_98) 430 | x4_98 = flip_Dim_back(x4_98,'98') 431 | 432 | x4 = x4_12 + x4_98 433 | x4 = self.m(x4) 434 | 435 | 436 | x5_12 = self.down4(x4) 437 | 438 | x5_98 = flip_Dim(x4,'98') 439 | x5_98 = self.down4(x5_98) 440 | x5_98 = flip_Dim_back(x5_98,'98') 441 | 442 | x5 = x5_12 + x5_98 443 | x5 = self.m(x5) 444 | 445 | x5 = self.avgpool(x5) 446 | x5 = torch.flatten(x5, 1) 447 | x5 = self.classifier(x5) 448 | return x5 449 | 450 | -------------------------------------------------------------------------------- /models/seg_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from mamba_ssm import Mamba 4 | 5 | D_state = 2 6 | Expand = 1 7 | drop_rate = 0.10 8 | 9 | def flip_Dim(x,dim): 10 | if dim == '12': 11 | x1=x 12 | if dim == '14': 13 | x1 = x.permute(0, 1, 3, 2).contiguous() 14 | if dim == '98': 15 | x1 = torch.flip(x,[2,3]) 16 | if dim == '96': 17 | x1 = torch.flip(x,[2,3]).permute(0, 1, 3, 2).contiguous() 18 | if dim == '78': 19 | x1 = torch.flip(x,[2]) 20 | if dim == '74': 21 | x1 = torch.flip(x,[2]).permute(0, 1, 3, 2).contiguous() 22 | if dim == '32': 23 | x1 = torch.flip(x,[3]) 24 | if dim == '36': 25 | x1 = torch.flip(x,[3]).permute(0, 1, 3, 2).contiguous() 26 | return x1 27 | 28 | def flip_Dim_back(x,dim): 29 | if dim == '12': 30 | x1=x 31 | if dim == '14': 32 | x1 = x.permute(0, 1, 3, 2).contiguous() 33 | if dim == '98': 34 | x1 = torch.flip(x,[2,3]) 35 | if dim == '96': 36 | x1 = torch.flip(x,[2,3]).permute(0, 1, 3, 2).contiguous() 37 | if dim == '78': 38 | x1 = torch.flip(x,[2]) 39 | if dim == '74': 40 | x1 = torch.flip(x,[3]).permute(0, 1, 3, 2).contiguous() 41 | if dim == '32': 42 | x1 = torch.flip(x,[3]) 43 | if dim == '36': 44 | x1 = torch.flip(x,[2]).permute(0, 1, 3, 2).contiguous() 45 | return x1 46 | 47 | class Linear_Layer(nn.Module): 48 | def __init__(self, n_channels,out_channels): 49 | super(Linear_Layer, self).__init__() 50 | self.n_channels = n_channels 51 | self.out_channels = out_channels 52 | 53 | self.linear1 = nn.Linear(self.n_channels,self.out_channels) 54 | self.drop = nn.Dropout(p=drop_rate) 55 | 56 | self.m = nn.SiLU() 57 | #self.norm = nn.LayerNorm(normalized_shape=out_channels) 58 | 59 | def forward(self, x1): 60 | 61 | b,c,h,w = x1.shape 62 | x1 = x1.permute(0,2,3,1).flatten(start_dim=1,end_dim=2) ## [B,H*W,C] 63 | x1 = self.linear1(x1) 64 | #x1 = self.norm(x1) 65 | x1 = self.m(x1) 66 | x1 = self.drop(x1) 67 | x1 = x1.view(b,h,w,self.out_channels).permute(0,3,1,2) 68 | return x1 69 | 70 | class Linear_Layer_Last(nn.Module): 71 | def __init__(self, n_channels,out_channels): 72 | super(Linear_Layer_Last, self).__init__() 73 | self.n_channels = n_channels 74 | self.out_channels = out_channels 75 | 76 | self.linear1 = nn.Linear(self.n_channels,self.out_channels) 77 | 78 | def forward(self, x1): 79 | 80 | b,c,h,w = x1.shape 81 | x1 = x1.permute(0,2,3,1).flatten(start_dim=1,end_dim=2) ## [B,H*W,C] 82 | x1 = self.linear1(x1) 83 | x1 = x1.view(b,h,w,self.out_channels).permute(0,3,1,2) 84 | return x1 85 | 86 | 87 | class SSM_spa(nn.Module): 88 | def __init__(self, sp_dim1,sp_dim2): 89 | super(SSM_spa, self).__init__() 90 | 91 | self.sp_dim2 = sp_dim2 92 | self.ssm2 = Mamba( 93 | d_model = sp_dim1*sp_dim1, 94 | out_c = sp_dim2*sp_dim2, 95 | d_state = D_state, 96 | expand=Expand) 97 | 98 | self.norm = nn.LayerNorm(normalized_shape=sp_dim2*sp_dim2) 99 | 100 | def forward(self, x1): 101 | 102 | b,c,h,w = x1.shape 103 | x1 = x1.flatten(start_dim=2,end_dim=3) ## [B,H*W,C] 104 | 105 | x1 = self.ssm2(x1) 106 | x1 = self.norm(x1) 107 | x1 = x1.view(b,c,self.sp_dim2,self.sp_dim2) 108 | return x1 109 | 110 | class SSM_cha(nn.Module): 111 | def __init__(self, n_channels,out_channels): 112 | super(SSM_cha, self).__init__() 113 | self.n_channels = n_channels 114 | self.out_channels = out_channels 115 | 116 | self.ssm1 = Mamba( 117 | d_model = self.n_channels, 118 | out_c = self.out_channels, 119 | d_state=D_state, 120 | expand=Expand, 121 | ) 122 | 123 | self.norm = nn.LayerNorm(normalized_shape=out_channels) 124 | 125 | def forward(self, x1): 126 | 127 | b,c,h,w = x1.shape 128 | x1 = x1.permute(0,2,3,1).flatten(start_dim=1,end_dim=2) ## [B,H*W,C] 129 | 130 | x1 = self.ssm1(x1) 131 | x1 = self.norm(x1) 132 | 133 | x1 = x1.view(b,h,w,self.out_channels).permute(0,3,1,2) 134 | return x1 135 | 136 | class Linear_Layer_SP_Res(nn.Module): 137 | def __init__(self, sp_dim1,sp_dim2,in_channels,out_channels): 138 | super(Linear_Layer_SP_Res, self).__init__() 139 | self.sp_dim1 = sp_dim1 140 | self.sp_dim2 = sp_dim2 141 | 142 | self.in_channels = in_channels 143 | self.out_channels = out_channels 144 | 145 | self.linear1 = nn.Linear(sp_dim1*sp_dim1,sp_dim2*sp_dim2) 146 | self.drop = nn.Dropout(p=drop_rate) 147 | 148 | 149 | 150 | self.lin_chan = Linear_Layer(self.in_channels,self.out_channels) 151 | #self.norm = nn.LayerNorm(normalized_shape=sp_dim2*sp_dim2) 152 | 153 | self.m = nn.SiLU() 154 | 155 | def forward(self, x1): 156 | 157 | b,c,h,w = x1.shape 158 | x1 = x1.flatten(start_dim=2,end_dim=3) 159 | 160 | x1 = self.linear1(x1) 161 | #x1 = self.norm(x1) 162 | x1 = self.m(x1) 163 | x1 = self.drop(x1) 164 | 165 | x1 = x1.view(b,c,h,w) 166 | x1 = self.lin_chan(x1) 167 | return x1 168 | 169 | class SSM_cha_Last(nn.Module): 170 | def __init__(self, n_channels,out_channels): 171 | super(SSM_cha_Last, self).__init__() 172 | self.n_channels = n_channels 173 | self.out_channels = out_channels 174 | 175 | self.ssm1 = Mamba( 176 | d_model = self.n_channels, 177 | out_c = self.out_channels, 178 | d_state=D_state, 179 | expand=Expand, 180 | ) 181 | 182 | 183 | def forward(self, x1): 184 | 185 | b,c,h,w = x1.shape 186 | x1 = x1.permute(0,2,3,1).flatten(start_dim=1,end_dim=2) ## [B,H*W,C] 187 | 188 | x1 = self.ssm1(x1) 189 | 190 | x1 = x1.view(b,h,w,self.out_channels).permute(0,3,1,2) 191 | return x1 192 | 193 | 194 | class Branch_3(nn.Module): 195 | def __init__(self,in_channels, out_channels,sp_dim1,sp_dim2): 196 | super().__init__() 197 | self.branch3 = nn.Sequential( 198 | SSM_spa(sp_dim1,sp_dim2), 199 | Linear_Layer(in_channels, out_channels), 200 | ) 201 | 202 | self.br_r = Linear_Layer_SP_Res(sp_dim1, sp_dim2,in_channels, out_channels) 203 | def forward(self, x): 204 | return (self.branch3(x) + self.br_r(x)) 205 | 206 | 207 | class Branch_2(nn.Module): 208 | def __init__(self,in_channels, out_channels): 209 | super().__init__() 210 | self.branch2 = nn.Sequential( 211 | SSM_cha(in_channels, out_channels), 212 | Linear_Layer(out_channels, out_channels), 213 | SSM_cha(out_channels, out_channels), 214 | nn.Dropout(p=drop_rate), 215 | ) 216 | def forward(self, x): 217 | return self.branch2(x) 218 | 219 | class Branch_2_Last(nn.Module): 220 | def __init__(self,in_channels, out_channels): 221 | super().__init__() 222 | self.branch2 = nn.Sequential( 223 | SSM_cha_Last(in_channels, out_channels), 224 | Linear_Layer_Last(out_channels, out_channels), 225 | SSM_cha_Last(out_channels, out_channels), 226 | ) 227 | def forward(self, x): 228 | return self.branch2(x) 229 | 230 | class Branch12_Last(nn.Module): 231 | def __init__(self, in_channels, out_channels): 232 | super().__init__() 233 | self.linear = Linear_Layer_Last(in_channels, out_channels) 234 | self.br2 = Branch_2_Last(in_channels, out_channels) 235 | 236 | def forward(self, x): 237 | x1 = self.linear(x) 238 | x2 = self.br2(x) 239 | return (x1+x2) 240 | 241 | 242 | class Branch12(nn.Module): 243 | def __init__(self, in_channels, out_channels): 244 | super().__init__() 245 | self.linear = Linear_Layer(in_channels, out_channels) 246 | self.br2 = Branch_2(in_channels, out_channels) 247 | self.drop = nn.Dropout(p=drop_rate) 248 | 249 | def forward(self, x): 250 | x1 = self.linear(x) 251 | x2 = self.br2(x) 252 | return self.drop(x1+x2) 253 | 254 | class Branch123(nn.Module): 255 | def __init__(self, in_channels, out_channels, sp_dim1,sp_dim2): 256 | super().__init__() 257 | self.linear = Linear_Layer(in_channels, out_channels) 258 | self.br2 = Branch_2(in_channels, out_channels) 259 | self.br3 = Branch_3(in_channels, out_channels, sp_dim1,sp_dim2) 260 | 261 | self.drop = nn.Dropout(p=drop_rate) 262 | 263 | def forward(self, x): 264 | x1 = self.linear(x) 265 | x2 = self.br2(x) 266 | x3 = self.br3(x) 267 | return self.drop(x1+x2+x3) 268 | 269 | 270 | class Down(nn.Module): 271 | def __init__(self, in_channels, out_channels): 272 | super().__init__() 273 | self.maxpool_br12 = nn.Sequential( 274 | nn.AvgPool2d(2), 275 | Branch12(in_channels, out_channels) 276 | ) 277 | def forward(self, x): 278 | return self.maxpool_br12(x) 279 | 280 | class Down_1(nn.Module): 281 | def __init__(self, in_channels, out_channels, sp_dim1,sp_dim2): 282 | super().__init__() 283 | self.maxpool_br123 = nn.Sequential( 284 | nn.AvgPool2d(2), 285 | Branch123(in_channels, out_channels, sp_dim1,sp_dim2) 286 | ) 287 | def forward(self, x): 288 | return self.maxpool_br123(x) 289 | 290 | class Up_1(nn.Module): 291 | def __init__(self, in_channels, out_channels,sp_dim1,sp_dim2): 292 | super().__init__() 293 | 294 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 295 | self.br123 = Branch123(in_channels+in_channels//2, out_channels,sp_dim1,sp_dim2) 296 | 297 | def forward(self, x1, x2): 298 | x1 = self.up(x1) 299 | x = torch.cat([x2, x1], dim=1) 300 | return self.br123(x) 301 | 302 | import math 303 | def positionalencoding2d(d_model, height, width): 304 | if d_model % 4 != 0: 305 | raise ValueError("Cannot use sin/cos positional encoding with " 306 | "odd dimension (got dim={:d})".format(d_model)) 307 | pe = torch.zeros(d_model, height, width) 308 | # Each dimension use half of d_model 309 | d_model = int(d_model / 2) 310 | div_term = torch.exp(torch.arange(0., d_model, 2) * 311 | -(math.log(10000.0) / d_model)) 312 | pos_w = torch.arange(0., width).unsqueeze(1) 313 | pos_h = torch.arange(0., height).unsqueeze(1) 314 | pe[0:d_model:2, :, :] = torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1) 315 | pe[1:d_model:2, :, :] = torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1) 316 | pe[d_model::2, :, :] = torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width) 317 | pe[d_model + 1::2, :, :] = torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width) 318 | 319 | return pe 320 | 321 | 322 | BN1 = 10 323 | BN2 = 5 324 | Base = 64 325 | image_size = 160 326 | 327 | from timm.models.layers import to_2tuple 328 | class PatchEmbed(nn.Module): # [2,1,160,160] -->[2,1600,96] 329 | def __init__(self, img_size=image_size, patch_size=2, in_chans=3, embed_dim=Base, Apply_Norm=False): 330 | super().__init__() 331 | img_size = to_2tuple(img_size) 332 | patch_size = to_2tuple(patch_size) 333 | patches_resolution = [img_size[0] // 334 | patch_size[0], img_size[1] // patch_size[1]] 335 | self.img_size = img_size 336 | self.patch_size = patch_size 337 | self.patches_resolution = patches_resolution 338 | self.num_patches = patches_resolution[0] * patches_resolution[1] 339 | 340 | self.in_chans = in_chans 341 | self.embed_dim = embed_dim 342 | 343 | self.proj = nn.Conv2d(in_chans, embed_dim, 344 | kernel_size=patch_size, stride=patch_size) 345 | 346 | self.norm = nn.LayerNorm(embed_dim) 347 | self.Apply_Norm = Apply_Norm 348 | 349 | def forward(self, x): 350 | B, C, H, W = x.shape 351 | # FIXME look at relaxing size constraints 352 | assert H == self.img_size[0] and W == self.img_size[1], \ 353 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 354 | x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C 355 | if self.Apply_Norm: 356 | x = self.norm(x) 357 | x = x.transpose(1, 2).view(B, self.embed_dim, H//2, W//2) 358 | return x 359 | 360 | class OutConv(nn.Module): 361 | def __init__(self, in_channels, out_channels): 362 | super(OutConv, self).__init__() 363 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) 364 | 365 | def forward(self, x): 366 | return self.conv(x) 367 | 368 | 369 | DEVICE = "cuda" if torch.cuda.is_available() else "cpu" 370 | class CustomEncoder(nn.Module): 371 | def __init__(self, n_channels=1): 372 | super(CustomEncoder, self).__init__() 373 | self.n_channels = n_channels 374 | 375 | self.pm = PatchEmbed() 376 | self.inc = Branch12(Base,Base) 377 | self.down1 = Down(Base,2*Base) 378 | self.down2 = Down(2*Base,4*Base) 379 | 380 | self.down3 = Down_1(4*Base,8*Base,BN1,BN1) 381 | self.down4 = Down_1(8*Base,16*Base,BN2,BN2) 382 | self.m = nn.SiLU() 383 | 384 | self.pos_embed = positionalencoding2d(Base,image_size//2,image_size//2).to(DEVICE) 385 | 386 | self.avgpool = nn.AdaptiveAvgPool2d(1) 387 | self.classifier = nn.Sequential( 388 | nn.Linear(1024, 1280), 389 | nn.Hardswish(inplace=True), 390 | nn.Dropout(p=0.2, inplace=True), 391 | nn.Linear(1280, 1000), 392 | ) 393 | 394 | def forward(self, inp): 395 | 396 | inp = inp.repeat(1,3,1,1) 397 | inp = self.pm(inp) 398 | inp = inp + self.pos_embed 399 | 400 | x1_12 = self.inc(inp) 401 | x1_98 = flip_Dim(inp,'98') 402 | x1_98 = self.inc(x1_98) 403 | x1_98 = flip_Dim_back(x1_98,'98') 404 | x1 = x1_12 + x1_98 405 | 406 | x1 = self.m(x1) 407 | 408 | x2_12 = self.down1(x1) 409 | 410 | x2_98 = flip_Dim(x1,'98') 411 | x2_98 = self.down1(x2_98) 412 | x2_98 = flip_Dim_back(x2_98,'98') 413 | 414 | x2 = x2_12 + x2_98 415 | x2 = self.m(x2) 416 | 417 | x3_12 = self.down2(x2) 418 | 419 | x3_98 = flip_Dim(x2,'98') 420 | x3_98 = self.down2(x3_98) 421 | x3_98 = flip_Dim_back(x3_98,'98') 422 | 423 | x3 = x3_12 + x3_98 424 | x3 = self.m(x3) 425 | 426 | x4_12 = self.down3(x3) 427 | 428 | x4_98 = flip_Dim(x3,'98') 429 | x4_98 = self.down3(x4_98) 430 | x4_98 = flip_Dim_back(x4_98,'98') 431 | 432 | x4 = x4_12 + x4_98 433 | x4 = self.m(x4) 434 | 435 | 436 | x5_12 = self.down4(x4) 437 | 438 | x5_98 = flip_Dim(x4,'98') 439 | x5_98 = self.down4(x5_98) 440 | x5_98 = flip_Dim_back(x5_98,'98') 441 | 442 | x5 = x5_12 + x5_98 443 | x5_r = self.m(x5) 444 | 445 | x5 = self.avgpool(x5_r) 446 | x5 = torch.flatten(x5, 1) 447 | x5 = self.classifier(x5) 448 | return x5_r,x4,x3,x2,x1,inp 449 | 450 | 451 | class Up(nn.Module): 452 | def __init__(self, in_channels, out_channels,last=None): 453 | super().__init__() 454 | 455 | self.last = last 456 | if self.last is None: 457 | in_channels = in_channels+in_channels//2 458 | if self.last is not None: 459 | in_channels = in_channels + in_channels 460 | 461 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 462 | self.br12 = Branch12(in_channels, out_channels) 463 | 464 | def forward(self, x1, x2): 465 | x1 = self.up(x1) 466 | if self.last is not None: 467 | x2 = self.up(x2) 468 | x = torch.cat([x2, x1], dim=1) 469 | return self.br12(x) 470 | 471 | Num_Classes = 5 472 | class Decoder(nn.Module): 473 | def __init__(self): 474 | super(Decoder, self).__init__() 475 | 476 | self.up1 = Up_1(16*Base,8*Base,BN1,BN1) 477 | #self.up2 = Up_1(8*Base,4*Base,BN3,BN3) 478 | 479 | self.up2 = Up(8*Base,4*Base) 480 | self.up3 = Up(4*Base,2*Base) 481 | self.up4 = Up(2*Base,Base) 482 | self.up5 = Up(Base,Base,last='yes') 483 | self.outc = Branch12_Last(Base, Num_Classes) 484 | 485 | def forward(self, x5,x4,x3,x2, x1,inp): 486 | 487 | x = self.up1(x5, x4) 488 | x = self.up2(x, x3) 489 | x = self.up3(x, x2) 490 | x = self.up4(x, x1) 491 | x = self.up5(x, inp) 492 | x = self.outc(x) 493 | return x 494 | 495 | class EncoderDecoder(nn.Module): 496 | def __init__(self, encoder): 497 | super(EncoderDecoder, self).__init__() 498 | self.encoder = encoder 499 | self.decoder = Decoder() 500 | 501 | def forward(self, x): 502 | x5,x4,x3,x2,x1,inp = self.encoder(x) 503 | out = self.decoder(x5,x4,x3,x2,x1,inp) 504 | return out 505 | --------------------------------------------------------------------------------