├── README.md
├── config
│   └── train.cfg
├── loss.py
├── metrics.py
├── model.py
├── network.py
├── run.sh
├── test.py
├── tf_to_np.py
├── train.py
└── utils.py

/README.md:
--------------------------------------------------------------------------------
1 | # SIFA-pytorch
2 | This is a PyTorch implementation of SIFA for 'Unsupervised Bidirectional Cross-Modality Adaptation via Deeply Synergistic Image and Feature Alignment for Medical Image Segmentation.'
3 |
4 |
5 | ### 1. Dataset
6 |
7 | To use the provided UnpairedDataset, prepare your dataset in the following layout. Each data unit is stored as an NPZ file, where '[arr_0]' holds the image array and '[arr_1]' holds the corresponding label array:
8 | ```
9 | your/data_root/
10 |     source_domain/
11 |         s001.npz
12 |             ['arr_0']:image_arr
13 |             ['arr_1']:label_arr
14 |         s002.npz
15 |         ...
16 |
17 |     target_domain/
18 |         t001.npz
19 |             ['arr_0']:image_arr
20 |             ['arr_1']:label_arr
21 |         t002.npz
22 |         ...
23 |     test/
24 |         t101.npz
25 |             ['arr_0']:image_arr
26 |             ['arr_1']:label_arr
27 |         t102.npz
28 |         ...
29 | ```
30 |
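For reference, here is a minimal sketch of how such a file can be produced with NumPy (file names and shapes below are only illustrative; the positional arguments of `np.savez` are stored under the keys `arr_0` and `arr_1`):

```python
import numpy as np

# illustrative 2D slice and integer label mask of the same spatial size
image = np.random.rand(256, 256).astype(np.float32)
label = np.zeros((256, 256), dtype=np.uint8)

# positional arrays become 'arr_0' (image) and 'arr_1' (label)
np.savez('your/data_root/source_domain/s001.npz', image, label)

# round-trip check
data = np.load('your/data_root/source_domain/s001.npz')
assert data['arr_0'].shape == data['arr_1'].shape
```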
31 | ### 2. Configure the experimental settings in ```config/train.cfg```
32 |
33 | ### 3. Train SIFA
34 | ```
35 | CUDA_LAUNCH_BLOCKING=0 python train.py
36 | ```
37 |
38 | ### 4. Test SIFA
39 | ```
40 | CUDA_LAUNCH_BLOCKING=0 python test.py
41 | ```
42 |
43 |
44 | #### If you find the code useful, please consider citing the following article (with [code](https://github.com/HiLab-git/FPL-plus)):
45 |
46 | ```bibtex
47 | @article{wu2024fpl+,
48 |   author={Wu, Jianghao and Guo, Dong and Wang, Guotai and Yue, Qiang and Yu, Huijun and Li, Kang and Zhang, Shaoting},
49 |   journal={IEEE Transactions on Medical Imaging},
50 |   title={FPL+: Filtered Pseudo Label-Based Unsupervised Cross-Modality Adaptation for 3D Medical Image Segmentation},
51 |   year={2024},
52 |   volume={43},
53 |   number={9},
54 |   pages={3098-3109}
55 | }
56 |
57 |
58 | ```
59 | #### Furthermore, Source-Free Domain Adaptation is a more advanced domain adaptation task that does not require source domain data for adaptation. Please refer to the following paper (with [code](https://github.com/HiLab-git/UPL-SFDA)):
60 | ```bibtex
61 | @ARTICLE{10261231,
62 |   author={Wu, Jianghao and Wang, Guotai and Gu, Ran and Lu, Tao and Chen, Yinan and Zhu, Wentao and Vercauteren, Tom and Ourselin, Sébastien and Zhang, Shaoting},
63 |   journal={IEEE Transactions on Medical Imaging},
64 |   title={UPL-SFDA: Uncertainty-Aware Pseudo Label Guided Source-Free Domain Adaptation for Medical Image Segmentation},
65 |   year={2023},
66 |   volume={42},
67 |   number={12},
68 |   pages={3932-3943}
69 | }
70 | ```
71 |
--------------------------------------------------------------------------------
/config/train.cfg:
--------------------------------------------------------------------------------
1 | [train]
2 | exp_name = VS_1
3 | GPU = 0
4 | batch_size = 10
5 | num_epochs = 400
6 | num_classes = 2
7 | lr_seg = 0.00005
8 | lr = 0.0001
9 | skip = False
10 |
11 | A_path = ./data/source_path
12 | B_path = ./data/target_path
13 |
14 | [test]
15 | test_path = ./data/test_path
16 | GPU = 0
17 | num_classes = 2
18 | test_model = ./save_model/VS_1/model-140.pth
19 | batch_size = 1
20 | image_shape = 256,256
--------------------------------------------------------------------------------
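As a note on how this file is consumed: `parse_config` in `utils.py` reads it with `configparser` and coerces each value to an int, float, bool, or list where possible (`configparser` also lower-cases the option names). A small illustrative check, assuming it is run from the repository root:

```python
from utils import parse_config

config = parse_config('./config/train.cfg')
# configparser lower-cases option names, so [train] GPU is read as 'gpu'
print(config['train']['gpu'])         # 0 (int)
print(config['train']['lr_seg'])      # 5e-05 (float)
print(config['train']['skip'])        # False (bool)
print(config['test']['image_shape'])  # '256,256' (kept as a string)
```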
/loss.py:
--------------------------------------------------------------------------------
1 | #loss function for SIFA
2 | import torch
3 | from torch import nn, Tensor
4 | import torch.nn.functional as F
5 | import math
6 | import numpy as np
7 |
8 |
9 |
10 | def dice_loss(predict,target):
11 |     target = target.float()
12 |     smooth = 1e-4
13 |     intersect = torch.sum(predict*target)
14 |     dice = (2 * intersect + smooth)/(torch.sum(target)+torch.sum(predict*predict)+smooth)
15 |     loss = 1.0 - dice
16 |     return loss
17 |
18 |
19 | class DiceLoss(nn.Module):
20 |     def __init__(self,n_classes):
21 |         super().__init__()
22 |         self.n_classes = n_classes
23 |
24 |     def one_hot_encode(self,input_tensor):
25 |         tensor_list = []
26 |         for i in range(self.n_classes):
27 |             tmp = (input_tensor==i) * torch.ones_like(input_tensor)
28 |             tensor_list.append(tmp)
29 |         output_tensor = torch.cat(tensor_list,dim=1)
30 |         return output_tensor.float()
31 |
32 |
33 |     def forward(self,input,target,weight=None,softmax=True):
34 |         # apply softmax unless the caller has already normalized the input
35 |         inputs = F.softmax(input,dim=1) if softmax else input
36 |         target = self.one_hot_encode(target)
37 |         if weight is None:
38 |             weight = [1] * self.n_classes
39 |         assert inputs.shape == target.shape,'size must match'
40 |         class_wise_dice = []
41 |         loss = 0.0
42 |         for i in range(self.n_classes):
43 |             diceloss = dice_loss(inputs[:,i], target[:,i])
44 |             class_wise_dice.append(diceloss)
45 |             loss += diceloss * weight[i]
46 |         return loss/self.n_classes
47 |
48 | class WeightedCrossEntropyLoss(nn.Module):
49 |     def __init__(self, num_classes):
50 |         super().__init__()
51 |         self.eps = 1e-4
52 |         self.num_classes = num_classes
53 |
54 |     def forward(self, predict, target):
55 |         weight = []
56 |         for c in range(self.num_classes):
57 |             weight_c = torch.sum(target == c).float()
58 |             weight.append(weight_c)
59 |         weight = torch.tensor(weight).to(target.device)
60 |         weight = 1 - weight / (torch.sum(weight))
61 |         if len(target.shape) == len(predict.shape):
62 |             assert target.shape[1] == 1
63 |             target = target[:, 0]
64 |         wce_loss = F.cross_entropy(predict, target.long(), weight)
65 |         return wce_loss
66 |
67 |
68 | class DiceCeLoss(nn.Module):
69 |     #predict : output of model (i.e. no softmax)[N,C,*]
70 |     #target : gt of img [N,1,*]
71 |     def __init__(self,num_classes,alpha=1.0):
72 |         '''
73 |         calculate loss:
74 |             celoss + alpha*diceloss
75 |         alpha : default is 1
76 |         '''
77 |         super().__init__()
78 |         self.alpha = alpha
79 |         self.num_classes = num_classes
80 |         self.diceloss = DiceLoss(self.num_classes)
81 |         self.celoss = WeightedCrossEntropyLoss(self.num_classes)
82 |
83 |     def forward(self,predict,label):
84 |         #predict is output of the model, i.e. without softmax [N,C,*]
85 |         #label is not one hot encoding [N,1,*]
86 |
87 |         diceloss = self.diceloss(predict,label)
88 |         celoss = self.celoss(predict,label)
89 |         loss = celoss + self.alpha * diceloss
90 |         return loss
91 |
92 | class NCC:
93 |     """
94 |     Local (over window) normalized cross correlation loss.
95 |     """
96 |
97 |     def __init__(self, win=None):
98 |         self.win = win
99 |
100 |     def loss(self, y_true, y_pred):
101 |
102 |         Ii = y_true
103 |         Ji = y_pred
104 |
105 |         # get dimension of volume
106 |         # assumes Ii, Ji are sized [batch_size, *vol_shape, nb_feats]
107 |         ndims = len(list(Ii.size())) - 2
108 |         assert ndims in [1, 2, 3], "volumes should be 1 to 3 dimensions. found: %d" % ndims
109 |
110 |         # set window size
111 |         win = [9] * ndims if self.win is None else self.win
112 |
113 |         # compute filters
114 |         sum_filt = torch.ones([1, 1, *win]).to(Ii.device)  # follow the input device instead of assuming CUDA
115 |
116 |         pad_no = math.floor(win[0] / 2)
117 |
118 |         if ndims == 1:
119 |             stride = (1)
120 |             padding = (pad_no)
121 |         elif ndims == 2:
122 |             stride = (1, 1)
123 |             padding = (pad_no, pad_no)
124 |         else:
125 |             stride = (1, 1, 1)
126 |             padding = (pad_no, pad_no, pad_no)
127 |
128 |         # get convolution function
129 |         conv_fn = getattr(F, 'conv%dd' % ndims)
130 |
131 |         # compute CC squares
132 |         I2 = Ii * Ii
133 |         J2 = Ji * Ji
134 |         IJ = Ii * Ji
135 |
136 |         I_sum = conv_fn(Ii, sum_filt, stride=stride, padding=padding)
137 |         J_sum = conv_fn(Ji, sum_filt, stride=stride, padding=padding)
138 |         I2_sum = conv_fn(I2, sum_filt, stride=stride, padding=padding)
139 |         J2_sum = conv_fn(J2, sum_filt, stride=stride, padding=padding)
140 |         IJ_sum = conv_fn(IJ, sum_filt, stride=stride, padding=padding)
141 |
142 |         win_size = np.prod(win)
143 |         u_I = I_sum / win_size
144 |         u_J = J_sum / win_size
145 |
146 |         cross = IJ_sum - u_J * I_sum - u_I * J_sum + u_I * u_J * win_size
147 |         I_var = I2_sum - 2 * u_I * I_sum + u_I * u_I * win_size
148 |         J_var = J2_sum - 2 * u_J * J_sum + u_J * u_J * win_size
149 |
150 |         cc = cross * cross / (I_var * J_var + 1e-5)
151 |
152 |         return -torch.mean(cc)
--------------------------------------------------------------------------------
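A quick shape-level sanity check for these losses (illustrative tensors only; `predict` is pre-softmax logits of shape [N,C,H,W] and `label` an integer mask of shape [N,1,H,W]):

```python
import torch
from loss import DiceCeLoss

criterion = DiceCeLoss(num_classes=2)
logits = torch.randn(4, 2, 256, 256)           # raw model output, no softmax
label = torch.randint(0, 2, (4, 1, 256, 256))  # integer mask, not one-hot
print(criterion(logits, label).item())         # scalar loss value
```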
/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from medpy import metric
3 |
4 | def dice_eval(predict,label,num_classes):
5 |     #Compute the Dice coefficient per class
6 |     dice = np.zeros(num_classes)
7 |     eps = 1e-7
8 |     for c in range(num_classes):
9 |         inter = 2.0 * (np.sum((predict==c)*(label==c),dtype=np.float32))
10 |         p_sum = np.sum(predict==c,dtype=np.float32)
11 |         gt_sum = np.sum(label==c,dtype=np.float32)
12 |         dice[c] = (inter+eps)/(p_sum+gt_sum+eps)
13 |     return dice[1:]
14 |
15 | def assd_eval(predict,label,num_classes):
16 |     #Average Symmetric Surface Distance (ASSD)
17 |     assd_all = np.zeros(num_classes)
18 |     for c in range(num_classes):
19 |         reference = (label==c) * 1
20 |         result = (predict==c) * 1
21 |         assd_all[c] = metric.binary.assd(result,reference)
22 |     return assd_all[1:]
23 |
24 | def create_visual_anno(anno):
25 |     assert np.max(anno) <= 7 # labels 0-7 are supported; add new colors to label2color_dict for more classes
26 |     label2color_dict = {
27 |         0: [0, 0, 0],
28 |         1: [0, 0, 255],
29 |         2: [0, 255, 0],
30 |         3: [0, 0, 255], # note: currently the same color as label 1
31 |         4: [255, 215, 0],
32 |         5: [160, 32, 100],
33 |         6: [255, 64, 64],
34 |         7: [139, 69, 19],
35 |     }
36 |     # visualize
37 |     visual_anno = np.zeros((anno.shape[0], anno.shape[1], 3), dtype=np.uint8)
38 |     for i in range(visual_anno.shape[0]): # i for h
39 |         for j in range(visual_anno.shape[1]):
40 |             color = label2color_dict[anno[i, j]]
41 |             visual_anno[i, j, 0] = color[0]
42 |             visual_anno[i, j, 1] = color[1]
43 |             visual_anno[i, j, 2] = color[2]
44 |
45 |     return visual_anno
--------------------------------------------------------------------------------
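An illustrative check of these helpers (assumes `medpy` is installed; `metric.binary.assd` raises when a class is missing from either mask, which is why `test.py` wraps the call in try/except):

```python
import numpy as np
from metrics import dice_eval, assd_eval

pred = np.zeros((256, 256), dtype=np.uint8)
gt = np.zeros((256, 256), dtype=np.uint8)
pred[100:150, 100:150] = 1
gt[105:155, 105:155] = 1

print(dice_eval(pred, gt, num_classes=2))  # background Dice is dropped
print(assd_eval(pred, gt, num_classes=2))  # in pixel units here
```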
/model.py:
--------------------------------------------------------------------------------
1 | "Trainer for SIFA"
2 | import torch.nn as nn
3 | import torch
4 | from network import Dis, Dis_aux, G, Encoder, Decoder, Seg,init_weights,ImagePool,get_scheduler
5 | from loss import DiceCeLoss
6 | from torchvision.utils import save_image
7 | import os
8 |
9 | torch.autograd.set_detect_anomaly(True) # helpful for debugging, but slows training
10 |
11 | def denorm(img):
12 |     #convert [-1,1] to [0,1]
13 |     #to use torchvision.utils.save_image
14 |     return (img+1.0)/2
15 |
16 |
17 |
18 | class SIFA(nn.Module):
19 |     def __init__(self,params):
20 |         super().__init__()
21 |         num_classes = params['train']['num_classes']
22 |         lr_seg = params['train']['lr_seg']
23 |         lr = params['train']['lr']
24 |         # self.skip = True
25 |         self.skip = params['train']['skip']
26 |         #network
27 |         self.gen = G(skip=self.skip)
28 |         self.enc = Encoder()
29 |         self.dec = Decoder(skip=self.skip)
30 |         self.seg = Seg(num_classes)
31 |         self.disA = Dis_aux()
32 |         self.disB = Dis()
33 |         self.disSeg = Dis(num_classes)
34 |         #optimizer
35 |         self.gen_opt = torch.optim.Adam(self.gen.parameters(),lr=lr,betas=(0.5,0.999))
36 |         self.enc_opt = torch.optim.Adam(self.enc.parameters(),lr=lr,betas=(0.5,0.999),weight_decay=0.0001)
37 |         self.dec_opt = torch.optim.Adam(self.dec.parameters(),lr=lr,betas=(0.5,0.999))
38 |         self.seg_opt = torch.optim.Adam(self.seg.parameters(),lr=lr_seg,weight_decay=0.0001)
39 |         self.disA_opt = torch.optim.Adam(self.disA.parameters(),lr=lr,betas=(0.5,0.999))
40 |         self.disB_opt = torch.optim.Adam(self.disB.parameters(),lr=lr,betas=(0.5,0.999))
41 |         self.disSeg_opt = torch.optim.Adam(self.disSeg.parameters(),lr=lr,betas=(0.5,0.999))
42 |         # fake image pool
43 |         self.fakeA_pool = ImagePool()
44 |         self.fakeB_pool = ImagePool()
45 |         # lr update
46 |         self.seg_opt_sch = get_scheduler(self.seg_opt)
47 |         #loss
48 |         self.segloss = DiceCeLoss(num_classes)
49 |         self.criterionL2 = nn.MSELoss()
50 |         self.criterionL1 = nn.L1Loss()
51 |
52 |     def initialize(self):
53 |         init_weights(self.gen)
54 |         init_weights(self.dec)
55 |         init_weights(self.enc)
56 |         init_weights(self.seg)
57 |         init_weights(self.disA)
58 |         init_weights(self.disB)
59 |         init_weights(self.disSeg)
60 |
61 |     def forward(self):
62 |         self.fakeB = self.gen(self.realA, self.realA)
63 |         self.latent_realB = self.enc(self.realB)
64 |         self.latent_fakeB = self.enc(self.fakeB)
65 |         self.fakeA = self.dec(self.latent_realB, self.realB)
66 |         self.pred_mask_b = self.seg(self.latent_realB)
67 |         self.pred_mask_fake_b = self.seg(self.latent_fakeB)
68 |         self.cycleA = self.dec(self.latent_fakeB, self.fakeB)
69 |         self.cycleB = self.gen(self.fakeA, self.fakeA)
70 |
71 |     def backward_D(self, netD, real, fake,aux=False):
72 |         if aux:
73 |             pred_real = netD.forward_aux(real)
74 |             pred_fake = netD.forward_aux(fake.detach())
75 |         else:
76 |             pred_real = netD(real)
77 |             pred_fake = netD(fake.detach())
78 |         all1 = torch.ones_like(pred_real)
79 |         all0 = torch.zeros_like(pred_fake)
80 |         loss_real = self.criterionL2(pred_real, all1)
81 |         loss_fake = self.criterionL2(pred_fake, all0)
82 |         loss_D = (loss_real + loss_fake) * 0.5
83 |         return loss_D
84 |
85 |
86 |     def backward_G(self, netD, fake,aux=False):
87 |         if aux:
88 |             out = netD.forward_aux(fake)
89 |         else:
90 |             out = netD(fake)
91 |         all1 = torch.ones_like(out)
92 |         loss_G = self.criterionL2(out, all1)
93 |         return loss_G
94 |
95 |     def update_lr(self):
96 |         self.seg_opt_sch.step()
97 |
98 |     def update_GAN(self,imagesa,imagesb):
99 |         self.realA = imagesa
100 |         self.realB = imagesb
101 |         self.forward()
102 |         #update DisA
103 |         self.disA_opt.zero_grad()
104 |         self.fakeA_from_pool = self.fakeA_pool.query(self.fakeA)
105 |         loss_disA = self.backward_D(self.disA,self.realA,self.fakeA_from_pool)
106 |         loss_disA_aux = self.backward_D(self.disA,self.cycleA.detach(),self.fakeA_from_pool,aux=True)
107 |         loss_disA = loss_disA + loss_disA_aux
108 |         loss_disA.backward()
109 |         self.disA_opt.step()
110 |         #update disB
111 |         self.disB_opt.zero_grad()
112 |         self.fakeB_from_pool = self.fakeB_pool.query(self.fakeB)
113 |         loss_disB = self.backward_D(self.disB,self.realB,self.fakeB_from_pool)
114 |         loss_disB.backward()
115 |         self.disB_opt.step()
116 |         #update DisSeg
117 |         self.disSeg_opt.zero_grad()
118 |         loss_disSeg = self.backward_D(self.disSeg,self.pred_mask_fake_b.detach(),self.pred_mask_b)
119 |         loss_disSeg.backward()
120 |         self.disSeg_opt.step()
121 |         #update G
122 |         self.gen_opt.zero_grad()
123 |         self.dec_opt.zero_grad()
124 |         loss_cycleA = self.criterionL1(self.realA,self.cycleA) * 10.0
125 |         loss_cycleB = self.criterionL1(self.realB,self.cycleB) * 10.0
126 |         g_loss = loss_cycleA + loss_cycleB + self.backward_G(self.disB,self.fakeB) + self.backward_G(self.disA,self.fakeA)
127 |         g_loss.backward()
128 |         self.gen_opt.step()
129 |         self.dec_opt.step()
130 |
131 |         self.loss_cyclea = loss_cycleA.item()
132 |         self.loss_cycleb = loss_cycleB.item()
133 |
134 |     def update_seg(self,imagesa,imagesb,labelsa):
135 |         self.realA = imagesa
136 |         self.realB = imagesb
137 |         self.labelA = labelsa
138 |         self.forward()
139 |         #update encoder and seg
140 |         self.enc_opt.zero_grad()
141 |         self.seg_opt.zero_grad()
142 |         seg_loss_B = self.segloss(self.pred_mask_fake_b,self.labelA) + \
143 |             self.criterionL1(self.realA,self.cycleA) + self.criterionL1(self.realB, self.cycleB) + \
144 |             0.1 * self.backward_G(self.disA,self.fakeA) + 0.1 * self.backward_G(self.disSeg,self.pred_mask_b) + 0.1 * self.backward_G(self.disA,self.fakeA,aux=True)
145 |
146 |         seg_loss_B.backward()
147 |         self.enc_opt.step()
148 |         self.seg_opt.step()
149 |         self.loss_seg = self.segloss(self.pred_mask_fake_b,self.labelA).item()
150 |
151 |     def test_seg(self,imagesb):
152 |         #for test
153 |         content_b = self.enc(imagesb)
154 |         pred_mask_b = self.seg(content_b)
155 |         return pred_mask_b
156 |
157 |     def sample_image(self, epoch, exp_name):
158 |         sample_image_dir = "sample_image/" + str(exp_name)
159 |         if(not os.path.exists(sample_image_dir)):
160 |             os.makedirs(sample_image_dir) # makedirs also creates the parent 'sample_image' directory
161 |         print(sample_image_dir,epoch+1)
162 |         save_image(denorm(self.realA), '{}/realA-epoch-{}.jpg'.format(sample_image_dir, epoch + 1))
163 |         save_image(denorm(self.realB), '{}/realB-epoch-{}.jpg'.format(sample_image_dir, epoch + 1))
164 |         save_image(denorm(self.fakeA), '{}/fakeA-epoch-{}.jpg'.format(sample_image_dir, epoch + 1))
165 |         save_image(denorm(self.fakeB), '{}/fakeB-epoch-{}.jpg'.format(sample_image_dir, epoch + 1))
166 |         save_image(denorm(self.cycleA), '{}/cycleA-epoch-{}.jpg'.format(sample_image_dir, epoch + 1))
167 |         save_image(denorm(self.cycleB), '{}/cycleB-epoch-{}.jpg'.format(sample_image_dir, epoch + 1))
168 |
169 |     def print_loss(self):
170 |         print('---------------------')
171 |         print('loss cycle A:', self.loss_cyclea)
172 |         print('loss cycle B:', self.loss_cycleb)
173 |         print('loss seg:', self.loss_seg)
174 |         return self.loss_cyclea,self.loss_cycleb,self.loss_seg
175 |
176 |
177 |
--------------------------------------------------------------------------------
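For orientation, a minimal smoke test of the trainer on random tensors (this mirrors what `train.py` does; it assumes a CUDA device and the parsed config shown earlier):

```python
import torch
from model import SIFA
from utils import parse_config

config = parse_config('./config/train.cfg')
device = torch.device('cuda:0')

sifa = SIFA(config).to(device)
sifa.train()
sifa.initialize()

# images scaled to [-1, 1]; labels are integer masks of shape [N, 1, H, W]
A = torch.randn(2, 1, 256, 256, device=device).clamp(-1, 1)
B = torch.randn(2, 1, 256, 256, device=device).clamp(-1, 1)
A_label = torch.randint(0, 2, (2, 1, 256, 256), device=device).float()

sifa.update_GAN(A, B)
sifa.update_seg(A, B, A_label)
sifa.print_loss()
```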
/network.py:
--------------------------------------------------------------------------------
1 | """network for SIFA"""
2 | import torch.nn as nn
3 | import torch
4 | import torch.nn.functional as F
5 | from torch.nn import init
6 | from torch.optim import lr_scheduler
7 | import random
8 |
9 |
10 | class InsResBlock(nn.Module):
11 |     def __init__(self, in_features):
12 |         super().__init__()
13 |         self.layer = []
14 |         self.layer += [nn.Conv2d(in_features, in_features, 3, 1, 1)]
15 |         self.layer += [nn.InstanceNorm2d(in_features)]
16 |         self.layer += [nn.ReLU(inplace=True)]
17 |         self.layer += [nn.Conv2d(in_features, in_features, 3, 1, 1)]
18 |         self.layer += [nn.InstanceNorm2d(in_features)]
19 |         self.layer = nn.Sequential(*self.layer)
20 |
21 |     def forward(self, x):
22 |         res = x
23 |         x = self.layer(x)
24 |         out = F.relu(x + res, inplace=True)
25 |         return out
26 |
27 |
28 | class BasicBlock(nn.Module):
29 |     def __init__(self, c_in, c_out, k_size, stride, pad, norm_type=None, p_dropout=None, do_relu=True):
30 |         super().__init__()
31 |         self.layer = [nn.Conv2d(c_in, c_out, k_size, stride=stride, padding=pad)]
32 |         if p_dropout is not None:
33 |             self.layer += [nn.Dropout(p=p_dropout)]
34 |         if norm_type is not None:
35 |             if norm_type == 'BN':
36 |                 self.layer += [nn.BatchNorm2d(c_out)]
37 |             if norm_type == 'IN':
38 |                 self.layer += [nn.InstanceNorm2d(c_out)]
39 |         if do_relu:
40 |             self.layer += [nn.ReLU(inplace=True)]
41 |         self.layer = nn.Sequential(*self.layer)
42 |
43 |     def forward(self, x):
44 |         return self.layer(x)
45 |
46 |
47 | class InsDeconv(nn.Module):
48 |     def __init__(self, c_in, c_out):
49 |         super().__init__()
50 |         self.layer = [nn.ConvTranspose2d(c_in, c_out, kernel_size=3, stride=2, padding=1, output_padding=1)]
51 |         self.layer += [nn.InstanceNorm2d(c_out)]
52 |         self.layer += [nn.ReLU(inplace=True)]
53 |         self.layer = nn.Sequential(*self.layer)
54 |
55 |     def forward(self, x):
56 |         return self.layer(x)
57 |
58 |
59 | class DilateBlock(nn.Module):
60 |     def __init__(self, c_in, c_out, p_dropout=0.25):
61 |         super().__init__()
62 |         self.layer = [nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=2, dilation=2)]
63 |         self.layer += [nn.Dropout(p_dropout)]
64 |         self.layer += [nn.BatchNorm2d(c_out)]
65 |         self.layer += [nn.ReLU(inplace=True)]
66 |         self.layer += [nn.Conv2d(c_out, c_out, kernel_size=3, stride=1, padding=2, dilation=2)]
67 |         self.layer += [nn.Dropout(p_dropout)]
68 |         self.layer += [nn.BatchNorm2d(c_out)]
69 |         self.layer = nn.Sequential(*self.layer)
70 |         self.relu = nn.ReLU(inplace=True)
71 |
72 |     def forward(self, x):
73 |         return self.relu(x + self.layer(x))
74 |
75 |
76 | class BNResBlock(nn.Module):
77 |     def __init__(self, c_in, c_out, p_dropout=0.25):
78 |         super().__init__()
79 |         # c_in==c_out or c_out = 2*c_in
80 |         self.c_in = c_in
81 |         self.c_out = c_out
82 |         p = p_dropout
83 |         self.layer = [nn.Conv2d(c_in, c_out, 3, 1, 1)]
84 |         self.layer += [nn.Dropout(p)]
85 |
self.layer += [nn.BatchNorm2d(c_out)] 86 | self.layer += [nn.ReLU(inplace=True)] 87 | self.layer += [nn.Conv2d(c_out, c_out, 3, 1, 1)] 88 | self.layer += [nn.Dropout(p)] 89 | self.layer += [nn.BatchNorm2d(c_out)] 90 | self.layer = nn.Sequential(*self.layer) 91 | 92 | def expand_channels(self, x): 93 | # expand channels [B,C,H,W] to [B,2C,H,W] 94 | all0 = torch.zeros_like(x) 95 | z1, z2 = torch.split(all0, x.size(1) // 2, dim=1) 96 | return torch.cat((z1, x, z2), dim=1) 97 | 98 | def forward(self, x): 99 | if self.c_in == self.c_out: 100 | return F.relu(x + self.layer(x), inplace=True) 101 | else: 102 | out0 = self.layer(x) 103 | out1 = self.expand_channels(x) 104 | return F.relu(out0 + out1, inplace=True) 105 | 106 | 107 | 108 | def get_scheduler(optimizer): 109 | scheduler = lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.9) 110 | return scheduler 111 | 112 | 113 | class G(nn.Module): 114 | """G from cyclegan""" 115 | def __init__(self, inputdim=1, skip=True): 116 | super().__init__() 117 | self._skip = skip 118 | self.model = [BasicBlock(inputdim, 32, 7, 1, 3, 'IN')] 119 | self.model += [BasicBlock(32, 64, 3, 2, 1, 'IN')] 120 | self.model += [BasicBlock(64, 128, 3, 2, 1, 'IN')] 121 | for _ in range(9): 122 | self.model += [InsResBlock(128)] 123 | self.model += [InsDeconv(128, 64)] 124 | self.model += [InsDeconv(64, 32)] 125 | self.model += [BasicBlock(32, 1, 7, 1, 3, do_relu=False)] 126 | self.model = nn.Sequential(*self.model) 127 | self.tanh = nn.Tanh() 128 | 129 | def forward(self, inputgen, inputimg): 130 | out = self.model(inputgen) 131 | if self._skip is True: 132 | out = self.tanh(out + inputimg) 133 | else: 134 | out = self.tanh(out) 135 | return out 136 | 137 | 138 | class Encoder(nn.Module): 139 | def __init__(self, inputdim=1, p_dropout=0.25): 140 | super().__init__() 141 | self.model = nn.Sequential( 142 | BasicBlock(inputdim, 16, 7, 1, 3, norm_type='BN', p_dropout=0.25), 143 | BNResBlock(16, 16), 144 | nn.MaxPool2d(kernel_size=2), 145 | BNResBlock(16, 32), 146 | nn.MaxPool2d(kernel_size=2), 147 | BNResBlock(32, 64), 148 | BNResBlock(64, 64), 149 | nn.MaxPool2d(kernel_size=2), 150 | BNResBlock(64, 128), 151 | BNResBlock(128, 128), 152 | BNResBlock(128, 256), 153 | BNResBlock(256, 256), 154 | BNResBlock(256, 256), 155 | BNResBlock(256, 256), 156 | BNResBlock(256, 512), 157 | BNResBlock(512, 512), 158 | DilateBlock(512, 512), 159 | DilateBlock(512, 512), 160 | BasicBlock(512, 512, 3, 1, 1, norm_type='BN', p_dropout=0.25), 161 | BasicBlock(512, 512, 3, 1, 1, norm_type='BN', p_dropout=0.25) 162 | ) 163 | 164 | def forward(self, x): 165 | out = self.model(x) 166 | return out 167 | 168 | 169 | class Decoder(nn.Module): 170 | def __init__(self, skip=True): 171 | super().__init__() 172 | self._skip = skip 173 | self.model = [BasicBlock(512, 128, 3, 1, 1, 'IN')] 174 | for _ in range(4): 175 | self.model += [InsResBlock(128)] 176 | self.model += [InsDeconv(128, 64)] 177 | self.model += [InsDeconv(64, 64)] 178 | self.model += [InsDeconv(64, 32)] 179 | self.model += [BasicBlock(32, 1, 7, 1, 3, do_relu=False)] 180 | self.model = nn.Sequential(*self.model) 181 | self.tanh = nn.Tanh() 182 | 183 | def forward(self, inputde, inputimg): 184 | out = self.model(inputde) 185 | if self._skip is True: 186 | out = self.tanh(out + inputimg) 187 | else: 188 | out = self.tanh(out) 189 | return out 190 | 191 | 192 | class Seg(nn.Module): 193 | def __init__(self, num_classes): 194 | super().__init__() 195 | self.model = nn.Sequential(BasicBlock(512, num_classes, 1, 1, 0, do_relu=False), 196 | 
nn.Upsample(scale_factor=8,mode='bilinear')) 197 | 198 | def forward(self, x): 199 | x = self.model(x) 200 | return x 201 | 202 | 203 | class Dis(nn.Module): 204 | def __init__(self, input_nc=1): 205 | super().__init__() 206 | 207 | model = [nn.Conv2d(input_nc, 64, 4, stride=2, padding=1), 208 | nn.LeakyReLU(0.2, inplace=True)] 209 | 210 | model += [nn.Conv2d(64, 128, 4, stride=2, padding=1), 211 | nn.InstanceNorm2d(128), 212 | nn.LeakyReLU(0.2, inplace=True)] 213 | 214 | model += [nn.Conv2d(128, 256, 4, stride=2, padding=1), 215 | nn.InstanceNorm2d(256), 216 | nn.LeakyReLU(0.2, inplace=True)] 217 | 218 | model += [nn.Conv2d(256, 512, 4, padding=1), 219 | nn.InstanceNorm2d(512), 220 | nn.LeakyReLU(0.2, inplace=True)] 221 | 222 | model += [nn.Conv2d(512, 1, 4, padding=1)] 223 | 224 | self.model = nn.Sequential(*model) 225 | 226 | def forward(self, x): 227 | x = self.model(x) 228 | return x 229 | 230 | 231 | class Dis_aux(nn.Module): 232 | def __init__(self, input_nc=1): 233 | super().__init__() 234 | 235 | model = [nn.Conv2d(input_nc, 64, 4, stride=2, padding=1), 236 | nn.LeakyReLU(0.2, inplace=True)] 237 | 238 | model += [nn.Conv2d(64, 128, 4, stride=2, padding=1), 239 | nn.InstanceNorm2d(128), 240 | nn.LeakyReLU(0.2, inplace=True)] 241 | 242 | model += [nn.Conv2d(128, 256, 4, stride=2, padding=1), 243 | nn.InstanceNorm2d(256), 244 | nn.LeakyReLU(0.2, inplace=True)] 245 | 246 | model += [nn.Conv2d(256, 512, 4, padding=1), 247 | nn.InstanceNorm2d(512), 248 | nn.LeakyReLU(0.2, inplace=True)] 249 | 250 | self.share = nn.Sequential(*model) 251 | self.model = nn.Sequential(nn.Conv2d(512, 1, 4, padding=1)) 252 | self.model_aux = nn.Sequential(nn.Conv2d(512, 1, 4, padding=1)) 253 | 254 | 255 | def forward(self, x): 256 | x = self.share(x) 257 | x = self.model(x) 258 | return x 259 | 260 | def forward_aux(self,x): 261 | x = self.share(x) 262 | x = self.model_aux(x) 263 | return x 264 | 265 | 266 | def init_weights(net, init_type='normal', init_gain=0.02): 267 | """Initialize network weights. 268 | 269 | Parameters: 270 | net (network) -- network to be initialized 271 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal 272 | init_gain (float) -- scaling factor for normal, xavier and orthogonal. 273 | 274 | We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might 275 | work better for some applications. Feel free to try yourself. 276 | """ 277 | def init_func(m): # define the initialization function 278 | classname = m.__class__.__name__ 279 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 280 | if init_type == 'normal': 281 | init.normal_(m.weight.data, 0.0, init_gain) 282 | elif init_type == 'xavier': 283 | init.xavier_normal_(m.weight.data, gain=init_gain) 284 | elif init_type == 'kaiming': 285 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 286 | elif init_type == 'orthogonal': 287 | init.orthogonal_(m.weight.data, gain=init_gain) 288 | else: 289 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 290 | if hasattr(m, 'bias') and m.bias is not None: 291 | init.constant_(m.bias.data, 0.0) 292 | elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies. 
293 |             init.normal_(m.weight.data, 1.0, init_gain)
294 |             init.constant_(m.bias.data, 0.0)
295 |
296 |     print('initialize network with %s' % init_type)
297 |     net.apply(init_func) # apply the initialization function
298 |
299 |
300 | class ImagePool():
301 |     """This class implements an image buffer that stores previously generated images.
302 |
303 |     This buffer enables us to update discriminators using a history of generated images
304 |     rather than the ones produced by the latest generators.
305 |     """
306 |
307 |     def __init__(self, pool_size=50):
308 |         """Initialize the ImagePool class
309 |
310 |         Parameters:
311 |             pool_size (int) -- the size of image buffer, if pool_size=0, no buffer will be created
312 |         """
313 |         self.pool_size = pool_size
314 |         if self.pool_size > 0: # create an empty pool
315 |             self.num_imgs = 0
316 |             self.images = []
317 |
318 |     def query(self, images):
319 |         """Return an image from the pool.
320 |
321 |         Parameters:
322 |             images: the latest generated images from the generator
323 |
324 |         Returns images from the buffer.
325 |
326 |         With 50% probability, the buffer returns the input image.
327 |         With 50% probability, the buffer returns an image previously stored in the buffer,
328 |         and inserts the current image into the buffer.
329 |         """
330 |         if self.pool_size == 0: # if the buffer size is 0, do nothing
331 |             return images
332 |         return_images = []
333 |         for image in images:
334 |             image = torch.unsqueeze(image.data, 0)
335 |             if self.num_imgs < self.pool_size: # if the buffer is not full; keep inserting current images to the buffer
336 |                 self.num_imgs = self.num_imgs + 1
337 |                 self.images.append(image)
338 |                 return_images.append(image)
339 |             else:
340 |                 p = random.uniform(0, 1)
341 |                 if p > 0.5: # by 50% chance, the buffer will return a previously stored image, and insert the current image into the buffer
342 |                     random_id = random.randint(0, self.pool_size - 1) # randint is inclusive
343 |                     tmp = self.images[random_id].clone()
344 |                     self.images[random_id] = image
345 |                     return_images.append(tmp)
346 |                 else: # by another 50% chance, the buffer will return the current image
347 |                     return_images.append(image)
348 |         return_images = torch.cat(return_images, 0) # collect all the images and return
349 |         return return_images
350 |
351 |
--------------------------------------------------------------------------------
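For intuition, the buffer's behavior in isolation (illustrative only):

```python
import torch
from network import ImagePool

pool = ImagePool(pool_size=50)
fake = torch.randn(10, 1, 256, 256)
mixed = pool.query(fake)   # early calls fill the buffer and return the inputs unchanged
print(mixed.shape)         # torch.Size([10, 1, 256, 256])
# once full, each returned image is either the current one or, with 50%
# probability, an older one swapped out of the buffer
```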
/run.sh:
--------------------------------------------------------------------------------
1 | CUDA_LAUNCH_BLOCKING=1 python train.py
2 | CUDA_LAUNCH_BLOCKING=1 python test.py
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | #Evaluation script for SIFA
2 | import numpy as np
3 | import matplotlib
4 | matplotlib.use('Agg')
5 | import matplotlib.pyplot as plt
6 | from torch.utils.data import DataLoader
7 | import torch
8 | from model import SIFA
9 | import yaml
10 | from utils import SingleDataset
11 | from metrics import dice_eval,assd_eval,create_visual_anno
12 | import cv2
13 | from utils import parse_config
14 | import os
15 |
16 | config = "./config/train.cfg" # path relative to the repository root
17 | config = parse_config(config)
18 | exp_name = config['train']['exp_name']
19 |
20 | def norm_01(image):
21 |     mn = np.min(image)
22 |     mx = np.max(image)
23 |     image = ((image-mn)/(mx-mn)).astype(np.float32)
24 |     return image
25 |
26 | def save_img(image):
27 |     image = norm_01(image)
28 |     image = (image*255).astype(np.uint8)
29 |     return image
30 |
31 |
32 | device = torch.device('cuda:{}'.format(config['test']['gpu']))
33 | test_path = config['test']['test_path']
34 | num_classes = config['test']['num_classes']
35 | sifa_model = SIFA(config).to(device)
36 | sifa_model.load_state_dict(torch.load('{}'.format(config['test']['test_model'])))
37 | sifa_model.eval()
38 | #test dataset
39 | test_dataset = SingleDataset(test_path)
40 | batch_size = config['test']['batch_size']
41 | test_loader = DataLoader(test_dataset,batch_size,shuffle=False)
42 |
43 | #test
44 | all_batch_dice = []
45 | all_batch_assd = []
46 | with torch.no_grad():
47 |     for it,(xt,xt_label) in enumerate(test_loader):
48 |         xt = xt.to(device)
49 |         xt_label = xt_label.numpy().squeeze().astype(np.uint8)
50 |         output = sifa_model.test_seg(xt).detach()
51 |         output = output.squeeze(0)
52 |         output = torch.argmax(output,dim=0)
53 |         output = output.cpu().numpy()
54 |
55 |
56 |         xt = xt.detach().cpu().numpy().squeeze()
57 |         gt = xt_label.reshape(256,256).astype(np.uint8)
58 |         output = output.squeeze()
59 |         xt = save_img(xt)
60 |
61 |         output_vis = create_visual_anno(output)
62 |         gt_vis = create_visual_anno(gt)
63 |         results = "results/" + str(exp_name)
64 |         if(not os.path.exists(results)):
65 |             os.makedirs(results) # makedirs also creates the parent 'results' directory
66 |         cv2.imwrite('{}/xt-{}.jpg'.format(results, it+1),xt)
67 |         cv2.imwrite('{}/gt-{}.jpg'.format(results, it+1),gt_vis)
68 |         cv2.imwrite('{}/output-{}.jpg'.format(results, it+1),output_vis)
69 |
70 |
71 |
72 |         one_case_dice = dice_eval(output,xt_label,num_classes) * 100
73 |         #print('{:.4f} th case dice MYO:{:.4f} LV:{:.4f} RV:{:.4f}'.format(it+1,one_case_dice[0],one_case_dice[1],one_case_dice[2]))
74 |         #dicefile.write('file:{},{} th case dice:{}\n'.format(filename,it+1,one_case_dice))
75 |         all_batch_dice += [one_case_dice]
76 |         try:
77 |             one_case_assd = assd_eval(output,xt_label,num_classes)
78 |         except:
79 |             continue # assd is undefined when a class is absent from prediction or ground truth
80 |         all_batch_assd.append(one_case_assd)
81 |
82 |
83 |
84 | all_batch_dice = np.array(all_batch_dice)
85 | all_batch_assd = np.array(all_batch_assd)
86 | mean_dice = np.mean(all_batch_dice,axis=0)
87 | std_dice = np.std(all_batch_dice,axis=0)
88 | mean_assd = np.mean(all_batch_assd,axis=0)
89 | print(all_batch_assd)
90 | std_assd = np.std(all_batch_assd,axis=0)
91 | print('-----------')
92 | print('MYO||LV||RV')
93 | print('Dice mean:{}'.format(mean_dice))
94 | print('Dice std:{}'.format(std_dice))
95 | print('total mean dice:',np.mean(mean_dice))
96 | print('ASSD mean:{}'.format(mean_assd))
97 | print('ASSD std:{}'.format(std_assd))
98 | print('total mean assd:',np.mean(mean_assd))
99 | print('-----------')
100 |
--------------------------------------------------------------------------------
/tf_to_np.py:
--------------------------------------------------------------------------------
1 | #Use tensorflow2
2 | # convert tf record to npz for SIFAdata
3 |
4 | import tensorflow as tf
5 | import matplotlib.pyplot as plt
6 | import numpy as np
7 | import os
8 |
9 | ##tfrecord to numpy
10 | def decode_tfrecords(example):
11 |     features = {
12 |         'data_vol':tf.io.FixedLenFeature([],tf.string),
13 |         'label_vol':tf.io.FixedLenFeature([],tf.string)
14 |     }
15 |     feature_dict = tf.io.parse_single_example(example,features)
16 |     img = tf.io.decode_raw(feature_dict['data_vol'],out_type = tf.float32)
17 |     lab = tf.io.decode_raw(feature_dict['label_vol'],out_type = tf.float32)
18 |     img = tf.reshape(img,[256,256,3])
19 |     lab = tf.reshape(lab,[256,256,3])
20 |     return img,lab
21 | '''
22 | #one case to test
23 | files = 'ct_train_slice169.tfrecords'
24 | rawdata = tf.data.TFRecordDataset(files)
25 | dataset = rawdata.map(decode_tfrecords)
26 | itera = tf.compat.v1.data.make_one_shot_iterator(dataset)
27 | img,lab = itera.get_next()
28 | img = img.numpy()
29 | lab = lab.numpy()
30 | print(img.shape)
31 | print(lab.shape)
32 | print(np.unique(lab))
33 | img = img[:,:,1]
34 | lab = lab[:,:,1]
35 | plt.figure()
36 | plt.imshow(img,cmap='gray')
37 | plt.figure()
38 | plt.imshow(lab)
39 | plt.show()
40 | '''
41 | ##convert tfrecord to npz
42 | #for CT
43 | path = 'SIFAdata/train/ct_train/'
44 | pathlist = os.listdir(path)
45 | for i in range(len(pathlist)):
46 |     files = pathlist[i]
47 |     name = files.split('.')[0]
48 |     rawdata = tf.data.TFRecordDataset(path+files)
49 |     dataset = rawdata.map(decode_tfrecords)
50 |     itera = tf.compat.v1.data.make_one_shot_iterator(dataset)
51 |     img,lab = itera.get_next() # assumes one example per tfrecords file
52 |     img = img.numpy()
53 |     lab = lab.numpy()
54 |     np.savez('UDA/data/SIFAdata/ct_train/'+name,img,lab) # the output directory must already exist
55 |
56 | #for MR
57 | path = 'SIFAdata/train/mr_train/'
58 | pathlist = os.listdir(path)
59 | for i in range(len(pathlist)):
60 |     files = pathlist[i]
61 |     name = files.split('.')[0]
62 |     rawdata = tf.data.TFRecordDataset(path+files)
63 |     dataset = rawdata.map(decode_tfrecords)
64 |     itera = tf.compat.v1.data.make_one_shot_iterator(dataset)
65 |     img,lab = itera.get_next() # assumes one example per tfrecords file
66 |     img = img.numpy()
67 |     lab = lab.numpy()
68 |     np.savez('UDA/data/SIFAdata/mr_train/'+name,img,lab) # the output directory must already exist
69 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | from utils import UnpairedDataset, parse_config,get_config, set_random
2 | import yaml
3 | import matplotlib.pyplot as plt
4 | from model import SIFA
5 | from torch.utils.data import DataLoader
6 | import torch
7 | import numpy as np
8 | import matplotlib
9 | import os
10 | import configparser
11 | matplotlib.use('Agg')
12 |
13 | # train
14 | def train():
15 |     # load config
16 |     config = "./config/train.cfg" # path relative to the repository root
17 |     config = parse_config(config)
18 |     # load data
19 |     print(config)
20 |     A_path = config['train']['a_path']
21 |     B_path = config['train']['b_path']
22 |     batch_size = config['train']['batch_size']
23 |
24 |
25 |
26 |     trainset = UnpairedDataset(A_path, B_path)
27 |     train_loader = DataLoader(trainset, batch_size,
28 |                               shuffle=True, drop_last=True)
29 |     # load exp_name
30 |     exp_name = config['train']['exp_name']
31 |
32 |     loss_cycle = []
33 |     loss_seg = []
34 |     # load model
35 |
36 |
37 |     device = torch.device('cuda:{}'.format(config['train']['gpu']))
38 |     # device = torch.device('cpu')
39 |     sifa_model = SIFA(config).to(device)
40 |     sifa_model.train()
41 |     sifa_model.initialize()
42 |     num_epochs = config['train']['num_epochs']
43 |     save_epoch = num_epochs // 20
44 |
45 |     for epoch in range(num_epochs):
46 |         for i, (A, A_label, B, _) in enumerate(train_loader):
47 |
48 |             A = A.to(device).detach()
49 |             B = B.to(device).detach()
50 |             A_label = A_label.to(device).detach()
51 |
52 |             sifa_model.update_GAN(A, B)
53 |             sifa_model.update_seg(A, B, A_label)
54 |             loss_cyclea, loss_cycleb, segloss = sifa_model.print_loss()
55 |             loss_cycle.append(loss_cyclea+loss_cycleb)
56 |             loss_seg.append(segloss)
57 |             # the segmentation lr is stepped once per epoch via update_lr() below, not per iteration
58 |         if (epoch+1) % save_epoch == 0:
59 |             model_dir = "save_model/" + str(exp_name)
60 |             if(not os.path.exists(model_dir)):
61 |                 os.makedirs(model_dir) # makedirs also creates the parent 'save_model' directory
62 |             sifa_model.sample_image(epoch, exp_name)
63 |             torch.save(sifa_model.state_dict(),
64 |                        '{}/model-{}.pth'.format(model_dir, epoch+1))
65 |
sifa_model.update_lr() 66 | 67 | print('train finished') 68 | loss_cycle = np.array(loss_cycle) 69 | loss_seg = np.array(loss_seg) 70 | np.savez('trainingloss.npz', loss_cycle, loss_seg) 71 | x = np.arange(0, loss_cycle.shape[0]) 72 | plt.figure(1) 73 | plt.plot(x, loss_cycle, label='cycle loss of training') 74 | plt.legend() 75 | plt.xlabel('iterations') 76 | plt.ylabel('cycle loss') 77 | plt.savefig('cycleloss.jpg') 78 | plt.close() 79 | plt.figure(2) 80 | plt.plot(x, loss_seg, label='seg loss of training') 81 | plt.legend() 82 | plt.xlabel('iterations') 83 | plt.ylabel('seg loss') 84 | plt.savefig('segloss.jpg') 85 | plt.close() 86 | print('loss saved') 87 | 88 | 89 | if __name__ == '__main__': 90 | set_random() 91 | train() 92 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | ##load data 2 | from torch.utils.data import Dataset 3 | import numpy as np 4 | import os 5 | import torch 6 | import yaml 7 | import random 8 | import configparser 9 | 10 | # config setting 11 | def is_int(val_str): 12 | start_digit = 0 13 | if(val_str[0] =='-'): 14 | start_digit = 1 15 | flag = True 16 | for i in range(start_digit, len(val_str)): 17 | if(str(val_str[i]) < '0' or str(val_str[i]) > '9'): 18 | flag = False 19 | break 20 | return flag 21 | 22 | def is_float(val_str): 23 | flag = False 24 | if('.' in val_str and len(val_str.split('.'))==2 and not('./' in val_str)): 25 | if(is_int(val_str.split('.')[0]) and is_int(val_str.split('.')[1])): 26 | flag = True 27 | else: 28 | flag = False 29 | elif('e' in val_str and val_str[0] != 'e' and len(val_str.split('e'))==2): 30 | if(is_int(val_str.split('e')[0]) and is_int(val_str.split('e')[1])): 31 | flag = True 32 | else: 33 | flag = False 34 | else: 35 | flag = False 36 | return flag 37 | 38 | def is_bool(var_str): 39 | if( var_str.lower() =='true' or var_str.lower() == 'false'): 40 | return True 41 | else: 42 | return False 43 | 44 | def parse_bool(var_str): 45 | if(var_str.lower() =='true'): 46 | return True 47 | else: 48 | return False 49 | 50 | def is_list(val_str): 51 | if(val_str[0] == '[' and val_str[-1] == ']'): 52 | return True 53 | else: 54 | return False 55 | 56 | def parse_list(val_str): 57 | sub_str = val_str[1:-1] 58 | splits = sub_str.split(',') 59 | output = [] 60 | for item in splits: 61 | item = item.strip() 62 | if(is_int(item)): 63 | output.append(int(item)) 64 | elif(is_float(item)): 65 | output.append(float(item)) 66 | elif(is_bool(item)): 67 | output.append(parse_bool(item)) 68 | elif(item.lower() == 'none'): 69 | output.append(None) 70 | else: 71 | output.append(item) 72 | return output 73 | 74 | def parse_value_from_string(val_str): 75 | # val_str = val_str.encode('ascii','ignore') 76 | if(is_int(val_str)): 77 | val = int(val_str) 78 | elif(is_float(val_str)): 79 | val = float(val_str) 80 | elif(is_list(val_str)): 81 | val = parse_list(val_str) 82 | elif(is_bool(val_str)): 83 | val = parse_bool(val_str) 84 | elif(val_str.lower() == 'none'): 85 | val = None 86 | else: 87 | val = val_str 88 | return val 89 | 90 | def parse_config(filename): 91 | config = configparser.ConfigParser() 92 | config.read(filename) 93 | output = {} 94 | for section in config.sections(): 95 | output[section] = {} 96 | for key in config[section]: 97 | val_str = str(config[section][key]) 98 | if(len(val_str)>0): 99 | val = parse_value_from_string(val_str) 100 | output[section][key] = val 101 | else: 102 | val = None 103 | print(section, 
key, val_str, val)
104 |     return output
105 |
106 |
107 | def load_npz(path):
108 |     img = np.load(path)['arr_0']
109 |     gt = np.load(path)['arr_1']
110 |     return img, gt
111 |
112 | def get_config(config):
113 |     with open(config, 'r') as stream:
114 |         return yaml.load(stream,Loader=yaml.FullLoader)
115 |
116 | def set_random(seed_id=1234):
117 |     np.random.seed(seed_id)
118 |     torch.manual_seed(seed_id) #for cpu
119 |     torch.cuda.manual_seed_all(seed_id) #for GPU
120 |     torch.backends.cudnn.deterministic = True
121 |     torch.backends.cudnn.benchmark = False # benchmark must stay off for deterministic behavior
122 |
222 | class UnpairedDataset(Dataset):
223 |     #get unpaired dataset, such as MR-CT dataset
224 |     def __init__(self,A_path,B_path):
225 |         listA = os.listdir(A_path)
226 |         listB = os.listdir(B_path)
227 |         self.listA = [os.path.join(A_path,k) for k in listA]
228 |         self.listB = [os.path.join(B_path,k) for k in listB]
229 |         self.Asize = len(self.listA)
230 |         self.Bsize =
len(self.listB) 231 | self.dataset_size = max(self.Asize,self.Bsize) 232 | 233 | def __getitem__(self,index): 234 | if self.Asize == self.dataset_size: 235 | A,A_gt = load_npz(self.listA[index]) 236 | B,B_gt = load_npz(self.listB[random.randint(0, self.Bsize - 1)]) 237 | else : 238 | B,B_gt = load_npz(self.listB[index]) 239 | A,A_gt = load_npz(self.listA[random.randint(0, self.Asize - 1)]) 240 | 241 | 242 | A = torch.from_numpy(A.copy()).unsqueeze(0).float() 243 | A_gt = torch.from_numpy(A_gt.copy()).unsqueeze(0).float() 244 | B = torch.from_numpy(B.copy()).unsqueeze(0).float() 245 | B_gt = torch.from_numpy(B_gt.copy()).unsqueeze(0).float() 246 | return A,A_gt,B,B_gt 247 | 248 | def __len__(self): 249 | return self.dataset_size 250 | 251 | 252 | class SingleDataset(Dataset): 253 | def __init__(self,test_path): 254 | test_list = os.listdir(test_path) 255 | self.test = [os.path.join(test_path,k) for k in test_list] 256 | 257 | def __getitem__(self,index): 258 | img,gt = load_npz(self.test[index]) 259 | 260 | img = torch.from_numpy(img.copy()).unsqueeze(0).float() 261 | gt = torch.from_numpy(gt.copy()).unsqueeze(0).float() 262 | return img, gt 263 | 264 | def __len__(self): 265 | return len(self.test) --------------------------------------------------------------------------------
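Finally, an illustrative round trip through the dataset classes (the paths are the defaults from `config/train.cfg`; each `.npz` must follow the layout described in the README):

```python
from torch.utils.data import DataLoader
from utils import UnpairedDataset, SingleDataset

trainset = UnpairedDataset('./data/source_path', './data/target_path')
A, A_gt, B, B_gt = trainset[0]   # each item is a [1, H, W] float tensor
print(A.shape, A_gt.shape)

test_loader = DataLoader(SingleDataset('./data/test_path'), batch_size=1, shuffle=False)
for img, gt in test_loader:
    print(img.shape, gt.shape)   # [1, 1, H, W] after batching
    break
```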