├── util
│   ├── util.pyc
│   ├── image_pool.py
│   ├── util.py
│   ├── html.py
│   ├── image_property.py
│   └── visualizer.py
├── dynamic_head.gif
├── data
│   ├── base_data_loader.py
│   ├── base_dataset.py
│   ├── __init__.py
│   ├── ms_3d_dataset.py
│   ├── data_generator.py
│   └── ms_dataset.py
├── LICENSE
├── options
│   ├── test_options.py
│   ├── train_options.py
│   └── base_options.py
├── models
│   ├── __init__.py
│   ├── losses.py
│   ├── tiramisu_layers.py
│   ├── networks.py
│   ├── tiramisu_layers_dyn.py
│   ├── tiramisu_model_dyn.py
│   ├── base_model.py
│   ├── tiramisu_model.py
│   └── ms_model.py
├── configurations.py
├── pytorch_ssim
│   └── __init__.py
├── data_conversion.py
├── train.py
├── Python38.yml
├── test.py
└── README.md
/util/util.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/han-liu/ModDropPlusPlus/HEAD/util/util.pyc
--------------------------------------------------------------------------------
/dynamic_head.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/han-liu/ModDropPlusPlus/HEAD/dynamic_head.gif
--------------------------------------------------------------------------------
/data/base_data_loader.py:
--------------------------------------------------------------------------------
1 | class BaseDataLoader():
2 |     def __init__(self):
3 |         pass
4 | 
5 |     def initialize(self, opt):
6 |         self.opt = opt
7 | 
8 |     def load_data(self):
9 |         return None
10 | 
--------------------------------------------------------------------------------
/data/base_dataset.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 | 
3 | 
4 | class BaseDataset(data.Dataset):
5 |     def __init__(self):
6 |         super(BaseDataset, self).__init__()
7 | 
8 |     def name(self):
9 |         return 'BaseDataset'
10 | 
11 |     @staticmethod
12 |     def modify_commandline_options(parser, is_train):
13 |         return parser
14 | 
15 |     def initialize(self, opt):
16 |         pass
17 | 
18 |     def __len__(self):
19 |         return 0
20 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2022 Han Liu
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/util/image_pool.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torch
3 | 
4 | 
5 | class ImagePool():
6 |     def __init__(self, pool_size):
7 |         self.pool_size = pool_size
8 |         if self.pool_size > 0:
9 |             self.num_imgs = 0
10 |             self.images = []
11 | 
12 |     def query(self, images):
13 |         if self.pool_size == 0:
14 |             return images
15 |         return_images = []
16 |         for image in images:
17 |             image = torch.unsqueeze(image.data, 0)
18 |             if self.num_imgs < self.pool_size:
19 |                 self.num_imgs = self.num_imgs + 1
20 |                 self.images.append(image)
21 |                 return_images.append(image)
22 |             else:
23 |                 p = random.uniform(0, 1)
24 |                 if p > 0.5:
25 |                     random_id = random.randint(0, self.pool_size - 1)  # randint is inclusive
26 |                     tmp = self.images[random_id].clone()
27 |                     self.images[random_id] = image
28 |                     return_images.append(tmp)
29 |                 else:
30 |                     return_images.append(image)
31 |         return_images = torch.cat(return_images, 0)
32 |         return return_images
--------------------------------------------------------------------------------
/options/test_options.py:
--------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 | 
3 | 
4 | class TestOptions(BaseOptions):
5 |     def initialize(self, parser):
6 |         parser = BaseOptions.initialize(self, parser)
7 |         parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.')
8 |         parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
9 |         parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
10 |         # Dropout and BatchNorm have different behavior during training and test.
11 |         parser.add_argument('--eval', action='store_true', help='use eval mode during test time.')
12 |         parser.add_argument('--num_test', type=int, default=50, help='how many test images to run')
13 |         parser.add_argument('--load_str', type=str, default='', help='the string describing the folds to load, e.g. val0,val1')
14 |         parser.add_argument('--use_modality_dropout', action='store_true', help='use modality dropout at training phase')
15 |         parser.add_argument('--use_dyn', action='store_false', help='use dynamic filters (enabled by default; pass this flag to disable)')
16 |         self.isTrain = False
17 |         return parser
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | from models.base_model import BaseModel
3 | 
4 | 
5 | def find_model_using_name(model_name):
6 |     # Given the option --model [modelname],
7 |     # the file "models/modelname_model.py"
8 |     # will be imported.
9 |     model_filename = "models." + model_name + "_model"
10 |     modellib = importlib.import_module(model_filename)
11 | 
12 |     # In the file, the class called ModelNameModel() will
13 |     # be instantiated. It has to be a subclass of BaseModel,
14 |     # and it is case-insensitive.
15 |     model = None
16 |     target_model_name = model_name.replace('_', '') + 'model'
17 |     for name, cls in modellib.__dict__.items():
18 |         if name.lower() == target_model_name.lower() \
19 |                 and issubclass(cls, BaseModel):
20 |             model = cls
21 | 
22 |     if model is None:
23 |         print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase."
              % (model_filename, target_model_name))
24 |         exit(0)
25 | 
26 |     return model
27 | 
28 | 
29 | def get_option_setter(model_name):
30 |     model_class = find_model_using_name(model_name)
31 |     return model_class.modify_commandline_options
32 | 
33 | 
34 | def create_model(opt, model_suffix=''):
35 |     model = find_model_using_name(opt.model)
36 |     instance = model()
37 |     instance.initialize(opt, model_suffix)
38 |     print("model [%s] was created" % (instance.name()))
39 |     return instance
--------------------------------------------------------------------------------
/util/util.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import torch
3 | import numpy as np
4 | from PIL import Image
5 | import os
6 | 
7 | 
8 | # Converts a Tensor into an image array (numpy)
9 | # |imtype|: the desired type of the converted numpy array
10 | def tensor2im(input_image, imtype=np.uint8):
11 |     if isinstance(input_image, torch.Tensor):
12 |         image_tensor = input_image.data
13 |     else:
14 |         return input_image
15 |     image_numpy = image_tensor[0].cpu().float().numpy()
16 |     if len(image_numpy.shape) == 2:
17 |         image_numpy = np.expand_dims(image_numpy, 0)
18 | 
19 |     ind = image_numpy.shape[0] // 2
20 |     image_numpy = image_numpy[ind:ind + 1, :, :]
21 |     if image_numpy.shape[0] == 1:
22 |         image_numpy = np.tile(image_numpy, (3, 1, 1))
23 |     max_value, min_value = 2, -2
24 |     image_numpy = (np.transpose(image_numpy, (1, 2, 0)) - min_value) / (max_value - min_value) * 255.0
25 |     return image_numpy.astype(imtype)
26 | 
27 | 
28 | def diagnose_network(net, name='network'):
29 |     mean = 0.0
30 |     count = 0
31 |     for param in net.parameters():
32 |         if param.grad is not None:
33 |             mean += torch.mean(torch.abs(param.grad.data))
34 |             count += 1
35 |     if count > 0:
36 |         mean = mean / count
37 |     print(name)
38 |     print(mean)
39 | 
40 | 
41 | def save_image(image_numpy, image_path):
42 |     image_pil = Image.fromarray(image_numpy)
43 |     image_pil.save(image_path)
44 | 
45 | 
46 | def print_numpy(x, val=True, shp=False):
47 |     x = x.astype(np.float64)
48 |     if shp:
49 |         print('shape,', x.shape)
50 |     if val:
51 |         x = x.flatten()
52 |         print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
53 |             np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))
54 | 
55 | 
56 | def mkdirs(paths):
57 |     if isinstance(paths, list) and not isinstance(paths, str):
58 |         for path in paths:
59 |             mkdir(path)
60 |     else:
61 |         mkdir(paths)
62 | 
63 | 
64 | def mkdir(path):
65 |     if not os.path.exists(path):
66 |         os.makedirs(path)
--------------------------------------------------------------------------------
/models/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.autograd import Variable
5 | 
6 | 
7 | def dice_loss(input, target):
8 |     # this loss function needs input in the range (0, 1), and target in (0, 1)
9 |     smooth = 0.01
10 | 
11 |     iflat = input.view(-1)
12 |     tflat = target.view(-1)
13 |     intersection = (iflat * tflat).sum()
14 | 
15 |     return 1 - ((2. * intersection + smooth) / (iflat.sum() + tflat.sum() + smooth))
16 | 
17 | 
18 | def focal_loss(input, target, alpha, gamma, eps=1e-6):
19 |     # this loss function needs input in the range (0, 1), and target in (0, 1)
20 |     input = input.view(-1, 1)
21 |     input = torch.clamp(input, min=eps, max=1-eps)
22 |     target = target.view(-1, 1)
23 |     loss = -target * alpha * ((1 - input) ** gamma) * torch.log(input) - (1 - target) * (1-alpha) * (
24 |         input ** gamma) * torch.log(1 - input)
25 |     return loss.mean()
26 | 
27 | 
28 | class FocalLoss(nn.Module):
29 |     # this loss function needs input in the range (-1, 1), and target in (0, 1)
30 |     def __init__(self, gamma=0, alpha=None):
31 |         super(FocalLoss, self).__init__()
32 |         self.gamma = gamma
33 |         self.alpha = alpha
34 |         if isinstance(alpha, (float, int)): self.alpha = torch.Tensor([alpha, 1-alpha])
35 |         if isinstance(alpha, list): self.alpha = torch.Tensor(alpha)
36 | 
37 |     def forward(self, input, target):
38 |         input = torch.unsqueeze(input, -1)
39 |         input = torch.cat([0 - input, input], -1)
40 |         input = input.contiguous().view(-1, 2)
41 |         target = target.view(-1, 1).long()
42 | 
43 |         logpt = F.log_softmax(input, -1)
44 |         logpt = logpt.gather(1, target)
45 |         logpt = logpt.view(-1)
46 |         pt = Variable(logpt.data.exp())
47 | 
48 |         if self.alpha is not None:
49 |             if self.alpha.type() != input.data.type():
50 |                 self.alpha = self.alpha.type_as(input.data)
51 |             at = self.alpha.gather(0, target.data.view(-1))
52 |             logpt = logpt * Variable(at)
53 | 
54 |         loss = -1 * (1-pt)**self.gamma * logpt
55 |         return loss.mean()
--------------------------------------------------------------------------------
/util/html.py:
--------------------------------------------------------------------------------
1 | import dominate
2 | from dominate.tags import meta, h3, table, tr, td, p, a, img, br
3 | import os
4 | 
5 | 
6 | class HTML:
7 |     def __init__(self, web_dir, title, refresh=0):
8 |         self.title = title
9 |         self.web_dir = web_dir
10 |         self.img_dir = os.path.join(self.web_dir, 'images')
11 |         if not os.path.exists(self.web_dir):
12 |             os.makedirs(self.web_dir)
13 |         if not os.path.exists(self.img_dir):
14 |             os.makedirs(self.img_dir)
15 |         # print(self.img_dir)
16 | 
17 |         self.doc = dominate.document(title=title)
18 |         if refresh > 0:
19 |             with self.doc.head:
20 |                 meta(http_equiv="refresh", content=str(refresh))
21 | 
22 |     def get_image_dir(self):
23 |         return self.img_dir
24 | 
25 |     def add_header(self, text):
26 |         with self.doc:
27 |             h3(text)
28 | 
29 |     def add_table(self, border=1):
30 |         self.t = table(border=border, style="table-layout: fixed;")
31 |         self.doc.add(self.t)
32 | 
33 |     def add_images(self, ims, txts, links, width=400):
34 |         self.add_table()
35 |         with self.t:
36 |             with tr():
37 |                 for im, txt, link in zip(ims, txts, links):
38 |                     with td(style="word-wrap: break-word;", halign="center", valign="top"):
39 |                         with p():
40 |                             with a(href=os.path.join('images', link)):
41 |                                 img(style="width:%dpx" % width, src=os.path.join('images', im))
42 |                             br()
43 |                             p(txt)
44 | 
45 |     def save(self):
46 |         html_file = '%s/index.html' % self.web_dir
47 |         f = open(html_file, 'wt')
48 |         f.write(self.doc.render())
49 |         f.close()
50 | 
51 | 
52 | if __name__ == '__main__':
53 |     html = HTML('web/', 'test_html')
54 |     html.add_header('hello world')
55 | 
56 |     ims = []
57 |     txts = []
58 |     links = []
59 |     for n in range(4):
60 |         ims.append('image_%d.png' % n)
61 |         txts.append('text_%d' % n)
62 |         links.append('image_%d.png' % n)
63 |     html.add_images(ims, txts, links)
64 |     html.save()
65 | 
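A minimal usage sketch for the loss functions in models/losses.py (an illustration only, not a file from the repository): per the range comments in that file, dice_loss expects probabilities in (0, 1), while the FocalLoss module expects inputs in (-1, 1). The batch shape below is hypothetical.

import torch
from models.losses import dice_loss, FocalLoss

prob = torch.sigmoid(torch.randn(4, 1, 128, 128))    # probability map in (0, 1)
target = (torch.rand(4, 1, 128, 128) > 0.5).float()  # binary mask in {0, 1}

print(dice_loss(prob, target))   # dice_loss: input and target in (0, 1)

logits = prob * 2 - 1            # FocalLoss: input in (-1, 1)
criterion = FocalLoss(gamma=2, alpha=0.25)
print(criterion(logits, target))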
--------------------------------------------------------------------------------
/configurations.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | # TODO: change the following path based on your dataset
4 | # each entry of DATASETS should be one of "ISBI", "MICCAI16_1", "MICCAI16_2", "MICCAI16_3", "UMCL"
5 | PATH_DATASET = '/gpfs23/scratch/liuh26/ModDropPlusPlus-main/ms_data'
6 | DATASETS = ["UMCL"]  # training domains
7 | 
8 | # We assume the format of file names in this pattern {PREFIX}_{PATIENT-ID}_{TIMEPOINT-ID}_{MODALITY(MASK)}.{SUFFIX}
9 | # PREFIX: can be any string
10 | # PATIENT-ID: number id of the patient
11 | # TIMEPOINT-ID: number id of the timepoint
12 | # MODALITY, MASK: t1, flair, t2, pd, mask1, mask2, etc
13 | # SUFFIX: nii, nii.gz, etc
14 | # e.g. training_01_01_flair.nii, training_03_05_mask1.nii
15 | # TODO: change the following constants based on your dataset
16 | 
17 | MODALITIES = ['t1', 'flair', 't2', 'pd', 'ce']  # general
18 | MASKS = ['mask1', 'mask2', 'mask']
19 | SUFFIX = 'nii.gz'
20 | 
21 | # The axis corresponding to axial, sagittal and coronal, respectively
22 | # TODO: change the following axes based on your dataset
23 | AXIS_TO_TAKE = [2, 0, 1]
24 | 
25 | 
26 | # training independent modes (IM)
27 | 
28 | #---- UMCL ----
29 | # MODALITIES = ['t1', 'flair', 't2', 'ce']  # 1234
30 | # MODALITIES = ['flair', 't2', 'ce']  # 234
31 | # MODALITIES = ['t1', 't2', 'ce']  # 134
32 | # MODALITIES = ['t1', 'flair', 'ce']  # 124
33 | # MODALITIES = ['t1', 'flair', 't2']  # 123
34 | # MODALITIES = ['t2', 'ce']  # 34
35 | # MODALITIES = ['flair', 'ce']  # 24
36 | # MODALITIES = ['flair', 't2']  # 23
37 | # MODALITIES = ['t1', 'ce']  # 14
38 | # MODALITIES = ['t1', 't2']  # 13
39 | # MODALITIES = ['t1']  # 1
40 | # MODALITIES = ['flair']  # 2
41 | # MODALITIES = ['t2']  # 3
42 | # MODALITIES = ['ce']  # 4
43 | 
44 | #---- ISBI ----
45 | # MODALITIES = ['t1', 'flair', 't2', 'pd']  # 1234
46 | # MODALITIES = ['flair', 't2', 'pd']  # 234
47 | # MODALITIES = ['t1', 't2', 'pd']  # 134
48 | # MODALITIES = ['t1', 'flair', 'pd']  # 124
49 | # MODALITIES = ['t1', 'flair', 't2']  # 123
50 | # MODALITIES = ['t2', 'pd']  # 34
51 | # MODALITIES = ['flair', 'pd']  # 24
52 | # MODALITIES = ['flair', 't2']  # 23
53 | # MODALITIES = ['t1', 'pd']  # 14
54 | # MODALITIES = ['t1', 't2']  # 13
55 | # MODALITIES = ['t1', 'flair']  # 12
56 | # MODALITIES = ['t1']  # 1
57 | # MODALITIES = ['flair']  # 2
58 | # MODALITIES = ['t2']  # 3
59 | # MODALITIES = ['pd']  # 4
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import torch.utils.data
3 | from data.base_data_loader import BaseDataLoader
4 | from data.base_dataset import BaseDataset
5 | 
6 | 
7 | def find_dataset_using_name(dataset_name):
8 |     # Given the option --dataset_mode [datasetname],
9 |     # the file "data/datasetname_dataset.py"
10 |     # will be imported.
11 |     dataset_filename = "data." + dataset_name + "_dataset"
12 |     datasetlib = importlib.import_module(dataset_filename)
13 | 
14 |     # In the file, the class called DatasetNameDataset() will
15 |     # be instantiated. It has to be a subclass of BaseDataset,
16 |     # and it is case-insensitive.
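    # e.g. --dataset_mode ms_3d imports data/ms_3d_dataset.py and matches
    # the class Ms3dDataset ('ms3ddataset', compared case-insensitively).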
17 | dataset = None 18 | target_dataset_name = dataset_name.replace('_', '') + 'dataset' 19 | for name, cls in datasetlib.__dict__.items(): 20 | if name.lower() == target_dataset_name.lower() \ 21 | and issubclass(cls, BaseDataset): 22 | dataset = cls 23 | 24 | if dataset is None: 25 | print("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name)) 26 | exit(0) 27 | 28 | return dataset 29 | 30 | 31 | def get_option_setter(dataset_name): 32 | dataset_class = find_dataset_using_name(dataset_name) 33 | return dataset_class.modify_commandline_options 34 | 35 | 36 | def create_dataset(opt): 37 | dataset = find_dataset_using_name(opt.dataset_mode) 38 | instance = dataset() 39 | instance.initialize(opt) 40 | print("dataset [%s] was created" % (instance.name())) 41 | return instance 42 | 43 | 44 | def CreateDataLoader(opt): 45 | data_loader = CustomDatasetDataLoader() 46 | data_loader.initialize(opt) 47 | return data_loader 48 | 49 | 50 | # Wrapper class of Dataset class that performs 51 | # multi-threaded data loading 52 | class CustomDatasetDataLoader(BaseDataLoader): 53 | def name(self): 54 | return 'CustomDatasetDataLoader' 55 | 56 | def initialize(self, opt): 57 | BaseDataLoader.initialize(self, opt) 58 | self.dataset = create_dataset(opt) 59 | self.dataloader = torch.utils.data.DataLoader( 60 | self.dataset, 61 | batch_size=opt.batch_size, 62 | shuffle=not opt.serial_batches, 63 | num_workers=int(opt.num_threads)) 64 | 65 | def load_data(self): 66 | return self 67 | 68 | def __len__(self): 69 | return min(len(self.dataset), self.opt.max_dataset_size) 70 | 71 | def __iter__(self): 72 | for i, data in enumerate(self.dataloader): 73 | if i * self.opt.batch_size >= self.opt.max_dataset_size: 74 | break 75 | yield data 76 | -------------------------------------------------------------------------------- /pytorch_ssim/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.autograd import Variable 4 | import numpy as np 5 | from math import exp 6 | 7 | def gaussian(window_size, sigma): 8 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 9 | return gauss/gauss.sum() 10 | 11 | def create_window(window_size, channel): 12 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 13 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 14 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 15 | return window 16 | 17 | def _ssim(img1, img2, window, window_size, channel, size_average = True): 18 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) 19 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) 20 | 21 | mu1_sq = mu1.pow(2) 22 | mu2_sq = mu2.pow(2) 23 | mu1_mu2 = mu1*mu2 24 | 25 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq 26 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq 27 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 28 | 29 | C1 = 0.01**2 30 | C2 = 0.03**2 31 | 32 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) 33 | 34 | if size_average: 35 | return ssim_map.mean() 36 | else: 37 | return ssim_map.mean(1).mean(1).mean(1) 38 | 39 | class 
SSIM(torch.nn.Module): 40 | def __init__(self, window_size = 11, size_average = True): 41 | super(SSIM, self).__init__() 42 | self.window_size = window_size 43 | self.size_average = size_average 44 | self.channel = 1 45 | self.window = create_window(window_size, self.channel) 46 | 47 | def forward(self, img1, img2): 48 | (_, channel, _, _) = img1.size() 49 | 50 | if channel == self.channel and self.window.data.type() == img1.data.type(): 51 | window = self.window 52 | else: 53 | window = create_window(self.window_size, channel) 54 | 55 | if img1.is_cuda: 56 | window = window.cuda(img1.get_device()) 57 | window = window.type_as(img1) 58 | 59 | self.window = window 60 | self.channel = channel 61 | 62 | 63 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average) 64 | 65 | def ssim(img1, img2, window_size = 11, size_average = True): 66 | (_, channel, _, _) = img1.size() 67 | window = create_window(window_size, channel) 68 | 69 | if img1.is_cuda: 70 | window = window.cuda(img1.get_device()) 71 | window = window.type_as(img1) 72 | 73 | return _ssim(img1, img2, window, window_size, channel, size_average) 74 | -------------------------------------------------------------------------------- /models/tiramisu_layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class DenseLayer(nn.Sequential): 6 | def __init__(self, in_channels, growth_rate): 7 | super().__init__() 8 | self.add_module('norm', nn.BatchNorm2d(in_channels)) 9 | self.add_module('relu', nn.ReLU(True)) 10 | self.add_module('conv', nn.Conv2d(in_channels, growth_rate, kernel_size=3, 11 | stride=1, padding=1, bias=True)) 12 | self.add_module('drop', nn.Dropout2d(0.2)) 13 | 14 | def forward(self, x): 15 | return super().forward(x) 16 | 17 | 18 | class DenseBlock(nn.Module): 19 | def __init__(self, in_channels, growth_rate, n_layers, upsample=False): 20 | super().__init__() 21 | self.upsample = upsample 22 | self.layers = nn.ModuleList([DenseLayer( 23 | in_channels + i*growth_rate, growth_rate) 24 | for i in range(n_layers)]) 25 | 26 | def forward(self, x): 27 | if self.upsample: 28 | new_features = [] 29 | #we pass all previous activations into each dense layer normally 30 | #But we only store each dense layer's output in the new_features array 31 | for layer in self.layers: 32 | out = layer(x) 33 | x = torch.cat([x, out], 1) 34 | new_features.append(out) 35 | return torch.cat(new_features,1) 36 | else: 37 | for layer in self.layers: 38 | out = layer(x) 39 | x = torch.cat([x, out], 1) # 1 = channel axis 40 | return x 41 | 42 | 43 | class TransitionDown(nn.Sequential): 44 | def __init__(self, in_channels): 45 | super().__init__() 46 | self.add_module('norm', nn.BatchNorm2d(num_features=in_channels)) 47 | self.add_module('relu', nn.ReLU(inplace=True)) 48 | self.add_module('conv', nn.Conv2d(in_channels, in_channels, 49 | kernel_size=1, stride=1, 50 | padding=0, bias=True)) 51 | self.add_module('drop', nn.Dropout2d(0.2)) 52 | self.add_module('maxpool', nn.MaxPool2d(2)) 53 | 54 | def forward(self, x): 55 | return super().forward(x) 56 | 57 | 58 | class TransitionUp(nn.Module): 59 | def __init__(self, in_channels, out_channels): 60 | super().__init__() 61 | # to use hiddenlayer, make padding=1, output_padding=1, 62 | self.convTrans = nn.ConvTranspose2d( 63 | in_channels=in_channels, out_channels=out_channels, 64 | kernel_size=3, stride=2, padding=0, bias=True) 65 | 66 | def forward(self, x, skip): 67 | out = self.convTrans(x) 68 | # 
to use hiddenlayer, comment the usage of center_crop
69 |         out = center_crop(out, skip.size(2), skip.size(3))
70 |         out = torch.cat([out, skip], 1)
71 |         return out
72 | 
73 | 
74 | class Bottleneck(nn.Sequential):
75 |     def __init__(self, in_channels, growth_rate, n_layers):
76 |         super().__init__()
77 |         self.add_module('bottleneck', DenseBlock(
78 |             in_channels, growth_rate, n_layers, upsample=True))
79 | 
80 |     def forward(self, x):
81 |         return super().forward(x)
82 | 
83 | 
84 | def center_crop(layer, max_height, max_width):
85 |     _, _, h, w = layer.size()
86 |     xy1 = (w - max_width) // 2
87 |     xy2 = (h - max_height) // 2
88 |     return layer[:, :, xy2:(xy2 + max_height), xy1:(xy1 + max_width)]
--------------------------------------------------------------------------------
/models/networks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import init
3 | from torch.optim import lr_scheduler
4 | from models.tiramisu_model_dyn import FCDenseNetDyn
5 | from models.tiramisu_model import FCDenseNet
6 | 
7 | ###############################################################################
8 | # Helper Functions
9 | ###############################################################################
10 | 
11 | 
12 | def get_scheduler(optimizer, opt):
13 |     if opt.lr_policy == 'lambda':
14 |         def lambda_rule(epoch):
15 |             lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
16 |             return lr_l
17 |         scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
18 |     elif opt.lr_policy == 'step':
19 |         scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
20 |     elif opt.lr_policy == 'plateau':
21 |         scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
22 |     elif opt.lr_policy == 'cosine':
23 |         scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.niter, eta_min=0)
24 |     else:
25 |         raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
26 |     return scheduler
27 | 
28 | 
29 | def init_weights(net, init_type='normal', gain=0.02):
30 |     def init_func(m):
31 |         classname = m.__class__.__name__
32 |         if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
33 |             if init_type == 'normal':
34 |                 init.normal_(m.weight.data, 0.0, gain)
35 |             elif init_type == 'xavier':
36 |                 init.xavier_normal_(m.weight.data, gain=gain)
37 |             elif init_type == 'kaiming':
38 |                 init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
39 |             elif init_type == 'orthogonal':
40 |                 init.orthogonal_(m.weight.data, gain=gain)
41 |             else:
42 |                 raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
43 |             if hasattr(m, 'bias') and m.bias is not None:
44 |                 init.constant_(m.bias.data, 0.0)
45 |         elif classname.find('BatchNorm2d') != -1:
46 |             init.normal_(m.weight.data, 1.0, gain)
47 |             init.constant_(m.bias.data, 0.0)
48 | 
49 |     print('initialize network with %s' % init_type)
50 |     net.apply(init_func)
51 | 
52 | 
53 | def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
54 |     if len(gpu_ids) > 0:
55 |         assert(torch.cuda.is_available())
56 |         net.to(gpu_ids[0])
57 |         # to use hiddenlayer, comment the usage of torch.nn.DataParallel
58 |         net = torch.nn.DataParallel(net, gpu_ids)
59 |     init_weights(net, init_type, gain=init_gain)
60 |     return net
61 | 
62 | 
63 | def define_G(input_nc, init_type='normal', init_gain=0.02, gpu_ids=[], dyn=False, DG=False):
64 | 
65 |     if dyn and not DG:
66 |         print('\n************************************************')
67 |         print('***** MS Lesion Segmentation via ModDrop++ *****')
68 |         print('************************************************\n')
69 |         net = FCDenseNetDyn(in_channels=input_nc, down_blocks=(4, 4, 4, 4, 4), up_blocks=(4, 4, 4, 4, 4),
70 |                             bottleneck_layers=4, growth_rate=12, out_chans_first_conv=48, n_classes=1)
71 | 
72 |     if not dyn and not DG:
73 |         print('\n*****************************************************')
74 |         print('***** MS Lesion Segmentation via Static Network *****')
75 |         print('*****************************************************\n')
76 |         net = FCDenseNet(in_channels=input_nc, down_blocks=(4, 4, 4, 4, 4), up_blocks=(4, 4, 4, 4, 4),
77 |                          bottleneck_layers=4, growth_rate=12, out_chans_first_conv=48, n_classes=1)
78 | 
79 |     if DG:
80 |         print('=== Code for Domain Generalization coming soon === !!!')
81 | 
82 |     return init_net(net, init_type, init_gain, gpu_ids)
--------------------------------------------------------------------------------
/util/image_property.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import pickle
4 | import cv2
5 | import hashlib
6 | import numpy as np
7 | import nibabel as nib
8 | import statsmodels.api as sm
9 | from scipy.signal import argrelextrema
10 | 
11 | 
12 | def hash_file(filename):
13 |     """This function returns the SHA-1 hash
14 |     of the file passed into it."""
15 | 
16 |     # make a hash object
17 |     h = hashlib.sha1()
18 | 
19 |     # open file for reading in binary mode
20 |     with open(filename, 'rb') as file:
21 |         # loop till the end of the file
22 |         chunk = 0
23 |         while chunk != b'':
24 |             # read only 1024 bytes at a time
25 |             chunk = file.read(1024)
26 |             h.update(chunk)
27 | 
28 |     # return the hex representation of digest
29 |     return h.hexdigest()
30 | 
31 | 
32 | def normalize_image(vol, contrast):
33 |     # copied from FLEXCONN
34 |     # slightly changed to fit our implementation
35 |     temp = vol[np.nonzero(vol)].astype(float)
36 |     q = np.percentile(temp, 99)
37 |     temp = temp[temp <= q]
38 |     temp = temp.reshape(-1, 1)
39 |     bw = q / 80
40 |     # print("99th quantile is %.4f, gridsize = %.4f" % (q, bw))
41 | 
42 |     kde = sm.nonparametric.KDEUnivariate(temp)
43 | 
44 |     kde.fit(kernel='gau', bw=bw, gridsize=80, fft=True)
45 |     x_mat = 100.0 * kde.density
46 |     y_mat = kde.support
47 | 
48 |     indx = argrelextrema(x_mat, np.greater)
49 |     indx = np.asarray(indx, dtype=int)
50 |     heights = x_mat[indx][0]
51 |     peaks = y_mat[indx][0]
52 |     peak = 0.00
53 |     # print("%d peaks found." % (len(peaks)))
54 | 
55 |     # norm_vol = vol
56 |     if contrast.lower() in ["t1", "mprage", "ce"]:  # customized by Han
57 |         peak = peaks[-1]
58 |         # print("Peak found at %.4f for %s" % (peak, contrast))
59 |         # norm_vol = vol/peak
60 |         # norm_vol[norm_vol > 1.25] = 1.25
61 |         # norm_vol = norm_vol/1.25
62 |     elif contrast.lower() in ['t2', 'pd', 'flair', 'fl']:
63 |         peak_height = np.amax(heights)
64 |         idx = np.where(heights == peak_height)
65 |         peak = peaks[idx]
66 |         # print("Peak found at %.4f for %s" % (peak, contrast))
67 |         # norm_vol = vol / peak
68 |         # norm_vol[norm_vol > 3.5] = 3.5
69 |         # norm_vol = norm_vol / 3.5
70 |     else:
71 |         print("Contrast must be either t1, t2, pd, ce or flair. You entered %s. Returning 0."
              % contrast)
72 | 
73 |     # return peak, norm_vol
74 |     return peak
75 | 
76 | 
77 | def slice_with_neighborhood(data_3d, axis_to_take, idx, neighborhood=0, ratio=1.0):
78 |     axis_len = data_3d.shape[axis_to_take]
79 |     assert axis_to_take in [0, 1, 2]
80 |     assert axis_len > idx >= 0
81 |     transpose = [[1, 2, 0], [0, 2, 1], [0, 1, 2]]
82 |     sl = [slice(None)] * 3
83 |     if idx - neighborhood < 0:
84 |         sl[axis_to_take] = slice(0, idx + neighborhood + 1, 1)
85 |         slice_tmp = np.transpose(np.copy(data_3d[tuple(sl)]), transpose[axis_to_take])
86 |         shape = slice_tmp.shape
87 |         array_pad = np.zeros((shape[0], shape[1], neighborhood-idx))
88 |         # print(slice_tmp.shape, array_pad.shape)
89 |         slice_to_return = np.concatenate((array_pad, slice_tmp), axis=2)
90 |     elif idx + neighborhood >= axis_len:
91 |         sl[axis_to_take] = slice(idx - neighborhood, axis_len, 1)
92 |         slice_tmp = np.transpose(np.copy(data_3d[tuple(sl)]), transpose[axis_to_take])
93 |         shape = slice_tmp.shape
94 |         array_pad = np.zeros((shape[0], shape[1], idx + neighborhood - axis_len + 1))
95 |         slice_to_return = np.concatenate((slice_tmp, array_pad), axis=2)
96 |     else:
97 |         sl[axis_to_take] = slice(idx - neighborhood, idx + neighborhood + 1, 1)
98 |         slice_to_return = np.transpose(np.copy(data_3d[tuple(sl)]), transpose[axis_to_take])
99 | 
100 |     if abs(ratio - 1.0) < 1e-3:
101 |         return slice_to_return
102 | 
103 |     return resize_with_ratio(slice_to_return, ratio)
104 | 
105 | 
106 | def resize_with_ratio(image, ratio, modality='t1'):
107 |     height, width = image.shape[:2]
108 |     if ratio > 1:
109 |         height, width = height, int(width * ratio)
110 |     else:
111 |         height, width = int(height / ratio), width
112 |     interpolation = cv2.INTER_LINEAR
113 |     return cv2.resize(image, (width, height), interpolation=interpolation)
--------------------------------------------------------------------------------
/models/tiramisu_layers_dyn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import functional as F
4 | 
5 | 
6 | class DenseLayer(nn.Sequential):
7 |     def __init__(self, in_channels, growth_rate):
8 |         super().__init__()
9 |         self.add_module('norm', nn.BatchNorm2d(in_channels))
10 |         self.add_module('relu', nn.ReLU(True))
11 |         self.add_module('conv', nn.Conv2d(in_channels, growth_rate, kernel_size=3,
12 |                                           stride=1, padding=1, bias=True))
13 |         self.add_module('drop', nn.Dropout2d(0.2))
14 | 
15 |     def forward(self, x):
16 |         return super().forward(x)
17 | 
18 | 
19 | class DenseBlock(nn.Module):
20 |     def __init__(self, in_channels, growth_rate, n_layers, upsample=False):
21 |         super().__init__()
22 |         self.upsample = upsample
23 |         self.layers = nn.ModuleList([DenseLayer(
24 |             in_channels + i*growth_rate, growth_rate)
25 |             for i in range(n_layers)])
26 | 
27 |     def forward(self, x):
28 |         if self.upsample:
29 |             new_features = []
30 |             # we pass all previous activations into each dense layer normally
31 |             # but we only store each dense layer's output in the new_features array
32 |             for layer in self.layers:
33 |                 out = layer(x)
34 |                 x = torch.cat([x, out], 1)
35 |                 new_features.append(out)
36 |             return torch.cat(new_features, 1)
37 |         else:
38 |             for layer in self.layers:
39 |                 out = layer(x)
40 |                 x = torch.cat([x, out], 1)  # 1 = channel axis
41 |             return x
42 | 
43 | 
44 | class TransitionDown(nn.Sequential):
45 |     def __init__(self, in_channels):
46 |         super().__init__()
47 |         self.add_module('norm', nn.BatchNorm2d(num_features=in_channels))
48 |         self.add_module('relu', nn.ReLU(inplace=True))
49 |         self.add_module('conv',
                        nn.Conv2d(in_channels, in_channels,
50 |                                 kernel_size=1, stride=1,
51 |                                 padding=0, bias=True))
52 |         self.add_module('drop', nn.Dropout2d(0.2))
53 |         self.add_module('maxpool', nn.MaxPool2d(2))
54 | 
55 |     def forward(self, x):
56 |         return super().forward(x)
57 | 
58 | 
59 | class TransitionUp(nn.Module):
60 |     def __init__(self, in_channels, out_channels):
61 |         super().__init__()
62 |         # to use hiddenlayer, make padding=1, output_padding=1,
63 |         self.convTrans = nn.ConvTranspose2d(
64 |             in_channels=in_channels, out_channels=out_channels,
65 |             kernel_size=3, stride=2, padding=0, bias=True)
66 | 
67 |     def forward(self, x, skip):
68 |         out = self.convTrans(x)
69 |         # to use hiddenlayer, comment the usage of center_crop
70 |         out = center_crop(out, skip.size(2), skip.size(3))
71 |         out = torch.cat([out, skip], 1)
72 |         return out
73 | 
74 | 
75 | class Bottleneck(nn.Sequential):
76 |     def __init__(self, in_channels, growth_rate, n_layers):
77 |         super().__init__()
78 |         self.add_module('bottleneck', DenseBlock(
79 |             in_channels, growth_rate, n_layers, upsample=True))
80 | 
81 |     def forward(self, x):
82 |         return super().forward(x)
83 | 
84 | 
85 | def center_crop(layer, max_height, max_width):
86 |     _, _, h, w = layer.size()
87 |     xy1 = (w - max_width) // 2
88 |     xy2 = (h - max_height) // 2
89 |     return layer[:, :, xy2:(xy2 + max_height), xy1:(xy1 + max_width)]
90 | 
91 | 
92 | def parse_dynamic_params(params, weight_nums, bias_nums):
93 |     assert params.dim() == 2
94 |     assert len(weight_nums) == len(bias_nums)
95 |     assert params.size(1) == sum(weight_nums) + sum(bias_nums)
96 | 
97 |     num_insts = params.size(0)  # batch size N
98 |     num_layers = len(weight_nums)
99 | 
100 |     params_splits = list(torch.split_with_sizes(
101 |         params, weight_nums + bias_nums, dim=1
102 |     ))
103 | 
104 |     weight_splits = params_splits[:num_layers]
105 |     bias_splits = params_splits[num_layers:]
106 |     w, b = weight_splits[0], bias_splits[0]
107 |     w = w.reshape(48*num_insts, 15, 1, 1)  # should be (out_channels, in_channels/groups, kH, kW)
108 |     b = b.reshape(48*num_insts)  # (out_channels)
109 |     return w, b
110 | 
111 | 
112 | def dynamic_head(x, weights, biases, num_insts):  # x: input data/feature maps
113 |     assert x.dim() == 4
114 |     # print(x.size())
115 |     x = F.conv2d(
116 |         x, weights, bias=biases,
117 |         stride=1, padding=1,
118 |         groups=num_insts
119 |     )
120 |     x = F.relu(x)
121 |     return x
--------------------------------------------------------------------------------
/options/train_options.py:
--------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 | 
3 | 
4 | class TrainOptions(BaseOptions):
5 |     def initialize(self, parser):
6 |         parser = BaseOptions.initialize(self, parser)
7 |         parser.add_argument('--test_mode', required=True,
8 |                             help='should be val, val_test, or test. If val_test, we will use 3/5 for training, '
9 |                                  '1/5 for validation and 1/5 for test; if val, we will use 4/5 for training and 1/5 for validation')
10 |         parser.add_argument('--display_freq', type=int, default=2000, help='frequency of showing training results on screen')
11 |         parser.add_argument('--display_ncols', type=int, default=3, help='if positive, display all images in a single visdom web panel with certain number of images per row.')
12 |         parser.add_argument('--display_id', type=int, default=1, help='window id of the web display')
13 |         parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display')
14 |         parser.add_argument('--display_env', type=str, default='main', help='visdom display environment name (default is "main")')
15 |         parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display')
16 |         parser.add_argument('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html')
17 |         parser.add_argument('--print_freq', type=int, default=200, help='frequency of showing training results on console')
18 |         parser.add_argument('--save_latest_freq', type=int, default=40000, help='frequency of saving the latest results')
19 |         parser.add_argument('--save_epoch_freq', type=int, default=10, help='frequency of saving checkpoints at the end of epochs')
20 |         parser.add_argument('--save_by_iter', action='store_true', help='whether saves model by iteration')
21 |         parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
22 |         parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...')
23 |         parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
24 |         parser.add_argument('--niter', type=int, default=100, help='# of iter at starting learning rate')
25 |         parser.add_argument('--niter_decay', type=int, default=500, help='# of iter to linearly decay learning rate to zero')
26 |         parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
27 |         parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
28 |         parser.add_argument('--pool_size', type=int, default=0, help='the size of image buffer that stores previously generated images')
29 |         parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
30 |         parser.add_argument('--lr_policy', type=str, default='lambda', help='learning rate policy: lambda|step|plateau|cosine')
31 |         parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations')
32 |         parser.add_argument('--val_epoch_freq', type=int, default=5, help='frequency of validation at the end of epochs')
33 |         parser.add_argument('--num_val', type=int, default=400, help='how many test images to run during validation')
34 |         parser.add_argument('--eval_val', action='store_true', help='use eval mode during validation.')
35 |         parser.add_argument('--base_name', type=str, default=None, help='the name of base model to be loaded for fine tuning')
36 |         parser.add_argument('--finetuning', action='store_true', help='continue training: fine tune the base model')
37 |         parser.add_argument('--feature_extract', action='store_true', help='only fine tune the last layer (mask output) if set to be true')
38 |         parser.add_argument('--loss_to_use', type=str, default='l2', help='the loss function to be used for segmentation training')
39 |         parser.add_argument('--n_fold', type=int, default=5, help='n_fold cross-validation')
40 |         parser.add_argument('--test_index', type=int, default=None,
41 |                             help='this argument has different effects depending on test_mode: '
42 |                                  'val: we will not do test, so this argument has no effect; '
43 |                                  'test: the test fold is set to this value; by default it is the last fold; '
44 |                                  'val_test: we will do both val and test, and the test fold is set to this value')
45 |         parser.add_argument('--use_modality_dropout', action='store_false', help='use modality dropout at training phase (enabled by default; pass this flag to disable)')
46 |         parser.add_argument('--use_dyn', action='store_false', help='use dynamic filters (enabled by default; pass this flag to disable)')
47 |         self.isTrain = True
48 |         return parser
--------------------------------------------------------------------------------
/data_conversion.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import shutil
4 | import nibabel as nib
5 | import numpy as np
6 | from util.util import mkdir
7 | from util.image_property import normalize_image, hash_file
8 | from configurations import *
9 | 
10 | # for ISBI dataset
11 | def get_ids():
12 |     dic_ids = {}
13 |     mkdir(os.path.join(PATH_DATASET, 'raw'))
14 |     fnames = sorted(os.listdir(PATH_DATASET))
15 |     for fname in fnames:
16 |         if not fname.endswith(SUFFIX):
17 |             continue
18 |         fname_tmp = fname.split('.')[0]
19 |         print(fname_tmp.split('_'))
20 |         prefix, patient_str, timepoint_str, modality = fname_tmp.split('_')
21 |         patient_id, timepoint_id = int(patient_str), int(timepoint_str)
22 |         # patient_id, timepoint_id = int(patient_id), int(timepoint_id)
23 |         if patient_id in dic_ids and timepoint_id in dic_ids[patient_id]:
24 |             continue
25 | 
26 |         # if there is a new patient id and timepoint id, we get all the modalities and masks based on the constants
27 |         # and we will move the files into the 'raw' subdirectory
28 |         if patient_id not in dic_ids:
29 |             dic_ids[patient_id] = {}
30 |         dic_ids[patient_id][timepoint_id] = {'modalities': {}, 'mask': {}}
31 |         for mod in MODALITIES+MASKS:
32 |             fname_modality = '_'.join((prefix, patient_str, timepoint_str, mod)) + '.'
+ SUFFIX 33 | path_modality_src = os.path.join(PATH_DATASET, fname_modality) 34 | path_modality_dst = os.path.join(PATH_DATASET, 'raw', fname_modality) 35 | category = 'modalities' if mod in MODALITIES else 'mask' 36 | print(path_modality_src) 37 | assert os.path.exists(path_modality_src) 38 | 39 | shutil.move(path_modality_src, path_modality_dst) 40 | dic_ids[patient_id][timepoint_id][category][mod] = path_modality_dst 41 | fname_json = os.path.join(PATH_DATASET, 'ids.json') 42 | with open(fname_json, 'w') as f: 43 | json.dump(dic_ids, f, indent=2) 44 | return dic_ids 45 | 46 | 47 | # for non-ISBI dataset 48 | # def get_ids(): 49 | # dic_ids = {} 50 | # mkdir(os.path.join(PATH_DATASET, 'raw')) 51 | # fnames = sorted(os.listdir(PATH_DATASET)) 52 | # for fname in fnames: 53 | # if not fname.endswith(SUFFIX): 54 | # continue 55 | # fname_tmp = fname.split('.')[0] 56 | # # print(fname_tmp.split('_')) 57 | # print(fname_tmp.split('_')) 58 | # patient_str, modality = fname_tmp.split('_') 59 | # # patient_id, timepoint_id = int(patient_id), int(timepoint_id) 60 | # 61 | # timepoint_str, timepoint_id = "1", 1 # consistent with the format of ISBI dataset 62 | # prefix = "training" 63 | # if patient_str in dic_ids and timepoint_id in dic_ids[patient_str]: 64 | # continue 65 | # 66 | # if patient_str not in dic_ids: 67 | # dic_ids[patient_str] = {} 68 | # dic_ids[patient_str][timepoint_id] = {'modalities':{}, 'mask':{}} 69 | # for mod in MODALITIES + MASKS: 70 | # fname_modality = '_'.join((patient_str, mod)) + '.' + SUFFIX 71 | # out_fname_modality = '_'.join((patient_str, timepoint_str, mod)) + '.' + SUFFIX 72 | # path_modality_src = os.path.join(PATH_DATASET, fname_modality) 73 | # # print(path_modality_src) 74 | # path_modality_dst = os.path.join(PATH_DATASET, 'raw', out_fname_modality) 75 | # category = 'modalities' if mod in MODALITIES else 'mask' 76 | # assert os.path.exists(path_modality_src) 77 | # shutil.move(path_modality_src, path_modality_dst) 78 | # dic_ids[patient_str][timepoint_id][category][mod] = path_modality_dst 79 | # fname_json = os.path.join(PATH_DATASET, 'ids.json') 80 | # with open(fname_json, 'w') as f: 81 | # json.dump(dic_ids, f, indent=2) 82 | # return dic_ids 83 | 84 | 85 | def get_properties(): 86 | fname_json = os.path.join(PATH_DATASET, 'ids.json') 87 | with open(fname_json, 'r') as f: 88 | dic_ids = json.load(f) 89 | dic_properties = {} 90 | for patient_id in dic_ids: 91 | print(patient_id) 92 | for timepoint_id in dic_ids[patient_id]: 93 | for modality in dic_ids[patient_id][timepoint_id]['modalities']: 94 | path_modality = dic_ids[patient_id][timepoint_id]['modalities'][modality] 95 | label = hash_file(path_modality) 96 | data = nib.load(path_modality).get_fdata() 97 | peak = normalize_image(data, modality) 98 | peak = peak[0] if isinstance(peak, np.ndarray) else peak 99 | dic_properties[label] = {} 100 | dic_properties[label]['path'] = path_modality 101 | dic_properties[label]['peak'] = peak 102 | fname_json = os.path.join(PATH_DATASET, 'properties.json') 103 | with open(fname_json, 'w') as f: 104 | json.dump(dic_properties, f, indent=2) 105 | return dic_properties 106 | 107 | 108 | if __name__ == '__main__': 109 | assert os.path.exists(PATH_DATASET) 110 | 111 | dic_ids = get_ids() 112 | dic_properties = get_properties() 113 | -------------------------------------------------------------------------------- /options/base_options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from util 
import util
4 | import torch
5 | import models
6 | import data
7 | from configurations import *
8 | 
9 | 
10 | class BaseOptions():
11 |     def __init__(self):
12 |         self.initialized = False
13 | 
14 |     def initialize(self, parser):
15 |         parser.add_argument('--dataroot', default=PATH_DATASET)
16 |         parser.add_argument('--batch_size', type=int, default=16, help='input batch size')
17 |         parser.add_argument('--testSize', type=int, default=256, help='the image size used in val/test phases')
18 |         parser.add_argument('--trainSize', type=int, default=128, help='the image size used in the train phase')
19 |         parser.add_argument('--display_winsize', type=int, default=128, help='display window size for both visdom and HTML')
20 |         parser.add_argument('--input_nc', type=int, default=1, help='# of input image channels')
21 |         parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0; 0,1,2; 0,2. use -1 for CPU')
22 |         parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')
23 |         parser.add_argument('--dataset_mode', type=str, default='ms', help='chooses how datasets are loaded. [unaligned | aligned | single]')
24 |         parser.add_argument('--model', type=str, default='ms', help='currently only ms model available')
25 |         parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
26 |         parser.add_argument('--load_iter', type=int, default='0', help='which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]')
27 |         parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data')
28 |         parser.add_argument('--checkpoints_dir', type=str, default='../Checkpoints', help='models are saved here')
29 |         parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
30 |         parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator')
31 |         parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset.
If the dataset directory contains more than max_dataset_size, only a subset is loaded.') 32 | parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation') 33 | parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal|xavier|kaiming|orthogonal]') 34 | parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.') 35 | parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information') 36 | parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_size{trainSize}') 37 | self.initialized = True 38 | return parser 39 | 40 | def gather_options(self): 41 | # initialize parser with basic options 42 | if not self.initialized: 43 | parser = argparse.ArgumentParser( 44 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 45 | parser = self.initialize(parser) 46 | 47 | # get the basic options 48 | opt, _ = parser.parse_known_args() 49 | 50 | # modify model-related parser options 51 | model_name = opt.model 52 | model_option_setter = models.get_option_setter(model_name) 53 | parser = model_option_setter(parser, self.isTrain) 54 | opt, _ = parser.parse_known_args() # parse again with the new defaults 55 | 56 | # modify dataset-related parser options 57 | dataset_name = opt.dataset_mode 58 | dataset_option_setter = data.get_option_setter(dataset_name) 59 | parser = dataset_option_setter(parser, self.isTrain) 60 | 61 | self.parser = parser 62 | 63 | return parser.parse_args() 64 | 65 | def print_options(self, opt): 66 | message = '' 67 | message += '----------------- Options ---------------\n' 68 | for k, v in sorted(vars(opt).items()): 69 | comment = '' 70 | default = self.parser.get_default(k) 71 | if v != default: 72 | comment = '\t[default: %s]' % str(default) 73 | message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) 74 | message += '----------------- End -------------------' 75 | print(message) 76 | 77 | # save to the disk 78 | expr_dir = os.path.join(opt.checkpoints_dir, opt.name) 79 | util.mkdirs(expr_dir) 80 | file_name = os.path.join(expr_dir, '%s_opt.txt' % opt.phase) 81 | with open(file_name, 'wt') as opt_file: 82 | opt_file.write(message) 83 | opt_file.write('\n') 84 | 85 | def parse(self): 86 | 87 | opt = self.gather_options() 88 | opt.isTrain = self.isTrain # train or test 89 | 90 | # process opt.suffix 91 | if opt.suffix: 92 | suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else '' 93 | opt.name = opt.name + suffix 94 | 95 | self.print_options(opt) 96 | 97 | # set gpu ids 98 | str_ids = opt.gpu_ids.split(',') 99 | opt.gpu_ids = [] 100 | for str_id in str_ids: 101 | id = int(str_id) 102 | if id >= 0: 103 | opt.gpu_ids.append(id) 104 | if len(opt.gpu_ids) > 0: 105 | torch.cuda.set_device(opt.gpu_ids[0]) 106 | 107 | self.opt = opt 108 | return self.opt 109 | -------------------------------------------------------------------------------- /data/ms_3d_dataset.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import torchvision.transforms as transforms 3 | import numpy as np 4 | import json 5 | import nibabel as nib 6 | from data.base_dataset import BaseDataset 7 | from configurations import * 8 | from util.image_property import hash_file, normalize_image, slice_with_neighborhood 9 | 10 | 11 | def get_domain_code(path): 12 | code = np.zeros(5) 
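        # one-hot domain code over [ISBI, MICCAI16_1, MICCAI16_2, MICCAI16_3, UMCL],
        # inferred from the scan's file path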
13 | if 'ISBI' in path: 14 | code[0] = 1 15 | elif 'MICCAI16_1' in path: 16 | code[1] = 1 17 | elif 'MICCAI16_2' in path: 18 | code[2] = 1 19 | elif 'MICCAI16_3' in path: 20 | code[3] = 1 21 | elif 'UMCL' in path: 22 | code[4] = 1 23 | return code 24 | 25 | 26 | def get_modality_code(modalities): 27 | code = np.zeros(5) 28 | if 't1' in modalities: 29 | code[0] = 1 30 | if 'flair' in modalities: 31 | code[1] = 1 32 | if 't2' in modalities: 33 | code[2] = 1 34 | if 'pd' in modalities: 35 | code[3] = 1 36 | if 'ce' in modalities: 37 | code[4] = 1 38 | return code 39 | 40 | 41 | def get_3d_paths(dir): 42 | images = [] 43 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 44 | 45 | for root, _, fnames in sorted(os.walk(dir)): 46 | for fname in fnames: 47 | if fname.endswith(MODALITIES[0]+'.' + SUFFIX): 48 | images.append([]) 49 | images[-1].append(os.path.join(root, fname)) 50 | for i in range(1, len(MODALITIES)): 51 | images[-1].append(os.path.join(root, fname.replace(MODALITIES[0]+'.'+SUFFIX, MODALITIES[i]+'.' + SUFFIX))) 52 | images[-1].append(os.path.join(root, fname.replace(MODALITIES[0] + '.' + SUFFIX, 'mask.' + SUFFIX))) 53 | images.sort(key=lambda x: x[0]) 54 | return images 55 | 56 | 57 | def flip_by_times(np_array, times): 58 | for i in range(times): 59 | np_array = np.flip(np_array, axis=1) 60 | return np_array 61 | 62 | 63 | class Ms3dDataset(BaseDataset): 64 | @staticmethod 65 | def modify_commandline_options(parser, is_train): 66 | return parser 67 | 68 | def initialize(self, opt): 69 | self.all_paths = [] 70 | self.dic_properties = {} 71 | for dataset_name in DATASETS: 72 | self.dir_AB = os.path.join(opt.dataroot, dataset_name, opt.phase) 73 | self.all_paths += get_3d_paths(self.dir_AB) 74 | with open(os.path.join(opt.dataroot, dataset_name, 'properties.json'), 'r') as f: 75 | self.dic_properties.update(json.load(f)) 76 | self.neighbors = opt.input_nc // 2 77 | 78 | def __getitem__(self, index): 79 | paths_this_scan = self.all_paths[index] 80 | # print(paths_this_scan): [../_t1.nii.gz, ../_flair.nii.gz, ../_t2.nii.gz, ../_pd.nii.gz, ../_mask.nii.gz] 81 | voxel_sizes = nib.affines.voxel_sizes(nib.load(paths_this_scan[0]).affine) 82 | data_all_modalities = {} 83 | all_modalities = [] 84 | if os.path.exists(paths_this_scan[-1]): 85 | data_all_modalities['mask'] = nib.load(paths_this_scan[-1]).get_fdata() 86 | path_name = paths_this_scan[0][:-9] 87 | 88 | for i, modality in enumerate(MODALITIES): 89 | # modality = 't1', 'flair', 't2', 'pd', 'ce' 90 | # i = 0, 1, 2, 3, 4 91 | path_modality = path_name + modality + '.nii.gz' 92 | if os.path.exists(path_modality): # and '_t2' not in path_modality and '_t1' not in path_modality and '_flair' not in path_modality: 93 | all_modalities.append(modality) 94 | label_modality = hash_file(path_modality) 95 | data_modality = nib.load(path_modality).get_fdata() 96 | 97 | if label_modality in self.dic_properties: 98 | peak_modality = self.dic_properties[label_modality]['peak'] 99 | else: 100 | peak_modality = normalize_image(data_modality, modality) 101 | data_all_modalities[modality] = np.array(data_modality / peak_modality, dtype=np.float32) 102 | else: 103 | data_all_modalities[modality] = np.zeros(data_all_modalities['mask'].shape) 104 | 105 | data_return = {mod: {'axial': [], 'sagittal': [], 'coronal': []} for mod in MODALITIES+['mask']} 106 | data_return['org_size'] = {'axial': None, 'sagittal': None, 'coronal': None} 107 | data_return['mask_paths'] = paths_this_scan[-1] 108 | data_return['alt_paths'] = paths_this_scan[0] 109 | 
data_return['dc'] = get_domain_code(paths_this_scan[0]) 110 | data_return['mc'] = get_modality_code(all_modalities) 111 | 112 | for k, orientation in enumerate(['axial', 'sagittal', 'coronal']): 113 | ratio = [size for axis, size in enumerate(voxel_sizes) if axis != AXIS_TO_TAKE[k]] 114 | ratio = ratio[1] / ratio[0] 115 | cur_shape = data_all_modalities[MODALITIES[0]].shape 116 | slices_per_image = cur_shape[AXIS_TO_TAKE[k]] 117 | data_return['org_size'][orientation] = \ 118 | tuple([axis_len for axis, axis_len in enumerate(cur_shape) if axis != AXIS_TO_TAKE[k]]) 119 | for i in range(slices_per_image): 120 | for modality in MODALITIES: 121 | slice_modality = slice_with_neighborhood(data_all_modalities[modality], AXIS_TO_TAKE[k], i, self.neighbors, ratio) 122 | slice_modality = transforms.ToTensor()(slice_modality).float() 123 | if modality in all_modalities: 124 | slice_modality = slice_modality / 2 - 1 125 | data_return[modality][orientation].append(slice_modality) 126 | if os.path.exists(paths_this_scan[-1]): 127 | slice_modality = slice_with_neighborhood(data_all_modalities['mask'], AXIS_TO_TAKE[k], i, 0) 128 | slice_modality = transforms.ToTensor()(slice_modality).float() 129 | slice_modality = slice_modality * 2 - 1 130 | data_return['mask'][orientation].append(slice_modality) 131 | return data_return 132 | 133 | def __len__(self): 134 | return len(self.all_paths) 135 | 136 | def name(self): 137 | return 'Ms3dDataset' 138 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import time 2 | import copy 3 | import os 4 | import torch 5 | import random 6 | import numpy as np 7 | from options.train_options import TrainOptions 8 | from data import CreateDataLoader 9 | from models import create_model 10 | from util.visualizer import Visualizer 11 | from data.data_generator import DataGenerator 12 | import test 13 | 14 | 15 | def set_seed(): 16 | seed = 10 17 | random.seed(seed) 18 | os.environ['PYTHONHASHSEED'] = str(seed) 19 | np.random.seed(seed) 20 | torch.manual_seed(seed) 21 | torch.cuda.manual_seed(seed) 22 | torch.cuda.manual_seed_all(seed) # if you are using multi-GPU. 
23 | 24 | # make cuDNN deterministic so runs are reproducible 25 | # these two lines can be commented out for faster training 26 | torch.backends.cudnn.benchmark = False 27 | torch.backends.cudnn.deterministic = True 28 | 29 | 30 | def get_val_test_opts(opt_train): 31 | opt_val = copy.deepcopy(opt_train) 32 | opt_val.phase = 'val' 33 | opt_val.num_threads = 1 34 | opt_val.batch_size = 1 35 | opt_val.serial_batches = True # no shuffle 36 | opt_val.no_flip = True # no flip 37 | opt_val.dataset_mode = 'ms_3d' 38 | 39 | opt_test = copy.deepcopy(opt_val) 40 | opt_test.phase = 'test' 41 | return opt_val, opt_test 42 | 43 | 44 | def create_data_loader(opt_this_phase): 45 | data_loader = CreateDataLoader(opt_this_phase) 46 | dataset = data_loader.load_data() 47 | dataset_size = len(data_loader) 48 | print('#%s images = %d' % (opt_this_phase.phase, dataset_size)) 49 | return dataset, dataset_size 50 | 51 | 52 | if __name__ == '__main__': 53 | set_seed() 54 | print('process id ', os.getpid()) 55 | 56 | opt = TrainOptions().parse() 57 | opt_val, opt_test = get_val_test_opts(opt) 58 | test_index = opt.n_fold - 1 if opt.test_index is None else opt.test_index 59 | test_index = opt.n_fold if 'test' not in opt.test_mode else test_index 60 | val_indices = [x for x in range(opt.n_fold) if x != test_index] if opt.test_mode != 'test' else [test_index] 61 | 62 | models = [] 63 | data_generator = DataGenerator() 64 | for val_index in val_indices: # for each fold in cross-validation 65 | # data_generator.build_dataset(val_index, test_index, opt.test_mode) # uncomment for online data generation 66 | dataset, dataset_size = create_data_loader(opt) 67 | dataset_val, dataset_size_val = create_data_loader(opt_val) 68 | dataset_test, dataset_size_test = create_data_loader(opt_test) 69 | model_suffix = 'val%d' % val_index if 'val' in opt.test_mode else '' 70 | model_suffix += 'test%d' % test_index if 'test' in opt.test_mode else '' 71 | model = create_model(opt, model_suffix) 72 | model.setup(opt) 73 | visualizer = Visualizer(opt) 74 | 75 | total_steps, best_epochs, val_losses_best = 0, 0, 0 76 | for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1): # for each epoch 77 | epoch_start_time = time.time() 78 | iter_data_time = time.time() 79 | epoch_iter = 0 80 | 81 | for i, data in enumerate(dataset): # for each iteration 82 | # print(i, data, list(data.keys()), data['paths']) 83 | iter_start_time = time.time() 84 | if total_steps % opt.print_freq == 0: 85 | t_data = iter_start_time - iter_data_time 86 | visualizer.reset() 87 | total_steps += opt.batch_size 88 | epoch_iter += opt.batch_size 89 | model.set_input(data) 90 | model.optimize_parameters() 91 | 92 | if total_steps % opt.display_freq == 0: 93 | save_result = total_steps % opt.update_html_freq == 0 94 | visualizer.display_current_results(model.get_current_visuals(), epoch, save_result) 95 | 96 | if total_steps % opt.print_freq == 0: 97 | losses = model.get_current_losses() 98 | t = (time.time() - iter_start_time) / opt.batch_size 99 | visualizer.print_current_losses(epoch, epoch_iter, losses, t, t_data) 100 | if opt.display_id > 0: 101 | visualizer.plot_current_losses(epoch, float(epoch_iter) / dataset_size, opt, losses) 102 | 103 | iter_data_time = time.time() 104 | 105 | # finished the training iterations for this epoch; save checkpoints and run validation below 106 | 107 | if epoch % opt.save_epoch_freq == 0: 108 | print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps)) 109 | model.save_networks(epoch) 110 | 111 | if dataset_size_val > 0 and epoch % opt.val_epoch_freq == 0: 112 | start_time_val =
time.time() 113 | if opt_val.eval_val: 114 | model.eval() 115 | losses_val = test.model_test([model], dataset_val, opt_val, dataset_size_val) 116 | if opt.display_id > 0: 117 | visualizer.plot_val_losses(epoch, 0, opt_val, losses_val, model_suffix=model_suffix) 118 | else: 119 | visualizer.save_val_losses(epoch, 0, opt_val, losses_val, model_suffix=model_suffix) 120 | visualizer.print_val_losses(epoch, losses_val, time.time() - start_time_val) 121 | model.train() 122 | 123 | if losses_val['dice'] > val_losses_best: 124 | val_losses_best = losses_val['dice'] 125 | best_epochs = epoch 126 | model.save_networks('latest') 127 | elif epoch - best_epochs >= 160 and 'val' in opt.test_mode: 128 | break 129 | 130 | print("best epoch", best_epochs, "best loss", val_losses_best) 131 | 132 | print('finished epoch %d / %d, \t time taken: %d sec' % 133 | (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time)) 134 | model.update_learning_rate() 135 | models.append(model) 136 | 137 | losses_test = test.model_test(models, dataset_test, opt_test, dataset_size_test, save_images=True, 138 | mask_suffix=opt_test.name, save_membership=False) 139 | print(losses_test) 140 | -------------------------------------------------------------------------------- /Python38.yml: -------------------------------------------------------------------------------- 1 | name: Python38 2 | channels: 3 | - conda-forge 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - argon2-cffi=20.1.0=py38h1e0a361_1 8 | - attrs=20.1.0=pyh9f0ad1d_0 9 | - backcall=0.2.0=pyh9f0ad1d_0 10 | - backports=1.0=py_2 11 | - backports.functools_lru_cache=1.6.1=py_0 12 | - bleach=3.1.5=pyh9f0ad1d_0 13 | - brotlipy=0.7.0=py38h1e0a361_1000 14 | - ca-certificates=2020.6.20=hecda079_0 15 | - certifi=2020.6.20=py38h32f6830_0 16 | - cffi=1.14.2=py38he30daa8_0 17 | - chardet=3.0.4=py38h32f6830_1006 18 | - cryptography=3.1=py38h766eaa4_0 19 | - decorator=4.4.2=py_0 20 | - defusedxml=0.6.0=py_0 21 | - entrypoints=0.3=py38h32f6830_1001 22 | - idna=2.10=pyh9f0ad1d_0 23 | - importlib-metadata=1.7.0=py38h32f6830_0 24 | - importlib_metadata=1.7.0=0 25 | - ipykernel=5.3.4=py38h23f93f0_0 26 | - ipython=7.18.1=py38h1cdfbd6_0 27 | - ipython_genutils=0.2.0=py_1 28 | - jedi=0.17.2=py38h32f6830_0 29 | - jinja2=2.11.2=pyh9f0ad1d_0 30 | - json5=0.9.4=pyh9f0ad1d_0 31 | - jsonschema=3.2.0=py38h32f6830_1 32 | - jupyter_client=6.1.7=py_0 33 | - jupyter_core=4.6.3=py38h32f6830_1 34 | - jupyterlab=2.2.6=py_0 35 | - jupyterlab_server=1.2.0=py_0 36 | - ld_impl_linux-64=2.33.1=h53a641e_7 37 | - libedit=3.1.20191231=h14c3975_1 38 | - libffi=3.3=he6710b0_2 39 | - libgcc-ng=9.1.0=hdf63c60_0 40 | - libsodium=1.0.18=h516909a_0 41 | - libstdcxx-ng=9.1.0=hdf63c60_0 42 | - markupsafe=1.1.1=py38h1e0a361_1 43 | - mistune=0.8.4=py38h1e0a361_1001 44 | - nbconvert=5.6.1=py38h32f6830_1 45 | - nbformat=5.0.7=py_0 46 | - ncurses=6.2=he6710b0_1 47 | - openssl=1.1.1g=h516909a_1 48 | - packaging=20.4=pyh9f0ad1d_0 49 | - pandoc=2.10.1=h516909a_0 50 | - pandocfilters=1.4.2=py_1 51 | - parso=0.7.1=pyh9f0ad1d_0 52 | - pexpect=4.8.0=py38h32f6830_1 53 | - pickleshare=0.7.5=py38h32f6830_1001 54 | - pip=20.1.1=py38_1 55 | - prometheus_client=0.8.0=pyh9f0ad1d_0 56 | - prompt-toolkit=3.0.6=py_0 57 | - ptyprocess=0.6.0=py_1001 58 | - pycparser=2.20=pyh9f0ad1d_2 59 | - pygments=2.6.1=py_0 60 | - pyopenssl=19.1.0=py_1 61 | - pyparsing=2.4.7=pyh9f0ad1d_0 62 | - pyrsistent=0.16.0=py38h1e0a361_0 63 | - pysocks=1.7.1=py38h32f6830_1 64 | - python=3.8.3=hcff3b4d_2 65 | - python-dateutil=2.8.1=py_0 
66 | - python_abi=3.8=1_cp38 67 | - pyzmq=19.0.2=py38ha71036d_0 68 | - readline=8.0=h7b6447c_0 69 | - requests=2.24.0=pyh9f0ad1d_0 70 | - send2trash=1.5.0=py_0 71 | - six=1.15.0=pyh9f0ad1d_0 72 | - sqlite=3.32.3=h62c20be_0 73 | - terminado=0.8.3=py38h32f6830_1 74 | - testpath=0.4.4=py_0 75 | - tk=8.6.10=hbc83047_0 76 | - tornado=6.0.4=py38h1e0a361_1 77 | - traitlets=4.3.3=py38h32f6830_1 78 | - urllib3=1.25.10=py_0 79 | - wcwidth=0.2.5=pyh9f0ad1d_1 80 | - webencodings=0.5.1=py_1 81 | - wheel=0.34.2=py38_0 82 | - xz=5.2.5=h7b6447c_0 83 | - zeromq=4.3.2=he1b5a44_3 84 | - zipp=3.1.0=py_0 85 | - zlib=1.2.11=h7b6447c_3 86 | - pip: 87 | - absl-py==0.10.0 88 | - antspyx==0.2.4 89 | - anyascii==0.1.7 90 | - anytree==2.8.0 91 | - appdirs==1.4.4 92 | - ase==3.20.1 93 | - beautifulsoup4==4.9.3 94 | - black==20.8b1 95 | - bs4==0.0.1 96 | - cachetools==4.1.1 97 | - chart-studio==1.1.0 98 | - ci-info==0.2.0 99 | - click==7.1.2 100 | - connected-components-3d==1.7.0 101 | - contractions==0.0.48 102 | - coverage==5.3 103 | - cycler==0.10.0 104 | - dipy==1.1.1 105 | - dominate==2.6.0 106 | - etelemetry==0.2.1 107 | - filelock==3.0.12 108 | - flake8==3.8.4 109 | - flake8-bugbear==20.1.4 110 | - flake8-comprehensions==3.2.3 111 | - flake8-executable==2.0.4 112 | - flake8-polyfill==1.0.2 113 | - flake8-pyi==20.10.0 114 | - future==0.18.2 115 | - gdown==3.12.2 116 | - gensim==4.0.0 117 | - google-auth==1.22.1 118 | - google-auth-oauthlib==0.4.1 119 | - googledrivedownloader==0.4 120 | - grpcio==1.33.1 121 | - h5py==2.10.0 122 | - imageio==2.9.0 123 | - imgaug==0.4.0 124 | - importlab==0.5.1 125 | - inflect==5.3.0 126 | - isodate==0.6.0 127 | - isort==5.6.4 128 | - itk==5.1.1.post1 129 | - itk-core==5.1.1.post1 130 | - itk-filtering==5.1.1.post1 131 | - itk-io==5.1.1.post1 132 | - itk-numerics==5.1.1.post1 133 | - itk-registration==5.1.1.post1 134 | - itk-segmentation==5.1.1.post1 135 | - joblib==0.16.0 136 | - jsonpatch==1.32 137 | - jsonpointer==2.1 138 | - jupyterthemes==0.20.0 139 | - kiwisolver==1.2.0 140 | - lesscpy==0.14.0 141 | - llvmlite==0.34.0 142 | - lxml==4.5.2 143 | - markdown==3.3.2 144 | - matplotlib==3.3.0 145 | - mccabe==0.6.1 146 | - monai==0.3.0 147 | - mypy==0.790 148 | - mypy-extensions==0.4.3 149 | - networkx==2.4 150 | - nibabel==3.1.1 151 | - nilearn==0.6.2 152 | - ninja==1.10.0.post2 153 | - nipype==1.5.0 154 | - nltk==3.5 155 | - notebook==6.1.5 156 | - numba==0.51.2 157 | - numpy==1.19.1 158 | - oauthlib==3.1.0 159 | - opencv-python==4.3.0.36 160 | - pandas==1.1.0 161 | - parameterized==0.7.4 162 | - pathspec==0.8.0 163 | - patsy==0.5.1 164 | - pep8-naming==0.11.1 165 | - pillow==7.2.0 166 | - plotly==4.14.3 167 | - ply==3.11 168 | - protobuf==3.13.0 169 | - prov==1.5.3 170 | - pyahocorasick==1.4.2 171 | - pyasn1==0.4.8 172 | - pyasn1-modules==0.2.8 173 | - pycodestyle==2.6.0 174 | - pydot==1.4.1 175 | - pydotplus==2.0.2 176 | - pyflakes==2.2.0 177 | - pystrum==0.1 178 | - pytorch-ignite==0.4.2 179 | - pytype==2020.10.8 180 | - pytz==2020.1 181 | - pywavelets==1.1.1 182 | - pyyaml==5.3.1 183 | - rdflib==5.0.0 184 | - regex==2020.10.15 185 | - requests-oauthlib==1.3.0 186 | - retrying==1.3.3 187 | - rsa==4.6 188 | - scikit-fmm==0.0.8 189 | - scikit-image==0.17.2 190 | - scikit-learn==0.24.1 191 | - scipy==1.5.2 192 | - seaborn==0.11.0 193 | - setuptools==50.3.0 194 | - shapely==1.7.0 195 | - simpleitk==1.2.4 196 | - simplejson==3.17.2 197 | - sklearn==0.0 198 | - smart-open==5.0.0 199 | - soupsieve==2.2.1 200 | - statsmodels==0.12.1 201 | - tensorboard==2.3.0 202 | - 
tensorboard-plugin-wit==1.7.0 203 | - textsearch==0.0.21 204 | - threadpoolctl==2.1.0 205 | - tifffile==2020.7.24 206 | - toml==0.10.1 207 | - torch==1.6.0 208 | - torch-cluster==1.5.8 209 | - torch-geometric==1.6.1 210 | - torch-scatter==2.0.5 211 | - torch-sparse==0.6.7 212 | - torch-spline-conv==1.2.0 213 | - torchfile==0.1.0 214 | - torchsummary==1.5.1 215 | - torchvision==0.7.0 216 | - tqdm==4.50.2 217 | - traits==6.1.1 218 | - typed-ast==1.4.1 219 | - typing-extensions==3.7.4.3 220 | - visdom==0.1.8.9 221 | - webcolors==1.11.1 222 | - websocket-client==0.58.0 223 | - werkzeug==1.0.1 224 | - xlrd==1.2.0 225 | prefix: /mnt/sdb2/anaconda3/envs/Python38 226 | -------------------------------------------------------------------------------- /models/tiramisu_model_dyn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from models.tiramisu_layers_dyn import * 4 | 5 | 6 | class FCDenseNetDyn(nn.Module): 7 | # copied from https://github.com/bfortuner/pytorch_tiramisu 8 | # slightly changed to output range from -1 to 1 9 | def __init__(self, in_channels=3, down_blocks=(5,5,5,5,5), 10 | up_blocks=(5,5,5,5,5), bottleneck_layers=5, 11 | growth_rate=16, out_chans_first_conv=48, n_classes=12): 12 | super().__init__() 13 | self.down_blocks = down_blocks 14 | self.up_blocks = up_blocks 15 | cur_channels_count = 0 16 | skip_connection_channel_counts = [] 17 | 18 | ## First Convolution ## 19 | self.add_module('firstconv', nn.Conv2d(in_channels=in_channels, 20 | out_channels=out_chans_first_conv, kernel_size=3, 21 | stride=1, padding=1, bias=True)) 22 | 23 | cur_channels_count = out_chans_first_conv 24 | 25 | ##################### 26 | # Downsampling path # 27 | ##################### 28 | 29 | self.denseBlocksDown = nn.ModuleList([]) 30 | self.transDownBlocks = nn.ModuleList([]) 31 | for i in range(len(down_blocks)): 32 | self.denseBlocksDown.append( 33 | DenseBlock(cur_channels_count, growth_rate, down_blocks[i])) 34 | cur_channels_count += (growth_rate*down_blocks[i]) 35 | skip_connection_channel_counts.insert(0,cur_channels_count) 36 | self.transDownBlocks.append(TransitionDown(cur_channels_count)) 37 | 38 | ##################### 39 | # Bottleneck # 40 | ##################### 41 | 42 | self.add_module('bottleneck',Bottleneck(cur_channels_count, growth_rate, bottleneck_layers)) 43 | prev_block_channels = growth_rate*bottleneck_layers 44 | cur_channels_count += prev_block_channels 45 | 46 | ####################### 47 | # Upsampling path # 48 | ####################### 49 | 50 | self.transUpBlocks = nn.ModuleList([]) 51 | self.denseBlocksUp = nn.ModuleList([]) 52 | for i in range(len(up_blocks)-1): 53 | self.transUpBlocks.append(TransitionUp(prev_block_channels, prev_block_channels)) 54 | cur_channels_count = prev_block_channels + skip_connection_channel_counts[i] 55 | 56 | self.denseBlocksUp.append(DenseBlock( 57 | cur_channels_count, growth_rate, up_blocks[i], 58 | upsample=True)) 59 | prev_block_channels = growth_rate*up_blocks[i] 60 | cur_channels_count += prev_block_channels 61 | 62 | ## Final DenseBlock ## 63 | 64 | self.transUpBlocks.append(TransitionUp( 65 | prev_block_channels, prev_block_channels)) 66 | cur_channels_count = prev_block_channels + skip_connection_channel_counts[-1] 67 | 68 | self.denseBlocksUp.append(DenseBlock( 69 | cur_channels_count, growth_rate, up_blocks[-1], 70 | upsample=False)) 71 | cur_channels_count += growth_rate*up_blocks[-1] 72 | 73 | ## Softmax ## 74 | 75 | self.finalConv = 
nn.Conv2d(in_channels=cur_channels_count, 76 | out_channels=n_classes, kernel_size=1, stride=1, 77 | padding=0, bias=True) 78 | self.tanh = nn.Tanh() 79 | # self.softmax = nn.LogSoftmax(dim=1) 80 | 81 | # self.controller = nn.Conv2d(5, 15*3*3*48+48, kernel_size=1, stride=1, padding=0) 82 | self.controller = nn.Conv2d(5, 15*48 + 48, kernel_size=1, stride=1, padding=0) 83 | 84 | def forward(self, x, mc, get_dyn_feat=False): # add modality code as additional input 85 | # adapted from https://github.com/jianpengz/DoDNet/blob/main/a_DynConv/unet3D_DynConv882.py 86 | 87 | N, _, H, W = x.size() 88 | x = x.reshape(1, -1, H, W) # 1 x N*15 x 128 x 128 89 | 90 | params = self.controller(mc.unsqueeze(-1).unsqueeze(-1).float()) # 1 x 48*15+48 x 1 x 1 91 | params.squeeze_(-1).squeeze_(-1) # 1 x 48*15+48 92 | # print(params.size()) # N x 768 93 | weight_nums, bias_nums = [], [] 94 | weight_nums.append(15 * 48) 95 | bias_nums.append(48) 96 | sw, sb = parse_dynamic_params(params, weight_nums, bias_nums) 97 | sw = sw.reshape(N, 48, 15).unsqueeze(-1).unsqueeze(-1) 98 | sb = sb.reshape(N, 48) 99 | 100 | """generate all parameters: 5 -> 6000 mapping""" 101 | # N, _, H, W = x.size() 102 | # x = x.reshape(1, -1, H, W) 103 | # out = my_first_conv(x, weights, biases, N) 104 | # out = out.reshape(-1, 48, out.size()[-2], out.size()[-1]) 105 | 106 | w, b = self.firstconv.weight.clone().unsqueeze(0), self.firstconv.bias.clone() # 1 x 48 x 15 x 3 x 3 107 | # This function is differentiable, so gradients will flow back from the result of this operation to input. 108 | # To create a tensor without an autograd relationship to input see detach() 109 | 110 | # check if both scalars and kernels are updated! 111 | # print('======================================') 112 | # print(self.firstconv.weight[0, 0, :, :], self.firstconv.bias[0]) 113 | # print(sw[0, :3, :3, 0, 0]) 114 | 115 | w = w * sw 116 | b = b * sb 117 | 118 | # print(sw.size(), w.size()) 119 | # print(w.size(), b.size()) 120 | 121 | w = w.reshape(-1, 15, 3, 3) 122 | b = b.reshape(-1) 123 | 124 | out = dynamic_head(x, w, b, N) 125 | out = out.reshape(-1, 48, out.size()[-2], out.size()[-1]) 126 | 127 | # dynamic filter place! # 128 | if get_dyn_feat: 129 | dyn_feat = out 130 | 131 | # regular network 132 | skip_connections = [] 133 | for i in range(len(self.down_blocks)): 134 | out = self.denseBlocksDown[i](out) 135 | skip_connections.append(out) 136 | out = self.transDownBlocks[i](out) 137 | 138 | # print(out) 139 | out = self.bottleneck(out) 140 | 141 | # BOTTLENECK PLACE! 
# deprecated 142 | # print('===============') 143 | # if get_dyn_feat: 144 | # dyn_feat = out 145 | # print('===============') 146 | 147 | for i in range(len(self.up_blocks)): 148 | skip = skip_connections.pop() 149 | out = self.transUpBlocks[i](out, skip) 150 | out = self.denseBlocksUp[i](out) 151 | 152 | out = self.finalConv(out) 153 | out = self.tanh(out) 154 | # out = self.softmax(out) 155 | # print(out.size()) 156 | 157 | if get_dyn_feat: 158 | return dyn_feat, out 159 | return out 160 | 161 | 162 | if __name__ == '__main__': 163 | from torchsummary import summary 164 | net = FCDenseNetDyn(in_channels=15, down_blocks=(4, 4, 4, 4, 4), up_blocks=(4, 4, 4, 4, 4), 165 | bottleneck_layers=4, growth_rate=12, out_chans_first_conv=48, n_classes=1).cuda() 166 | 167 | data = torch.zeros((1, 15, 128, 128)).cuda() 168 | mc = torch.zeros((1, 5)).cuda() 169 | output = net(data, mc) 170 | print('finish running...') 171 | 172 | -------------------------------------------------------------------------------- /models/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from collections import OrderedDict 4 | from . import networks 5 | 6 | 7 | class BaseModel(): 8 | # modify parser to add command line options, 9 | # and also change the default values if needed 10 | @staticmethod 11 | def modify_commandline_options(parser, is_train): 12 | return parser 13 | 14 | def name(self): 15 | return 'BaseModel' 16 | 17 | def initialize(self, opt, model_suffix): 18 | self.opt = opt 19 | self.gpu_ids = opt.gpu_ids 20 | self.isTrain = opt.isTrain 21 | self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') 22 | self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) 23 | self.loss_names = [] 24 | self.model_names = [] 25 | self.visual_names = [] 26 | self.image_paths = [] 27 | self.model_suffix = model_suffix 28 | 29 | def set_input(self, input): 30 | pass 31 | 32 | def forward(self): 33 | pass 34 | 35 | # load and print networks; create schedulers 36 | def setup(self, opt, parser=None): 37 | if self.isTrain: 38 | self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers] 39 | if not self.isTrain or opt.continue_train: 40 | load_suffix = 'iter_%d' % opt.load_iter if opt.load_iter > 0 else opt.epoch 41 | if self.isTrain: 42 | base_dir = os.path.join(opt.checkpoints_dir, opt.base_name) if opt.finetuning else self.save_dir 43 | else: 44 | base_dir = self.save_dir 45 | self.load_networks(load_suffix, base_dir) 46 | self.print_networks(opt.verbose) 47 | 48 | # make models eval mode during test time 49 | def eval(self): 50 | for name in self.model_names: 51 | if isinstance(name, str): 52 | net = getattr(self, 'net' + name) 53 | net.eval() 54 | 55 | def train(self): 56 | for name in self.model_names: 57 | if isinstance(name, str): 58 | net = getattr(self, 'net' + name) 59 | net.train() 60 | 61 | # used in test time, wrapping `forward` in no_grad() so we don't save 62 | # intermediate steps for backprop 63 | def test(self): 64 | with torch.no_grad(): 65 | self.forward() 66 | 67 | def get_val_losses(self): 68 | pass 69 | 70 | # get image paths 71 | def get_image_paths(self): 72 | return self.image_paths 73 | 74 | def optimize_parameters(self): 75 | pass 76 | 77 | # update learning rate (called once every epoch) 78 | def update_learning_rate(self): 79 | for scheduler in self.schedulers: 80 | scheduler.step() 81 | lr = self.optimizers[0].param_groups[0]['lr'] 82 | 
print('learning rate = %.7f' % lr) 83 | 84 | # return visualization images. train.py will display these images and save them to an HTML file 85 | def get_current_visuals(self): 86 | visual_ret = OrderedDict() 87 | for name in self.visual_names: 88 | if isinstance(name, str): 89 | visual_ret[name] = getattr(self, name) 90 | return visual_ret 91 | 92 | # return training losses/errors. train.py will print out these errors as debugging information 93 | def get_current_losses(self): 94 | errors_ret = OrderedDict() 95 | for name in self.loss_names: 96 | if isinstance(name, str): 97 | # float(...) works for both scalar tensor and float number 98 | errors_ret[name] = float(getattr(self, 'loss_' + name)) 99 | return errors_ret 100 | 101 | # save models to the disk 102 | def save_networks(self, epoch): 103 | for name in self.model_names: 104 | if isinstance(name, str): 105 | save_filename = '%s_net_%s_%s.pth' % (epoch, name, self.model_suffix) 106 | save_path = os.path.join(self.save_dir, save_filename) 107 | net = getattr(self, 'net' + name) 108 | 109 | if len(self.gpu_ids) > 0 and torch.cuda.is_available(): 110 | # to use hiddenlayer, change net.module.cpu() to be net.cpu() 111 | torch.save(net.module.cpu().state_dict(), save_path) 112 | net.cuda(self.gpu_ids[0]) 113 | else: 114 | torch.save(net.cpu().state_dict(), save_path) 115 | 116 | def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0): 117 | key = keys[i] 118 | if i + 1 == len(keys): # at the end, pointing to a parameter/buffer 119 | if module.__class__.__name__.startswith('InstanceNorm') and \ 120 | (key == 'running_mean' or key == 'running_var'): 121 | if getattr(module, key) is None: 122 | state_dict.pop('.'.join(keys)) 123 | if module.__class__.__name__.startswith('InstanceNorm') and \ 124 | (key == 'num_batches_tracked'): 125 | state_dict.pop('.'.join(keys)) 126 | else: 127 | self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1) 128 | 129 | # load models from the disk 130 | def load_networks(self, epoch, base_dir): 131 | for name in self.model_names: 132 | if isinstance(name, str): 133 | load_filename = '%s_net_%s_%s.pth' % (epoch, name, self.model_suffix) 134 | load_path = os.path.join(base_dir, load_filename) 135 | net = getattr(self, 'net' + name) 136 | if isinstance(net, torch.nn.DataParallel): 137 | net = net.module 138 | print('loading the model from %s' % load_path) 139 | # if you are using PyTorch newer than 0.4 (e.g., built from 140 | # GitHub source), you can remove str() on self.device 141 | state_dict = torch.load(load_path, map_location=str(self.device)) 142 | if hasattr(state_dict, '_metadata'): 143 | del state_dict._metadata 144 | 145 | # patch InstanceNorm checkpoints prior to 0.4 146 | for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop 147 | self.__patch_instance_norm_state_dict(state_dict, net, key.split('.')) 148 | strict = not self.opt.feature_extract if self.isTrain else True 149 | net.load_state_dict(state_dict, strict) 150 | 151 | # print network information 152 | def print_networks(self, verbose): 153 | print('---------- Networks initialized -------------') 154 | for name in self.model_names: 155 | if isinstance(name, str): 156 | net = getattr(self, 'net' + name) 157 | num_params = 0 158 | for param in net.parameters(): 159 | num_params += param.numel() 160 | if verbose: 161 | print(net) 162 | print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6)) 163 | 
print('-----------------------------------------------') 164 | 165 | # set requires_grad=False to avoid unnecessary gradient computation 166 | def set_requires_grad(self, nets, requires_grad=False): 167 | if not isinstance(nets, list): 168 | nets = [nets] 169 | for net in nets: 170 | if net is not None: 171 | for param in net.parameters(): 172 | param.requires_grad = requires_grad 173 | -------------------------------------------------------------------------------- /data/data_generator.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import shutil 4 | import pickle 5 | import numpy as np 6 | import nibabel as nib 7 | from util.util import mkdir 8 | from util.image_property import hash_file, slice_with_neighborhood 9 | from configurations import * 10 | 11 | 12 | def get_domain_code(dataroot): 13 | code = np.zeros(5) 14 | if dataroot.endswith('ISBI'): 15 | code[0] = 1 16 | elif dataroot.endswith('MICCAI16_1'): 17 | code[1] = 1 18 | elif dataroot.endswith('MICCAI16_2'): 19 | code[2] = 1 20 | elif dataroot.endswith('MICCAI16_3'): 21 | code[3] = 1 22 | elif dataroot.endswith('UMCL'): 23 | code[4] = 1 24 | return code 25 | 26 | 27 | def get_modality_code(modalities): 28 | code = np.zeros(5) 29 | if 't1' in modalities: 30 | code[0] = 1 31 | if 'flair' in modalities: 32 | code[1] = 1 33 | if 't2' in modalities: 34 | code[2] = 1 35 | if 'pd' in modalities: 36 | code[3] = 1 37 | if 'ce' in modalities: 38 | code[4] = 1 39 | return code 40 | 41 | 42 | def remove_folder_if_exist(folder_path): 43 | if os.path.exists(folder_path): 44 | shutil.rmtree(folder_path) 45 | print('%s removed' % folder_path) 46 | else: 47 | print('%s does not exist' % folder_path) 48 | 49 | 50 | class DataGenerator: 51 | def __init__(self): 52 | self.sample_2d_count = 0 53 | self.sample_3d_count = 0 54 | self.dataroot = PATH_DATASET 55 | self.dataset_names = DATASETS 56 | self.dir_this_phase = None 57 | self.axes = [0, 1, 2] 58 | self.rotation_2d = -1 59 | 60 | def build_dataset(self, val_index, test_index, test_phases): 61 | val, test = 'val' in test_phases, 'test' in test_phases 62 | 63 | for dataset_name in self.dataset_names: 64 | print(f"Start generating data for dataset: {dataset_name}...") 65 | self.domain_code = get_domain_code(dataset_name) 66 | self.sample_2d_count, self.sample_3d_count = 0, 0 67 | 68 | with open(os.path.join(self.dataroot, dataset_name, 'properties.json'), 'r') as f: 69 | self.dic_properties = json.load(f) 70 | with open(os.path.join(self.dataroot, dataset_name, 'ids.json'), 'r') as f: 71 | self.dic_ids = json.load(f) 72 | 73 | remove_folder_if_exist(os.path.join(self.dataroot, dataset_name, 'train')) 74 | remove_folder_if_exist(os.path.join(self.dataroot, dataset_name, 'val')) 75 | remove_folder_if_exist(os.path.join(self.dataroot, dataset_name, 'test')) 76 | 77 | all_ids = sorted(list(map(str, self.dic_ids.keys()))) 78 | total_num = len(all_ids) 79 | num_fold = 5 80 | splits = [total_num // num_fold] * num_fold 81 | for i in range(total_num % num_fold): 82 | splits[i] += 1 83 | val_ids = [all_ids[i] for i in range(sum(splits[:val_index]), sum(splits[:val_index+1]))] if val else [] 84 | test_ids = [all_ids[i] for i in range(sum(splits[:test_index]), sum(splits[:test_index+1]))] if test else [] 85 | train_ids = [i for i in all_ids if i not in val_ids + test_ids] 86 | 87 | print(val_ids, test_ids) 88 | self.generate_general_data('val', val_ids, '3d', dataset_name) 89 | self.generate_general_data('test', test_ids, '3d', dataset_name) 90 | 
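# Note: val/test splits keep full 3-D volumes (generate_3d_data copies the NIfTI
# files as-is), while the training split below is sliced into shuffled 2-D stacks
# by generate_2d_data, so validation/testing run on whole scans and training on
# 2-D samples.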
self.generate_general_data('train', train_ids, '2d', dataset_name) 91 | 92 | def generate_general_data(self, phase, ids, mode, dataset_name, neighborhood=1): 93 | self.dir_this_phase = os.path.join(self.dataroot, dataset_name, phase) # /UMCL/train 94 | mkdir(self.dir_this_phase) 95 | for subject_id in ids: 96 | timepoints = self.dic_ids[str(subject_id)].keys() 97 | for timepoint in timepoints: 98 | masks = self.dic_ids[str(subject_id)][str(timepoint)]['mask'].keys() 99 | if mode == '3d': 100 | self.sample_3d_count += 1 101 | for mask in masks: 102 | if mode == '3d': 103 | self.generate_3d_data(subject_id, timepoint, mask, phase) 104 | else: 105 | self.generate_2d_data(subject_id, timepoint, mask, neighborhood) 106 | 107 | if mode == '2d': 108 | seq = np.arange(0, self.sample_2d_count) 109 | np.random.shuffle(seq) 110 | # print(seq) 111 | for i, org in enumerate(seq): 112 | file_org = os.path.join(self.dir_this_phase, 'tmp_%d.pkl' % org) 113 | file_new = os.path.join(self.dir_this_phase, '%d.pkl' % i) 114 | shutil.move(file_org, file_new) 115 | 116 | def generate_2d_data(self, subject_id, timepoint, mask, neighborhood): 117 | modalities = self.dic_ids[str(subject_id)][str(timepoint)]['modalities'] # available modalities here 118 | path_mask = self.dic_ids[str(subject_id)][str(timepoint)]['mask'][mask] 119 | image_mask = nib.load(path_mask) 120 | voxel_sizes = nib.affines.voxel_sizes(image_mask.affine) 121 | 122 | # append mask into image_data 123 | image_data = {'mask': np.array(image_mask.get_fdata(), dtype=np.float32)} 124 | 125 | for modality in modalities: 126 | path_modality = self.dic_ids[str(subject_id)][str(timepoint)]['modalities'][modality] 127 | hash_label = hash_file(path_modality) 128 | modality_peak = self.dic_properties[hash_label]['peak'] 129 | image_data[modality] = np.array(nib.load(path_modality).get_fdata() / modality_peak, dtype=np.float32) 130 | 131 | modality_code = get_modality_code(modalities) 132 | data_to_save = {i: [] for i in modalities} 133 | data_to_save['mask'] = [] 134 | for axis in self.axes: 135 | ratio = [k for i, k in enumerate(voxel_sizes) if i != axis] 136 | ratio = ratio[0] / ratio[1] # if there is no rot90 following, it should be ratio[1]/ratio[0] 137 | slices_per_image = image_data['mask'].shape[axis] 138 | print("Slices per image %d, current samples %d" % (slices_per_image, self.sample_2d_count)) 139 | for i in range(slices_per_image): 140 | slice_mask = slice_with_neighborhood(image_data['mask'], axis, i, 0) 141 | if np.count_nonzero(slice_mask) < 2: 142 | continue 143 | data_to_save['mask'] = np.rot90(slice_mask, self.rotation_2d) # legacy due to saving png 144 | for modality in modalities: 145 | slice_modality = slice_with_neighborhood(image_data[modality], axis, i, neighborhood) 146 | data_to_save[modality] = np.rot90(slice_modality, self.rotation_2d) 147 | 148 | data_to_save['ratio'] = ratio 149 | data_to_save['dc'] = self.domain_code 150 | data_to_save['mc'] = modality_code 151 | 152 | with open(os.path.join(self.dir_this_phase, 'tmp_%d.pkl' % self.sample_2d_count), 'wb') as f: 153 | pickle.dump(data_to_save, f) 154 | 155 | self.sample_2d_count += 1 156 | 157 | def generate_3d_data(self, subject_id, timepoint, mask, phase): 158 | modalities = self.dic_ids[str(subject_id)][str(timepoint)]['modalities'] 159 | for modality in modalities: 160 | path_src = self.dic_ids[str(subject_id)][str(timepoint)]['modalities'][modality] 161 | path_dst = os.path.join(self.dir_this_phase, mask + '_%s%d_%s.%s' % (phase, self.sample_3d_count, modality, SUFFIX)) 
162 | shutil.copyfile(path_src, path_dst) 163 | path_src_mask = self.dic_ids[str(subject_id)][str(timepoint)]['mask'][mask] 164 | path_dst_mask = os.path.join(self.dir_this_phase, mask + '_%s%d_mask.%s' % (phase, self.sample_3d_count, SUFFIX)) 165 | shutil.copyfile(path_src_mask, path_dst_mask) 166 | 167 | 168 | if __name__ == '__main__': 169 | seed = 10 170 | os.environ['PYTHONHASHSEED'] = str(seed) 171 | np.random.seed(seed) 172 | dataroot = '/media/liuhan/HanLiu/ms_data/MICCAI16_3' -------------------------------------------------------------------------------- /data/ms_dataset.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import random 3 | import torchvision.transforms as transforms 4 | import torch.nn.functional as F 5 | import torch 6 | import cv2 7 | import numpy as np 8 | from data.base_dataset import BaseDataset 9 | from configurations import * 10 | import copy 11 | from albumentations.augmentations.functional import grid_distortion 12 | import matplotlib.pyplot as plt 13 | 14 | 15 | def get_2d_paths(dir): 16 | arrays = [] 17 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 18 | 19 | for root, _, fnames in sorted(os.walk(dir)): 20 | for fname in fnames: 21 | if fname.endswith('.pkl'): 22 | path = os.path.join(root, fname) 23 | arrays.append(path) 24 | 25 | return arrays 26 | 27 | 28 | def augmentations(data, ratio, opt): 29 | height, width = data['mask'].shape[:2] # height/y for first axis, width/x for second axis 30 | for axis in [1, 0]: 31 | if random.random() < 0.3: 32 | for modality in MODALITIES + ['mask']: 33 | data[modality] = np.flip(data[modality], axis).copy() 34 | 35 | if random.random() < 0.5: 36 | height, width = width, height 37 | for modality in MODALITIES: 38 | data[modality] = np.transpose(data[modality], (1, 0, 2)) 39 | data['mask'] = np.transpose(data['mask'], (1, 0)) 40 | need_resize = False 41 | if random.random() < 0: # random crop-and-resize augmentation disabled (probability 0); kept for reference 42 | crop_size = random.randint(int(opt.trainSize / 1.5), min(height, width)) 43 | need_resize = True 44 | else: 45 | crop_size = opt.trainSize 46 | 47 | mask = data['mask'] 48 | if np.sum(mask) == 0 or random.random() < 0.005: 49 | x_min = random.randint(0, width - crop_size) 50 | y_min = random.randint(0, height - crop_size) 51 | else: 52 | non_zero_yx = np.argwhere(mask) 53 | y, x = random.choice(non_zero_yx) 54 | x_min = x - random.randint(0, crop_size - 1) 55 | y_min = y - random.randint(0, crop_size - 1) 56 | x_min = np.clip(x_min, 0, width - crop_size) 57 | y_min = np.clip(y_min, 0, height - crop_size) 58 | 59 | for modality in MODALITIES + ['mask']: 60 | interpolation = cv2.INTER_LINEAR 61 | data[modality] = data[modality][y_min:y_min + crop_size, x_min:x_min + crop_size] 62 | if need_resize: 63 | data[modality] = cv2.resize(data[modality], (opt.trainSize, opt.trainSize), interpolation) 64 | 65 | data['mask'] = (data['mask'] > 0.5).astype(np.float32) 66 | return data 67 | 68 | 69 | #======================================================================================================================= 70 | # Code description 71 | # This class is used for: 72 | # (1) independent models (fixed combination of modalities) 73 | # (2) ModDrop: regular modality dropout 74 | # (3) ModDrop+: dynamic filter network ONLY (without intra-subject co-training) 75 | 76 | """ 77 | class MsDataset(BaseDataset): 78 | @staticmethod 79 | def modify_commandline_options(parser, is_train): 80 | return parser 81 | 82 | def initialize(self, opt): 83 | self.opt = opt 84 | 
self.use_modality_dropout = opt.use_modality_dropout 85 | self.all_paths = [] 86 | for dataset_name in DATASETS: 87 | self.dir_data = os.path.join(opt.dataroot, dataset_name, opt.phase) 88 | self.all_paths += sorted(get_2d_paths(self.dir_data)) 89 | 90 | def __getitem__(self, index): 91 | path_this_sample = self.all_paths[index] 92 | data_all_modalities = np.load(path_this_sample, allow_pickle=True) 93 | # store the available modalities in a list 94 | data_return = {'paths': path_this_sample} 95 | available = [] 96 | 97 | for modality in MODALITIES: 98 | if modality in data_all_modalities: 99 | available.append(modality) 100 | data_return[modality] = data_all_modalities[modality] 101 | else: 102 | data_return[modality] = np.zeros(data_all_modalities['t1'].shape) 103 | 104 | data_return['mask'] = data_all_modalities['mask'][:, :, 0] 105 | 106 | # augmentation 107 | data_return = augmentations(data_return, data_all_modalities['ratio'], self.opt) 108 | 109 | # preprocessing 110 | for modality in available: 111 | data_return[modality] = data_return[modality] / 2 - 1 112 | data_return['mask'] = data_return['mask'] * 2 - 1 113 | data_return['dc'] = data_all_modalities['dc'] 114 | data_return['mc'] = data_all_modalities['mc'] 115 | 116 | for modality in MODALITIES: 117 | data_return[modality] = transforms.ToTensor()(data_return[modality]).float() 118 | 119 | # ======== modality dropout ======== 120 | if self.use_modality_dropout: 121 | mc_idx = list(np.where(data_return['mc'] == 1)[0]) 122 | zero_idx = random.sample(mc_idx, random.randint(0, len(mc_idx)-1)) 123 | for idx in zero_idx: 124 | # image set as zero tensor 125 | data_return[MODALITIES[idx]] = torch.zeros(data_return[MODALITIES[idx]].size()) 126 | data_return['mc'][idx] = 0 # modality code set as 0 127 | 128 | return data_return 129 | 130 | def __len__(self): 131 | return len(self.all_paths) 132 | 133 | def name(self): 134 | return 'MsDataset' 135 | """ 136 | 137 | 138 | #======================================================================================================================= 139 | # Code description: 140 | # This class is used for ModDrop++: (1) dynamic filter network and (2) intra-subject co-training. This dataloader 141 | # returns both (1) full-modality data and (2) missing modality data (randomly dropped) from the same subject. 
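# As an illustration of the dropout step in __getitem__ below (hypothetical values,
# assuming e.g. MODALITIES = ['t1', 'flair', 't2', 'pd', 'ce']): with
# mc = [1, 1, 1, 1, 0] (ce missing), mc_idx = [0, 1, 2, 3] and
# random.sample(mc_idx, random.randint(0, 3)) may pick e.g. [0, 3], so the t1 and
# pd tensors are zeroed and mc becomes [0, 1, 1, 0, 0]. Since at most
# len(mc_idx) - 1 modalities are dropped, at least one modality always remains
# in data_miss.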
142 | 143 | 144 | class MsDataset(BaseDataset): 145 | @staticmethod 146 | def modify_commandline_options(parser, is_train): 147 | return parser 148 | 149 | def initialize(self, opt): 150 | self.opt = opt 151 | self.use_modality_dropout = opt.use_modality_dropout 152 | self.all_paths = [] 153 | for dataset_name in DATASETS: 154 | self.dir_data = os.path.join(opt.dataroot, dataset_name, opt.phase) 155 | self.all_paths += sorted(get_2d_paths(self.dir_data)) 156 | 157 | def __getitem__(self, index): 158 | path_this_sample = self.all_paths[index] 159 | data_all_modalities = np.load(path_this_sample, allow_pickle=True) 160 | # store the available modalities in a list 161 | data_full = {'paths': path_this_sample} 162 | available = [] 163 | 164 | for modality in MODALITIES: 165 | if modality in data_all_modalities: 166 | available.append(modality) 167 | data_full[modality] = data_all_modalities[modality] 168 | else: 169 | data_full[modality] = np.zeros(data_all_modalities['t1'].shape) 170 | data_full['mask'] = data_all_modalities['mask'][:, :, 0] 171 | 172 | # augmentation 173 | data_full = augmentations(data_full, data_all_modalities['ratio'], self.opt) 174 | 175 | # preprocessing 176 | for modality in available: 177 | data_full[modality] = data_full[modality] / 2 - 1 178 | 179 | data_full['mask'] = data_full['mask'] * 2 - 1 180 | data_full['dc'] = data_all_modalities['dc'] 181 | data_full['mc'] = data_all_modalities['mc'] 182 | 183 | for modality in MODALITIES: 184 | data_full[modality] = transforms.ToTensor()(data_full[modality]).float() 185 | data_miss = copy.deepcopy(data_full) 186 | # === modality dropout === 187 | if self.use_modality_dropout: 188 | mc_idx = list(np.where(data_miss['mc'] == 1)[0]) 189 | zero_idx = random.sample(mc_idx, random.randint(0, len(mc_idx) - 1)) 190 | for idx in zero_idx: 191 | data_miss[MODALITIES[idx]] = torch.zeros(data_miss[MODALITIES[idx]].size()) # image set as zero tensor 192 | data_miss['mc'][idx] = 0 # modality code set as 0 193 | 194 | return data_full, data_miss 195 | 196 | def __len__(self): 197 | return len(self.all_paths) 198 | 199 | def name(self): 200 | return 'MsDataset' -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import time 3 | import copy 4 | import cv2 5 | import numpy as np 6 | import nibabel as nib 7 | import torch.nn.functional as F 8 | import matplotlib.pyplot as plt 9 | from options.test_options import TestOptions 10 | from data import CreateDataLoader 11 | from models import create_model 12 | from collections import OrderedDict, defaultdict 13 | from skimage import measure 14 | from scipy.stats import pearsonr 15 | from configurations import * 16 | from util.image_property import hash_file 17 | import torch.multiprocessing 18 | torch.multiprocessing.set_sharing_strategy('file_system') 19 | 20 | 21 | def pad_images(opt_test, *image_list): 22 | padNum = -1 23 | pad_pos_y = (opt_test.testSize - image_list[0][0].shape[-2]) // 2 24 | pad_pos_x = (opt_test.testSize - image_list[0][0].shape[-1]) // 2 25 | pad_param = [pad_pos_x, opt_test.testSize - image_list[0].shape[-1] - pad_pos_x, 26 | pad_pos_y, opt_test.testSize - image_list[0].shape[-2] - pad_pos_y] 27 | 28 | var_return = [] 29 | image_list = list(image_list) 30 | for one_image in image_list: 31 | pad_image = F.pad(one_image, pad_param, 'constant', padNum) 32 | var_return += [pad_image] 33 | 34 | sl = [slice(None)] * 2 35 | sl[0] = 
slice(pad_pos_y, pad_pos_y + image_list[0][0].shape[-2], 1) 36 | sl[1] = slice(pad_pos_x, pad_pos_x + image_list[0][0].shape[-1], 1) 37 | var_return += [tuple(sl)] 38 | return var_return 39 | 40 | 41 | def seg_metrics(seg_vol, truth_vol, output_errors=False): 42 | time_start = time.time() 43 | seg_total = np.sum(seg_vol) 44 | truth_total = np.sum(truth_vol) 45 | tp = np.sum(seg_vol[truth_vol == 1]) 46 | dice = 2 * tp / (seg_total + truth_total) 47 | ppv = tp / (seg_total + 0.001) 48 | tpr = tp / (truth_total + 0.001) 49 | vd = abs(seg_total - truth_total) / truth_total 50 | 51 | # calculate LFPR 52 | seg_labels, seg_num = measure.label(seg_vol, return_num=True, connectivity=2) 53 | lfp_cnt = 0 54 | # tmp_cnt = 0 55 | for label in range(1, seg_num + 1): 56 | # tmp_cnt = np.sum(seg_vol[seg_labels == label]) 57 | if np.sum(truth_vol[seg_labels == label]) == 0: # was == 58 | lfp_cnt += 1 59 | lfpr = lfp_cnt / (seg_num + 0.001) 60 | 61 | # calculate LTPR 62 | truth_labels, truth_num = measure.label(truth_vol, return_num=True, connectivity=2) 63 | ltp_cnt = 0 64 | for label in range(1, truth_num + 1): 65 | if np.sum(seg_vol[truth_labels == label]) > 0: 66 | ltp_cnt += 1 67 | ltpr = ltp_cnt / truth_num 68 | 69 | # calculate Pearson's correlation coefficient 70 | corr = pearsonr(seg_vol.flatten(), truth_vol.flatten())[0] 71 | # print("Timed used calculating metrics: ", time.time() - time_start) 72 | 73 | return OrderedDict([('dice', dice), ('ppv', ppv), ('tpr', tpr), ('lfpr', lfpr), 74 | ('ltpr', ltpr), ('vd', vd), ('corr', corr)]) 75 | 76 | 77 | def print_metrics(prefix, metrics): 78 | message = prefix + ' ' 79 | for k, v in metrics.items(): 80 | message += '%s: %.3f ' % (k, v) 81 | print(message) 82 | 83 | 84 | def model_test(models, dataset_test, opt_test, num_test, save_images=False, models_weight=None, 85 | mask_suffix='pred', save_membership=False): 86 | if not num_test: 87 | print("no %s subjects" % opt_test.phase) 88 | assert len(models), "no models loaded" 89 | 90 | start_time = time.time() 91 | orientations = ['axial', 'sagittal', 'coronal'] 92 | transpose = {2: (1, 2, 0), 0: (0, 1, 2), 1: (1, 0, 2)} 93 | orientation_weight = [1, 1, 1] 94 | ret_metrics = defaultdict(float) 95 | metrics = [] 96 | 97 | dict_results = {} 98 | for i, data in enumerate(dataset_test): 99 | if i >= num_test: 100 | break 101 | 102 | mask, mask_path, alt_path = data['mask'], data['mask_paths'][0], data['alt_paths'][0] 103 | dc, mc = data['dc'], data['mc'] 104 | basename = os.path.basename(data['alt_paths'][0]) 105 | basename = basename[:len(basename) - len(MODALITIES[0]) - len(SUFFIX) - 1] 106 | 107 | hash_label = hash_file(alt_path) 108 | if hash_label not in dict_results: 109 | mask_pred = 0 110 | for k, orientation in enumerate(orientations): 111 | mask_cur_orientation = [] 112 | num_slices = len(data[MODALITIES[0]][orientation]) 113 | org_size = data['org_size'][orientation] 114 | interpolation = cv2.INTER_LINEAR 115 | for j in range(num_slices): 116 | pad_data ={} 117 | pad_data['mask'] = torch.zeros_like(data[MODALITIES[0]][orientation][j]) 118 | for modality in MODALITIES: 119 | pad_data[modality] = data[modality][orientation][j] 120 | 121 | slice_all_models = 0 122 | for m, current_model in enumerate(models): 123 | m_input = {mod: pad_data[mod] for mod in MODALITIES + ['mask']} 124 | m_input['dc'] = dc 125 | m_input['mc'] = mc 126 | current_model.set_input(m_input) 127 | current_model.test() 128 | current_visuals = current_model.get_current_visuals() # inference 129 | weight_this_model = 1 if 
models_weight is None else models_weight[m] 130 | slice_this_model = np.squeeze(current_visuals['fake_mask'].cpu().numpy()) # DYN 131 | slice_all_models += slice_this_model * weight_this_model 132 | 133 | numerator = len(models) if models_weight is None else np.sum(models_weight) 134 | slice_all_models = np.array(slice_all_models) / numerator 135 | slice_all_models = np.squeeze(slice_all_models + 1) / 2 136 | slice_all_models = cv2.resize(slice_all_models, (int(org_size[1][0].cpu().numpy()), int(org_size[0][0].cpu().numpy())), interpolation) 137 | mask_cur_orientation.append(slice_all_models) 138 | 139 | mask_pred += np.transpose(np.squeeze(mask_cur_orientation), transpose[AXIS_TO_TAKE[k]]) * \ 140 | orientation_weight[k] 141 | mask_pred = np.array(mask_pred) / np.sum(orientation_weight) 142 | dict_results[hash_label] = mask_pred 143 | 144 | alt_image = nib.load(alt_path) 145 | if save_membership: 146 | mask_membership_name = alt_path.replace('%s.%s' % (MODALITIES[0], SUFFIX), 147 | 'membership_%s.%s' % (mask_suffix, SUFFIX)) 148 | nib.Nifti1Image(mask_pred, alt_image.affine, alt_image.header).to_filename(mask_membership_name) 149 | else: 150 | mask_pred = dict_results[hash_label] 151 | 152 | mask_pred = (mask_pred > 0.5).astype(np.int8) 153 | if os.path.exists(mask_path): 154 | mask_data = nib.load(mask_path).get_fdata().astype(np.int8) 155 | res_this_mask = seg_metrics(mask_pred, mask_data, output_errors=False) 156 | metrics = list(res_this_mask.keys()) 157 | for k in metrics: 158 | ret_metrics[k] += res_this_mask[k] 159 | print_metrics('processed ' + basename + '*,', res_this_mask) 160 | else: 161 | print('processed ' + basename + '*') 162 | 163 | if save_images: 164 | mask_pred_name = alt_path.replace('%s.%s' % (MODALITIES[0], SUFFIX), 'pred_%s.%s' % (mask_suffix, SUFFIX)) 165 | nib.Nifti1Image(mask_pred, alt_image.affine, alt_image.header).to_filename(mask_pred_name) 166 | 167 | for k in metrics: 168 | ret_metrics[k] = ret_metrics[k] / num_test if num_test != 0 else ret_metrics[k] 169 | print("time used for validation: ", time.time() - start_time) 170 | return ret_metrics 171 | 172 | 173 | if __name__ == '__main__': 174 | opt_test = TestOptions().parse() 175 | 176 | # hard-code some parameters for test 177 | opt_test.num_threads = 1 # test code only supports num_threads = 1 178 | opt_test.batch_size = 1 # test code only supports batch_size = 1 179 | opt_test.serial_batches = True # no shuffle 180 | opt_test.no_flip = True # no flip 181 | opt_test.display_id = -1 # no visdom display 182 | opt_test.dataset_mode = 'ms_3d' 183 | data_loader = CreateDataLoader(opt_test) 184 | dataset_test = data_loader.load_data() 185 | 186 | models = [] 187 | models_indx = opt_test.load_str.split(',') 188 | models_weight = [1] * len(models_indx) 189 | for i in models_indx: 190 | current_model = create_model(opt_test, i) 191 | current_model.setup(opt_test) 192 | if opt_test.eval: 193 | current_model.eval() 194 | models.append(current_model) 195 | 196 | losses = model_test(models, dataset_test, opt_test, len(data_loader), save_images=True, 197 | models_weight=models_weight, mask_suffix=opt_test.name, save_membership=False) 198 | print_metrics('test results', losses) 199 | 200 | -------------------------------------------------------------------------------- /models/tiramisu_model.py: -------------------------------------------------------------------------------- 1 | from models.tiramisu_layers import * 2 | 3 | 4 | class FCDenseNet(nn.Module): 5 | # copied from https://github.com/bfortuner/pytorch_tiramisu 
6 | # slightly changed to output range from -1 to 1 7 | def __init__(self, in_channels=3, down_blocks=(5,5,5,5,5), 8 | up_blocks=(5,5,5,5,5), bottleneck_layers=5, 9 | growth_rate=16, out_chans_first_conv=48, n_classes=12): 10 | super().__init__() 11 | self.down_blocks = down_blocks 12 | self.up_blocks = up_blocks 13 | cur_channels_count = 0 14 | skip_connection_channel_counts = [] 15 | 16 | ## First Convolution ## 17 | 18 | self.add_module('firstconv', nn.Conv2d(in_channels=in_channels, 19 | out_channels=out_chans_first_conv, kernel_size=3, 20 | stride=1, padding=1, bias=True)) 21 | cur_channels_count = out_chans_first_conv 22 | 23 | ##################### 24 | # Downsampling path # 25 | ##################### 26 | 27 | self.denseBlocksDown = nn.ModuleList([]) 28 | self.transDownBlocks = nn.ModuleList([]) 29 | for i in range(len(down_blocks)): 30 | self.denseBlocksDown.append( 31 | DenseBlock(cur_channels_count, growth_rate, down_blocks[i])) 32 | cur_channels_count += (growth_rate*down_blocks[i]) 33 | skip_connection_channel_counts.insert(0,cur_channels_count) 34 | self.transDownBlocks.append(TransitionDown(cur_channels_count)) 35 | 36 | ##################### 37 | # Bottleneck # 38 | ##################### 39 | 40 | self.add_module('bottleneck',Bottleneck(cur_channels_count, 41 | growth_rate, bottleneck_layers)) 42 | prev_block_channels = growth_rate*bottleneck_layers 43 | cur_channels_count += prev_block_channels 44 | 45 | ####################### 46 | # Upsampling path # 47 | ####################### 48 | 49 | self.transUpBlocks = nn.ModuleList([]) 50 | self.denseBlocksUp = nn.ModuleList([]) 51 | for i in range(len(up_blocks)-1): 52 | self.transUpBlocks.append(TransitionUp(prev_block_channels, prev_block_channels)) 53 | cur_channels_count = prev_block_channels + skip_connection_channel_counts[i] 54 | 55 | self.denseBlocksUp.append(DenseBlock( 56 | cur_channels_count, growth_rate, up_blocks[i], 57 | upsample=True)) 58 | prev_block_channels = growth_rate*up_blocks[i] 59 | cur_channels_count += prev_block_channels 60 | 61 | ## Final DenseBlock ## 62 | 63 | self.transUpBlocks.append(TransitionUp( 64 | prev_block_channels, prev_block_channels)) 65 | cur_channels_count = prev_block_channels + skip_connection_channel_counts[-1] 66 | 67 | self.denseBlocksUp.append(DenseBlock( 68 | cur_channels_count, growth_rate, up_blocks[-1], 69 | upsample=False)) 70 | cur_channels_count += growth_rate*up_blocks[-1] 71 | 72 | ## Softmax ## 73 | 74 | self.finalConv = nn.Conv2d(in_channels=cur_channels_count, 75 | out_channels=n_classes, kernel_size=1, stride=1, 76 | padding=0, bias=True) 77 | self.tanh = nn.Tanh() 78 | # self.softmax = nn.LogSoftmax(dim=1) 79 | 80 | def forward(self, x): 81 | out = self.firstconv(x) 82 | 83 | skip_connections = [] 84 | for i in range(len(self.down_blocks)): 85 | out = self.denseBlocksDown[i](out) 86 | skip_connections.append(out) 87 | out = self.transDownBlocks[i](out) 88 | 89 | out = self.bottleneck(out) 90 | for i in range(len(self.up_blocks)): 91 | skip = skip_connections.pop() 92 | out = self.transUpBlocks[i](out, skip) 93 | out = self.denseBlocksUp[i](out) 94 | 95 | out = self.finalConv(out) 96 | out = self.tanh(out) 97 | # out = self.softmax(out) 98 | return out 99 | 100 | 101 | 102 | ## 103 | ############ FOR Teacher Student ############ 104 | # class FCDenseNet(nn.Module): 105 | # # copied from https://github.com/bfortuner/pytorch_tiramisu 106 | # # slightly changed to output range from -1 to 1 107 | # def __init__(self, in_channels=3, down_blocks=(5,5,5,5,5), 108 | # 
up_blocks=(5,5,5,5,5), bottleneck_layers=5, 109 | # growth_rate=16, out_chans_first_conv=48, n_classes=12): 110 | # super().__init__() 111 | # self.down_blocks = down_blocks 112 | # self.up_blocks = up_blocks 113 | # cur_channels_count = 0 114 | # skip_connection_channel_counts = [] 115 | 116 | # ## First Convolution ## 117 | 118 | # self.add_module('firstconv', nn.Conv2d(in_channels=in_channels, 119 | # out_channels=out_chans_first_conv, kernel_size=3, 120 | # stride=1, padding=1, bias=True)) 121 | # cur_channels_count = out_chans_first_conv 122 | 123 | # ##################### 124 | # # Downsampling path # 125 | # ##################### 126 | 127 | # self.denseBlocksDown = nn.ModuleList([]) 128 | # self.transDownBlocks = nn.ModuleList([]) 129 | # for i in range(len(down_blocks)): 130 | # self.denseBlocksDown.append( 131 | # DenseBlock(cur_channels_count, growth_rate, down_blocks[i])) 132 | # cur_channels_count += (growth_rate*down_blocks[i]) 133 | # skip_connection_channel_counts.insert(0,cur_channels_count) 134 | # self.transDownBlocks.append(TransitionDown(cur_channels_count)) 135 | 136 | # ##################### 137 | # # Bottleneck # 138 | # ##################### 139 | 140 | # self.add_module('bottleneck',Bottleneck(cur_channels_count, 141 | # growth_rate, bottleneck_layers)) 142 | 143 | # prev_block_channels = growth_rate*bottleneck_layers 144 | # cur_channels_count += prev_block_channels 145 | 146 | # ####################### 147 | # # Upsampling path # 148 | # ####################### 149 | 150 | # self.transUpBlocks = nn.ModuleList([]) 151 | # self.denseBlocksUp = nn.ModuleList([]) 152 | # for i in range(len(up_blocks)-1): 153 | # self.transUpBlocks.append(TransitionUp(prev_block_channels, prev_block_channels)) 154 | # cur_channels_count = prev_block_channels + skip_connection_channel_counts[i] 155 | 156 | # self.denseBlocksUp.append(DenseBlock( 157 | # cur_channels_count, growth_rate, up_blocks[i], 158 | # upsample=True)) 159 | # prev_block_channels = growth_rate*up_blocks[i] 160 | # cur_channels_count += prev_block_channels 161 | 162 | # ## Final DenseBlock ## 163 | 164 | # self.transUpBlocks.append(TransitionUp( 165 | # prev_block_channels, prev_block_channels)) 166 | # cur_channels_count = prev_block_channels + skip_connection_channel_counts[-1] 167 | 168 | # self.denseBlocksUp.append(DenseBlock( 169 | # cur_channels_count, growth_rate, up_blocks[-1], 170 | # upsample=False)) 171 | # cur_channels_count += growth_rate*up_blocks[-1] 172 | 173 | # ## Softmax ## 174 | 175 | # self.finalConv = nn.Conv2d(in_channels=cur_channels_count, 176 | # out_channels=n_classes, kernel_size=1, stride=1, 177 | # padding=0, bias=True) 178 | # self.tanh = nn.Tanh() 179 | # # self.softmax = nn.LogSoftmax(dim=1) 180 | 181 | # def forward(self, x): 182 | # out = self.firstconv(x) 183 | 184 | # dyn_feat = out # DYN place 185 | 186 | # skip_connections = [] 187 | 188 | # for i in range(len(self.down_blocks)): 189 | # out = self.denseBlocksDown[i](out) 190 | # skip_connections.append(out) 191 | # out = self.transDownBlocks[i](out) 192 | 193 | # out = self.bottleneck(out) 194 | 195 | # # dyn_feat = out # BOTTLENECK 196 | 197 | # for i in range(len(self.up_blocks)): 198 | # skip = skip_connections.pop() 199 | # out = self.transUpBlocks[i](out, skip) 200 | # out = self.denseBlocksUp[i](out) 201 | 202 | # out = self.finalConv(out) 203 | # out = self.tanh(out) 204 | # # out = self.softmax(out) 205 | # return dyn_feat, out 206 | 207 | 208 | 209 | # def FCDenseNet57(n_classes): 210 | # return FCDenseNet( 211 | # 
in_channels=3, down_blocks=(4, 4, 4, 4, 4), 212 | # up_blocks=(4, 4, 4, 4, 4), bottleneck_layers=4, 213 | # growth_rate=12, out_chans_first_conv=48, n_classes=n_classes) 214 | 215 | 216 | # def FCDenseNet67(n_classes): 217 | # return FCDenseNet( 218 | # in_channels=3, down_blocks=(5, 5, 5, 5, 5), 219 | # up_blocks=(5, 5, 5, 5, 5), bottleneck_layers=5, 220 | # growth_rate=16, out_chans_first_conv=48, n_classes=n_classes) 221 | 222 | 223 | # def FCDenseNet103(n_classes): 224 | # return FCDenseNet( 225 | # in_channels=3, down_blocks=(4,5,7,10,12), 226 | # up_blocks=(12,10,7,5,4), bottleneck_layers=15, 227 | # growth_rate=16, out_chans_first_conv=48, n_classes=n_classes) 228 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![](https://img.shields.io/badge/Language-python-brightgreen.svg) 2 | [![](https://img.shields.io/badge/License-BSD%203--Clause-orange.svg)](https://github.com/han-liu/ModDropPlusPlus/blob/main/LICENSE) 3 | 4 | 5 | # ModDrop++ 6 | 7 | This repository is the official PyTorch implementation of the paper "ModDrop++: A Dynamic Filter Network with Intra-subject Co-training for Multiple Sclerosis Lesion Segmentation with Missing Modalities". [[paper]](https://arxiv.org/pdf/2203.04959.pdf) 8 | 9 | ![gif](https://github.com/han-liu/ModDropPlusPlus/blob/main/dynamic_head.gif) 10 | 11 | Modality Dropout (ModDrop) has been widely used as an effective training scheme to train a unified model that can be self-adaptive to different missing-modality conditions. However, the classic ModDrop suffers from two limitations: (1) regardless of the missing condition, it always forces the network to learn a single set of parameters, which may limit the expressiveness of the network, and (2) ModDrop does not leverage the intra-subject relation between full- and missing-modality data. To address these two limitations, the proposed ModDrop++ incorporates (1) a plug-and-play dynamic head and (2) an intra-subject co-training strategy to upgrade ModDrop (a minimal sketch of the dynamic-head idea is given below). ModDrop++ has been developed and implemented based on the [2.5D Tiramisu model](https://github.com/MedICL-VU/LesionSeg), which achieved state-of-the-art performance for MS lesion segmentation on the ISBI 2015 challenge. 12 | 13 | The trained models for the UMCL and ISBI datasets are available [here](https://drive.google.com/drive/folders/1_g-OdFeCPtzYRL9UjH8gTami1uGqU7D4?usp=sharing). 14 | 15 | If you find our code/paper helpful for your research, please consider citing our work: 16 | ``` 17 | @inproceedings{liu2022moddrop++, 18 | title={ModDrop++: A Dynamic Filter Network with Intra-subject Co-training for Multiple Sclerosis Lesion Segmentation with Missing Modalities}, 19 | author={Liu, Han and Fan, Yubo and Li, Hao and Wang, Jiacheng and Hu, Dewei and Cui, Can and Lee, Ho Hin and Zhang, Huahong and Oguz, Ipek}, 20 | booktitle={Medical Image Computing and Computer Assisted Intervention--MICCAI 2022: 25th International Conference, Singapore, September 18--22, 2022, Proceedings, Part V}, 21 | pages={444--453}, 22 | year={2022}, 23 | organization={Springer} 24 | } 25 | ``` 26 | If you have any questions, feel free to contact han.liu@vanderbilt.edu or open an Issue in this repo.
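For intuition, the core of the dynamic head can be summarized in a few lines. The sketch below is illustrative only (the class name DynamicFirstConv and the default shapes are our assumptions; the actual implementation lives in models/tiramisu_model_dyn.py): a 1x1 "controller" convolution maps the 5-bit modality code to one scalar per kernel and per bias of the shared first convolution, and the rescaled kernels are applied per sample via a grouped convolution.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DynamicFirstConv(nn.Module):
    """Illustrative dynamic head: per-sample rescaling of shared first-layer
    kernels, driven by the one-hot modality code."""

    def __init__(self, in_ch=15, out_ch=48, n_mod=5):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)   # shared kernels
        self.controller = nn.Conv2d(n_mod, in_ch * out_ch + out_ch, 1)   # code -> scalars

    def forward(self, x, mc):
        n, c, h, w = x.size()
        k = self.conv.out_channels
        # one scalar per kernel (c*k) plus one per bias (k), for each sample
        params = self.controller(mc.view(n, -1, 1, 1).float()).view(n, -1)
        sw, sb = params[:, :c * k], params[:, c * k:]
        w_dyn = self.conv.weight.unsqueeze(0) * sw.view(n, k, c, 1, 1)   # (n, k, c, 3, 3)
        b_dyn = self.conv.bias.unsqueeze(0) * sb                         # (n, k)
        # grouped convolution applies each sample's rescaled kernels to that sample only
        out = F.conv2d(x.reshape(1, -1, h, w), w_dyn.reshape(-1, c, 3, 3),
                       b_dyn.reshape(-1), padding=1, groups=n)
        return out.reshape(n, k, h, w)
```
For example, DynamicFirstConv()(torch.zeros(2, 15, 128, 128), torch.zeros(2, 5)) returns a tensor of shape (2, 48, 128, 128).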
27 | 28 | ## Prerequisites 29 | * NVIDIA GPU + CUDA + cuDNN 30 | 31 | ## Installation 32 | We suggest installing the dependencies using Anaconda: 33 | * Create the environment and activate it (replace DL with your environment name): 34 | ```shell script 35 | conda create --name DL python=3.8 36 | ``` 37 | * Install PyTorch following the official guide from http://pytorch.org (we used CUDA version 10.0), 38 | and then install the other dependencies: 39 | ```shell script 40 | conda activate DL 41 | conda install pytorch torchvision cudatoolkit=10.0 -c pytorch 42 | conda install nibabel statsmodels visdom jsonpatch dominate scikit-image -c conda-forge -c anaconda 43 | ``` 44 | 45 | ## Datasets 46 | 47 | You can put your images in any folder, but to run this model, 48 | we assume the file names follow the pattern {PREFIX}\_{PATIENT-ID}\_{TIMEPOINT-ID}\_{MODALITY(MASK)}.{SUFFIX}, 49 | e.g. training_01_01_flair.nii or training_03_05_mask1.nii: 50 | ``` 51 | PREFIX: can be any string 52 | PATIENT-ID: numeric id of the patient 53 | TIMEPOINT-ID: numeric id of the timepoint 54 | MODALITY, MASK: t1, flair, t2, pd, mask1, mask2, etc 55 | SUFFIX: nii, nii.gz, etc 56 | ``` 57 | 58 | You need to specify the names of MODALITIES, MASKS and SUFFIX in configurations.py under the root folder. 59 | 
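Since the pattern is strict, it can help to sanity-check your file names before conversion. Below is a small, hypothetical helper (not part of this repository) that parses a file name into its components:

```python
# Hypothetical helper: validate/parse {PREFIX}_{PATIENT-ID}_{TIMEPOINT-ID}_{MODALITY(MASK)}.{SUFFIX}.
import re

FILENAME_RE = re.compile(
    r'^(?P<prefix>.+)_(?P<patient>\d+)_(?P<timepoint>\d+)_(?P<modality>[A-Za-z0-9]+)\.(?P<suffix>nii(\.gz)?)$')

def parse_name(filename):
    match = FILENAME_RE.match(filename)
    if match is None:
        raise ValueError('file name does not follow the expected pattern: %s' % filename)
    return match.groupdict()

print(parse_name('training_01_01_flair.nii'))
# -> {'prefix': 'training', 'patient': '01', 'timepoint': '01', 'modality': 'flair', 'suffix': 'nii'}
```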
60 | ## Data Conversion 61 | 62 | This data conversion step converts the different kinds of datasets into the same structure. 63 | To illustrate the process of training and validation, we use the ISBI-2015 challenge data as an example. 64 | The steps are as follows: 65 | 1. Copy all of the preprocessed training data of the ISBI challenge dataset into the target dataset folder. 66 | In this example, the folder is ~/Documents/Datasets/sample_dataset. 67 | 2. Modify the configuration file (configurations.py). The places that need attention are marked with TODO. 68 |     * The dataset folder, e.g. PATH_DATASET = os.path.join(os.path.expanduser('~'), 'Documents', 'Datasets', 'sample_dataset'). 69 |     * The modalities available with this dataset, e.g. MODALITIES = ['t1', 'flair', 't2', 'pd']. 70 |     * The delineations available with this dataset, e.g. MASKS = ['mask1', 'mask2']. 71 |     * The suffix of the files, usually 'nii' or 'nii.gz'. 72 |     * (This is not necessary for now) The axes corresponding to axial, sagittal and coronal, respectively. 73 | For the ISBI dataset, it is [2, 0, 1]. 74 | 3. Rename all the files to comply with the pattern described in the Datasets section. For example, since we use 't1' 75 | in MODALITIES, simply rename 'training01_01_mprage_pp.nii' to 'training_01_01_t1.nii'. 76 | The following commands might be useful: 77 | ```shell script 78 | rename 's/training/training_/g' *; rename 's/_pp.nii/.nii/g' *; rename 's/mprage/t1/g' *; 79 | ``` 80 | 4. Run the data\_conversion.py script. It will move your files under the sample_dataset folder 81 | into its subfolder ./raw and generate two JSON files (ids.json, properties.json). 82 | 83 | The ids.json contains the new data paths indexed by PATIENT-ID, TIMEPOINT-ID, and MODALITY or MASK. 84 | It is needed for cross-validation to avoid splitting scans from the same patient into 85 | both the Train and Val (or Test) folders. An example is 86 | ``` 87 | "1": { 88 |     "1": { 89 |         "modalities": { 90 |             "t1": "~/Documents/Datasets/sample_dataset/raw/training_01_01_t1.nii", 91 |             "flair": "~/Documents/Datasets/sample_dataset/raw/training_01_01_flair.nii", 92 |             "t2": "~/Documents/Datasets/sample_dataset/raw/training_01_01_t2.nii", 93 |             "pd": "~/Documents/Datasets/sample_dataset/raw/training_01_01_pd.nii" 94 |         }, 95 |         "mask": { 96 |             "mask1": "~/Documents/Datasets/sample_dataset/raw/training_01_01_mask1.nii", 97 |             "mask2": "~/Documents/Datasets/sample_dataset/raw/training_01_01_mask2.nii" 98 |         } 99 |     }, 100 |     "2": { 101 |         ... 102 |     } 103 |     ... 104 | } 105 | ``` 106 | 107 | The properties.json contains the intensity peak of each image, computed with kernel density estimation. 108 | It is saved in the JSON file so that we don't need to calculate it repeatedly during the training process. An example is 109 | ``` 110 | "406576b1b92f6740b0e20a29016952ae1fa6c4cf": { 111 |     "path": "~/Documents/Datasets/sample_dataset/raw/training_01_01_t1.nii", 112 |     "peak": 178855.95956321745 113 | } 114 | ``` 115 | 
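To make the "peak" field concrete, here is a minimal sketch of how such an intensity peak can be estimated with a kernel density estimate. This is our illustration using `scipy.stats.gaussian_kde` (available since the dependencies above pull in SciPy); the exact estimator and masking used by data\_conversion.py may differ:

```python
# Illustrative only: estimate a scan's intensity peak with a Gaussian KDE.
import numpy as np
import nibabel as nib
from scipy.stats import gaussian_kde

def intensity_peak(nii_path, n_grid=256, max_samples=10000):
    data = nib.load(nii_path).get_fdata()
    fg = data[data > 0].ravel()                      # ignore the zero background
    if fg.size > max_samples:                        # subsample for speed
        fg = np.random.choice(fg, max_samples, replace=False)
    kde = gaussian_kde(fg)
    grid = np.linspace(fg.min(), fg.max(), n_grid)
    return float(grid[np.argmax(kde(grid))])         # intensity with the highest density
```

A peak like this is typically used to normalize intensities consistently across scans, which is why caching it in properties.json avoids recomputing it at training time.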
116 | ## Training and Validation 117 | 118 | Before you run training and validation, you can decide what kind of cross-validation strategy you want to use. 119 | By default, we use 5-fold cross-validation. 120 | We implemented three strategies, which can be set with the test_mode option: 121 | * 'val': validation only, no test. For each fold, 4/5 of the data is used for training and 1/5 for validation. 122 | The models with the best Dice score will be saved as 'latest' models; 5 'latest' models will be preserved. 123 | The training will stop after 160 epochs of no improvement or after the fixed number of epochs provided in the training options, whichever comes first. 124 | Since all the ISBI data we placed into sample\_dataset are from the training dataset 125 | (which means the test data will be provided separately), we use this mode. 126 | * 'test': test only, no validation. The fold can be set with the test_index option. Only one model will be generated. 127 | The program will not automatically save the 'latest' model because the model performance is determined using the validation set. 128 | The program will run for a fixed number of epochs and stop. 129 | * 'val_test': do both validation and test. The test fold is set to the last fold by default (1/5 of the data) 130 | and will not be seen by the model during the training/validation process. 131 | Of the remaining 4/5 of the data, training takes 3/5 and validation takes 1/5. 132 | In this way, 4 'latest' models will be saved and finally tested on the held-out test set. 133 | 134 | Visdom needs to be running before training if you want to visualize the results: 135 | ```shell script 136 | conda activate DL 137 | python -m visdom.server 138 | ``` 139 | 140 | An example to run the ModDrop++ training is 141 | ```shell script 142 | conda activate DL 143 | python train.py --loss_to_use focal_ssim --input_nc 3 --trainSize 128 --test_mode val_test --name experiment_name --eval_val --batch_size 16 144 | ``` 145 | where: 146 | * loss_to_use: you can choose from ['focal', 'l2', 'dice', 'ssim']. You can also use combinations of them, e.g. 'focal_l2'. 147 | If you use more than one loss function, you can set the weight of each loss with lambda\_L2, lambda\_focal and lambda\_dice. 148 | * input_nc: the number of slices in a stack. As mentioned in the paper, using 3 achieved the best results. 149 | * trainSize: for the ISBI dataset, the size is 181x217x181. We set trainSize to 128 so that the program will crop 150 | the slices during the training process. If you set it to 256, the program will pad the slices. 151 | * name: the experiment name. The checkpoint of this name will be automatically created under the ../Checkpoint folder. 152 | Also, in the test phase, the mask prediction has the same suffix as this value. 153 | * eval_val: use the eval mode for validation during the training process. 154 | 155 | _If you get "RuntimeError: unable to open shared memory object", 156 | use PyCharm to run the code instead of the terminal._ 157 | 158 | ## Testing 159 | 160 | In the testing phase, you can create a new folder for the test dataset. 161 | For example, if you want to use the test set of the ISBI challenge, create a new subfolder ./challenge 162 | under sample\_dataset (parallel to train, val, test). 163 | The naming of the files does not need to follow the strict pattern described in Datasets and Data Conversion, 164 | but the files should end with {MODALITY(MASK)}.{SUFFIX}. The reference mask files are not needed. 165 | If the reference mask files are provided, the segmentation metrics will be output. 166 | 167 | Remember to switch to Testing mode in the 'models/ms_model.py' file. This needs to be done manually for now and will be cleaned up later. 168 | 169 | Example: 170 | ```shell script 171 | conda activate DL 172 | python test.py --load_str val0test4 --input_nc 3 --testSize 512 --name experiment_name --epoch 300 --phase test --eval 173 | ``` 174 | where: 175 | * load_str: in the training phase, the latest models are saved due to cross-validation. 176 | This string describes the models (folds) you want to load. 177 | When it is set to 'val0,val1,val2,val3,val4', the latest models from all 5 folds are loaded 178 | (you can check the model names by exploring the checkpoint folder). 179 | * testSize: the default value is 256. The slices will be padded to this size in the inference stage. 180 | If any dimension is larger than 256, simply change it to a number larger than all the dimension sizes. 181 | * epoch: the default value is latest. You can set it to a specific number, but this is not recommended. 182 | * phase: the new subfolder of the test set. 183 | 184 | You can find the options files under the options folder for more flexibility in running the experiments. 185 | -------------------------------------------------------------------------------- /models/ms_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytorch_ssim 3 | import random 4 | import numpy as np 5 | import torch.nn as nn 6 | from util.image_pool import ImagePool 7 | from .base_model import BaseModel 8 | from .losses import FocalLoss, dice_loss 9 | from . 
import networks 10 | from configurations import * 11 | 12 | 13 | class MsModel(BaseModel): # ModDrop++ model 14 | def name(self): 15 | return 'MsModel' 16 | 17 | @staticmethod 18 | def modify_commandline_options(parser, is_train=True): 19 | if is_train: 20 | parser.add_argument('--lambda_L2', type=int, default=2000, help='weight for L2 loss') 21 | parser.add_argument('--lambda_dice', type=int, default=100, help='weight for dice loss') 22 | parser.add_argument('--lambda_SSIM', type=int, default=200, help='weight for SSIM loss (feature space)') 23 | parser.add_argument('--lambda_KL', type=int, default=100, help='weight for KL loss (feature space)') 24 | parser.add_argument('--lambda_focal', type=int, default=10000, help='weight for focal loss') 25 | return parser 26 | 27 | def initialize(self, opt, model_suffix): 28 | BaseModel.initialize(self, opt, model_suffix) 29 | self.isTrain = opt.isTrain 30 | self.use_dyn = opt.use_dyn 31 | self.model_names = ['G'] 32 | self.netG = networks.define_G(opt.input_nc * len(MODALITIES), opt.init_type, opt.init_gain, self.gpu_ids, dyn=self.use_dyn) # Network input 3 slices x 5 modalities 33 | 34 | if self.isTrain: 35 | self.visual_names = ['real_mask', 'full_mask', 'miss_mask'] 36 | for modality in MODALITIES: 37 | self.visual_names += ['full_' + modality] 38 | self.visual_names += ['miss_' + modality] 39 | else: 40 | self.visual_names = ['real_mask', 'fake_mask'] 41 | for modality in MODALITIES: 42 | self.visual_names += [modality] 43 | 44 | if self.isTrain: 45 | self.loss_names = ['total'] 46 | self.criterion_names = [] 47 | criterions = {'L2': torch.nn.MSELoss(), 'focal': FocalLoss(gamma=1, alpha=0.25).to(self.device), 'dice': dice_loss, 'ssim': pytorch_ssim.SSIM(), 'KL': torch.nn.KLDivLoss()} 48 | for k in criterions.keys(): 49 | if k in opt.loss_to_use: 50 | self.loss_names += [k] 51 | setattr(self, 'criterion_%s' % k, criterions[k]) 52 | self.criterion_names.append(k) 53 | assert len(self.criterion_names), 'should use at least one loss function in L2, focal, dice' 54 | self.fake_AB_pool = ImagePool(opt.pool_size) 55 | # for feature extraction, update only the last layer, otherwise update all the parameters 56 | if self.opt.feature_extract: 57 | params_to_update = [] 58 | print("Params to learn:") 59 | for name, param in self.netG.named_parameters(): 60 | if 'thres' in name: 61 | params_to_update.append(param) 62 | else: 63 | param.requires_grad = False 64 | else: 65 | params_to_update = self.netG.parameters() 66 | self.optimizers = [torch.optim.Adam(params_to_update, lr=opt.lr, betas=(opt.beta1, 0.999))] 67 | 68 | def set_input(self, input): 69 | if self.isTrain: # training phase 70 | data_full, data_miss = input 71 | self.full_dc = data_full['dc'].to(self.device) 72 | self.full_mc = data_full['mc'].to(self.device) 73 | self.miss_dc = data_miss['dc'].to(self.device) 74 | self.miss_mc = data_miss['mc'].to(self.device) 75 | self.real_mask = data_full['mask'].to(self.device) 76 | for modality in MODALITIES: 77 | setattr(self, 'full_' + modality, data_full[modality].to(self.device)) 78 | setattr(self, 'miss_' + modality, data_miss[modality].to(self.device)) 79 | self.full_input = torch.cat([getattr(self, 'full_' + k) for k in MODALITIES], 1) 80 | self.miss_input = torch.cat([getattr(self, 'miss_' + k) for k in MODALITIES], 1) 81 | else: # inference phase 82 | self.mc = input['mc'].to(self.device) 83 | for modality in MODALITIES: 84 | setattr(self, modality, input[modality].to(self.device)) 85 | self.real_mask = input['mask'].to(self.device) 86 | self.input 
= torch.cat([getattr(self, k) for k in MODALITIES], 1) 87 | 88 |     def forward(self): 89 |         if self.isTrain: 90 |             self.full_feat, self.full_mask = self.netG(self.full_input, self.full_mc, get_dyn_feat=True) 91 |             self.miss_feat, self.miss_mask = self.netG(self.miss_input, self.miss_mc, get_dyn_feat=True) 92 |         else: 93 |             self.fake_mask = self.netG(self.input, self.mc)  # inference phase 94 | 95 | 96 |     def backward_G(self): 97 |         def normalize(feat):  # min-max normalize a flattened feature map to [0, 1] 98 |             vector = torch.flatten(feat) 99 |             min_v = torch.min(vector) 100 |             range_v = torch.max(vector) - min_v 101 |             if range_v > 0: 102 |                 normalised = (vector - min_v) / range_v 103 |             else: 104 |                 normalised = torch.zeros_like(vector)  # keeps device and dtype of the input 105 |             return normalised 106 | 107 |         self.loss_total = 0 108 |         full_mask = (self.full_mask + 1) / 2 109 |         miss_mask = (self.miss_mask + 1) / 2 110 |         real_mask = (self.real_mask + 1) / 2 111 |         for criterion_name in self.criterion_names: 112 |             criterion = getattr(self, 'criterion_%s' % criterion_name) 113 |             if criterion_name == 'dice': 114 |                 cur_loss = criterion(full_mask, real_mask) * self.opt.lambda_dice + criterion(miss_mask, real_mask) * self.opt.lambda_dice 115 |                 self.loss_total += cur_loss 116 |             elif criterion_name == 'focal': 117 |                 cur_loss = criterion(self.full_mask, real_mask) * self.opt.lambda_focal + criterion(self.miss_mask, real_mask) * self.opt.lambda_focal 118 |                 self.loss_total += cur_loss 119 |             elif criterion_name == 'L2': 120 |                 cur_loss = criterion(self.full_feat, self.miss_feat) * self.opt.lambda_L2 121 |                 self.loss_total += cur_loss 122 |             elif criterion_name == 'ssim': 123 |                 cur_loss = -criterion(self.full_feat, self.miss_feat) * self.opt.lambda_SSIM 124 |                 self.loss_total += cur_loss 125 |             elif criterion_name == 'KL': 126 |                 cur_loss = criterion(normalize(self.miss_feat), normalize(self.full_feat)) * self.opt.lambda_KL 127 |                 self.loss_total += cur_loss 128 |             setattr(self, 'loss_%s' % criterion_name, cur_loss) 129 | 130 |         self.loss_total.backward() 131 | 132 |     def optimize_parameters(self): 133 |         self.forward() 134 |         self.optimizers[0].zero_grad() 135 |         self.backward_G() 136 |         self.optimizers[0].step() 137 | 138 | 139 | #=========================================== Testing/ModDrop/ModDrop+=================================================== 140 | 141 | # import torch 142 | # import random 143 | # import numpy as np 144 | # import torch.nn as nn 145 | # from util.image_pool import ImagePool 146 | # from .base_model import BaseModel 147 | # from .losses import FocalLoss, dice_loss 148 | # from . 
import networks 149 | # from configurations import * 150 | # 151 | # 152 | # class MsModel(BaseModel): 153 | # def name(self): 154 | # return 'MsModel' 155 | # 156 | # @staticmethod 157 | # def modify_commandline_options(parser, is_train=True): 158 | # if is_train: 159 | # parser.add_argument('--lambda_L2', type=int, default=2000, help='weight for L2 loss') 160 | # parser.add_argument('--lambda_dice', type=int, default=100, help='weight for dice loss') 161 | # parser.add_argument('--lambda_focal', type=int, default=10000, help='weight for focal loss') 162 | # return parser 163 | # 164 | # def initialize(self, opt, model_suffix): 165 | # BaseModel.initialize(self, opt, model_suffix) 166 | # self.isTrain = opt.isTrain 167 | # self.use_dyn = opt.use_dyn 168 | # self.model_names = ['G'] 169 | # self.netG = networks.define_G(opt.input_nc * len(MODALITIES), opt.init_type, opt.init_gain, self.gpu_ids, dyn=self.use_dyn) # Network input 3 slices x 5 modalities 170 | # self.visual_names = ['real_mask', 'fake_mask'] 171 | # for modality in MODALITIES: 172 | # self.visual_names += [modality] 173 | # 174 | # if self.isTrain: 175 | # self.loss_names = ['total'] 176 | # self.criterion_names = [] 177 | # criterions = {'L2': torch.nn.MSELoss(), 'focal': FocalLoss(gamma=1, alpha=0.25).to(self.device), 'dice': dice_loss} 178 | # for k in criterions.keys(): 179 | # if k in opt.loss_to_use: 180 | # self.loss_names += [k] 181 | # setattr(self, 'criterion_%s' % k, criterions[k]) 182 | # self.criterion_names.append(k) 183 | # assert len(self.criterion_names), 'should use at least one loss function in L2, focal, dice' 184 | # self.fake_AB_pool = ImagePool(opt.pool_size) 185 | # 186 | # # for feature extraction, update only the last layer, otherwise update all the parameters 187 | # if self.opt.feature_extract: 188 | # params_to_update = [] 189 | # print("Params to learn:") 190 | # for name, param in self.netG.named_parameters(): 191 | # if 'thres' in name: 192 | # params_to_update.append(param) 193 | # else: 194 | # param.requires_grad = False 195 | # else: 196 | # params_to_update = self.netG.parameters() 197 | # self.optimizers = [torch.optim.Adam(params_to_update, lr=opt.lr, betas=(opt.beta1, 0.999))] 198 | # 199 | # def set_input(self, input): 200 | # self.dc = input['dc'].to(self.device) 201 | # self.mc = input['mc'].to(self.device) 202 | # 203 | # for modality in MODALITIES: 204 | # setattr(self, modality, input[modality].to(self.device)) 205 | # 206 | # self.real_mask = input['mask'].to(self.device) 207 | # self.input = torch.cat([getattr(self, k) for k in MODALITIES], 1) 208 | # 209 | # def forward(self): 210 | # if self.use_dyn: 211 | # self.fake_mask = self.netG(self.input, self.mc) # dynamic filter version 212 | # else: 213 | # self.fake_mask = self.netG(self.input) # regular version 214 | # 215 | # def backward_G(self): 216 | # self.loss_total = 0 217 | # fake_mask = (self.fake_mask + 1) / 2 218 | # real_mask = (self.real_mask + 1) / 2 219 | # for k, criterion_name in enumerate(self.criterion_names): 220 | # criterion = getattr(self, 'criterion_%s' % criterion_name) 221 | # if criterion_name == 'dice': 222 | # tmp = criterion(fake_mask, real_mask) * self.opt.lambda_dice 223 | # elif self.criterion_names[k] == 'focal': 224 | # tmp = criterion(self.fake_mask, real_mask) * self.opt.lambda_focal 225 | # else: 226 | # tmp = criterion(self.fake_mask, self.real_mask) * self.opt.lambda_L2 227 | # self.loss_total += tmp 228 | # setattr(self, 'loss_%s' % self.criterion_names[k], tmp) 229 | # 230 | # 
self.loss_total.backward() 231 | # 232 | # def optimize_parameters(self): 233 | # self.forward() 234 | # self.optimizers[0].zero_grad() 235 | # self.backward_G() 236 | # self.optimizers[0].step() 237 | -------------------------------------------------------------------------------- /util/visualizer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import sys 4 | import ntpath 5 | import time 6 | import pandas as pd 7 | from . import util 8 | from . import html 9 | from PIL import Image 10 | 11 | if sys.version_info[0] == 2: 12 | VisdomExceptionBase = Exception 13 | else: 14 | VisdomExceptionBase = ConnectionError 15 | 16 | 17 | # save image to the disk 18 | def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256): 19 | image_dir = webpage.get_image_dir() 20 | short_path = ntpath.basename(image_path[0]) 21 | name = os.path.splitext(short_path)[0] 22 | 23 | webpage.add_header(name) 24 | ims, txts, links = [], [], [] 25 | 26 | for label, im_data in visuals.items(): 27 | im = util.tensor2im(im_data) 28 | image_name = '%s_%s.png' % (name, label) 29 | save_path = os.path.join(image_dir, image_name) 30 | h, w, _ = im.shape 31 | if aspect_ratio > 1.0: 32 | im = np.array(Image.fromarray(im).resize((h, int(w * aspect_ratio)))) 33 | if aspect_ratio < 1.0: 34 | im = np.array(Image.fromarray(im).resize((int(h / aspect_ratio), w))) 35 | util.save_image(im, save_path) 36 | 37 | ims.append(image_name) 38 | txts.append(label) 39 | links.append(image_name) 40 | webpage.add_images(ims, txts, links, width=width) 41 | 42 | 43 | class Visualizer(): 44 | def __init__(self, opt): 45 | self.display_id = opt.display_id 46 | self.use_html = opt.isTrain and not opt.no_html 47 | self.win_size = opt.display_winsize 48 | self.name = opt.name 49 | self.opt = opt 50 | self.saved = False 51 | if self.display_id > 0: 52 | import visdom 53 | self.ncols = opt.display_ncols 54 | self.vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, env=opt.display_env, raise_exceptions=True) 55 | 56 | if self.use_html: 57 | self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web') 58 | self.img_dir = os.path.join(self.web_dir, 'images') 59 | print('create web directory %s...' % self.web_dir) 60 | util.mkdirs([self.web_dir, self.img_dir]) 61 | self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt') 62 | with open(self.log_name, "a") as log_file: 63 | now = time.strftime("%c") 64 | log_file.write('================ Training Loss (%s) ================\n' % now) 65 | self.val_log_name = os.path.join(opt.checkpoints_dir, opt.name, 'val_loss_log.csv') 66 | 67 | def reset(self): 68 | self.saved = False 69 | 70 | def throw_visdom_connection_error(self): 71 | print('\n\nCould not connect to Visdom server (https://github.com/facebookresearch/visdom) for displaying training progress.\nYou can suppress connection to Visdom using the option --display_id -1. 
To install visdom, run \n$ pip install visdom\n, and start the server by \n$ python -m visdom.server.\n\n') 72 |         exit(1) 73 | 74 |     # |visuals|: dictionary of images to display or save 75 |     def display_current_results(self, visuals, epoch, save_result): 76 |         if self.display_id > 0:  # show images in the browser 77 |             ncols = self.ncols 78 |             if ncols > 0: 79 |                 ncols = min(ncols, len(visuals)) 80 |                 h, w = next(iter(visuals.values())).shape[:2] 81 |                 table_css = """<style> 82 |                         table {border-collapse: separate; border-spacing: 4px; white-space: nowrap; text-align: center} 83 |                         table td {width: %dpx; height: %dpx; padding: 4px; outline: 4px solid black} 84 |                         </style>""" % (w, h) 85 |                 title = self.name 86 |                 label_html = '' 87 |                 label_html_row = '' 88 |                 images = [] 89 |                 idx = 0 90 |                 for label, image in visuals.items(): 91 |                     image_numpy = util.tensor2im(image) 92 |                     label_html_row += '<td>%s</td>' % label 93 |                     images.append(image_numpy.transpose([2, 0, 1])) 94 |                     idx += 1 95 |                     if idx % ncols == 0: 96 |                         label_html += '<tr>%s</tr>' % label_html_row 97 |                         label_html_row = '' 98 |                 white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255 99 |                 while idx % ncols != 0: 100 |                     images.append(white_image) 101 |                     label_html_row += '<td></td>' 102 |                     idx += 1 103 |                 if label_html_row != '': 104 |                     label_html += '<tr>%s</tr>' % label_html_row 105 |                 # pane col = image row 106 |                 try: 107 |                     self.vis.images(images, nrow=ncols, win=self.display_id + 1, 108 |                                     padding=2, opts=dict(title=title + ' images')) 109 |                     label_html = '<table>%s</table>' % label_html
110 |                     self.vis.text(table_css + label_html, win=self.display_id + 2, 111 |                                   opts=dict(title=title + ' labels')) 112 |                 except VisdomExceptionBase: 113 |                     self.throw_visdom_connection_error() 114 | 115 |             else: 116 |                 idx = 1 117 |                 for label, image in visuals.items(): 118 |                     image_numpy = util.tensor2im(image) 119 |                     self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label), 120 |                                    win=self.display_id + idx) 121 |                     idx += 1 122 | 123 |         if self.use_html and (save_result or not self.saved):  # save images to a html file 124 |             self.saved = True 125 |             for label, image in visuals.items(): 126 |                 image_numpy = util.tensor2im(image) 127 |                 img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label)) 128 |                 util.save_image(image_numpy, img_path) 129 |             # update website 130 |             webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, reflesh=1) 131 |             for n in range(epoch, 0, -1): 132 |                 webpage.add_header('epoch [%d]' % n) 133 |                 ims, txts, links = [], [], [] 134 | 135 |                 for label, image in visuals.items(): 136 |                     image_numpy = util.tensor2im(image) 137 |                     img_path = 'epoch%.3d_%s.png' % (n, label) 138 |                     ims.append(img_path) 139 |                     txts.append(label) 140 |                     links.append(img_path) 141 |                 webpage.add_images(ims, txts, links, width=self.win_size) 142 |             webpage.save() 143 | 144 |     # losses: dictionary of error labels and values 145 |     def plot_current_losses(self, epoch, counter_ratio, opt, losses): 146 |         if not hasattr(self, 'plot_data'): 147 |             self.plot_data = {'X': [], 'Y': [], 'legend': list(losses.keys())} 148 |         self.plot_data['X'].append(epoch + counter_ratio) 149 |         self.plot_data['Y'].append([losses[k] for k in self.plot_data['legend']]) 150 |         try: 151 |             self.vis.line( 152 |                 X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1), 153 |                 Y=np.array(self.plot_data['Y']), 154 |                 opts={ 155 |                     'title': self.name + ' loss over time', 156 |                     'legend': self.plot_data['legend'], 157 |                     'xlabel': 'epoch', 158 |                     'ylabel': 'loss'}, 159 |                 win=self.display_id) 160 |         except VisdomExceptionBase: 161 |             self.throw_visdom_connection_error() 162 | 163 |     # losses: same format as |losses| of plot_current_losses 164 |     def print_current_losses(self, epoch, i, losses, t, t_data): 165 |         # '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, i, t, t_data) 166 |         message = '(epoch: %d, iters: %d) ' % (epoch, i) 167 |         for k, v in losses.items(): 168 |             message += '%s: %.3f ' % (k, v) 169 | 170 |         print(message) 171 |         with open(self.log_name, "a") as log_file: 172 |             log_file.write('%s\n' % message) 173 | 174 |     def display_val_results(self, visuals, epoch): 175 |         if self.display_id > 0:  # show images in the browser 176 |             ncols = self.ncols 177 |             if ncols > 0: 178 |                 ncols = min(ncols, len(visuals)) 179 |                 title = self.name 180 |                 images = [] 181 |                 idx = 0 182 |                 for label, image in visuals.items(): 183 |                     image_numpy = util.tensor2im(image) 184 |                     images.append(image_numpy.transpose([2, 0, 1])) 185 |                     idx += 1 186 |                 white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255 187 |                 while idx % ncols != 0: 188 |                     images.append(white_image) 189 |                     idx += 1 190 |                 try: 191 |                     self.vis.images(images, nrow=ncols, win=self.display_id + 4, 192 |                                     padding=2, opts=dict(title=title + ' val images')) 193 |                 except VisdomExceptionBase: 194 |                     self.throw_visdom_connection_error() 195 | 196 |             else: 197 |                 idx = 1 198 |                 for label, image in visuals.items(): 199 |                     image_numpy = util.tensor2im(image) 200 |                     self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label), 201 | 
win=self.display_id + idx) 202 | idx += 1 203 | 204 | # losses: dictionary of error labels and values 205 | def plot_val_losses(self, epoch, counter_ratio, opt, losses, model_suffix=None): 206 | if not hasattr(self, 'val_data'): 207 | self.val_data = {'X': [], 'Y': [], 'legend': list(losses.keys())} 208 | self.val_data['X'].append(epoch + counter_ratio) 209 | self.val_data['Y'].append([losses[k] for k in self.val_data['legend']]) 210 | # if epoch == opt.niter + opt.niter_decay: 211 | df = pd.DataFrame(np.array(self.val_data['Y']), columns=losses.keys()) 212 | if model_suffix is not None: 213 | self.val_log_name = self.val_log_name.replace('val_loss_log.csv', 'val_loss_log_%s.csv' % model_suffix) 214 | df.to_csv(self.val_log_name) 215 | try: 216 | self.vis.line( 217 | X=np.stack([np.array(self.val_data['X'])] * len(self.val_data['legend']), 1), 218 | Y=np.array(self.val_data['Y']), 219 | opts={ 220 | 'title': self.name + ' val loss over time', 221 | 'legend': self.val_data['legend'], 222 | 'xlabel': 'epoch', 223 | 'ylabel': 'loss'}, 224 | win=self.display_id+3) 225 | except VisdomExceptionBase: 226 | self.throw_visdom_connection_error() 227 | 228 | def save_val_losses(self, epoch, counter_ratio, opt, losses, model_suffix=None): 229 | if not hasattr(self, 'val_data'): 230 | self.val_data = {'X': [], 'Y': [], 'legend': list(losses.keys())} 231 | self.val_data['X'].append(epoch + counter_ratio) 232 | self.val_data['Y'].append([losses[k] for k in self.val_data['legend']]) 233 | # if epoch == opt.niter + opt.niter_decay: 234 | df = pd.DataFrame(np.array(self.val_data['Y']), columns=losses.keys()) 235 | if model_suffix is not None: 236 | self.val_log_name = self.val_log_name.replace('val_loss_log.csv', 'val_loss_log_%s.csv' % model_suffix) 237 | df.to_csv(self.val_log_name) 238 | 239 | # losses: same format as |losses| of plot_current_losses 240 | def print_val_losses(self, epoch, losses, t): 241 | message = 'VAL: (epoch: %d) ' % epoch # 'VAL: (epoch: %d, time: %.3f) ' % (epoch, t) 242 | for k, v in losses.items(): 243 | message += '%s: %.3f ' % (k, v) 244 | 245 | print(message) 246 | with open(self.log_name, "a") as log_file: 247 | log_file.write('%s\n' % message) 248 | --------------------------------------------------------------------------------