├── LICENSE
├── README.md
├── data
│   ├── chaos
│   │   └── .keep
│   └── promise12
│       └── .keep
├── datasets
│   ├── chaos.py
│   └── promise12.py
├── dice_loss.py
├── eval.py
├── models
│   ├── nested_unet.py
│   └── unet.py
└── train.py

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 ProfessorHuang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# 2D-UNet-Pytorch

Segmentation of the CHAOS and PROMISE12 datasets with 2D-UNet and 2D-UNet++ (Nested UNet).

## Dependencies

- PyTorch 1.x
- numpy
- tqdm
- opencv-python
- Pillow (PIL)
- pydicom
- SimpleITK
- scikit-image

## Getting the datasets

### CHAOS

https://chaos.grand-challenge.org/Combined_Healthy_Abdominal_Organ_Segmentation/

After downloading the data from the official site, unzip CHAOS_Train_Sets.zip and copy the CT folder inside it into the data/chaos folder of the code directory.

### PROMISE12

https://promise12.grand-challenge.org/

The training data from the official site comes as three archives. Unzip all three and copy their contents into the data/promise12 folder of the code directory.

The expected layout is:

    data
    ├── chaos
    │   └── CT
    │       ├── 1
    │       ├── 2
    │       ├── 5
    │       └── ...
    └── promise12
        ├── Case00.mhd
        ├── Case00.raw
        ├── Case00_segmentation.mhd
        ├── Case00_segmentation.raw
        ├── Case01.mhd
        ├── Case01.raw
        ├── Case01_segmentation.mhd
        ├── Case01_segmentation.raw
        └── ...

## Training

Run the following in a terminal:

    python train.py --model=unet --dataset=promise12

This trains a UNet on the PROMISE12 dataset. To use UNet++ instead, pass --model=nestedunet; to train on the CHAOS dataset, pass --dataset=chaos.
When training starts, a logs_train folder is created in the code directory, and each run gets its own subfolder there holding that run's training logs.
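
For example, to train UNet++ on CHAOS with a shorter schedule and a smaller batch (all flags are defined in train.py; the specific values below are only an illustration):

    python train.py --model=nestedunet --dataset=chaos --epochs=50 --batch_size=4 --lr=1e-3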
## Viewing the training curves in TensorBoard

During training, the loss and Dice score on the training set and on the validation set are recorded for every epoch and written to TensorBoard. Run

    tensorboard --logdir=logs_train

and open the reported port in a browser to inspect the training records.

--------------------------------------------------------------------------------
/data/chaos/.keep:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/data/promise12/.keep:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/datasets/chaos.py:
--------------------------------------------------------------------------------
import os

import cv2
import numpy as np
import PIL.Image as Image
import pydicom
import torchvision.transforms as transforms
from torch.utils.data import Dataset


class Chaos(Dataset):

    def __init__(self, data_dir, mode):

        base_dir = os.path.join(data_dir, 'CT')
        patient_numbers = os.listdir(base_dir)
        patient_numbers.sort(key=int)
        image_dirs = [os.path.join(base_dir, index, 'DICOM_anon') for index in patient_numbers]
        mask_dirs = [os.path.join(base_dir, index, 'Ground') for index in patient_numbers]

        # 5-fold cross validation: fold k is held out for validation
        fold_size = len(image_dirs) // 5
        k = 1
        val_list = range(k * fold_size, (k + 1) * fold_size)

        self.train_image_paths = []
        self.val_image_paths = []
        self.train_mask_paths = []
        self.val_mask_paths = []

        for i in range(len(image_dirs)):

            image_files = os.listdir(image_dirs[i])
            image_files.sort(key=image_name_key)
            mask_files = os.listdir(mask_dirs[i])
            mask_files.sort(key=mask_name_key)

            if i in val_list:
                for image_file in image_files:
                    self.val_image_paths.append(os.path.join(image_dirs[i], image_file))
                for mask_file in mask_files:
                    self.val_mask_paths.append(os.path.join(mask_dirs[i], mask_file))
            else:
                for image_file in image_files:
                    self.train_image_paths.append(os.path.join(image_dirs[i], image_file))
                for mask_file in mask_files:
                    self.train_mask_paths.append(os.path.join(mask_dirs[i], mask_file))

        # dataset-wide statistics of the windowed images, precomputed with cal_mean_std()
        self.mean = [0.3667]
        self.std = [0.3533]
        self.mode = mode

    def __getitem__(self, i):

        if self.mode == 'train':
            img_path, mask_path = self.train_image_paths[i], self.train_mask_paths[i]
        elif self.mode == 'val':
            img_path, mask_path = self.val_image_paths[i], self.val_mask_paths[i]

        mask = Image.open(mask_path).convert('L')
        mask = transforms.Resize((256, 256), interpolation=Image.NEAREST)(mask)
        mask_tensor = transforms.ToTensor()(mask)

        # convert the raw DICOM pixel values to Hounsfield units
        dataset = pydicom.dcmread(img_path)
        HU_img = dataset.RescaleSlope * dataset.pixel_array + dataset.RescaleIntercept
        MIN_BOUND = -1000.0
        MAX_BOUND = 400.0
        # window the HU values and rescale to [0, 1]
        img = (HU_img - MIN_BOUND) / (MAX_BOUND - MIN_BOUND)
        img[img > 1] = 1.
        img[img < 0] = 0.
        # resize
        img = cv2.resize(img, (256, 256), interpolation=cv2.INTER_CUBIC)
        img_tensor = transforms.ToTensor()(img)
        img_tensor = transforms.Normalize(mean=self.mean, std=self.std)(img_tensor)

        return {'image': img_tensor, 'mask': mask_tensor}

    def __len__(self):
        if self.mode == 'train':
            return len(self.train_image_paths)
        elif self.mode == 'val':
            return len(self.val_image_paths)

    def cal_mean_std(self):
        # offline helper used to compute self.mean / self.std over all slices
        image_paths = self.train_image_paths + self.val_image_paths
        image_array = np.zeros((len(image_paths), 512, 512))
        for i in range(len(image_paths)):
            dataset = pydicom.dcmread(image_paths[i])
            HU_img = dataset.RescaleSlope * dataset.pixel_array + dataset.RescaleIntercept
            MIN_BOUND = -1000.0
            MAX_BOUND = 400.0
            img = (HU_img - MIN_BOUND) / (MAX_BOUND - MIN_BOUND)
            img[img > 1] = 1.
            img[img < 0] = 0.
            image_array[i, :, :] = img
        return np.mean(image_array), np.std(image_array)


def image_name_key(image_name):
    # sort slice files by the numeric index embedded in either naming scheme
    if image_name[0] == 'i':
        return int(image_name[1:5])
    elif image_name[0:3] == 'IMG':
        return int(image_name[-8:-4])


def mask_name_key(mask_name):
    # sort mask files by the three-digit index before the file extension
    return int(mask_name[-7:-4])
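
if __name__ == '__main__':
    # Hypothetical quick check (not part of the training pipeline): assumes the
    # CHAOS CT data has been unpacked under data/chaos as described in the README.
    chaos_dir = os.path.join('data', 'chaos')
    if os.path.isdir(os.path.join(chaos_dir, 'CT')):
        ds = Chaos(chaos_dir, mode='train')
        sample = ds[0]
        print(len(ds), sample['image'].shape, sample['mask'].shape)
    else:
        print('CHAOS data not found; see the README for the expected layout.')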
--------------------------------------------------------------------------------
/datasets/promise12.py:
--------------------------------------------------------------------------------
import os

import cv2
import numpy as np
import SimpleITK as sitk
import torchvision.transforms as transforms
from skimage.exposure import equalize_adapthist
from torch.utils.data import Dataset


class Promise12(Dataset):

    def __init__(self, data_dir, mode):

        # preprocess the volumes into npy files on first use
        np_data_path = os.path.join(data_dir, 'npy_image')
        if not os.path.exists(np_data_path):
            os.makedirs(np_data_path)
            data_to_array(data_dir, np_data_path, 256, 256)
        else:
            print('read the data from: {}'.format(np_data_path))

        self.mode = mode
        # read the data from npy
        self.X_train = np.load(os.path.join(np_data_path, 'X_train.npy'))
        self.y_train = np.load(os.path.join(np_data_path, 'y_train.npy'))
        self.X_val = np.load(os.path.join(np_data_path, 'X_val.npy'))
        self.y_val = np.load(os.path.join(np_data_path, 'y_val.npy'))

    def __getitem__(self, i):

        if self.mode == 'train':
            img, mask = self.X_train[i], self.y_train[i]
        elif self.mode == 'val':
            img, mask = self.X_val[i], self.y_val[i]

        img_tensor = transforms.ToTensor()(img)
        mask_tensor = transforms.ToTensor()(mask.astype(np.float32))

        return {'image': img_tensor, 'mask': mask_tensor}

    def __len__(self):
        if self.mode == 'train':
            return self.X_train.shape[0]
        elif self.mode == 'val':
            return self.X_val.shape[0]


def data_to_array(base_path, store_path, img_rows, img_cols):

    fileList = os.listdir(base_path)
    fileList = sorted(x for x in fileList if '.mhd' in x)

    # hold out five fixed cases for validation
    val_list = [5, 15, 25, 35, 45]
    train_list = list(set(range(50)) - set(val_list))
    count = 0
    for the_list in [train_list, val_list]:
        images = []
        masks = []

        filtered = [file for file in fileList for ff in the_list if str(ff).zfill(2) in file]

        for filename in filtered:

            itkimage = sitk.ReadImage(os.path.join(base_path, filename))
            imgs = sitk.GetArrayFromImage(itkimage)

            if 'segm' in filename.lower():
                imgs = img_resize(imgs, img_rows, img_cols, equalize=False)
                masks.append(imgs)
            else:
                imgs = img_resize(imgs, img_rows, img_cols, equalize=True)
                images.append(imgs)

        # images: per-case stacks of slices ==> total slices x rows x cols
        images = np.concatenate(images, axis=0).reshape(-1, img_rows, img_cols)
        masks = np.concatenate(masks, axis=0).reshape(-1, img_rows, img_cols)
        masks = masks.astype(np.uint8)

        # smooth images using curvature flow
        images = smooth_images(images)
        images = images.astype(np.float32)

        if count == 0:
            # normalize with the training-set statistics (the train split is processed first)
            mu = np.mean(images)
            sigma = np.std(images)
            images = (images - mu) / sigma

            np.save(os.path.join(store_path, 'X_train.npy'), images)
            np.save(os.path.join(store_path, 'y_train.npy'), masks)
        elif count == 1:
            # reuse the training-set mean/std for the validation split
            images = (images - mu) / sigma
            np.save(os.path.join(store_path, 'X_val.npy'), images)
            np.save(os.path.join(store_path, 'y_val.npy'), masks)
        count += 1


def img_resize(imgs, img_rows, img_cols, equalize=True):

    new_imgs = np.zeros([len(imgs), img_rows, img_cols])
    for mm, img in enumerate(imgs):
        if equalize:
            img = equalize_adapthist(img, clip_limit=0.05)

        # cv2.resize takes (width, height); rows and cols are equal here
        new_imgs[mm] = cv2.resize(img, (img_rows, img_cols), interpolation=cv2.INTER_NEAREST)

    return new_imgs


def smooth_images(imgs, t_step=0.125, n_iter=5):
    """
    Curvature driven image denoising.
    In my experience helps significantly with segmentation.
    """
    for mm in range(len(imgs)):
        img = sitk.GetImageFromArray(imgs[mm])
        img = sitk.CurvatureFlow(image1=img,
                                 timeStep=t_step,
                                 numberOfIterations=n_iter)
        imgs[mm] = sitk.GetArrayFromImage(img)

    return imgs
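
if __name__ == '__main__':
    # Hypothetical quick check (not part of the training pipeline): assumes the
    # PROMISE12 .mhd/.raw files sit in data/promise12 as described in the README.
    # The first run also builds the npy cache, which can take a while.
    promise_dir = os.path.join('data', 'promise12')
    if os.path.isfile(os.path.join(promise_dir, 'Case00.mhd')):
        ds = Promise12(promise_dir, mode='val')
        sample = ds[0]
        print(len(ds), sample['image'].shape, sample['mask'].shape)
    else:
        print('PROMISE12 data not found; see the README for the expected layout.')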
--------------------------------------------------------------------------------
/dice_loss.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


def dice_coeff(input, target):
    """Dice coefficient averaged over the samples of a batch.

    Used as a metric on thresholded predictions, so no gradient path is
    required; the differentiable soft Dice lives in DiceBCELoss below.
    """
    eps = 0.0001
    s = 0.
    for inp, tgt in zip(input, target):
        inter = torch.dot(inp.reshape(-1), tgt.reshape(-1))
        union = torch.sum(inp) + torch.sum(tgt) + eps
        s = s + (2 * inter + eps) / union
    return s / input.shape[0]


class DiceBCELoss(nn.Module):
    """Binary cross-entropy plus soft Dice loss, computed from raw logits."""

    def __init__(self):
        super(DiceBCELoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):

        # comment out if your model already ends with a sigmoid (or equivalent) activation layer
        inputs = torch.sigmoid(inputs)

        # flatten label and prediction tensors
        inputs = inputs.view(-1)
        targets = targets.view(-1)

        intersection = (inputs * targets).sum()
        dice_loss = 1 - (2. * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)
        BCE = F.binary_cross_entropy(inputs, targets, reduction='mean')
        Dice_BCE = BCE + dice_loss

        return Dice_BCE
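
if __name__ == '__main__':
    # Minimal sanity check on random tensors (illustration only).
    torch.manual_seed(0)
    logits = torch.randn(2, 1, 8, 8)
    target = (torch.rand(2, 1, 8, 8) > 0.5).float()
    print('DiceBCELoss:', DiceBCELoss()(logits, target).item())
    pred = (torch.sigmoid(logits) > 0.5).float()
    print('dice_coeff :', dice_coeff(pred, target).item())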
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
import torch
from tqdm import tqdm

from dice_loss import dice_coeff


def eval_net(net, loader, device, criterion):
    """Evaluate the network on the validation loader; returns (mean Dice, mean loss)."""
    net.eval()
    batch_size = loader.batch_size
    n_val = len(loader) * batch_size  # approximate number of validation images
    tot_dice = 0
    tot_loss = 0

    with tqdm(total=n_val, desc='Validation round', unit='img', leave=False) as pbar:
        for batch in loader:
            imgs, true_masks = batch['image'], batch['mask']
            imgs = imgs.to(device=device, dtype=torch.float32)
            true_masks = true_masks.to(device=device, dtype=torch.float32)

            with torch.no_grad():
                masks_pred = net(imgs)
                loss = criterion(masks_pred, true_masks)
                tot_loss += loss.item()

                # threshold the sigmoid probabilities at 0.5 before computing Dice
                pred = torch.sigmoid(masks_pred)
                pred = (pred > 0.5).float()
                tot_dice += dice_coeff(pred, true_masks).item()
            pbar.update(batch_size)

    net.train()
    return tot_dice / len(loader), tot_loss / len(loader)
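
if __name__ == '__main__':
    # Minimal smoke test on random data (illustration only, run from the repo
    # root); the dict-style batches mirror what the dataset classes return.
    from torch.utils.data import DataLoader, Dataset

    from dice_loss import DiceBCELoss
    from models.unet import UNet

    class _RandomSeg(Dataset):
        def __len__(self):
            return 4

        def __getitem__(self, i):
            return {'image': torch.randn(1, 64, 64),
                    'mask': (torch.rand(1, 64, 64) > 0.5).float()}

    loader = DataLoader(_RandomSeg(), batch_size=2)
    dice, loss = eval_net(UNet(), loader, torch.device('cpu'), DiceBCELoss())
    print(f'dice={dice:.4f} loss={loss:.4f}')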
--------------------------------------------------------------------------------
/models/nested_unet.py:
--------------------------------------------------------------------------------
import torch
from torch import nn
import torch.nn.functional as F


class VGGBlock(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, middle_channels, out_channels):
        super().__init__()
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_channels, middle_channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(middle_channels)
        self.conv2 = nn.Conv2d(middle_channels, out_channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        return out


class Up(nn.Module):
    """Upscaling and concat"""

    def __init__(self):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # input is NCHW; pad x1 so its spatial size matches x2 before concatenation
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        x = torch.cat([x2, x1], dim=1)
        return x


class NestedUNet(nn.Module):
    def __init__(self, num_classes=1, input_channels=1, deep_supervision=False, **kwargs):
        super().__init__()

        nb_filter = [32, 64, 128, 256, 512]

        self.deep_supervision = deep_supervision

        self.pool = nn.MaxPool2d(2, 2)
        self.up = Up()

        # conv{i}_{j}: node at depth i, column j of the nested skip pathways
        self.conv0_0 = VGGBlock(input_channels, nb_filter[0], nb_filter[0])
        self.conv1_0 = VGGBlock(nb_filter[0], nb_filter[1], nb_filter[1])
        self.conv2_0 = VGGBlock(nb_filter[1], nb_filter[2], nb_filter[2])
        self.conv3_0 = VGGBlock(nb_filter[2], nb_filter[3], nb_filter[3])
        self.conv4_0 = VGGBlock(nb_filter[3], nb_filter[4], nb_filter[4])

        self.conv0_1 = VGGBlock(nb_filter[0]+nb_filter[1], nb_filter[0], nb_filter[0])
        self.conv1_1 = VGGBlock(nb_filter[1]+nb_filter[2], nb_filter[1], nb_filter[1])
        self.conv2_1 = VGGBlock(nb_filter[2]+nb_filter[3], nb_filter[2], nb_filter[2])
        self.conv3_1 = VGGBlock(nb_filter[3]+nb_filter[4], nb_filter[3], nb_filter[3])

        self.conv0_2 = VGGBlock(nb_filter[0]*2+nb_filter[1], nb_filter[0], nb_filter[0])
        self.conv1_2 = VGGBlock(nb_filter[1]*2+nb_filter[2], nb_filter[1], nb_filter[1])
        self.conv2_2 = VGGBlock(nb_filter[2]*2+nb_filter[3], nb_filter[2], nb_filter[2])

        self.conv0_3 = VGGBlock(nb_filter[0]*3+nb_filter[1], nb_filter[0], nb_filter[0])
        self.conv1_3 = VGGBlock(nb_filter[1]*3+nb_filter[2], nb_filter[1], nb_filter[1])

        self.conv0_4 = VGGBlock(nb_filter[0]*4+nb_filter[1], nb_filter[0], nb_filter[0])

        if self.deep_supervision:
            self.final1 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
            self.final2 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
            self.final3 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
            self.final4 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
        else:
            self.final = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)

    def forward(self, input):
        x0_0 = self.conv0_0(input)
        x1_0 = self.conv1_0(self.pool(x0_0))
        x0_1 = self.conv0_1(self.up(x1_0, x0_0))

        x2_0 = self.conv2_0(self.pool(x1_0))
        x1_1 = self.conv1_1(self.up(x2_0, x1_0))
        x0_2 = self.conv0_2(self.up(x1_1, torch.cat([x0_0, x0_1], 1)))

        x3_0 = self.conv3_0(self.pool(x2_0))
        x2_1 = self.conv2_1(self.up(x3_0, x2_0))
        x1_2 = self.conv1_2(self.up(x2_1, torch.cat([x1_0, x1_1], 1)))
        x0_3 = self.conv0_3(self.up(x1_2, torch.cat([x0_0, x0_1, x0_2], 1)))

        x4_0 = self.conv4_0(self.pool(x3_0))
        x3_1 = self.conv3_1(self.up(x4_0, x3_0))
        x2_2 = self.conv2_2(self.up(x3_1, torch.cat([x2_0, x2_1], 1)))
        x1_3 = self.conv1_3(self.up(x2_2, torch.cat([x1_0, x1_1, x1_2], 1)))
        x0_4 = self.conv0_4(self.up(x1_3, torch.cat([x0_0, x0_1, x0_2, x0_3], 1)))

        if self.deep_supervision:
            output1 = self.final1(x0_1)
            output2 = self.final2(x0_2)
            output3 = self.final3(x0_3)
            output4 = self.final4(x0_4)
            return [output1, output2, output3, output4]
        else:
            output = self.final(x0_4)
            return output
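
if __name__ == '__main__':
    # Shape check on a dummy input (illustration only). With deep supervision the
    # model returns four full-resolution maps, one per nested decoder column.
    x = torch.randn(1, 1, 64, 64)
    print(NestedUNet(deep_supervision=False)(x).shape)
    print([tuple(o.shape) for o in NestedUNet(deep_supervision=True)(x)])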
--------------------------------------------------------------------------------
/models/unet.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.double_conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # input is NCHW; pad x1 so its spatial size matches x2 before concatenation
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        x = torch.cat([x2, x1], dim=1)
        return self.double_conv(x)


class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)


class UNet(nn.Module):
    def __init__(self, n_channels=1, n_classes=1):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes

        n_filters = [32, 64, 128, 256, 512]

        self.inc = DoubleConv(n_channels, n_filters[0])
        self.down1 = Down(n_filters[0], n_filters[1])
        self.down2 = Down(n_filters[1], n_filters[2])
        self.down3 = Down(n_filters[2], n_filters[3])
        self.down4 = Down(n_filters[3], n_filters[4])
        self.up1 = Up(n_filters[4], n_filters[3])
        self.up2 = Up(n_filters[3], n_filters[2])
        self.up3 = Up(n_filters[2], n_filters[1])
        self.up4 = Up(n_filters[1], n_filters[0])
        self.outc = OutConv(n_filters[0], n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits
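
if __name__ == '__main__':
    # Shape check on a dummy input (illustration only): the logits keep the
    # input's spatial size, with one channel per class.
    x = torch.randn(1, 1, 256, 256)
    print(UNet()(x).shape)  # expected: torch.Size([1, 1, 256, 256])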
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import argparse
import logging
import os
import sys
import time

import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from datasets.chaos import Chaos
from datasets.promise12 import Promise12
from dice_loss import DiceBCELoss, dice_coeff
from eval import eval_net
from models.nested_unet import NestedUNet
from models.unet import UNet

torch.manual_seed(2020)


def train_net(net, trainset, valset, device, epochs, batch_size, lr, weight_decay, log_save_path):

    train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
    val_loader = DataLoader(valset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True, drop_last=True)

    writer = SummaryWriter(log_dir=log_save_path)

    optimizer = optim.Adam(net.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=0.95)
    criterion = DiceBCELoss()

    for epoch in range(epochs):
        logging.info(f'Epoch {epoch + 1}')
        epoch_loss = 0
        epoch_dice = 0
        with tqdm(total=len(trainset), desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar:
            for batch in train_loader:
                net.train()
                imgs = batch['image']
                true_masks = batch['mask']

                imgs = imgs.to(device=device, dtype=torch.float32)
                true_masks = true_masks.to(device=device, dtype=torch.float32)
                masks_pred = net(imgs)

                # thresholded predictions feed only the Dice metric; the loss uses raw logits
                pred = torch.sigmoid(masks_pred)
                pred = (pred > 0.5).float()
                loss = criterion(masks_pred, true_masks)
                epoch_loss += loss.item()
                epoch_dice += dice_coeff(pred, true_masks).item()
                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_value_(net.parameters(), 5)
                optimizer.step()

                pbar.set_postfix(**{'loss (batch)': loss.item()})
                pbar.update(imgs.shape[0])

        scheduler.step()

        logging.info('Training loss: {}'.format(epoch_loss / len(train_loader)))
        writer.add_scalar('Train/loss', epoch_loss / len(train_loader), epoch)
        logging.info('Training DSC: {}'.format(epoch_dice / len(train_loader)))
        writer.add_scalar('Train/dice', epoch_dice / len(train_loader), epoch)

        val_dice, val_loss = eval_net(net, val_loader, device, criterion)
        logging.info('Validation Loss: {}'.format(val_loss))
        writer.add_scalar('Val/loss', val_loss, epoch)
        logging.info('Validation DSC: {}'.format(val_dice))
        writer.add_scalar('Val/dice', val_dice, epoch)

        writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], epoch)

        # writer.add_images('images', imgs, epoch)
        writer.add_images('masks/true', true_masks, epoch)
        writer.add_images('masks/pred', torch.sigmoid(masks_pred) > 0.5, epoch)

    writer.close()


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--epochs', type=int, default=100, help='Number of epochs')
    parser.add_argument('--batch_size', metavar='B', type=int, nargs='?', default=8, help='Batch size')
    parser.add_argument('--lr', metavar='LR', type=float, nargs='?', default=1e-3, help='Learning rate')
    parser.add_argument('--weight_decay', type=float, nargs='?', default=1e-5, help='Weight decay')
    parser.add_argument('--model', type=str, default='unet', help='Model name')
    parser.add_argument('--dataset', type=str, default='promise12', help='Dataset name')
    parser.add_argument('--gpu', type=int, default=0, help='GPU number')
    parser.add_argument('--save', type=str, default='EXP', help='Experiment name')
    return parser.parse_args()

if __name__ == '__main__':

    args = get_args()
    args.save = 'logs_train/{}-{}-{}'.format(args.model, args.dataset, time.strftime("%Y%m%d-%H%M%S"))
    if not os.path.exists(args.save):
        os.makedirs(args.save)
    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(stream=sys.stdout, level=logging.INFO, format=log_format, datefmt='%m/%d %I:%M:%S %p')
    fh = logging.FileHandler(os.path.join(args.save, 'log.txt'))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)
    logging.info(f'''
        Model:          {args.model}
        Dataset:        {args.dataset}
        Total Epochs:   {args.epochs}
        Batch size:     {args.batch_size}
        Learning rate:  {args.lr}
        Weight decay:   {args.weight_decay}
        Device:         GPU{args.gpu}
        Log name:       {args.save}
    ''')

    if torch.cuda.is_available():
        torch.cuda.set_device(args.gpu)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # choose a model
    if args.model == 'unet':
        net = UNet()
    elif args.model == 'nestedunet':
        net = NestedUNet()

    net.to(device=device)

    # choose a dataset (paths follow the layout described in the README)
    if args.dataset == 'promise12':
        dir_data = 'data/promise12'
        trainset = Promise12(dir_data, mode='train')
        valset = Promise12(dir_data, mode='val')
    elif args.dataset == 'chaos':
        dir_data = 'data/chaos'
        trainset = Chaos(dir_data, mode='train')
        valset = Chaos(dir_data, mode='val')

    try:
        train_net(net=net,
                  trainset=trainset,
                  valset=valset,
                  epochs=args.epochs,
                  batch_size=args.batch_size,
                  lr=args.lr,
                  weight_decay=args.weight_decay,
                  device=device,
                  log_save_path=args.save)
    except KeyboardInterrupt:
        try:
            sys.exit(0)
        except SystemExit:
            os._exit(0)

--------------------------------------------------------------------------------