├── LICENSE
├── README.md
├── CSAM_modules.py
├── CSAM_networks.py
└── experiment.py

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2023 Alex Hung

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# CSAM: A 2.5D Cross-Slice Attention Module for Anisotropic Volumetric Medical Image Segmentation

This is the official implementation of our WACV paper [CSAM: A 2.5D Cross-Slice Attention Module for Anisotropic Volumetric Medical Image Segmentation](https://arxiv.org/abs/2311.04942).

[Presentation](https://youtu.be/602DSt8DXBw) | [Poster](https://drive.google.com/file/d/1gS9O3QnCQUeYqCPwuGRRSmPb-MdnyPV9/view?usp=sharing) | [Slides](https://drive.google.com/file/d/1JjapsZT7IiE1Vdm9CoUnXxas7_JuE-qd/view?usp=sharing)

This is a follow-up to our previous TMI work, [CAT-Net: A Cross-Slice Attention Transformer Model for Prostate Zonal Segmentation in MRI](https://github.com/aL3x-O-o-Hung/CAT-Net).

See also our follow-up MICCAI 2024 paper, [Cross-Slice Attention and Evidential Critical Loss for Uncertainty-Aware Prostate Cancer Detection](https://github.com/aL3x-O-o-Hung/GLCSA_ECLoss).
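# Usage

A minimal sketch of how the pieces fit together (the random tensor below is only a stand-in for a real anisotropic volume; the class and argument names are taken from `CSAM_networks.py`):

```python
import torch
from CSAM_networks import C2BAMUNet

# One volume, treated as a stack of 2D slices: (num_slices, channels, height, width).
volume = torch.randn(20, 1, 128, 128)

# batch_size is the number of slices processed together; the slice-attention
# module sizes its linear layers by it, so it must match the slice count.
model = C2BAMUNet(input_channels=1, num_classes=3, num_layers=4, base_num=64, batch_size=20)
logits = model(volume)  # (20, 3, 128, 128)
```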
# Credits

If you use this code or the paper in your work, please cite our paper:

```bibtex
@inproceedings{hung2024csam,
  title={CSAM: A 2.5D Cross-Slice Attention Module for Anisotropic Volumetric Medical Image Segmentation},
  author={Hung, Alex Ling Yu and Zheng, Haoxin and Zhao, Kai and Du, Xiaoxi and Pang, Kaifeng and Miao, Qi and Raman, Steven S and Terzopoulos, Demetri and Sung, Kyunghyun},
  booktitle={Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision},
  pages={5923--5932},
  year={2024}
}
```

--------------------------------------------------------------------------------
/CSAM_modules.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.distributions as td

def custom_max(x,dim,keepdim=True):
    # reduce with torch.max over several dims, one at a time
    temp_x=x
    for i in dim:
        temp_x=torch.max(temp_x,dim=i,keepdim=True)[0]
    if not keepdim:
        temp_x=temp_x.squeeze()
    return temp_x

class PositionalAttentionModule(nn.Module):
    # Spatial attention: max- and mean-pool across slices and channels, then a
    # 7x7 convolution yields one (H x W) attention map shared by all slices.
    def __init__(self):
        super(PositionalAttentionModule,self).__init__()
        self.conv=nn.Conv2d(in_channels=2,out_channels=1,kernel_size=(7,7),padding=3)
    def forward(self,x):
        max_x=custom_max(x,dim=(0,1),keepdim=True)
        avg_x=torch.mean(x,dim=(0,1),keepdim=True)
        att=torch.cat((max_x,avg_x),dim=1)
        att=self.conv(att)
        att=torch.sigmoid(att)
        return x*att

class SemanticAttentionModule(nn.Module):
    # Channel attention: squeeze over slices and spatial dims, then a two-layer
    # MLP with a reduction bottleneck yields per-channel weights.
    def __init__(self,in_features,reduction_rate=16):
        super(SemanticAttentionModule,self).__init__()
        self.linear=[]
        self.linear.append(nn.Linear(in_features=in_features,out_features=in_features//reduction_rate))
        self.linear.append(nn.ReLU())
        self.linear.append(nn.Linear(in_features=in_features//reduction_rate,out_features=in_features))
        self.linear=nn.Sequential(*self.linear)
    def forward(self,x):
        max_x=custom_max(x,dim=(0,2,3),keepdim=False).unsqueeze(0)
        avg_x=torch.mean(x,dim=(0,2,3),keepdim=False).unsqueeze(0)
        max_x=self.linear(max_x)
        avg_x=self.linear(avg_x)
        att=max_x+avg_x
        att=torch.sigmoid(att).unsqueeze(-1).unsqueeze(-1)
        return x*att

class SliceAttentionModule(nn.Module):
    # Slice attention: squeeze each slice to a scalar, then an MLP yields one
    # weight per slice. With uncertainty=True, the weights are sampled from a
    # low-rank multivariate normal whose parameters the network predicts.
    def __init__(self,in_features,rate=4,uncertainty=True,rank=5):
        super(SliceAttentionModule,self).__init__()
        self.uncertainty=uncertainty
        self.rank=rank
        self.linear=[]
        self.linear.append(nn.Linear(in_features=in_features,out_features=int(in_features*rate)))
        self.linear.append(nn.ReLU())
        self.linear.append(nn.Linear(in_features=int(in_features*rate),out_features=in_features))
        self.linear=nn.Sequential(*self.linear)
        if uncertainty:
            self.non_linear=nn.ReLU()
            self.mean=nn.Linear(in_features=in_features,out_features=in_features)
            self.log_diag=nn.Linear(in_features=in_features,out_features=in_features)
            self.factor=nn.Linear(in_features=in_features,out_features=in_features*rank)
    def forward(self,x):
        max_x=custom_max(x,dim=(1,2,3),keepdim=False).unsqueeze(0)
        avg_x=torch.mean(x,dim=(1,2,3),keepdim=False).unsqueeze(0)
        max_x=self.linear(max_x)
        avg_x=self.linear(avg_x)
        att=max_x+avg_x
        if self.uncertainty:
            temp=self.non_linear(att)
            mean=self.mean(temp)
            diag=self.log_diag(temp).exp()
            factor=self.factor(temp)
            factor=factor.view(1,-1,self.rank)
            dist=td.LowRankMultivariateNormal(loc=mean,cov_factor=factor,cov_diag=diag)
            att=dist.sample()
        att=torch.sigmoid(att).squeeze().unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
        return x*att


class CSAM(nn.Module):
    # Cross-Slice Attention Module: semantic (channel), positional (spatial),
    # and slice attention applied in sequence; each stage can be switched off.
    def __init__(self,num_slices,num_channels,semantic=True,positional=True,slice=True,uncertainty=True,rank=5):
        super(CSAM,self).__init__()
        self.semantic=semantic
        self.positional=positional
        self.slice=slice
        if semantic:
            self.semantic_att=SemanticAttentionModule(num_channels)
        if positional:
            self.positional_att=PositionalAttentionModule()
        if slice:
            self.slice_att=SliceAttentionModule(num_slices,uncertainty=uncertainty,rank=rank)
    def forward(self,x):
        if self.semantic:
            x=self.semantic_att(x)
        if self.positional:
            x=self.positional_att(x)
        if self.slice:
            x=self.slice_att(x)
        return x

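# A quick self-check of the combined module (a sketch, shapes only; not part of
# the training code). num_slices must equal the first dimension of the input,
# since the slice-attention linear layers are sized by it.
if __name__=='__main__':
    feats=torch.randn(20,64,32,32) # (slices, channels, height, width)
    att=CSAM(num_slices=20,num_channels=64)
    print(att(feats).shape) # torch.Size([20, 64, 32, 32])
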
--------------------------------------------------------------------------------
/CSAM_networks.py:
--------------------------------------------------------------------------------
from CSAM_modules import *

class ConvBlock(nn.Module):
    # Two (3x3 conv + InstanceNorm + LeakyReLU) layers with optional 2x2 max
    # pooling. Returns the pooled features plus the pre-pooling features (for
    # skip connections) unless return_single is set.
    def __init__(self,input_channels,output_channels,max_pool,return_single=False):
        super(ConvBlock,self).__init__()
        self.max_pool=max_pool
        self.conv=[]
        self.conv.append(nn.Conv2d(in_channels=input_channels,out_channels=output_channels,kernel_size=3,stride=1,padding=1))
        self.conv.append(nn.InstanceNorm2d(output_channels))
        self.conv.append(nn.LeakyReLU())
        self.conv.append(nn.Conv2d(in_channels=output_channels,out_channels=output_channels,kernel_size=3,stride=1,padding=1))
        self.conv.append(nn.InstanceNorm2d(output_channels))
        self.conv.append(nn.LeakyReLU())
        self.return_single=return_single
        if max_pool:
            self.pool=nn.MaxPool2d(2,stride=2)
        self.conv=nn.Sequential(*self.conv)

    def forward(self,x):
        x=self.conv(x)
        b=x
        if self.max_pool:
            x=self.pool(x)
        if self.return_single:
            return x
        else:
            return x,b


class DeconvBlock(nn.Module):
    # Bilinear upsampling followed by a 3x3 conv, concatenation with the skip
    # features, and a ConvBlock to fuse them.
    def __init__(self,input_channels,output_channels,intermediate_channels=-1):
        super(DeconvBlock,self).__init__()
        input_channels=int(input_channels)
        output_channels=int(output_channels)
        if intermediate_channels<0:
            intermediate_channels=output_channels*2
        self.upconv=[]
        self.upconv.append(nn.UpsamplingBilinear2d(scale_factor=2))
        self.upconv.append(nn.Conv2d(in_channels=input_channels,out_channels=intermediate_channels//2,kernel_size=3,stride=1,padding=1))
        self.conv=ConvBlock(intermediate_channels,output_channels,False)
        self.upconv=nn.Sequential(*self.upconv)

    def forward(self,x,b):
        x=self.upconv(x)
        x=torch.cat((x,b),dim=1)
        x,_=self.conv(x)
        return x

class UNetDecoder(nn.Module):
    def __init__(self,num_layers,base_num):
        super(UNetDecoder,self).__init__()
        self.conv=[]
        self.num_layers=num_layers
        for i in range(num_layers-1,0,-1):
            self.conv.append(DeconvBlock(base_num*(2**i),base_num*(2**(i-1))))
        self.conv=nn.Sequential(*self.conv)

    def forward(self,x,b):
        # walk back up the pyramid, fusing one skip connection per step
        for i in range(self.num_layers-1):
            x=self.conv[i](x,b[i])
        return x

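# The encoder below treats the slice dimension as the batch dimension, so a
# whole volume passes through the 2D convolutions at once. A CSAM block is
# applied at every scale: to the skip features at all but the deepest level,
# and to the bottleneck output at the deepest level.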
class EncoderCSAM(nn.Module):
    def __init__(self,input_channels,num_layers,base_num,batch_size=20,semantic=True,positional=True,slice=True,uncertainty=True,rank=5):
        super(EncoderCSAM,self).__init__()
        self.conv=[]
        self.num_layers=num_layers
        for i in range(num_layers):
            if i==0:
                self.conv.append(ConvBlock(input_channels,base_num,True))
            else:
                self.conv.append(ConvBlock(base_num*(2**(i-1)),base_num*(2**i),(i!=num_layers-1)))
        self.conv=nn.Sequential(*self.conv)
        self.attentions=[]
        for i in range(num_layers):
            self.attentions.append(CSAM(batch_size,base_num*(2**i),semantic,positional,slice,uncertainty,rank))
        self.attentions=nn.Sequential(*self.attentions)

    def forward(self,x):
        b=[]
        for i in range(self.num_layers):
            x,block=self.conv[i](x)
            if i!=self.num_layers-1:
                block=self.attentions[i](block)
            else:
                x=self.attentions[i](x)
            b.append(block)
        b=b[:-1]
        b=b[::-1] # deepest skip first, to match the decoder's order
        return x,b

class C2BAMUNet(nn.Module):
    # The CSAM U-Net: a 2D U-Net whose encoder applies CSAM at every scale.
    def __init__(self,input_channels,num_classes,num_layers,base_num=64,batch_size=20,semantic=True,positional=True,slice=True,uncertainty=True,rank=5):
        super(C2BAMUNet,self).__init__()
        self.encoder=EncoderCSAM(input_channels,num_layers,base_num,batch_size=batch_size,semantic=semantic,positional=positional,slice=slice,uncertainty=uncertainty,rank=rank)
        self.decoder=UNetDecoder(num_layers,base_num)
        self.base_num=base_num
        self.input_channels=input_channels
        self.num_classes=num_classes
        self.conv_final=nn.Conv2d(in_channels=base_num,out_channels=num_classes,kernel_size=(1,1))

    def forward(self,x):
        x,b=self.encoder(x)
        x=self.decoder(x,b)
        x=self.conv_final(x)
        return x


class CSAMUNetPlusPlus(nn.Module):
    # CSAM UNet++: nested dense skip connections with CSAM applied to the
    # encoder-column features at each level, and deep supervision in training.
    def __init__(self,input_channels,num_classes,num_layers,base_num=64,batch_size=20,semantic=True,positional=True,slice=True,uncertainty=True,rank=5):
        super(CSAMUNetPlusPlus,self).__init__()
        self.num_layers=num_layers
        nb_filter=[]
        for i in range(num_layers):
            nb_filter.append(base_num*(2**i))
        self.pool=nn.MaxPool2d(2,2)
        self.up=nn.Upsample(scale_factor=2,mode='bilinear',align_corners=True)
        self.conv=[]
        for i in range(num_layers):
            temp_conv=[]
            for j in range(num_layers-i):
                if j==0:
                    if i==0:
                        inp=input_channels
                    else:
                        inp=nb_filter[i-1]
                else:
                    inp=nb_filter[i]*j+nb_filter[i+1]
                temp_conv.append(ConvBlock(inp,nb_filter[i],False,True))
            self.conv.append(nn.Sequential(*temp_conv))
        self.conv=nn.Sequential(*self.conv)
        self.attentions=[]
        for i in range(num_layers):
            self.attentions.append(CSAM(batch_size,base_num*(2**i),semantic=semantic,positional=positional,slice=slice,uncertainty=uncertainty,rank=rank))
        self.attentions=nn.Sequential(*self.attentions)
        self.final=[]
        for i in range(num_layers-1):
            self.final.append(nn.Conv2d(nb_filter[0],num_classes,kernel_size=(1,1)))
        self.final=nn.Sequential(*self.final)

    def forward(self,inputs):
        x=[]
        for i in range(self.num_layers):
            temp=[]
            for j in range(self.num_layers-i):
                temp.append([])
            x.append(temp)
        x[0][0].append(self.conv[0][0](inputs))
        for s in range(1,self.num_layers):
            for i in range(s+1):
                if i==0:
                    x[s-i][i].append(self.conv[s-i][i](self.pool(x[s-i-1][i][0])))
                else:
                    for j in range(i):
                        if j==0:
                            block=x[s-i][j][0]
                            block=self.attentions[s-i](block)
                            temp_x=block
                        else:
                            temp_x=torch.cat((temp_x,x[s-i][j][0]),dim=1)
                    temp_x=torch.cat((temp_x,self.up(x[s-i+1][i-1][0])),dim=1)
                    x[s-i][i].append(self.conv[s-i][i](temp_x))
        if self.training:
            # deep supervision: one output head per intermediate depth
            res=[]
            for i in range(self.num_layers-1):
                res.append(self.final[i](x[0][i+1][0]))
            return res
        else:
            return self.final[-1](x[0][-1][0])

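# A minimal smoke test (a sketch; the 20-slice, 128x128 volume is arbitrary).
# In training mode CSAMUNetPlusPlus would instead return one prediction per
# supervision depth; in eval mode it returns only the final one.
if __name__=='__main__':
    vol=torch.randn(20,1,128,128) # (slices, channels, height, width)
    net=C2BAMUNet(input_channels=1,num_classes=3,num_layers=4)
    print(net(vol).shape) # torch.Size([20, 3, 128, 128])
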
--------------------------------------------------------------------------------
/experiment.py:
--------------------------------------------------------------------------------
from dataloader import *
from CAT_Net import CrossSliceAttentionUNet,CrossSliceUNetPlusPlus
import argparse
from datetime import datetime
import numpy as np
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import pickle


def print_log(print_string,log):
    print("{:}".format(print_string))
    log.write('{:}\n'.format(print_string))
    log.flush()


def dice_coeff(seg,target,smooth=0.001):
    # soft dice with additive smoothing to avoid division by zero
    intersection=np.sum(seg*target)
    dice=(2*intersection+smooth)/(np.sum(seg)+np.sum(target)+smooth)
    return dice
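# Worked example (hypothetical numbers): if seg and target each mark 100 voxels
# as foreground and 80 of those overlap, dice = (2*80+0.001)/(100+100+0.001) ≈ 0.8.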

def validate(model,dataloader,args):
    device=args.device
    model.eval()
    total_dice=np.zeros(args.num_classes-1)
    c=0
    criterion=torch.nn.CrossEntropyLoss()
    loss=0
    for batch_num,data in enumerate(dataloader):
        img,mask,mask_onehot,length=data['im'],data['mask'],data['m'],data['length']
        img=img.to(device).squeeze(0)[:length[0],:,:,:]
        mask=mask.to(device).squeeze(0)[:length[0],:,:]
        mask_onehot=mask_onehot.to(device).squeeze(0)[:length[0],:,:,:]
        pred_raw=model(img)
        pred=F.softmax(pred_raw,dim=1)

        tmp_loss=criterion(pred_raw,mask)
        loss+=tmp_loss.item()

        pred_np=pred.detach().cpu().numpy()
        mask_onehot_np=mask_onehot.detach().cpu().numpy()

        # move the class axis last: (slices, height, width, classes)
        pred_np=np.moveaxis(pred_np,1,-1)
        mask_onehot_np=np.moveaxis(mask_onehot_np,1,-1)

        # one-hot encode the argmax prediction
        pred_labels=np.argmax(pred_np,axis=-1)
        pred_onehot_np=np.eye(pred_np.shape[-1])[pred_labels]

        for i in range(args.num_classes-1):
            total_dice[i]+=dice_coeff(pred_onehot_np[:,:,:,i:i+1],mask_onehot_np[:,:,:,i:i+1])
        c+=1

    return total_dice/c,loss/c

def unet_init(input_channels=1,num_classes=3,num_layers=6,heads=1,num_attention_blocks=1,base_num=64,pool_kernel_size=4,input_size=128,batch_size=20,pool_method="avgpool",is_pe_learnable=True):
    network=CrossSliceAttentionUNet(input_channels,num_classes,num_layers,heads,num_attention_blocks,base_num,(pool_kernel_size,pool_kernel_size),(input_size,input_size),batch_size,pool_method,is_pe_learnable)
    return network

def unetplusplus_init(input_channels=1,num_classes=3,num_layers=5,heads=1,num_attention_blocks=1,base_num=64,pool_kernel_size=4,input_size=128,batch_size=20,pool_method="avgpool",is_pe_learnable=True):
    network=CrossSliceUNetPlusPlus(input_channels,num_classes,num_layers,heads,num_attention_blocks,base_num,(pool_kernel_size,pool_kernel_size),(input_size,input_size),batch_size,pool_method,is_pe_learnable)
    return network

def train(args):
    device=args.device
    if args.mode=='unet':
        network=unet_init(heads=args.num_heads,num_attention_blocks=args.num_attention_blocks,pool_method=args.pool_method,is_pe_learnable=args.is_pe_learnable,batch_size=args.sequence_length,pool_kernel_size=args.pool_kernel_size,input_size=args.input_size)
    elif args.mode=='unetplusplus':
        network=unetplusplus_init(heads=args.num_heads,num_attention_blocks=args.num_attention_blocks,pool_method=args.pool_method,is_pe_learnable=args.is_pe_learnable,batch_size=args.sequence_length,pool_kernel_size=args.pool_kernel_size,input_size=args.input_size)
    else:
        print('not implemented yet!')
        return

    now=datetime.now()
    current_time=now.strftime("%m-%d-%Y_%H:%M:%S")

    if not os.path.exists(args.save_path):
        os.mkdir(args.save_path)

    log=open(args.save_path+"model_training_log_id-{}_t-{}.txt".format(args.try_id,current_time),'w')
    state={k:v for k,v in args._get_kwargs()}
    print_log(state,log) # log all hyperparameters, e.g. {'batch_size': 2, 'epochs': 150, ...}

    print("Load Network Successfully!")
    model_parameters=filter(lambda p:p.requires_grad,network.parameters())
    params=sum([np.prod(p.size()) for p in model_parameters])
    print(params) # number of trainable parameters
    network=network.to(device)

    if args.current_epoch>=0:
        # resume from a saved per-epoch checkpoint
        network.load_state_dict(torch.load(args.save_path+str(args.current_epoch)+'.pt'))
        network=network.to(device)
    network.train()

    criterion=torch.nn.CrossEntropyLoss()

    if args.is_regularization:
        optimizer=torch.optim.Adam(network.parameters(),lr=args.lr,weight_decay=args.reg_val)
    else:
        optimizer=torch.optim.Adam(network.parameters(),lr=args.lr)

    train_dataloader=Internal3D(root=args.data_path,split='train')
    train_loader=torch.utils.data.DataLoader(train_dataloader,batch_size=args.batch_size,shuffle=True,num_workers=1,pin_memory=True,sampler=None,drop_last=True)
    val_dataloader=Internal3D(root=args.data_path,split='val')
    val_loader=torch.utils.data.DataLoader(val_dataloader,batch_size=1,shuffle=False,num_workers=1,pin_memory=True,sampler=None,drop_last=True)
    test_dataloader=Internal3D(root=args.data_path,split='test')
    test_loader=torch.utils.data.DataLoader(test_dataloader,batch_size=1,shuffle=False,num_workers=1,pin_memory=True,sampler=None,drop_last=True)
    loss_history_train=[]
    loss_history_val=[]
    dice_score_history_val=[]
    loss_history_test=[]
    dice_score_history_test=[]

    pz_val=0
    tz_val=0
    sum_val=0
    pz_test=0
    tz_test=0
    sum_test=0

    for epoch in range(args.current_epoch+1,args.epochs):

        for batch,data in enumerate(train_loader):
            img,mask,mask_onehot,length=data['im'],data['mask'],data['m'],data['length']
            # img is batch_size x sequence_length x channels (1) x height x width

            img=img.to(device)
            mask=mask.to(device)
            mask_onehot=mask_onehot.to(device)
            length=length.to(device)

            network.train()
            optimizer.zero_grad()
            loss=0
            for i in range(img.size(0)):
                pred=network(img[i,:length[i],:,:,:])
                if 'unetplusplus' in args.mode:
                    # deep supervision: sum the loss over all output heads
                    for p in pred:
                        loss+=criterion(p,mask[i,:length[i],:,:])
                else:
                    loss+=criterion(pred,mask[i,:length[i],:,:])

            # average over the volumes in the batch
            loss/=img.size(0)

            if batch==0:
                loss_history_train.append(loss.item())

            loss.backward()
            optimizer.step()
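
        # After every epoch the model is evaluated and checkpointed greedily:
        # one checkpoint per criterion (best TZ dice, best PZ dice, best TZ+PZ
        # sum), so the three .pt files below may come from different epochs.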
        # validation and testing
        with torch.no_grad():
            network.eval()
            # validation loss and dice
            dice_scores_val,loss_val=validate(network,val_loader,args)
            dice_score_history_val.append(dice_scores_val)
            loss_history_val.append(loss_val)

            if dice_scores_val[0]>tz_val:
                tz_val=dice_scores_val[0]
                torch.save(network.state_dict(),args.save_path+'tz.pt')
                print_log("----------Save model for TZ at: {:.4f}, {:.4f} ---------".format(dice_scores_val[0],dice_scores_val[1]),log)

            if dice_scores_val[1]>pz_val:
                pz_val=dice_scores_val[1]
                torch.save(network.state_dict(),args.save_path+'pz.pt')
                print_log("----------Save model for PZ at: {:.4f}, {:.4f} ---------".format(dice_scores_val[0],dice_scores_val[1]),log)

            if dice_scores_val[1]+dice_scores_val[0]>sum_val:
                sum_val=dice_scores_val[1]+dice_scores_val[0]
                torch.save(network.state_dict(),args.save_path+'sum.pt')
                print_log("----------Save model for Both at: {:.4f}, {:.4f} ---------".format(dice_scores_val[0],dice_scores_val[1]),log)

            # test loss and dice (tracked for reporting; no checkpoints saved here)
            dice_scores_test,loss_test=validate(network,test_loader,args)
            dice_score_history_test.append(dice_scores_test)
            loss_history_test.append(loss_test)

            if dice_scores_test[0]>tz_test:
                tz_test=dice_scores_test[0]
                print_log("----------Test TZ max at: {:.4f}, {:.4f} ---------".format(dice_scores_test[0],dice_scores_test[1]),log)

            if dice_scores_test[1]>pz_test:
                pz_test=dice_scores_test[1]
                print_log("----------Test PZ max at: {:.4f}, {:.4f} ---------".format(dice_scores_test[0],dice_scores_test[1]),log)

            if dice_scores_test[1]+dice_scores_test[0]>sum_test:
                sum_test=dice_scores_test[1]+dice_scores_test[0]
                print_log("----------Test Both max at: {:.4f}, {:.4f}---------".format(dice_scores_test[0],dice_scores_test[1]),log)

        msg="Epoch:{}, LR:{:.6f}, Train-Loss:{:.4f}, Val-Dice:[{:.4f}, {:.4f}], Val-Loss:{:.4f}, Test-Dice:[{:.4f}, {:.4f}], Test-Loss:{:.4f}".format(epoch,optimizer.param_groups[0]['lr'],loss_history_train[-1],dice_scores_val[0],dice_scores_val[1],loss_history_val[-1],dice_scores_test[0],dice_scores_test[1],loss_history_test[-1])
        print_log(msg,log)

        pickle.dump(loss_history_train,open(args.save_path+'loss_history_train.p','wb'))
        pickle.dump(loss_history_val,open(args.save_path+'loss_history_val.p','wb'))
        pickle.dump(dice_score_history_val,open(args.save_path+'val_dice_score_history.p','wb'))
        pickle.dump(loss_history_test,open(args.save_path+'loss_history_test.p','wb'))
        pickle.dump(dice_score_history_test,open(args.save_path+'test_dice_score_history.p','wb'))

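# To evaluate a saved checkpoint later, something along these lines would work
# (a sketch; the path is a placeholder):
#   network.load_state_dict(torch.load(args.save_path+'sum.pt'))
#   network.eval()
#   dice_scores,loss=validate(network,test_loader,args)
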
def main():
    parser=argparse.ArgumentParser(description='After fixing the bug of the first block')
    parser.add_argument('--comments',default="Modified the code, add log and history of loss/dice & change BN to default momentum 0.1",type=str,help='comment on which hyperparameter this group of experiments aims to test')
    parser.add_argument('--epochs',default=150,type=int,help='number of total epochs to run')
    parser.add_argument('--current_epoch',default=-1,type=int,help='epoch to resume from (-1 trains from scratch)')
    parser.add_argument('--mode',default='unet',type=str,help='model architecture to use (unet or unetplusplus)')
    parser.add_argument('--num_classes',default=3,type=int,help='TZ, PZ, background')
    parser.add_argument('--batch_size',default=2,type=int,help='number of volumes per batch')
    parser.add_argument('--num_heads',default=3,type=int,help='number of heads')
    parser.add_argument('--num_attention_blocks',default=2,type=int,help='number of attention blocks')
    parser.add_argument('--pool_kernel_size',default=4,type=int,help='pool kernel size')
    parser.add_argument('--input_size',default=128,type=int,help='input size')
    parser.add_argument('--lr',default=0.0001,type=float,help='learning rate')
    parser.add_argument('--data_path',default='../data/',type=str,help='path to the dataset')
    parser.add_argument('--save_path',default='',type=str,help='path to save models and logs')
    parser.add_argument('--dataset',default='internal',type=str,help='dataset name (used only to tag the save directory)')
    parser.add_argument('--try_id',default='0',type=str,help='id of the run')
    parser.add_argument('--network_dim',default="3D",type=str,help='2D or 3D network')
    parser.add_argument('--is_gamma',default=True,help='pass this flag to disable the gamma transformation',action='store_false')
    parser.add_argument('--is_regularization',default=False,help='pass this flag to add weight decay to the optimizer',action='store_true')
    parser.add_argument('--reg_val',default=1e-5,type=float,help='how much regularization to add')
    parser.add_argument('--pool_method',default="avgpool",type=str,help='maxpool or avgpool for extracting features for self-attention')
    parser.add_argument('--is_pe_learnable',default=True,help='pass this flag to freeze the positional embedding',action='store_false')
    parser.add_argument('--sequence_length',default=20,type=int,help='length of the slice sequence')
    parser.add_argument('--device',default='cuda:0',type=str,help='device to use')

    args=parser.parse_args()

    now=datetime.now()
    current_time=now.strftime("%m-%d-%Y_%H:%M:%S")

    args.save_path=args.mode+'_'+args.dataset+'_{}_model/'.format(current_time)
    train(args)


if __name__=='__main__':
    main()
--------------------------------------------------------------------------------
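
A typical training run (a sketch; assumes the dataset layout expected by `dataloader.py`, which is not included in this repository):

```bash
python experiment.py --mode unet --epochs 150 --lr 0.0001 \
    --data_path ../data/ --dataset internal --try_id 0 --device cuda:0
```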