├── .vscode
│   └── settings.json
├── models
│   ├── __init__.py
│   ├── README.md
│   ├── nn
│   │   └── module.py
│   ├── UNet.py
│   ├── SegNet.py
│   ├── ResUNet.py
│   └── KiUNet.py
├── .gitignore
├── experiments
│   └── README.md
├── utils
│   ├── metrics.py
│   ├── logger.py
│   ├── common.py
│   ├── weights_init.py
│   └── loss.py
├── dataset
│   ├── dataset_lits_val.py
│   ├── dataset_lits_train.py
│   ├── dataset_lits_test.py
│   └── transforms.py
├── config.py
├── test.py
├── README.md
├── train.py
└── preprocess_LiTS.py

/.vscode/settings.json:
--------------------------------------------------------------------------------
{
    "python.pythonPath": "/home/yx/anaconda3/envs/lzq/bin/python"
}
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
from .UNet import UNet
from .KiUNet import KiUNet_min, KiUNet_org
from .SegNet import SegNet
from .ResUNet import ResUNet
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# ignore some files
.idea
/__pycache__/
*/__pycache__/
/samples/
/Lab/
/experiments/*/
/fixed_data/
/.vscode/
--------------------------------------------------------------------------------
/experiments/README.md:
--------------------------------------------------------------------------------
All trained models are saved under `./experiments/model_name`.
Every trained model is saved together with an `events.out.tfevents.***` log file and a `log.csv`.
You can observe the loss and dice during training through TensorBoard by running:
`tensorboard --logdir ./experiments/model_name`
--------------------------------------------------------------------------------
/models/README.md:
--------------------------------------------------------------------------------

Recently (May 2021), I surveyed the existing U-Net-based 3D medical image segmentation algorithms and found that the models designed in [MICCAI-LITS2017](https://github.com/assassint2017/MICCAI-LITS2017) and [KiU-Net](https://github.com/jeya-maria-jose/KiU-Net-pytorch) are excellent; I recommend paying attention to these two projects. Rewriting the model files from scratch would have been a waste of time, so I deleted my previous models, referenced some of their work, and made minor modifications. Many thanks to these two open-source projects.

### Reference:
1. https://github.com/assassint2017/MICCAI-LITS2017
2. https://github.com/jeya-maria-jose/KiU-Net-pytorch
--------------------------------------------------------------------------------
/utils/metrics.py:
--------------------------------------------------------------------------------

import torch.nn as nn
import torch.nn.functional as F
import torch
import numpy as np

class LossAverage(object):
    """Computes and stores the current loss value and its running average"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = round(self.sum / self.count, 4)

class DiceAverage(object):
    """Computes and stores the current per-class Dice values and their running averages"""
    def __init__(self, class_num):
        self.class_num = class_num
        self.reset()

    def reset(self):
        self.value = np.asarray([0]*self.class_num, dtype='float64')
        self.avg = np.asarray([0]*self.class_num, dtype='float64')
        self.sum = np.asarray([0]*self.class_num, dtype='float64')
        self.count = 0

    def update(self, logits, targets):
        self.value = DiceAverage.get_dices(logits, targets)
        self.sum += self.value
        self.count += 1
        self.avg = np.around(self.sum / self.count, 4)

    @staticmethod
    def get_dices(logits, targets):
        dices = []
        for class_index in range(targets.size()[1]):
            inter = torch.sum(logits[:, class_index, :, :, :] * targets[:, class_index, :, :, :])
            union = torch.sum(logits[:, class_index, :, :, :]) + torch.sum(targets[:, class_index, :, :, :])
            dice = (2. * inter + 1) / (union + 1)
            dices.append(dice.item())
        return np.asarray(dices)
--------------------------------------------------------------------------------
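A quick sanity check of how `DiceAverage.get_dices` behaves on one-hot tensors (an illustrative sketch with made-up values, not part of the repo):

```python
import torch
from utils.metrics import DiceAverage

# toy batch: 1 sample, 2 classes (background / liver), one 1x2x2 volume per class
pred = torch.tensor([[[[[1., 0.], [0., 1.]]], [[[0., 1.], [1., 0.]]]]])  # shape [1, 2, 1, 2, 2]
target = pred.clone()                       # pretend the prediction is perfect
print(DiceAverage.get_dices(pred, target))  # -> [1. 1.]; the +1 smoothing keeps empty classes well defined
```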
/dataset/dataset_lits_val.py:
--------------------------------------------------------------------------------
from posixpath import join
from torch.utils.data import DataLoader
import os
import sys
import numpy as np
import SimpleITK as sitk
import torch
from torch.utils.data import Dataset as dataset
from .transforms import Center_Crop, Compose


class Val_Dataset(dataset):
    def __init__(self, args):

        self.args = args
        self.filename_list = self.load_file_name_list(os.path.join(args.dataset_path, 'val_path_list.txt'))

        self.transforms = Compose([Center_Crop(base=16, max_size=args.val_crop_max_size)])

    def __getitem__(self, index):

        ct = sitk.ReadImage(self.filename_list[index][0], sitk.sitkInt16)
        seg = sitk.ReadImage(self.filename_list[index][1], sitk.sitkUInt8)

        ct_array = sitk.GetArrayFromImage(ct)
        seg_array = sitk.GetArrayFromImage(seg)

        ct_array = ct_array / self.args.norm_factor
        ct_array = ct_array.astype(np.float32)

        ct_array = torch.FloatTensor(ct_array).unsqueeze(0)
        seg_array = torch.FloatTensor(seg_array).unsqueeze(0)

        if self.transforms:
            ct_array, seg_array = self.transforms(ct_array, seg_array)

        return ct_array, seg_array.squeeze(0)

    def __len__(self):
        return len(self.filename_list)

    def load_file_name_list(self, file_path):
        file_name_list = []
        with open(file_path, 'r') as file_to_read:
            while True:
                lines = file_to_read.readline().strip()  # read the list file line by line
                if not lines:
                    break
                file_name_list.append(lines.split())
        return file_name_list

if __name__ == "__main__":
    sys.path.append('/ssd/lzq/3DUNet')
    from config import args
    val_ds = Val_Dataset(args)

    # build the data loader
    val_dl = DataLoader(val_ds, 2, False, num_workers=1)

    for i, (ct, seg) in enumerate(val_dl):
        print(i, ct.size(), seg.size())
--------------------------------------------------------------------------------
/dataset/dataset_lits_train.py:
--------------------------------------------------------------------------------
from posixpath import join
from torch.utils.data import DataLoader
import os
import sys
import random
import numpy as np
import SimpleITK as sitk
import torch
from torch.utils.data import Dataset as dataset
from .transforms import RandomCrop, RandomFlip_LR, RandomFlip_UD, Center_Crop, Compose, Resize

class Train_Dataset(dataset):
    def __init__(self, args):

        self.args = args

        self.filename_list = self.load_file_name_list(os.path.join(args.dataset_path, 'train_path_list.txt'))

        self.transforms = Compose([
                RandomCrop(self.args.crop_size),
                RandomFlip_LR(prob=0.5),
                RandomFlip_UD(prob=0.5),
                # RandomRotate()
            ])

    def __getitem__(self, index):

        ct = sitk.ReadImage(self.filename_list[index][0], sitk.sitkInt16)
        seg = sitk.ReadImage(self.filename_list[index][1], sitk.sitkUInt8)

        ct_array = sitk.GetArrayFromImage(ct)
        seg_array = sitk.GetArrayFromImage(seg)

        ct_array = ct_array / self.args.norm_factor
        ct_array = ct_array.astype(np.float32)

        ct_array = torch.FloatTensor(ct_array).unsqueeze(0)
        seg_array = torch.FloatTensor(seg_array).unsqueeze(0)

        if self.transforms:
            ct_array, seg_array = self.transforms(ct_array, seg_array)

        return ct_array, seg_array.squeeze(0)

    def __len__(self):
        return len(self.filename_list)

    def load_file_name_list(self, file_path):
        file_name_list = []
        with open(file_path, 'r') as file_to_read:
            while True:
                lines = file_to_read.readline().strip()  # read the list file line by line
                if not lines:
                    break
                file_name_list.append(lines.split())
        return file_name_list

if __name__ == "__main__":
    sys.path.append('/ssd/lzq/3DUNet')
    from config import args
    train_ds = Train_Dataset(args)

    # build the data loader
    train_dl = DataLoader(train_ds, 2, False, num_workers=1)

    for i, (ct, seg) in enumerate(train_dl):
        print(i, ct.size(), seg.size())
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
import argparse

parser = argparse.ArgumentParser(description='Hyper-parameters management')

# Hardware options
parser.add_argument('--n_threads', type=int, default=6, help='number of threads for data loading')
parser.add_argument('--cpu', action='store_true', help='use cpu only')
parser.add_argument('--gpu_id', type=int, nargs='+', default=[0, 1], help='ids of the gpus used by DataParallel')
parser.add_argument('--seed', type=int, default=2021, help='random seed')

# Preprocess parameters
parser.add_argument('--n_labels', type=int, default=2, help='number of classes')  # set to 2 for liver-only (binary) segmentation, 3 for liver + tumor (three-class) segmentation
parser.add_argument('--upper', type=int, default=200, help='upper bound for clipping CT gray values')
parser.add_argument('--lower', type=int, default=-200, help='lower bound for clipping CT gray values')
parser.add_argument('--norm_factor', type=float, default=200.0, help='factor the CT values are divided by for normalization')
parser.add_argument('--expand_slice', type=int, default=20, help='number of slices to expand outward around the liver region')
parser.add_argument('--min_slices', type=int, default=48, help='minimum number of slices required to keep a sample')
parser.add_argument('--xy_down_scale', type=float, default=0.5, help='downsampling scale for the x and y axes')
parser.add_argument('--slice_down_scale', type=float, default=1.0, help='target spacing for the slice axis')
parser.add_argument('--valid_rate', type=float, default=0.2, help='fraction of the training data used for validation')

# data in/out and dataset
parser.add_argument('--dataset_path', default='/ssd/lzq/dataset/fixed_lits', help='fixed trainset root path')
parser.add_argument('--test_data_path', default='/ssd/lzq/dataset/LiTS/test', help='Testset path')
parser.add_argument('--save', default='ResUNet', help='save path of trained model')
parser.add_argument('--batch_size', type=int, default=2, help='batch size of trainset')

# train
parser.add_argument('--epochs', type=int, default=200, metavar='N', help='number of epochs to train (default: 200)')
parser.add_argument('--lr', type=float, default=0.0001, metavar='LR', help='learning rate (default: 0.0001)')
parser.add_argument('--early-stop', default=30, type=int, help='early stopping (default: 30)')
parser.add_argument('--crop_size', type=int, default=48)
parser.add_argument('--val_crop_max_size', type=int, default=96)

# test
parser.add_argument('--test_cut_size', type=int, default=48, help='size of sliding window')
parser.add_argument('--test_cut_stride', type=int, default=24, help='stride of sliding window')
parser.add_argument('--postprocess', type=bool, default=False, help='post process')


args = parser.parse_args()
--------------------------------------------------------------------------------
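For example, a three-class (liver + tumor) run under a new experiment name would be started with `python train.py --save ResUNet_3labels --n_labels 3` (an illustrative invocation; both flags are defined above, and the experiment name is made up).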
/utils/logger.py:
--------------------------------------------------------------------------------
import pandas as pd
from tensorboardX import SummaryWriter
import torch, random
import numpy as np
from collections import OrderedDict

class Train_Logger():
    def __init__(self, save_path, save_name):
        self.log = None
        self.summary = None
        self.save_path = save_path
        self.save_name = save_name

    def update(self, epoch, train_log, val_log):
        item = OrderedDict({'epoch': epoch})
        item.update(train_log)
        item.update(val_log)
        # item = dict_round(item, 4)  # keep four significant decimal places
        print("\033[0;33mTrain:\033[0m", train_log)
        print("\033[0;33mValid:\033[0m", val_log)
        self.update_csv(item)
        self.update_tensorboard(item)

    def update_csv(self, item):
        tmp = pd.DataFrame(item, index=[0])
        if self.log is not None:
            self.log = pd.concat([self.log, tmp], ignore_index=True)
        else:
            self.log = tmp
        self.log.to_csv('%s/%s.csv' % (self.save_path, self.save_name), index=False)

    def update_tensorboard(self, item):
        if self.summary is None:
            self.summary = SummaryWriter('%s/' % self.save_path)
        epoch = item['epoch']
        for key, value in item.items():
            if key != 'epoch': self.summary.add_scalar(key, value, epoch)

class Test_Logger():
    def __init__(self, save_path, save_name):
        self.log = None
        self.summary = None
        self.save_path = save_path
        self.save_name = save_name

    def update(self, name, log):
        item = OrderedDict({'img_name': name})
        item.update(log)
        print("\033[0;33mTest:\033[0m", log)
        self.update_csv(item)

    def update_csv(self, item):
        tmp = pd.DataFrame(item, index=[0])
        if self.log is not None:
            self.log = pd.concat([self.log, tmp], ignore_index=True)
        else:
            self.log = tmp
        self.log.to_csv('%s/%s.csv' % (self.save_path, self.save_name), index=False)

def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = True
    random.seed(seed)

def dict_round(dic, num):
    for key, value in dic.items():
        dic[key] = round(value, num)
    return dic
--------------------------------------------------------------------------------
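`Train_Logger` is driven from `train.py` with one `update` per epoch, writing to both `log.csv` and TensorBoard; a minimal usage sketch (made-up scalars, assuming the save directory already exists for the CSV):

```python
from collections import OrderedDict
from utils.logger import Train_Logger

log = Train_Logger('./experiments/demo', 'train_log')
train_log = OrderedDict({'Train_Loss': 0.42, 'Train_dice_liver': 0.81})
val_log = OrderedDict({'Val_Loss': 0.47, 'Val_dice_liver': 0.78})
log.update(epoch=1, train_log=train_log, val_log=val_log)  # prints both dicts, appends to the CSV, logs the scalars
```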
/utils/common.py:
--------------------------------------------------------------------------------
import SimpleITK as sitk
import numpy as np
from scipy import ndimage
import torch, random

# one-hot encode the target
def to_one_hot_3d(tensor, n_classes=3):  # shape = [batch, s, h, w]
    n, s, h, w = tensor.size()
    one_hot = torch.zeros(n, n_classes, s, h, w).scatter_(1, tensor.view(n, 1, s, h, w), 1)
    return one_hot

def random_crop_3d(img, label, crop_size):
    random_x_max = img.shape[0] - crop_size[0]
    random_y_max = img.shape[1] - crop_size[1]
    random_z_max = img.shape[2] - crop_size[2]

    if random_x_max < 0 or random_y_max < 0 or random_z_max < 0:
        return None

    x_random = random.randint(0, random_x_max)
    y_random = random.randint(0, random_y_max)
    z_random = random.randint(0, random_z_max)

    crop_img = img[x_random:x_random + crop_size[0], y_random:y_random + crop_size[1], z_random:z_random + crop_size[2]]
    crop_label = label[x_random:x_random + crop_size[0], y_random:y_random + crop_size[1], z_random:z_random + crop_size[2]]

    return crop_img, crop_label

def center_crop_3d(img, label, slice_num=16):
    if img.shape[0] < slice_num:
        return None
    left_x = img.shape[0]//2 - slice_num//2
    right_x = img.shape[0]//2 + slice_num//2

    crop_img = img[left_x:right_x]
    crop_label = label[left_x:right_x]
    return crop_img, crop_label

def load_file_name_list(file_path):
    file_name_list = []
    with open(file_path, 'r') as file_to_read:
        while True:
            lines = file_to_read.readline().strip()  # read the list file line by line
            if not lines:
                break
            file_name_list.append(lines)
    return file_name_list

def print_network(net):
    num_params = 0
    for param in net.parameters():
        num_params += param.numel()
    print(net)
    print('Total number of parameters: %d' % num_params)

def adjust_learning_rate(optimizer, epoch, args):
    """Sets the learning rate to the initial LR decayed by 10 every 20 epochs"""
    lr = args.lr * (0.1 ** (epoch // 20))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

def adjust_learning_rate_V2(optimizer, lr):
    """Sets the learning rate to a fixed number"""
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
--------------------------------------------------------------------------------
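`to_one_hot_3d` relies on `Tensor.scatter_` along the channel dimension; a shape sanity check (toy labels, not from the repo):

```python
import torch
from utils.common import to_one_hot_3d

labels = torch.randint(0, 3, (2, 4, 8, 8))   # [batch, slices, h, w], values in {0, 1, 2}
onehot = to_one_hot_3d(labels, n_classes=3)  # -> [batch, n_classes, slices, h, w]
assert onehot.shape == (2, 3, 4, 8, 8)
assert torch.all(onehot.sum(dim=1) == 1)     # exactly one hot channel per voxel
```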
/utils/weights_init.py:
--------------------------------------------------------------------------------
from torch.nn import init
from torch import nn

def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('Linear') != -1:
        init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)


def weights_init_xavier(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        init.xavier_normal_(m.weight.data, gain=1)
    elif classname.find('Linear') != -1:
        init.xavier_normal_(m.weight.data, gain=1)
    elif classname.find('BatchNorm') != -1:
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)


def weights_init_kaiming(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
    elif classname.find('Linear') != -1:
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
    elif classname.find('BatchNorm') != -1:
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)


def weights_init_orthogonal(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        init.orthogonal_(m.weight.data, gain=1)
    elif classname.find('Linear') != -1:
        init.orthogonal_(m.weight.data, gain=1)
    elif classname.find('BatchNorm') != -1:
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)


def init_weights(net, init_type='normal'):
    if init_type == 'normal':
        net.apply(weights_init_normal)
    elif init_type == 'xavier':
        net.apply(weights_init_xavier)
    elif init_type == 'kaiming':
        net.apply(weights_init_kaiming)
    elif init_type == 'orthogonal':
        net.apply(weights_init_orthogonal)
    else:
        raise NotImplementedError('initialization method [%s] is not implemented' % init_type)

def init_model(net):
    if isinstance(net, nn.Conv3d) or isinstance(net, nn.ConvTranspose3d):
        nn.init.kaiming_normal_(net.weight.data, 0.25)
        nn.init.constant_(net.bias.data, 0)
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
from torch.utils.data import DataLoader
import torch
from tqdm import tqdm
import config
from utils import logger, common
from dataset.dataset_lits_test import Test_Datasets, to_one_hot_3d
import SimpleITK as sitk
import os
import numpy as np
from models import ResUNet
from utils.metrics import DiceAverage
from collections import OrderedDict


def predict_one_img(model, img_dataset, args):
    dataloader = DataLoader(dataset=img_dataset, batch_size=1, num_workers=0, shuffle=False)
    model.eval()
    test_dice = DiceAverage(args.n_labels)
    target = to_one_hot_3d(img_dataset.label, args.n_labels)

    with torch.no_grad():
        for data in tqdm(dataloader, total=len(dataloader)):
            data = data.to(device)
            output = model(data)
            # output = nn.functional.interpolate(output, scale_factor=(1//args.slice_down_scale,1//args.xy_down_scale,1//args.xy_down_scale), mode='trilinear', align_corners=False)  # restore the spatial resolution to the original size
            img_dataset.update_result(output.detach().cpu())

    pred = img_dataset.recompone_result()
    pred = torch.argmax(pred, dim=1)

    pred_img = common.to_one_hot_3d(pred, args.n_labels)
    test_dice.update(pred_img, target)

    dice_log = OrderedDict({'Dice_liver': test_dice.avg[1]})
    if args.n_labels == 3: dice_log.update({'Dice_tumor': test_dice.avg[2]})

    pred = np.asarray(pred.numpy(), dtype='uint8')
    if args.postprocess:
        pass  # TO DO
    pred = sitk.GetImageFromArray(np.squeeze(pred, axis=0))

    return dice_log, pred

if __name__ == '__main__':
    args = config.args
    save_path = os.path.join('./experiments', args.save)
    device = torch.device('cpu' if args.cpu else 'cuda')
    # model info
    model = ResUNet(in_channel=1, out_channel=args.n_labels, training=False).to(device)
    model = torch.nn.DataParallel(model, device_ids=args.gpu_id)  # multi-GPU
    ckpt = torch.load('{}/best_model.pth'.format(save_path))
    model.load_state_dict(ckpt['net'])

    test_log = logger.Test_Logger(save_path, "test_log")
    # data info
    result_save_path = '{}/result'.format(save_path)
    if not os.path.exists(result_save_path):
        os.mkdir(result_save_path)

    datasets = Test_Datasets(args.test_data_path, args=args)
    for img_dataset, file_idx in datasets:
        test_dice, pred_img = predict_one_img(model, img_dataset, args)
        test_log.update(file_idx, test_dice)
        sitk.WriteImage(pred_img, os.path.join(result_save_path, 'result-'+file_idx+'.gz'))
--------------------------------------------------------------------------------
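The sliding-window logic that `predict_one_img` relies on lives in `dataset_lits_test.py` further below: the volume is padded so that `(slices - test_cut_size)` becomes divisible by `test_cut_stride`, then split into overlapping patches. With the defaults (`test_cut_size=48`, `test_cut_stride=24`) the arithmetic works out as follows (illustrative slice count):

```python
img_s, size, stride = 130, 48, 24           # e.g. a 130-slice CT volume
leftover = (img_s - size) % stride          # 82 % 24 = 10
padded = img_s + (stride - leftover)        # 130 + 14 = 144
n_patches = (padded - size) // stride + 1   # (144 - 48) // 24 + 1 = 5 overlapping patches
print(padded, n_patches)                    # 144 5
```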
/models/nn/module.py:
--------------------------------------------------------------------------------
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F

# not tested yet
class SEBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(SEBlock, self).__init__()
        self.gap = nn.AdaptiveAvgPool3d(1)
        conv = nn.Conv3d

        self.conv1 = conv(in_channels, out_channels, 1)
        self.conv2 = conv(out_channels, in_channels, 1)  # back to in_channels so the gating multiply matches the input

        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        inpu = x
        x = self.gap(x)
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.sigmoid(x)

        return inpu * x

# not tested yet
class DenseBlock(nn.Module):
    def __init__(self, channels, conv_num):
        super(DenseBlock, self).__init__()
        self.conv_num = conv_num
        conv = nn.Conv3d

        self.relu = nn.ReLU()
        # ModuleList so the sub-convolutions are registered as parameters
        self.conv_list = nn.ModuleList()
        self.bottle_conv_list = nn.ModuleList()
        for i in range(conv_num):
            self.bottle_conv_list.append(conv(channels*(i+1), channels*4, 1))
            self.conv_list.append(conv(channels*4, channels, 3, padding=1))


    def forward(self, x):

        res_x = []
        res_x.append(x)

        for i in range(self.conv_num):
            inputs = torch.cat(res_x, dim=1)
            x = self.bottle_conv_list[i](inputs)
            x = self.relu(x)
            x = self.conv_list[i](x)
            x = self.relu(x)
            res_x.append(x)

        return x

class ResBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(ResBlock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.conv1 = nn.Conv3d(in_channels, out_channels, 3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm3d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv3d(out_channels, out_channels, 3, stride=stride, padding=1)
        self.bn2 = nn.BatchNorm3d(out_channels)

        if in_channels != out_channels:
            self.res_conv = nn.Conv3d(in_channels, out_channels, 1, stride=stride)

    def forward(self, x):
        if self.in_channels != self.out_channels:
            res = self.res_conv(x)
        else:
            res = x
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)

        out = x + res
        out = self.relu(out)

        return out
--------------------------------------------------------------------------------
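A quick shape check for `ResBlock` (a sketch assuming `models.nn` is importable as a package; the 1x1 projection path kicks in whenever the channel counts differ):

```python
import torch
from models.nn.module import ResBlock

block = ResBlock(in_channels=16, out_channels=32)
x = torch.randn(1, 16, 8, 32, 32)  # [batch, channels, slices, h, w]
print(block(x).shape)              # torch.Size([1, 32, 8, 32, 32])
```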
/models/UNet.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

class UNet(nn.Module):
    def __init__(self, in_channel=1, out_channel=2, training=True):
        super(UNet, self).__init__()
        self.training = training
        self.encoder1 = nn.Conv3d(in_channel, 32, 3, stride=1, padding=1)
        self.encoder2 = nn.Conv3d(32, 64, 3, stride=1, padding=1)
        self.encoder3 = nn.Conv3d(64, 128, 3, stride=1, padding=1)
        self.encoder4 = nn.Conv3d(128, 256, 3, stride=1, padding=1)
        # self.encoder5 = nn.Conv3d(256, 512, 3, stride=1, padding=1)

        # self.decoder1 = nn.Conv3d(512, 256, 3, stride=1, padding=1)
        self.decoder2 = nn.Conv3d(256, 128, 3, stride=1, padding=1)
        self.decoder3 = nn.Conv3d(128, 64, 3, stride=1, padding=1)
        self.decoder4 = nn.Conv3d(64, 32, 3, stride=1, padding=1)
        self.decoder5 = nn.Conv3d(32, 2, 3, stride=1, padding=1)

        self.map4 = nn.Sequential(
            nn.Conv3d(2, out_channel, 1, 1),
            nn.Upsample(scale_factor=(1, 2, 2), mode='trilinear'),
            nn.Softmax(dim=1)
        )

        # mapping at the 128*128 scale
        self.map3 = nn.Sequential(
            nn.Conv3d(64, out_channel, 1, 1),
            nn.Upsample(scale_factor=(4, 8, 8), mode='trilinear'),
            nn.Softmax(dim=1)
        )

        # mapping at the 64*64 scale
        self.map2 = nn.Sequential(
            nn.Conv3d(128, out_channel, 1, 1),
            nn.Upsample(scale_factor=(8, 16, 16), mode='trilinear'),
            nn.Softmax(dim=1)
        )

        # mapping at the 32*32 scale
        self.map1 = nn.Sequential(
            nn.Conv3d(256, out_channel, 1, 1),
            nn.Upsample(scale_factor=(16, 32, 32), mode='trilinear'),
            nn.Softmax(dim=1)
        )

    def forward(self, x):

        out = F.relu(F.max_pool3d(self.encoder1(x), 2, 2))
        t1 = out
        out = F.relu(F.max_pool3d(self.encoder2(out), 2, 2))
        t2 = out
        out = F.relu(F.max_pool3d(self.encoder3(out), 2, 2))
        t3 = out
        out = F.relu(F.max_pool3d(self.encoder4(out), 2, 2))
        # t4 = out
        # out = F.relu(F.max_pool3d(self.encoder5(out), 2, 2))

        # out = F.relu(F.interpolate(self.decoder1(out), scale_factor=(2, 2, 2), mode='trilinear'))
        output1 = self.map1(out)
        out = F.relu(F.interpolate(self.decoder2(out), scale_factor=(2, 2, 2), mode='trilinear'))
        out = torch.add(out, t3)
        output2 = self.map2(out)
        out = F.relu(F.interpolate(self.decoder3(out), scale_factor=(2, 2, 2), mode='trilinear'))
        out = torch.add(out, t2)
        output3 = self.map3(out)
        out = F.relu(F.interpolate(self.decoder4(out), scale_factor=(2, 2, 2), mode='trilinear'))
        out = torch.add(out, t1)

        out = F.relu(F.interpolate(self.decoder5(out), scale_factor=(2, 2, 2), mode='trilinear'))
        output4 = self.map4(out)

        if self.training is True:
            return output1, output2, output3, output4
        else:
            return output4
--------------------------------------------------------------------------------
/models/SegNet.py:
--------------------------------------------------------------------------------
import os
import sys

import torch
import torch.nn as nn
import torch.nn.functional as F

class SegNet(nn.Module):
    def __init__(self, in_channel=1, out_channel=2, training=True):
        super(SegNet, self).__init__()
        self.training = training
        self.encoder1 = nn.Conv3d(in_channel, 32, 3, stride=1, padding=1)
        self.encoder2 = nn.Conv3d(32, 64, 3, stride=1, padding=1)
        self.encoder3 = nn.Conv3d(64, 128, 3, stride=1, padding=1)
        self.encoder4 = nn.Conv3d(128, 256, 3, stride=1, padding=1)
        # self.encoder5 = nn.Conv3d(256, 512, 3, stride=1, padding=1)

        # self.decoder1 = nn.Conv3d(512, 256, 3, stride=1, padding=1)
        self.decoder2 = nn.Conv3d(256, 128, 3, stride=1, padding=1)
        self.decoder3 = nn.Conv3d(128, 64, 3, stride=1, padding=1)
        self.decoder4 = nn.Conv3d(64, 32, 3, stride=1, padding=1)
        self.decoder5 = nn.Conv3d(32, 2, 3, stride=1, padding=1)

        self.map4 = nn.Sequential(
            nn.Conv3d(2, out_channel, 1, 1),
            nn.Upsample(scale_factor=(1, 2, 2), mode='trilinear'),
            nn.Softmax(dim=1)
        )

        # mapping at the 128*128 scale
        self.map3 = nn.Sequential(
            nn.Conv3d(64, out_channel, 1, 1),
            nn.Upsample(scale_factor=(4, 8, 8), mode='trilinear'),
            nn.Softmax(dim=1)
        )

        # mapping at the 64*64 scale
        self.map2 = nn.Sequential(
            nn.Conv3d(128, out_channel, 1, 1),
            nn.Upsample(scale_factor=(8, 16, 16), mode='trilinear'),
            nn.Softmax(dim=1)
        )

        # mapping at the 32*32 scale
        self.map1 = nn.Sequential(
            nn.Conv3d(256, out_channel, 1, 1),
            nn.Upsample(scale_factor=(16, 32, 32), mode='trilinear'),
            nn.Softmax(dim=1)
        )

        self.soft = nn.Softmax(dim=1)

    def forward(self, x):

        out = F.relu(F.max_pool3d(self.encoder1(x), 2, 2))
        t1 = out
        out = F.relu(F.max_pool3d(self.encoder2(out), 2, 2))
        t2 = out
        out = F.relu(F.max_pool3d(self.encoder3(out), 2, 2))
        t3 = out
        out = F.relu(F.max_pool3d(self.encoder4(out), 2, 2))
        # t4 = out
        # out = F.relu(F.max_pool3d(self.encoder5(out), 2, 2))

        # out = F.relu(F.interpolate(self.decoder1(out), scale_factor=(2, 2, 2), mode='trilinear'))
        # out = torch.add(F.pad(out, [0, 0, 0, 0, 0, 1]), t4)
        output1 = self.map1(out)
        out = F.relu(F.interpolate(self.decoder2(out), scale_factor=(2, 2, 2), mode='trilinear'))
        # out = torch.add(out, t3)
        output2 = self.map2(out)
        out = F.relu(F.interpolate(self.decoder3(out), scale_factor=(2, 2, 2), mode='trilinear'))
        # out = torch.add(out, t2)
        output3 = self.map3(out)
        out = F.relu(F.interpolate(self.decoder4(out), scale_factor=(2, 2, 2), mode='trilinear'))
        # out = torch.add(out, t1)

        out = F.relu(F.interpolate(self.decoder5(out), scale_factor=(2, 2, 2), mode='trilinear'))
        output4 = self.map4(out)

        if self.training is True:
            return output1, output2, output3, output4
        else:
            return output4
--------------------------------------------------------------------------------
/utils/loss.py:
--------------------------------------------------------------------------------
"""
Dice-based loss functions. When computing them, pred and target must have the
same shape, i.e. target is the one-hot encoded Tensor.
"""

import torch
import torch.nn as nn


class DiceLoss(nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, pred, target):

        smooth = 1

        dice = 0.
        # definition of the Dice coefficient
        for i in range(pred.size(1)):
            dice += 2 * (pred[:,i] * target[:,i]).sum(dim=1).sum(dim=1).sum(dim=1) / (pred[:,i].pow(2).sum(dim=1).sum(dim=1).sum(dim=1) +
                        target[:,i].pow(2).sum(dim=1).sum(dim=1).sum(dim=1) + smooth)
        # return the Dice distance
        dice = dice / pred.size(1)
        return torch.clamp((1 - dice).mean(), 0, 1)

class ELDiceLoss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, pred, target):

        smooth = 1

        dice = 0.
        # definition of the Dice coefficient
        for i in range(pred.size(1)):
            dice += 2 * (pred[:,i] * target[:,i]).sum(dim=1).sum(dim=1).sum(dim=1) / (pred[:,i].pow(2).sum(dim=1).sum(dim=1).sum(dim=1) +
                        target[:,i].pow(2).sum(dim=1).sum(dim=1).sum(dim=1) + smooth)

        dice = dice / pred.size(1)
        # return the exponential-logarithmic Dice distance
        return torch.clamp((torch.pow(-torch.log(dice + 1e-5), 0.3)).mean(), 0, 2)


class HybridLoss(nn.Module):
    def __init__(self):
        super().__init__()

        self.bce_loss = nn.BCELoss()
        self.bce_weight = 1.0

    def forward(self, pred, target):

        smooth = 1

        dice = 0.
        # definition of the Dice coefficient
        for i in range(pred.size(1)):
            dice += 2 * (pred[:,i] * target[:,i]).sum(dim=1).sum(dim=1).sum(dim=1) / (pred[:,i].pow(2).sum(dim=1).sum(dim=1).sum(dim=1) +
                        target[:,i].pow(2).sum(dim=1).sum(dim=1).sum(dim=1) + smooth)

        dice = dice / pred.size(1)

        # return the Dice distance + binary cross-entropy loss
        return torch.clamp((1 - dice).mean(), 0, 1) + self.bce_loss(pred, target) * self.bce_weight


class JaccardLoss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, pred, target):

        smooth = 1

        # definition of the Jaccard coefficient
        jaccard = 0.

        for i in range(pred.size(1)):
            jaccard += (pred[:,i] * target[:,i]).sum(dim=1).sum(dim=1).sum(dim=1) / (pred[:,i].pow(2).sum(dim=1).sum(dim=1).sum(dim=1) +
                        target[:,i].pow(2).sum(dim=1).sum(dim=1).sum(dim=1) - (pred[:,i] * target[:,i]).sum(dim=1).sum(dim=1).sum(dim=1) + smooth)

        # return the Jaccard distance
        jaccard = jaccard / pred.size(1)
        return torch.clamp((1 - jaccard).mean(), 0, 1)

class SSLoss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, pred, target):

        smooth = 1

        loss = 0.

        for i in range(pred.size(1)):
            s1 = ((pred[:,i] - target[:,i]).pow(2) * target[:,i]).sum(dim=1).sum(dim=1).sum(dim=1) / (smooth + target[:,i].sum(dim=1).sum(dim=1).sum(dim=1))

            s2 = ((pred[:,i] - target[:,i]).pow(2) * (1 - target[:,i])).sum(dim=1).sum(dim=1).sum(dim=1) / (smooth + (1 - target[:,i]).sum(dim=1).sum(dim=1).sum(dim=1))

            loss += (0.05 * s1 + 0.95 * s2)

        return loss / pred.size(1)

class TverskyLoss(nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, pred, target):

        smooth = 1

        dice = 0.

        for i in range(pred.size(1)):
            dice += (pred[:,i] * target[:,i]).sum(dim=1).sum(dim=1).sum(dim=1) / ((pred[:,i] * target[:,i]).sum(dim=1).sum(dim=1).sum(dim=1) +
                        0.3 * (pred[:,i] * (1 - target[:,i])).sum(dim=1).sum(dim=1).sum(dim=1) + 0.7 * ((1 - pred[:,i]) * target[:,i]).sum(dim=1).sum(dim=1).sum(dim=1) + smooth)

        dice = dice / pred.size(1)
        return torch.clamp((1 - dice).mean(), 0, 2)
--------------------------------------------------------------------------------
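`TverskyLoss` above weights false positives by 0.3 and false negatives by 0.7, so it favors recall; with both weights at 0.5 the Tversky index reduces to the Dice coefficient. A toy comparison (made-up tensors; the two printed values differ because `DiceLoss` squares the denominator terms):

```python
import torch
from utils.loss import TverskyLoss, DiceLoss

pred = torch.rand(1, 2, 4, 8, 8)   # soft predictions, [batch, class, s, h, w]
target = (pred > 0.5).float()      # a pretend hard target with the same shape
print(TverskyLoss()(pred, target).item(), DiceLoss()(pred, target).item())
```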
/README.md:
--------------------------------------------------------------------------------
## 3DUNet implemented with pytorch

## Introduction
This repository is a 3D UNet implemented with PyTorch, referring to
this [project](https://github.com/panxiaobai/lits_pytorch).
I have redesigned the code structure and used the model to perform liver and tumor segmentation on the LiTS2017 dataset.
Paper: [3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation](https://lmb.informatik.uni-freiburg.de/Publications/2016/CABR16/cicek16miccai.pdf)
### Requirements:
```
pytorch >= 1.1.0
torchvision
SimpleITK
Tensorboard
Scipy
```
### Code Structure
```
├── dataset          # Training and testing dataset
│   ├── dataset_lits_train.py
│   ├── dataset_lits_val.py
│   ├── dataset_lits_test.py
│   └── transforms.py
├── models           # Model design
│   ├── nn
│   │   └── module.py
│   │── ResUNet.py   # ResUNet model
│   │── UNet.py      # UNet model
│   │── SegNet.py    # SegNet model
│   └── KiUNet.py    # KiUNet model
├── experiments      # Trained models
|── utils            # Some related tools
|   ├── common.py
|   ├── weights_init.py
|   ├── logger.py
|   ├── metrics.py
|   └── loss.py
├── preprocess_LiTS.py  # preprocessing for raw data
├── test.py          # Test code
├── train.py         # Standard training code
└── config.py        # Configuration information for training and testing
```
## Quick Start
### 1) LITS2017 dataset preprocessing:
1. Download the dataset from Google Drive: [Liver Tumor Segmentation Challenge.](https://drive.google.com/drive/folders/0B0vscETPGI1-Q1h1WFdEM2FHSUE)
Or from my share: https://pan.baidu.com/s/1WgP2Ttxn_CV-yRT4UyqHWw
Extraction code: **hfl8** (The dataset consists of two parts: batch1 and batch2)
2. Then you need to decompress and merge batch1 and batch2 into one folder. It is recommended to use 20 samples (27\~46) of the LiTS dataset as the test set
and 111 samples (0\~26 and 47\~131) as the training set. Please put the volume data and segmentation labels of the training set and test set into different local folders,
such as:
```
raw_dataset:
├── test       # 20 samples (27~46)
│   ├── ct
│   │   ├── volume-27.nii
│   │   ├── volume-28.nii
|   |   |—— ...
│   └── label
│       ├── segmentation-27.nii
│       ├── segmentation-28.nii
|       |—— ...
│
├── train      # 111 samples (0~26 and 47~131)
│   ├── ct
│   │   ├── volume-0.nii
│   │   ├── volume-1.nii
|   |   |—— ...
│   └── label
│       ├── segmentation-0.nii
│       ├── segmentation-1.nii
|       |—— ...
```
3. Finally, you need to change the root paths of the volume data and segmentation labels in `./preprocess_LiTS.py`, such as:
```
raw_dataset_path = './raw_dataset/train/'  # path of the raw dataset
fixed_dataset_path = './fixed_data/'       # path of the fixed (preprocessed) dataset
```
4. Run `python ./preprocess_LiTS.py`.
If nothing goes wrong, you will see the following files in the directory `./fixed_data`:
```
│—— train_path_list.txt
│—— val_path_list.txt
│
|—— ct
│       volume-0.nii
│       volume-1.nii
│       volume-2.nii
│       ...
└─ label
        segmentation-0.nii.gz
        segmentation-1.nii.gz
        segmentation-2.nii.gz
        ...
```
---
### 2) Training 3DUNet
1. First, you should change some parameters in `config.py`; in particular, please set `--dataset_path` to `./fixed_data`.
All parameters are commented in the file `config.py`.
2. Second, run `python train.py --save model_name`.
3. Besides, you can observe the dice and loss during the training process
in the browser through `tensorboard --logdir ./experiments/model_name`.
---
### 3) Testing 3DUNet
Run `python test.py`.
Please pay attention to the path of the trained model in `test.py`.
(Since the computation of 3D convolutions is very heavy,
I use a sliding window to split the input tensor before prediction, and then stitch the results back together to get the final result.
The size of the sliding window can be set in `config.py`.)

After the test, you can find the test results in the corresponding folder: `./experiments/model_name/result`

You can also read my Chinese introduction about this 3DUNet project [here](https://zhuanlan.zhihu.com/p/113318562). However, I no longer update the blog; I will try my best to keep the GitHub code updated.
If you have any suggestions or questions,
welcome to open an issue to communicate with me.
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
from dataset.dataset_lits_val import Val_Dataset
from dataset.dataset_lits_train import Train_Dataset

from torch.utils.data import DataLoader
import torch
import torch.optim as optim
from tqdm import tqdm
import config

from models import UNet, ResUNet, KiUNet_min, SegNet

from utils import logger, weights_init, metrics, common, loss
import os
import numpy as np
from collections import OrderedDict

def val(model, val_loader, loss_func, n_labels):
    model.eval()
    val_loss = metrics.LossAverage()
    val_dice = metrics.DiceAverage(n_labels)
    with torch.no_grad():
        for idx, (data, target) in tqdm(enumerate(val_loader), total=len(val_loader)):
            data, target = data.float(), target.long()
            target = common.to_one_hot_3d(target, n_labels)
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = loss_func(output, target)

            val_loss.update(loss.item(), data.size(0))
            val_dice.update(output, target)
    val_log = OrderedDict({'Val_Loss': val_loss.avg, 'Val_dice_liver': val_dice.avg[1]})
    if n_labels == 3: val_log.update({'Val_dice_tumor': val_dice.avg[2]})
    return val_log

def train(model, train_loader, optimizer, loss_func, n_labels, alpha):
    print("=======Epoch:{}=======lr:{}".format(epoch, optimizer.state_dict()['param_groups'][0]['lr']))
    model.train()
    train_loss = metrics.LossAverage()
    train_dice = metrics.DiceAverage(n_labels)

    for idx, (data, target) in tqdm(enumerate(train_loader), total=len(train_loader)):
        data, target = data.float(), target.long()
        target = common.to_one_hot_3d(target, n_labels)
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()

        output = model(data)
        loss0 = loss_func(output[0], target)
        loss1 = loss_func(output[1], target)
        loss2 = loss_func(output[2], target)
        loss3 = loss_func(output[3], target)

        # deep supervision: the full-scale head plus down-weighted auxiliary heads
        loss = loss3 + alpha * (loss0 + loss1 + loss2)
        loss.backward()
        optimizer.step()

        train_loss.update(loss3.item(), data.size(0))
        train_dice.update(output[3], target)

    train_log = OrderedDict({'Train_Loss': train_loss.avg, 'Train_dice_liver': train_dice.avg[1]})
    if n_labels == 3: train_log.update({'Train_dice_tumor': train_dice.avg[2]})
    return train_log

if __name__ == '__main__':
    args = config.args
    save_path = os.path.join('./experiments', args.save)
    if not os.path.exists(save_path): os.makedirs(save_path)
    device = torch.device('cpu' if args.cpu else 'cuda')
    # data info
    train_loader = DataLoader(dataset=Train_Dataset(args), batch_size=args.batch_size, num_workers=args.n_threads, shuffle=True)
    val_loader = DataLoader(dataset=Val_Dataset(args), batch_size=1, num_workers=args.n_threads, shuffle=False)

    # model info
    model = ResUNet(in_channel=1, out_channel=args.n_labels, training=True).to(device)

    model.apply(weights_init.init_model)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    common.print_network(model)
    model = torch.nn.DataParallel(model, device_ids=args.gpu_id)  # multi-GPU

    criterion = loss.TverskyLoss()

    log = logger.Train_Logger(save_path, "train_log")

    best = [0, 0]  # epoch and performance of the best model so far
    trigger = 0    # early-stopping counter
    alpha = 0.4    # initial value of the deep-supervision weight
    for epoch in range(1, args.epochs + 1):
        common.adjust_learning_rate(optimizer, epoch, args)
        train_log = train(model, train_loader, optimizer, criterion, args.n_labels, alpha)
        val_log = val(model, val_loader, criterion, args.n_labels)
        log.update(epoch, train_log, val_log)

        # Save checkpoint.
        state = {'net': model.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch}
        torch.save(state, os.path.join(save_path, 'latest_model.pth'))
        trigger += 1
        if val_log['Val_dice_liver'] > best[1]:
            print('Saving best model')
            torch.save(state, os.path.join(save_path, 'best_model.pth'))
            best[0] = epoch
            best[1] = val_log['Val_dice_liver']
            trigger = 0
        print('Best performance at Epoch: {} | {}'.format(best[0], best[1]))

        # decay the deep-supervision weight
        if epoch % 30 == 0: alpha *= 0.8

        # early stopping
        if args.early_stop is not None:
            if trigger >= args.early_stop:
                print("=> early stopping")
                break
        torch.cuda.empty_cache()
--------------------------------------------------------------------------------
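The deep-supervision schedule above combines the four heads as `loss3 + alpha * (loss0 + loss1 + loss2)` and decays `alpha` by 0.8 every 30 epochs; over a default 200-epoch run the auxiliary weight evolves like this (a restatement of the loop logic, for illustration only):

```python
alpha = 0.4
for epoch in range(1, 201):
    if epoch % 30 == 0:
        alpha *= 0.8
print(round(alpha, 4))  # 0.4 * 0.8**6 ≈ 0.1049, so the auxiliary heads fade out slowly
```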
/dataset/dataset_lits_test.py:
--------------------------------------------------------------------------------
from utils.common import *
from scipy import ndimage
import numpy as np
import torch, os
from torch.utils.data import Dataset, DataLoader
from glob import glob
import SimpleITK as sitk

class Img_DataSet(Dataset):
    def __init__(self, data_path, label_path, args):
        self.n_labels = args.n_labels
        self.cut_size = args.test_cut_size
        self.cut_stride = args.test_cut_stride

        # read one data file, then normalize and resize it
        self.ct = sitk.ReadImage(data_path, sitk.sitkInt16)
        self.data_np = sitk.GetArrayFromImage(self.ct)
        self.ori_shape = self.data_np.shape
        self.data_np = ndimage.zoom(self.data_np, (args.slice_down_scale, args.xy_down_scale, args.xy_down_scale), order=3)  # bicubic resampling
        self.data_np[self.data_np > args.upper] = args.upper
        self.data_np[self.data_np < args.lower] = args.lower
        self.data_np = self.data_np / args.norm_factor
        self.resized_shape = self.data_np.shape
        # pad with extra slices so the strided patching works out with the convolutional downsampling
        self.data_np = self.padding_img(self.data_np, self.cut_size, self.cut_stride)
        self.padding_shape = self.data_np.shape
        # split the data into overlapping patches along the slice axis to avoid running out of GPU memory
        self.data_np = self.extract_ordered_overlap(self.data_np, self.cut_size, self.cut_stride)

        # read the label file  shape: [s, h, w]
        self.seg = sitk.ReadImage(label_path, sitk.sitkInt8)
        self.label_np = sitk.GetArrayFromImage(self.seg)
        if self.n_labels == 2:
            self.label_np[self.label_np > 0] = 1
        self.label = torch.from_numpy(np.expand_dims(self.label_np, axis=0)).long()

        # storage for the prediction result
        self.result = None

    def __getitem__(self, index):
        data = torch.from_numpy(self.data_np[index])
        data = torch.FloatTensor(data).unsqueeze(0)
        return data

    def __len__(self):
        return len(self.data_np)

    def update_result(self, tensor):
        # tensor shape: [N, class, s, h, w]
        if self.result is not None:
            self.result = torch.cat((self.result, tensor), dim=0)
        else:
            self.result = tensor

    def recompone_result(self):

        patch_s = self.result.shape[2]

        N_patches_img = (self.padding_shape[0] - patch_s) // self.cut_stride + 1
        assert (self.result.shape[0] == N_patches_img)

        full_prob = torch.zeros((self.n_labels, self.padding_shape[0], self.ori_shape[1], self.ori_shape[2]))  # initialize to zero the array accumulating the sum of probabilities
        full_sum = torch.zeros((self.n_labels, self.padding_shape[0], self.ori_shape[1], self.ori_shape[2]))

        for s in range(N_patches_img):
            full_prob[:, s * self.cut_stride:s * self.cut_stride + patch_s] += self.result[s]
            full_sum[:, s * self.cut_stride:s * self.cut_stride + patch_s] += 1

        assert (torch.min(full_sum) >= 1.0)  # at least one
        final_avg = full_prob / full_sum
        assert (torch.max(final_avg) <= 1.0)  # max value for a pixel is 1.0
        assert (torch.min(final_avg) >= 0.0)  # min value for a pixel is 0.0
        img = final_avg[:, :self.ori_shape[0], :self.ori_shape[1], :self.ori_shape[2]]
        return img.unsqueeze(0)

    def padding_img(self, img, size, stride):
        assert (len(img.shape) == 3)  # 3D array
        img_s, img_h, img_w = img.shape
        leftover_s = (img_s - size) % stride

        if (leftover_s != 0):
            s = img_s + (stride - leftover_s)
        else:
            s = img_s

        tmp_full_imgs = np.zeros((s, img_h, img_w), dtype=np.float32)
        tmp_full_imgs[:img_s] = img
        print("Padded images shape: " + str(tmp_full_imgs.shape))
        return tmp_full_imgs

    # divide the full volume into patches
    def extract_ordered_overlap(self, img, size, stride):
        img_s, img_h, img_w = img.shape
        assert (img_s - size) % stride == 0
        N_patches_img = (img_s - size) // stride + 1

        print("Patches number of the image:{}".format(N_patches_img))
        patches = np.empty((N_patches_img, size, img_h, img_w), dtype=np.float32)

        for s in range(N_patches_img):  # loop over the full volume
            patch = img[s * stride : s * stride + size]
            patches[s] = patch

        return patches  # array with the full volume divided into patches

def Test_Datasets(dataset_path, args):
    data_list = sorted(glob(os.path.join(dataset_path, 'ct/*')))
    label_list = sorted(glob(os.path.join(dataset_path, 'label/*')))
    print("The number of test samples is: ", len(data_list))
    for datapath, labelpath in zip(data_list, label_list):
        print("\nStart Evaluate: ", datapath)
        yield Img_DataSet(datapath, labelpath, args=args), datapath.split('-')[-1]
--------------------------------------------------------------------------------
/dataset/transforms.py:
--------------------------------------------------------------------------------
"""
This part is based on the dataset class implemented by pytorch,
including train_dataset and test_dataset, as well as data augmentation
"""
from torch.utils.data import Dataset
import torch
import numpy as np
import random
import torch.nn.functional as F
from torchvision import transforms
from torchvision.transforms.functional import normalize

#----------------------data augment-------------------------------------------
class Resize:
    def __init__(self, scale):
        self.scale = scale

    def __call__(self, img, mask):
        img, mask = img.unsqueeze(0), mask.unsqueeze(0).float()
        img = F.interpolate(img, scale_factor=(1, self.scale, self.scale), mode='trilinear', align_corners=False, recompute_scale_factor=True)
        mask = F.interpolate(mask, scale_factor=(1, self.scale, self.scale), mode="nearest", recompute_scale_factor=True)
        return img[0], mask[0]

class RandomResize:
    def __init__(self, s_rank, w_rank, h_rank):
        self.w_rank = w_rank
        self.h_rank = h_rank
        self.s_rank = s_rank

    def __call__(self, img, mask):
        random_w = random.randint(self.w_rank[0], self.w_rank[1])
        random_h = random.randint(self.h_rank[0], self.h_rank[1])
        random_s = random.randint(self.s_rank[0], self.s_rank[1])
        self.shape = [random_s, random_h, random_w]
        img, mask = img.unsqueeze(0), mask.unsqueeze(0).float()
        img = F.interpolate(img, size=self.shape, mode='trilinear', align_corners=False)
        mask = F.interpolate(mask, size=self.shape, mode="nearest")
        return img[0], mask[0].long()

class RandomCrop:
    def __init__(self, slices):
        self.slices = slices

    def _get_range(self, slices, crop_slices):
        if slices < crop_slices:
            start = 0
        else:
            start = random.randint(0, slices - crop_slices)
        end = start + crop_slices
        if end > slices:
            end = slices
        return start, end

    def __call__(self, img, mask):

        ss, es = self._get_range(mask.size(1), self.slices)

        tmp_img = torch.zeros((img.size(0), self.slices, img.size(2), img.size(3)))
        tmp_mask = torch.zeros((mask.size(0), self.slices, mask.size(2), mask.size(3)))
        tmp_img[:, :es-ss] = img[:, ss:es]
        tmp_mask[:, :es-ss] = mask[:, ss:es]
        return tmp_img, tmp_mask

class RandomFlip_LR:
    def __init__(self, prob=0.5):
        self.prob = prob

    def _flip(self, img, prob):
        if prob[0] <= self.prob:
            img = img.flip(2)
        return img

    def __call__(self, img, mask):
        prob = (random.uniform(0, 1), random.uniform(0, 1))
        return self._flip(img, prob), self._flip(mask, prob)

class RandomFlip_UD:
    def __init__(self, prob=0.5):
        self.prob = prob

    def _flip(self, img, prob):
        if prob[1] <= self.prob:
            img = img.flip(3)
        return img

    def __call__(self, img, mask):
        prob = (random.uniform(0, 1), random.uniform(0, 1))
        return self._flip(img, prob), self._flip(mask, prob)

class RandomRotate:
    def __init__(self, max_cnt=3):
        self.max_cnt = max_cnt

    def _rotate(self, img, cnt):
        img = torch.rot90(img, cnt, [1, 2])
        return img

    def __call__(self, img, mask):
        cnt = random.randint(0, self.max_cnt)
        return self._rotate(img, cnt), self._rotate(mask, cnt)


class Center_Crop:
    def __init__(self, base, max_size):
        self.base = base  # base defaults to 16, because after 4 downsamplings the size becomes 1
        self.max_size = max_size
        if self.max_size % self.base:
            self.max_size = self.max_size - self.max_size % self.base  # max_size caps the number of sampled slices to avoid GPU memory overflow, and must also be a multiple of 16
    def __call__(self, img, label):
        if img.size(1) < self.base:
            return None
        slice_num = img.size(1) - img.size(1) % self.base
        slice_num = min(self.max_size, slice_num)

        left = img.size(1)//2 - slice_num//2
        right = img.size(1)//2 + slice_num//2

        crop_img = img[:, left:right]
        crop_label = label[:, left:right]
        return crop_img, crop_label

class ToTensor:
    def __init__(self):
        self.to_tensor = transforms.ToTensor()

    def __call__(self, img, mask):
        img = self.to_tensor(img)
        mask = torch.from_numpy(np.array(mask))
        return img, mask[None]


class Normalize:
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, img, mask):
        return normalize(img, self.mean, self.std, False), mask


class Compose:
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img, mask):
        for t in self.transforms:
            img, mask = t(img, mask)
        return img, mask
--------------------------------------------------------------------------------
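A minimal end-to-end check of the augmentation pipeline that `Train_Dataset` builds from these classes (dummy tensors; `48` matches the default `--crop_size`):

```python
import torch
from dataset.transforms import RandomCrop, RandomFlip_LR, RandomFlip_UD, Compose

transforms = Compose([RandomCrop(48), RandomFlip_LR(prob=0.5), RandomFlip_UD(prob=0.5)])
ct = torch.randn(1, 60, 128, 128)   # [channel, slices, h, w]
seg = torch.zeros(1, 60, 128, 128)
ct, seg = transforms(ct, seg)
print(ct.shape, seg.shape)          # both torch.Size([1, 48, 128, 128])
```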
/preprocess_LiTS.py:
--------------------------------------------------------------------------------
import numpy as np
import os
import SimpleITK as sitk
import random
from scipy import ndimage
from os.path import join
import config

class LITS_preprocess:
    def __init__(self, raw_dataset_path, fixed_dataset_path, args):
        self.raw_root_path = raw_dataset_path
        self.fixed_path = fixed_dataset_path
        self.classes = args.n_labels  # number of segmentation classes (2 for liver only, 3 for liver and tumor)
        self.upper = args.upper
        self.lower = args.lower
        self.expand_slice = args.expand_slice  # number of slices to expand outward along the axial direction
        self.size = args.min_slices  # minimum number of slices per sample
        self.xy_down_scale = args.xy_down_scale
        self.slice_down_scale = args.slice_down_scale

        self.valid_rate = args.valid_rate

    def fix_data(self):
        if not os.path.exists(self.fixed_path):  # create the output directories
            os.makedirs(join(self.fixed_path, 'ct'))
            os.makedirs(join(self.fixed_path, 'label'))
        file_list = os.listdir(join(self.raw_root_path, 'ct'))
        Numbers = len(file_list)
        print('Total number of samples is:', Numbers)
        for ct_file, i in zip(file_list, range(Numbers)):
            print("==== {} | {}/{} ====".format(ct_file, i+1, Numbers))
            ct_path = os.path.join(self.raw_root_path, 'ct', ct_file)
            seg_path = os.path.join(self.raw_root_path, 'label', ct_file.replace('volume', 'segmentation'))
            new_ct, new_seg = self.process(ct_path, seg_path, classes=self.classes)
            if new_ct is not None and new_seg is not None:
                sitk.WriteImage(new_ct, os.path.join(self.fixed_path, 'ct', ct_file))
                sitk.WriteImage(new_seg, os.path.join(self.fixed_path, 'label', ct_file.replace('volume', 'segmentation').replace('.nii', '.nii.gz')))

    def process(self, ct_path, seg_path, classes=None):
        ct = sitk.ReadImage(ct_path, sitk.sitkInt16)
        ct_array = sitk.GetArrayFromImage(ct)
        seg = sitk.ReadImage(seg_path, sitk.sitkInt8)
        seg_array = sitk.GetArrayFromImage(seg)

        print("Ori shape:", ct_array.shape, seg_array.shape)
        if classes == 2:
            # merge the liver and tumor labels of the ground truth into one
            seg_array[seg_array > 0] = 1
        # clip gray values outside the thresholds
        ct_array[ct_array > self.upper] = self.upper
        ct_array[ct_array < self.lower] = self.lower

        # downsample the x and y axes, and normalize the slice-axis spacing to slice_down_scale
        ct_array = ndimage.zoom(ct_array, (ct.GetSpacing()[-1] / self.slice_down_scale, self.xy_down_scale, self.xy_down_scale), order=3)
        seg_array = ndimage.zoom(seg_array, (ct.GetSpacing()[-1] / self.slice_down_scale, self.xy_down_scale, self.xy_down_scale), order=0)

        # find the first and last slices containing liver, then expand outward
        z = np.any(seg_array, axis=(1, 2))
        start_slice, end_slice = np.where(z)[0][[0, -1]]

        # expand by expand_slice slices in both directions
        if start_slice - self.expand_slice < 0:
            start_slice = 0
        else:
            start_slice -= self.expand_slice

        if end_slice + self.expand_slice >= seg_array.shape[0]:
            end_slice = seg_array.shape[0] - 1
        else:
            end_slice += self.expand_slice

        print("Cut out range:", str(start_slice) + '--' + str(end_slice))
        # if fewer than `size` slices remain, give up the sample (such cases are rare)
        if end_slice - start_slice + 1 < self.size:
            print('Too few slices, giving up the sample:', ct_path)
            return None, None
        # crop the retained region
        ct_array = ct_array[start_slice:end_slice + 1, :, :]
        seg_array = seg_array[start_slice:end_slice + 1, :, :]
        print("Preprocessed shape:", ct_array.shape, seg_array.shape)
        # save in the corresponding format
        new_ct = sitk.GetImageFromArray(ct_array)
        new_ct.SetDirection(ct.GetDirection())
        new_ct.SetOrigin(ct.GetOrigin())
        new_ct.SetSpacing((ct.GetSpacing()[0] * int(1 / self.xy_down_scale), ct.GetSpacing()[1] * int(1 / self.xy_down_scale), self.slice_down_scale))

        new_seg = sitk.GetImageFromArray(seg_array)
        new_seg.SetDirection(ct.GetDirection())
        new_seg.SetOrigin(ct.GetOrigin())
        new_seg.SetSpacing((ct.GetSpacing()[0] * int(1 / self.xy_down_scale), ct.GetSpacing()[1] * int(1 / self.xy_down_scale), self.slice_down_scale))
        return new_ct, new_seg

    def write_train_val_name_list(self):
        data_name_list = os.listdir(join(self.fixed_path, "ct"))
        data_num = len(data_name_list)
        print('Total number of samples in the fixed dataset:', data_num)
        random.shuffle(data_name_list)

        assert self.valid_rate < 1.0
        train_name_list = data_name_list[0:int(data_num*(1-self.valid_rate))]
        val_name_list = data_name_list[int(data_num*(1-self.valid_rate)):]

        self.write_name_list(train_name_list, "train_path_list.txt")
        self.write_name_list(val_name_list, "val_path_list.txt")


    def write_name_list(self, name_list, file_name):
        f = open(join(self.fixed_path, file_name), 'w')
        for name in name_list:
            ct_path = os.path.join(self.fixed_path, 'ct', name)
            seg_path = os.path.join(self.fixed_path, 'label', name.replace('volume', 'segmentation').replace('.nii', '.nii.gz'))
            f.write(ct_path + ' ' + seg_path + "\n")
        f.close()

if __name__ == '__main__':
    raw_dataset_path = '/ssd/lzq/dataset/LiTS/train'
    fixed_dataset_path = '/ssd/lzq/dataset/fixed_lits'

    args = config.args
    tool = LITS_preprocess(raw_dataset_path, fixed_dataset_path, args)
    tool.fix_data()                   # trim the raw images and save them
    tool.write_train_val_name_list()  # create the index txt files
--------------------------------------------------------------------------------
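After preprocessing, `write_train_val_name_list` produces plain-text index files with one `ct_path seg_path` pair per line, which is exactly what `load_file_name_list` in the dataset classes splits on; a line in `train_path_list.txt` would look like `./fixed_data/ct/volume-0.nii ./fixed_data/label/segmentation-0.nii.gz` (illustrative paths).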
--------------------------------------------------------------------------------
/models/ResUNet.py:
--------------------------------------------------------------------------------
"""
This code is referenced from https://github.com/assassint2017/MICCAI-LITS2017
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class ResUNet(nn.Module):
    def __init__(self, in_channel=1, out_channel=2, training=True):
        super().__init__()

        self.training = training
        self.drop_rate = 0.2

        self.encoder_stage1 = nn.Sequential(
            nn.Conv3d(in_channel, 16, 3, 1, padding=1),
            nn.PReLU(16),

            nn.Conv3d(16, 16, 3, 1, padding=1),
            nn.PReLU(16),
        )

        self.encoder_stage2 = nn.Sequential(
            nn.Conv3d(32, 32, 3, 1, padding=1),
            nn.PReLU(32),

            nn.Conv3d(32, 32, 3, 1, padding=1),
            nn.PReLU(32),

            nn.Conv3d(32, 32, 3, 1, padding=1),
            nn.PReLU(32),
        )

        self.encoder_stage3 = nn.Sequential(
            nn.Conv3d(64, 64, 3, 1, padding=1),
            nn.PReLU(64),

            nn.Conv3d(64, 64, 3, 1, padding=2, dilation=2),
            nn.PReLU(64),

            nn.Conv3d(64, 64, 3, 1, padding=4, dilation=4),
            nn.PReLU(64),
        )

        self.encoder_stage4 = nn.Sequential(
            nn.Conv3d(128, 128, 3, 1, padding=3, dilation=3),
            nn.PReLU(128),

            nn.Conv3d(128, 128, 3, 1, padding=4, dilation=4),
            nn.PReLU(128),

            nn.Conv3d(128, 128, 3, 1, padding=5, dilation=5),
            nn.PReLU(128),
        )

        self.decoder_stage1 = nn.Sequential(
            nn.Conv3d(128, 256, 3, 1, padding=1),
            nn.PReLU(256),

            nn.Conv3d(256, 256, 3, 1, padding=1),
            nn.PReLU(256),

            nn.Conv3d(256, 256, 3, 1, padding=1),
            nn.PReLU(256),
        )

        self.decoder_stage2 = nn.Sequential(
            nn.Conv3d(128 + 64, 128, 3, 1, padding=1),
            nn.PReLU(128),

            nn.Conv3d(128, 128, 3, 1, padding=1),
            nn.PReLU(128),

            nn.Conv3d(128, 128, 3, 1, padding=1),
            nn.PReLU(128),
        )

        self.decoder_stage3 = nn.Sequential(
            nn.Conv3d(64 + 32, 64, 3, 1, padding=1),
            nn.PReLU(64),

            nn.Conv3d(64, 64, 3, 1, padding=1),
            nn.PReLU(64),

            nn.Conv3d(64, 64, 3, 1, padding=1),
            nn.PReLU(64),
        )

        self.decoder_stage4 = nn.Sequential(
            nn.Conv3d(32 + 16, 32, 3, 1, padding=1),
            nn.PReLU(32),

            nn.Conv3d(32, 32, 3, 1, padding=1),
            nn.PReLU(32),
        )

        self.down_conv1 = nn.Sequential(
            nn.Conv3d(16, 32, 2, 2),
            nn.PReLU(32)
        )

        self.down_conv2 = nn.Sequential(
            nn.Conv3d(32, 64, 2, 2),
            nn.PReLU(64)
        )

        self.down_conv3 = nn.Sequential(
            nn.Conv3d(64, 128, 2, 2),
            nn.PReLU(128)
        )

        self.down_conv4 = nn.Sequential(
            nn.Conv3d(128, 256, 3, 1, padding=1),
            nn.PReLU(256)
        )

        self.up_conv2 = nn.Sequential(
            nn.ConvTranspose3d(256, 128, 2, 2),
            nn.PReLU(128)
        )

        self.up_conv3 = nn.Sequential(
            nn.ConvTranspose3d(128, 64, 2, 2),
            nn.PReLU(64)
        )

        self.up_conv4 = nn.Sequential(
            nn.ConvTranspose3d(64, 32, 2, 2),
            nn.PReLU(32)
        )
        # Deep-supervision head at the largest scale (256*256); the heads below
        # map successively smaller scales back to the output resolution
        self.map4 = nn.Sequential(
            nn.Conv3d(32, out_channel, 1, 1),
            nn.Upsample(scale_factor=(1, 2, 2), mode='trilinear', align_corners=False),
            nn.Softmax(dim=1)
        )

        # Head at the 128*128 scale
        self.map3 = nn.Sequential(
            nn.Conv3d(64, out_channel, 1, 1),
            nn.Upsample(scale_factor=(2, 4, 4), mode='trilinear', align_corners=False),
            nn.Softmax(dim=1)
        )

        # Head at the 64*64 scale
        self.map2 = nn.Sequential(
            nn.Conv3d(128, out_channel, 1, 1),
            nn.Upsample(scale_factor=(4, 8, 8), mode='trilinear', align_corners=False),
            nn.Softmax(dim=1)
        )

        # Head at the 32*32 scale
        self.map1 = nn.Sequential(
            nn.Conv3d(256, out_channel, 1, 1),
            nn.Upsample(scale_factor=(8, 16, 16), mode='trilinear', align_corners=False),
            nn.Softmax(dim=1)
        )

    def forward(self, inputs):

        long_range1 = self.encoder_stage1(inputs) + inputs

        short_range1 = self.down_conv1(long_range1)

        long_range2 = self.encoder_stage2(short_range1) + short_range1
        long_range2 = F.dropout(long_range2, self.drop_rate, self.training)

        short_range2 = self.down_conv2(long_range2)

        long_range3 = self.encoder_stage3(short_range2) + short_range2
        long_range3 = F.dropout(long_range3, self.drop_rate, self.training)

        short_range3 = self.down_conv3(long_range3)

        long_range4 = self.encoder_stage4(short_range3) + short_range3
        long_range4 = F.dropout(long_range4, self.drop_rate, self.training)

        short_range4 = self.down_conv4(long_range4)

        outputs = self.decoder_stage1(long_range4) + short_range4
        outputs = F.dropout(outputs, self.drop_rate, self.training)

        output1 = self.map1(outputs)

        short_range6 = self.up_conv2(outputs)

        outputs = self.decoder_stage2(torch.cat([short_range6, long_range3], dim=1)) + short_range6
        outputs = F.dropout(outputs, self.drop_rate, self.training)

        output2 = self.map2(outputs)

        short_range7 = self.up_conv3(outputs)

        outputs = self.decoder_stage3(torch.cat([short_range7, long_range2], dim=1)) + short_range7
        outputs = F.dropout(outputs, self.drop_rate, self.training)

        output3 = self.map3(outputs)

        short_range8 = self.up_conv4(outputs)

        outputs = self.decoder_stage4(torch.cat([short_range8, long_range1], dim=1)) + short_range8

        output4 = self.map4(outputs)

        if self.training:
            return output1, output2, output3, output4
        else:
            return output4
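
# Quick shape check (an added sketch, mirroring the smoke test in KiUNet.py;
# the input size below is an assumption -- depth and xy must be divisible by 8):
if __name__ == "__main__":
    net = ResUNet(training=True)
    x = torch.rand((1, 1, 16, 64, 64))
    for out in net(x):
        print(out.size())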
--------------------------------------------------------------------------------
/models/KiUNet.py:
--------------------------------------------------------------------------------
"""
This code is referenced from https://github.com/jeya-maria-jose/KiU-Net-pytorch/blob/master/LiTS/net/models.py
"""

import os
import sys

import torch
import torch.nn as nn
import torch.nn.functional as F


# This model is a lightweight KiU-Net implementation: starting from the original
# code, I removed the encoder/decoder stages at the smallest resolution
class KiUNet_min(nn.Module):
    def __init__(self, in_channel=1, out_channel=2, training=True):
        super(KiUNet_min, self).__init__()
        self.training = training
        self.encoder1 = nn.Conv3d(in_channel, 32, 3, stride=1, padding=1)
        self.encoder2 = nn.Conv3d(32, 64, 3, stride=1, padding=1)
        self.encoder3 = nn.Conv3d(64, 128, 3, stride=1, padding=1)
        self.encoder4 = nn.Conv3d(128, 256, 3, stride=1, padding=1)
        # self.encoder5 = nn.Conv3d(256, 512, 3, stride=1, padding=1)

        self.kencoder1 = nn.Conv3d(1, 32, 3, stride=1, padding=1)
        self.kdecoder1 = nn.Conv3d(32, 2, 3, stride=1, padding=1)

        # self.decoder1 = nn.Conv3d(512, 256, 3, stride=1, padding=1)
        self.decoder2 = nn.Conv3d(256, 128, 3, stride=1, padding=1)
        self.decoder3 = nn.Conv3d(128, 64, 3, stride=1, padding=1)
        self.decoder4 = nn.Conv3d(64, 32, 3, stride=1, padding=1)
        self.decoder5 = nn.Conv3d(32, 2, 3, stride=1, padding=1)

        self.map4 = nn.Sequential(
            nn.Conv3d(2, out_channel, 1, 1),
            nn.Upsample(scale_factor=(1, 2, 2), mode='trilinear'),
            # nn.Sigmoid()
            nn.Softmax(dim=1)
        )

        # Head at the 128*128 scale
        self.map3 = nn.Sequential(
            nn.Conv3d(64, out_channel, 1, 1),
            nn.Upsample(scale_factor=(4, 8, 8), mode='trilinear'),
            # nn.Sigmoid()
            nn.Softmax(dim=1)
        )

        # Head at the 64*64 scale
        self.map2 = nn.Sequential(
            nn.Conv3d(128, out_channel, 1, 1),
            nn.Upsample(scale_factor=(8, 16, 16), mode='trilinear'),
            # nn.Sigmoid()
            nn.Softmax(dim=1)
        )

        # Head at the 32*32 scale
        self.map1 = nn.Sequential(
            nn.Conv3d(256, out_channel, 1, 1),
            nn.Upsample(scale_factor=(16, 32, 32), mode='trilinear'),
            # nn.Sigmoid()
            nn.Softmax(dim=1)
        )

    def forward(self, x):

        out = F.relu(F.max_pool3d(self.encoder1(x), 2, 2))
        t1 = out
        out = F.relu(F.max_pool3d(self.encoder2(out), 2, 2))
        t2 = out
        out = F.relu(F.max_pool3d(self.encoder3(out), 2, 2))
        t3 = out
        out = F.relu(F.max_pool3d(self.encoder4(out), 2, 2))

        # t4 = out
        # out = F.relu(F.max_pool3d(self.encoder5(out), 2, 2))

        # t2 = out
        # out = F.relu(F.interpolate(self.decoder1(out), scale_factor=(2, 2, 2), mode='trilinear'))
        # print(out.shape, t4.shape)
        # out = torch.add(F.pad(out, [0, 0, 0, 0, 0, 1]), t4)
        # out = torch.add(out, t4)

        output1 = self.map1(out)
        out = F.relu(F.interpolate(self.decoder2(out), scale_factor=(2, 2, 2), mode='trilinear'))
        out = torch.add(out, t3)
        output2 = self.map2(out)
        out = F.relu(F.interpolate(self.decoder3(out), scale_factor=(2, 2, 2), mode='trilinear'))
        out = torch.add(out, t2)
        output3 = self.map3(out)
        out = F.relu(F.interpolate(self.decoder4(out), scale_factor=(2, 2, 2), mode='trilinear'))
        out = torch.add(out, t1)

        # Ki-Net branch: upsample first, then map back down to the input scale
        out1 = F.relu(F.interpolate(self.kencoder1(x), scale_factor=(1, 2, 2), mode='trilinear'))
        out1 = F.relu(F.interpolate(self.kdecoder1(out1), scale_factor=(1, 0.5, 0.5), mode='trilinear'))

        out = F.relu(F.interpolate(self.decoder5(out), scale_factor=(2, 2, 2), mode='trilinear'))
        # print(out.shape, out1.shape)
        out = torch.add(out, out1)
        output4 = self.map4(out)

        # print(output1.shape, output2.shape, output3.shape, output4.shape)
        if self.training:
            return output1, output2, output3, output4
        else:
            return output4

# This implementation still has some issues; the official code itself also has
# problems and the authors have not yet responded. I tried to reproduce the paper,
# but the GPU memory usage is too high to train or run inference
class KiUNet_org(nn.Module):

    def __init__(self, in_channel=1, out_channel=2, training=True):
        super(KiUNet_org, self).__init__()
        self.training = training
        self.start = nn.Conv3d(in_channel, 1, 3, stride=1, padding=1)

        self.encoder1 = nn.Conv3d(1, 16, 3, stride=1, padding=1)  # First layer, grayscale image; change input channels to 3 for RGB
        self.en1_bn = nn.BatchNorm3d(16)
        self.encoder2 = nn.Conv3d(16, 32, 3, stride=1, padding=1)
        self.en2_bn = nn.BatchNorm3d(32)
        self.encoder3 = nn.Conv3d(32, 64, 3, stride=1, padding=1)
        self.en3_bn = nn.BatchNorm3d(64)

        self.decoder1 = nn.Conv3d(64, 32, 3, stride=1, padding=1)
        self.de1_bn = nn.BatchNorm3d(32)
        self.decoder2 = nn.Conv3d(32, 16, 3, stride=1, padding=1)
        self.de2_bn = nn.BatchNorm3d(16)
        self.decoder3 = nn.Conv3d(16, 8, 3, stride=1, padding=1)
        self.de3_bn = nn.BatchNorm3d(8)

        self.decoderf1 = nn.Conv3d(64, 32, 3, stride=1, padding=1)
        self.def1_bn = nn.BatchNorm3d(32)
        self.decoderf2 = nn.Conv3d(32, 16, 3, stride=1, padding=1)
        self.def2_bn = nn.BatchNorm3d(16)
        self.decoderf3 = nn.Conv3d(16, 8, 3, stride=1, padding=1)
        self.def3_bn = nn.BatchNorm3d(8)

        self.encoderf1 = nn.Conv3d(1, 16, 3, stride=1, padding=1)  # First layer, grayscale image; change input channels to 3 for RGB
        self.enf1_bn = nn.BatchNorm3d(16)
        self.encoderf2 = nn.Conv3d(16, 32, 3, stride=1, padding=1)
        self.enf2_bn = nn.BatchNorm3d(32)
        self.encoderf3 = nn.Conv3d(32, 64, 3, stride=1, padding=1)
        self.enf3_bn = nn.BatchNorm3d(64)

        self.intere1_1 = nn.Conv3d(16, 16, 3, stride=1, padding=1)
        self.inte1_1bn = nn.BatchNorm3d(16)
        self.intere2_1 = nn.Conv3d(32, 32, 3, stride=1, padding=1)
        self.inte2_1bn = nn.BatchNorm3d(32)
        self.intere3_1 = nn.Conv3d(64, 64, 3, stride=1, padding=1)
        self.inte3_1bn = nn.BatchNorm3d(64)

        self.intere1_2 = nn.Conv3d(16, 16, 3, stride=1, padding=1)
        self.inte1_2bn = nn.BatchNorm3d(16)
        self.intere2_2 = nn.Conv3d(32, 32, 3, stride=1, padding=1)
        self.inte2_2bn = nn.BatchNorm3d(32)
        self.intere3_2 = nn.Conv3d(64, 64, 3, stride=1, padding=1)
        self.inte3_2bn = nn.BatchNorm3d(64)

        self.interd1_1 = nn.Conv3d(32, 32, 3, stride=1, padding=1)
        self.intd1_1bn = nn.BatchNorm3d(32)
        self.interd2_1 = nn.Conv3d(16, 16, 3, stride=1, padding=1)
        self.intd2_1bn = nn.BatchNorm3d(16)
        self.interd3_1 = nn.Conv3d(64, 64, 3, stride=1, padding=1)
        self.intd3_1bn = nn.BatchNorm3d(64)

        self.interd1_2 = nn.Conv3d(32, 32, 3, stride=1, padding=1)
        self.intd1_2bn = nn.BatchNorm3d(32)
        self.interd2_2 = nn.Conv3d(16, 16, 3, stride=1, padding=1)
        self.intd2_2bn = nn.BatchNorm3d(16)
        self.interd3_2 = nn.Conv3d(64, 64, 3, stride=1, padding=1)
        self.intd3_2bn = nn.BatchNorm3d(64)

        # self.start = nn.Conv3d(1, 1, 3, stride=1, padding=1)
        self.final = nn.Conv3d(8, 1, 1, stride=1, padding=0)
        self.fin = nn.Conv3d(1, 1, 1, stride=1, padding=0)

        self.map4 = nn.Sequential(
            nn.Conv3d(32, out_channel, 1, 1),
            nn.Upsample(scale_factor=(4, 16, 16), mode='trilinear'),
            # nn.Sigmoid()
            nn.Softmax(dim=1)
        )

        # Head at the 128*128 scale
        self.map3 = nn.Sequential(
            nn.Conv3d(16, out_channel, 1, 1),
            nn.Upsample(scale_factor=(4, 8, 8), mode='trilinear'),
            # nn.Sigmoid()
            nn.Softmax(dim=1)
        )

        # Head at the 64*64 scale
        self.map2 = nn.Sequential(
            nn.Conv3d(8, out_channel, 1, 1),
            nn.Upsample(scale_factor=(4, 4, 4), mode='trilinear'),
            # nn.Sigmoid()
            nn.Softmax(dim=1)
        )

        # Head at the 32*32 scale
        self.map1 = nn.Sequential(
            nn.Conv3d(256, out_channel, 1, 1),
            nn.Upsample(scale_factor=(16, 32, 32), mode='trilinear'),
            # nn.Sigmoid()
            nn.Softmax(dim=1)
        )

        # self.soft = nn.Softmax(dim=1)

    def forward(self, x):
        # print(x.shape)
        outx = self.start(x)
        # print(outx.shape)
        out = F.relu(self.en1_bn(F.max_pool3d(self.encoder1(outx), 2, 2)))  # U-Net branch
        out1 = F.relu(self.enf1_bn(F.interpolate(self.encoderf1(outx), scale_factor=(0.5, 1, 1), mode='trilinear')))  # Ki-Net branch
        tmp = out
        # print(out.shape, out1.shape)
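        # Added annotation: CRFB (cross residual fusion block, from the KiU-Net
        # paper) gives each branch the other branch's features, passed through a
        # conv+BN and interpolated to its own resolution, as a residual addition.
        # The scale factors below only reconcile the two branches' resolutions.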
        out = torch.add(out, F.interpolate(F.relu(self.inte1_1bn(self.intere1_1(out1))), scale_factor=(1, 0.5, 0.5), mode='trilinear'))  # CRFB
        out1 = torch.add(out1, F.interpolate(F.relu(self.inte1_2bn(self.intere1_2(tmp))), scale_factor=(1, 2, 2), mode='trilinear'))  # CRFB

        u1 = out    # skip connection
        o1 = out1   # skip connection

        out = F.relu(self.en2_bn(F.max_pool3d(self.encoder2(out), 2, 2)))
        out1 = F.relu(self.enf2_bn(F.interpolate(self.encoderf2(out1), scale_factor=(1, 1, 1), mode='trilinear')))
        tmp = out
        # print(out.shape, out1.shape)
        out = torch.add(out, F.interpolate(F.relu(self.inte2_1bn(self.intere2_1(out1))), scale_factor=(0.5, 0.25, 0.25), mode='trilinear'))
        out1 = torch.add(out1, F.interpolate(F.relu(self.inte2_2bn(self.intere2_2(tmp))), scale_factor=(2, 4, 4), mode='trilinear'))

        u2 = out
        o2 = out1
        # out = F.pad(out, [0, 0, 0, 0, 0, 1])
        # print(out.shape)
        out = F.relu(self.en3_bn(F.max_pool3d(self.encoder3(out), 2, 2)))
        out1 = F.relu(self.enf3_bn(F.interpolate(self.encoderf3(out1), scale_factor=(1, 2, 2), mode='trilinear')))
        # print(out.shape, out1.shape)
        tmp = out
        out = torch.add(out, F.interpolate(F.relu(self.inte3_1bn(self.intere3_1(out1))), scale_factor=(0.5, 0.0625, 0.0625), mode='trilinear'))
        out1 = torch.add(out1, F.interpolate(F.relu(self.inte3_2bn(self.intere3_2(tmp))), scale_factor=(2, 16, 16), mode='trilinear'))

        ### End of encoder block

        ### Start decoder

        out = F.relu(self.de1_bn(F.interpolate(self.decoder1(out), scale_factor=(2, 2, 2), mode='trilinear')))  # U-Net
        out1 = F.relu(self.def1_bn(F.max_pool3d(self.decoderf1(out1), 2, 2)))  # Ki-Net
        tmp = out
        # print(out.shape, out1.shape)
        out = torch.add(out, F.interpolate(F.relu(self.intd1_1bn(self.interd1_1(out1))), scale_factor=(2, 0.25, 0.25), mode='trilinear'))
        out1 = torch.add(out1, F.interpolate(F.relu(self.intd1_2bn(self.interd1_2(tmp))), scale_factor=(0.5, 4, 4), mode='trilinear'))
        # print(out.shape)
        output1 = self.map4(out)
        out = torch.add(out, u2)    # skip connection
        out1 = torch.add(out1, o2)  # skip connection

        out = F.relu(self.de2_bn(F.interpolate(self.decoder2(out), scale_factor=(1, 2, 2), mode='trilinear')))
        out1 = F.relu(self.def2_bn(F.max_pool3d(self.decoderf2(out1), 1, 1)))
        # print(out.shape, out1.shape)
        tmp = out
        out = torch.add(out, F.interpolate(F.relu(self.intd2_1bn(self.interd2_1(out1))), scale_factor=(1, 0.5, 0.5), mode='trilinear'))
        out1 = torch.add(out1, F.interpolate(F.relu(self.intd2_2bn(self.interd2_2(tmp))), scale_factor=(1, 2, 2), mode='trilinear'))
        output2 = self.map3(out)
        # print(out1.shape, o1.shape)
        out = torch.add(out, u1)
        out1 = torch.add(out1, o1)

        out = F.relu(self.de3_bn(F.interpolate(self.decoder3(out), scale_factor=(1, 2, 2), mode='trilinear')))
        out1 = F.relu(self.def3_bn(F.max_pool3d(self.decoderf3(out1), 1, 1)))
        # print(out.shape, out1.shape)
        output3 = self.map2(out)

        out = torch.add(out, out1)  # fusion of both branches

        out = F.relu(self.final(out))  # 1*1 conv

        output4 = F.interpolate(self.fin(out), scale_factor=(4, 4, 4), mode='trilinear')
        # print(out.shape)
        # out = self.soft(out)
        # print(output1.shape, output2.shape, output3.shape, output4.shape)
        if self.training:
            return output1, output2, output3, output4
        else:
            return output4


if __name__ == "__main__":
    # Use training=True so the model returns all four deep-supervision outputs;
    # with training=False it returns a single tensor, and the original loop
    # would have iterated over the batch dimension instead of the outputs
    net = KiUNet_org(training=True)
    in1 = torch.rand((2, 1, 48, 256, 256))
    outs = net(in1)
    for out in outs:
        print(out.size())
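
# --- Illustrative sketch (not part of the original file) ----------------------
# All models here return four outputs in training mode for deep supervision.
# train.py (not shown in this file) presumably combines them; one common
# weighting scheme, given purely as an assumption, looks like:
#
#     loss = loss_func(output4, target) \
#            + alpha * (loss_func(output1, target)
#                       + loss_func(output2, target)
#                       + loss_func(output3, target))
#
# where loss_func and alpha are hypothetical placeholders for whatever
# utils/loss.py and the training script actually define.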
--------------------------------------------------------------------------------