├── .gitignore
├── README.md
├── inputs.py
├── train_vnet.py
├── utils.py
└── vnet.py
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# Jupyter Notebook
*.ipynb_checkpoints

# Source data
data/

# PyCharm
.idea/

.idea/deployment.xml
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Segmentation & PyTorch

## **Tools**

### **PyTorch 0.1.12**

### **Python 3.6**

### **visdom**

Launch it with
`python -m visdom.server`
> Access http://localhost:8097 in the browser; see [more](https://github.com/facebookresearch/visdom).
>
> **NOTE**<br>
> The server can also be reached at `host_ip:port` from other machines on the LAN.

PyTorch is extremely **flexible** compared with TensorFlow; a training run looks like this:
![snapshot](http://r.photo.store.qq.com/psb?/V10rff5c47qzzd/VyRyappjA3B3GfSdvxJu.bDtOdn0V9ekPWn.dGaC40E!/o/dB8BAAAAAAAA&bo=awWAAisHTwMRALs!&rf=viewer_4)
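
To verify the setup, a minimal smoke test (the window title and data here are arbitrary):

```python
import numpy as np
import visdom

vis = visdom.Visdom()  # assumes the server launched above is running
vis.line(X=np.arange(10), Y=np.random.rand(10), opts=dict(title='smoke test'))
```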
--------------------------------------------------------------------------------
/inputs.py:
--------------------------------------------------------------------------------
import os
import time

import nibabel as nib
import numpy as np
import torch
import torch.utils.data as data
from skimage.exposure import adjust_gamma
from skimage.filters import sobel_h
from skimage.transform import rotate, rescale


def _get_boundary(im):
    """Find the upper and lower boundary between body and background."""
    edge_sobel = sobel_h(im)
    threshold = np.max(edge_sobel) / 20
    top, bottom = np.where(edge_sobel > threshold)[0][[0, -1]]  # first and last rows above threshold
    return top, bottom


def _banish_darkness(xs, ys):
    """Clip the black background region from the nii raw data along the y-axis, to reduce computation.
    Arguments:
        xs: image_3d, int16 with shape [depth, height, width]
        ys: label_3d, uint8 with shape [depth, height, width]
    Return:
        A tuple (images, labels, top, bottom).
        + images: [depth, reduced_height, width]
        + labels: [depth, reduced_height, width]
        + top: upper boundary
        + bottom: lower boundary
    """

    boundaries = np.array([_get_boundary(im) for im in xs])
    t, b = np.mean(boundaries, axis=0).astype(int)  # plain int; uint8 would overflow for heights > 255
    # Empirically the lower boundary is more robust, so derive the top from it when needed.
    if (b - t) < 180:
        t = max(b - 180, 0)
    return xs[:, t: b, :], ys[:, t: b, :], t, b


def _augment(xs):
    """Adjust image intensity; the image shape is unchanged.
    Return:
        images: 3-d array with shape [depth, height, width]
    """

    # `xs` has shape [depth, height, width] with values in [0, 1].
    gamma = np.random.uniform(low=0.9, high=1.1)
    return adjust_gamma(xs, gamma)


def _rotate_and_rescale(xs, ys):
    """Rotate and rescale images and labels by a small random factor.
    Both are transposed from [depth, height, width] to [height, width, depth],
    as required by the skimage.transform library, then transposed back.
    """

    degree = int(np.random.uniform(low=-3, high=5))
    factor = np.random.uniform(low=0.85, high=0.95)
    # swap axes
    HWC_xs, HWC_ys = [np.transpose(item, [1, 2, 0]) for item in [xs, ys]]
    # rotate and rescale
    HWC_xs, HWC_ys = [rotate(item, degree, mode='symmetric', preserve_range=True) for item in [HWC_xs, HWC_ys]]
    HWC_xs, HWC_ys = [rescale(item, factor, mode='symmetric', preserve_range=True) for item in [HWC_xs, HWC_ys]]
    # swap back
    xs, ys = [np.transpose(item, [2, 0, 1]) for item in [HWC_xs, HWC_ys]]
    return xs, ys


def _translate(xs, ys):
    """Translate by randomly cropping the borders.
    The crop margins are drawn from a transformed power distribution,
    so they are skewed toward the lower end of their range.
    """

    samples = np.random.power(5, size=4)            # samples lie in [0, 1), skewed toward 1
    skewed_samples = np.int8((- samples + 2) * 15)  # margins lie in [15, 30], skewed toward 15
    r1, c1, r2, c2 = skewed_samples                 # margins are never 0, so `-r2`/`-c2` index safely
    trans_xs, trans_ys = [item[:, r1: -r2, c1: -c2] for item in [xs, ys]]
    return trans_xs, trans_ys
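

# Worked example of the margin sampling above (illustrative values, not part of
# the original pipeline):
#   samples        = [0.97, 0.84, 0.99, 0.91]      # np.random.power(5) clusters near 1
#   (2 - s) * 15   = [15.45, 17.4, 15.15, 16.35]
#   skewed_samples = [15, 17, 15, 16]              # int8 truncation keeps margins in [15, 30]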
75 | """ 76 | 77 | samples = np.random.power(5, size=4) # samples now in range [0, 1] 78 | skewed_samples = np.int8((- samples + 2) * 15) # skewed_samples in range [1, 16] 79 | r1, c1, r2, c2 = skewed_samples # discard 0 for indexing `-0` 80 | trans_xs, trans_ys = [item[:, r1: -r2, c1: -c2] for item in [xs, ys]] 81 | return trans_xs, trans_ys 82 | 83 | 84 | class DatasetFromFolder(data.Dataset): 85 | def __init__(self, data_path='./data/train'): 86 | """Assume dataset is in directory '.data/train' or './data/val' 87 | 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open). 88 | 2. Preprocess the data (e.g. torchvision.Transform). 89 | 3. Return a data pair (e.g. image and label). 90 | """ 91 | super(DatasetFromFolder, self).__init__() 92 | self.data_path = data_path 93 | self.label_path = [os.path.join(self.data_path, p) 94 | for p in os.listdir(self.data_path) if p.endswith('Label.nii')] 95 | self.image_path = [p.replace('_Label', '') for p in self.label_path] 96 | 97 | def __getitem__(self, index): 98 | # Set random seed for ramdom augment. 99 | np.random.seed(int(time.time())) 100 | 101 | # Load nii file. 102 | xs, ys = [nib.load(p[index]).get_data() for p in [self.image_path, self.label_path]] 103 | 104 | # Crop black region to reduce nii volumes. 105 | xs, ys, *_ = _banish_darkness(xs, ys) 106 | 107 | # Normalize, `xs` with dtype float64 108 | xs = (xs - np.min(xs)) / (np.max(xs) - np.min(xs)) 109 | 110 | # Image augment. 111 | xs, ys = _rotate_and_rescale(xs, ys) 112 | xs, ys = _translate(xs, ys) 113 | xs = _augment(xs) 114 | 115 | # Regenerate the binary label, just in case. 116 | ys = (ys > 0.5).astype(np.uint8) 117 | 118 | # Add gray image channel, with shape [1, depth, height, width] 119 | xs, ys = [item[np.newaxis, ...] for item in [xs, ys]] 120 | return torch.from_numpy(xs), torch.from_numpy(ys) 121 | 122 | def __len__(self): 123 | return len(self.image_path) 124 | -------------------------------------------------------------------------------- /train_vnet.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import visdom 4 | import torch.optim as optim 5 | import torch.nn.functional as F 6 | 7 | from torch.autograd import Variable 8 | from torch.utils.data import DataLoader 9 | 10 | from vnet import VNet 11 | from utils import * 12 | from inputs import DatasetFromFolder 13 | from functools import partial 14 | 15 | # Set net configuration 16 | conf = config() 17 | conf.prefix = 'vnet' 18 | conf.checkpoint_dir += conf.prefix 19 | conf.result_dir += conf.prefix 20 | conf.learning_rate = 1e-6 21 | conf.from_scratch = False 22 | conf.resume_step = -1 23 | conf.criterion = 'dice' # 'dice' or 'nll' 24 | 25 | # Instantiate plot 26 | vis = visdom.Visdom() 27 | 28 | # GPU configuration 29 | if conf.cuda: 30 | torch.cuda.set_device(3) 31 | print('===> Current GPU device is', torch.cuda.current_device()) 32 | 33 | torch.manual_seed(conf.seed) 34 | if conf.cuda: 35 | torch.cuda.manual_seed(conf.seed) 36 | 37 | 38 | def training_data_loader(): 39 | return DataLoader(dataset=DatasetFromFolder(), num_workers=conf.threads, batch_size=conf.batch_size, shuffle=True) 40 | 41 | 42 | def validation_data_loader(): 43 | return DataLoader(dataset=DatasetFromFolder('./data/val'), num_workers=conf.threads, batch_size=conf.batch_size, 44 | shuffle=True) 45 | 46 | 47 | def get_resume_path(): 48 | """Return latest checkpoints by default otherwise return the specified one.""" 49 | 50 | names = [os.path.join(conf.checkpoint_dir, p) for p 
--------------------------------------------------------------------------------
/train_vnet.py:
--------------------------------------------------------------------------------
import os

import visdom
import torch.optim as optim
import torch.nn.functional as F

from torch.autograd import Variable
from torch.utils.data import DataLoader

from vnet import VNet
from utils import *
from inputs import DatasetFromFolder
from functools import partial

# Set the net configuration
conf = config()
conf.prefix = 'vnet'
conf.checkpoint_dir += conf.prefix
conf.result_dir += conf.prefix
conf.learning_rate = 1e-6
conf.from_scratch = False
conf.resume_step = -1
conf.criterion = 'dice'  # 'dice' or 'nll'

# Instantiate the plotter
vis = visdom.Visdom()

# GPU configuration
if conf.cuda:
    torch.cuda.set_device(3)
    print('===> Current GPU device is', torch.cuda.current_device())

torch.manual_seed(conf.seed)
if conf.cuda:
    torch.cuda.manual_seed(conf.seed)


def training_data_loader():
    return DataLoader(dataset=DatasetFromFolder(), num_workers=conf.threads, batch_size=conf.batch_size, shuffle=True)


def validation_data_loader():
    return DataLoader(dataset=DatasetFromFolder('./data/val'), num_workers=conf.threads, batch_size=conf.batch_size,
                      shuffle=True)


def get_resume_path():
    """Return the latest checkpoint by default, otherwise the one specified by `conf.resume_step`."""

    names = [os.path.join(conf.checkpoint_dir, p) for p in os.listdir(conf.checkpoint_dir)]
    require = os.path.join(conf.checkpoint_dir, conf.prefix + '_' + str(conf.resume_step) + '.pth')
    if conf.resume_step == -1:
        return sorted(names, key=os.path.getmtime)[-1]
    elif os.path.isfile(require):
        return require
    raise Exception('\'%s\' does not exist!' % require)


def save_checkpoints(model, step):
    # Keep at most 20 checkpoints, dropping the oldest first.
    names = sorted(os.listdir(conf.checkpoint_dir),
                   key=lambda p: os.path.getmtime(os.path.join(conf.checkpoint_dir, p)))
    if len(names) >= 20:
        os.remove(os.path.join(conf.checkpoint_dir, names[0]))
    # Recommended: save and load only the model parameters.
    filename = conf.prefix + '_' + str(step) + '.pth'
    torch.save(model.state_dict(), os.path.join(conf.checkpoint_dir, filename))
    print("===> ===> ===> Saved checkpoint {} to {}".format(step, filename))
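

# The round trip for the parameter-only checkpoints above (a sketch; the path
# follows conf.checkpoint_dir + conf.prefix as configured at the top):
#   torch.save(model.state_dict(), './checkpoints/vnet/vnet_100.pth')
#   model.load_state_dict(torch.load('./checkpoints/vnet/vnet_100.pth'))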


def main():
    print('===> Building vnet...')
    model = VNet(conf.criterion)
    if conf.cuda:
        model = model.cuda()

    if conf.criterion == 'nll':
        # Balance foreground against background for NLL.
        pos_ratio = np.mean([label.float().mean() for image, label in validation_data_loader()])
        bg_weight = pos_ratio / (1. + pos_ratio)
        fg_weight = 1. - bg_weight
        class_weight = torch.FloatTensor([bg_weight, fg_weight])
        if conf.cuda:
            class_weight = class_weight.cuda()
        print('---> Background weight:', bg_weight)

        criterion = partial(F.nll_loss, weight=class_weight)
    else:  # must be 'dice' here
        criterion = dice_loss

    print('===> Loss function: {}'.format(conf.criterion))
    print('===> Number of params: {}'.format(sum([p.data.nelement() for p in model.parameters()])))

    stats = ['loss', 'dice', 'acc']
    names = np.array([x + '_' + y for x in ['train', 'val'] for y in stats])

    start_i = 1
    if conf.from_scratch:
        # Create fresh statistics: a dictionary holding six arrays, one per stat.
        results_dict = {name: np.zeros(conf.epochs * conf.augment_size) for name in names}
        # Initialize the weights.
        model.apply(weights_init)
    else:
        # Load statistics from `.npy`.
        results_dict = np.load(os.path.join(conf.result_dir, 'results_dict.npy')).item()
        # Load the previous checkpoint.
        cp = get_resume_path()
        model.load_state_dict(torch.load(cp))
        cp_name = os.path.basename(cp)
        print('---> Loading checkpoint {}...'.format(cp_name))
        start_i = int(cp_name.split('_')[-1].split('.')[0]) + 1
    print('===> Begin training at epoch {}'.format(start_i))

    # Define the optimizer; the loss depends on conf.criterion.
    optimizer = optim.Adam(model.parameters(), lr=conf.learning_rate)
    total_i = conf.epochs * conf.augment_size

    def reimage(image, true, pred):
        """Reshape the flattened `pred` and `true` back into image stacks for visualization."""

        image = image.cpu().squeeze(0).permute(1, 0, 2, 3).data.numpy()
        true, pred = [item.cpu().view(*image.shape).numpy() for item in [true, pred]]
        # Display several central slices.
        mid = image.shape[0] // 2
        return [item[mid - 3: mid + 3] for item in [image, true, pred]]

    def train():
        epoch_loss = 0
        epoch_overlap = 0
        epoch_acc = 0

        global image, pred, true

        # Sets the module in training mode.
        # This has an effect only on modules such as Dropout or BatchNorm.
        model.train()
        for partial_epoch, (image, label) in enumerate(training_data_loader(), 1):
            image, label = Variable(image).float(), Variable(label).float()
            if conf.cuda:
                image = image.cuda()
                label = label.cuda()

            optimizer.zero_grad()

            output = model(image).contiguous()
            target = label.view(-1).long()

            loss = criterion(output, target)

            loss.backward()
            optimizer.step()

            epoch_loss += loss.data[0]

            # Compute the dice overlap from the `argmax` prediction.
            pred = output.data.max(1)[1]
            true = target.data.long()

            dice_overlap = 2 * torch.sum(pred * true) / (torch.sum(pred) + torch.sum(true)) * 100
            epoch_overlap += dice_overlap

            # Compute the accuracy.
            accuracy = pred.eq(true).cpu().sum() / true.numel() * 100
            epoch_acc += accuracy

        avg_loss, avg_dice, avg_acc = np.array([epoch_loss, epoch_overlap, epoch_acc]) / conf.training_size
        print_format = [i, i // conf.augment_size + 1, conf.epochs, avg_loss, avg_dice, avg_acc]
        print(
            '===> Training step {} ({}/{})\tLoss: {:.5f}\tDice Overlap: {:.5f}\tAccuracy: {:.5f}'.format(*print_format))

        image, true, pred = reimage(image, true, pred)
        return avg_loss, avg_dice, avg_acc, image, true, pred

    def validate():
        epoch_loss = 0
        epoch_overlap = 0
        epoch_acc = 0

        global image, pred, true

        # Sets the module in evaluation mode.
        # The behavior matches `model.train()` here, because there are no norm/dropout layers.
        model.eval()

        for image, label in validation_data_loader():
            image, label = Variable(image, volatile=True).float(), Variable(label, volatile=True).float()
            if conf.cuda:
                image = image.cuda()
                label = label.cuda()

            output = model(image).contiguous()
            target = label.view(-1).long()

            loss = criterion(output, target)

            epoch_loss += loss.data[0]

            # Compute the dice overlap.
            pred = output.data.max(1)[1].float()
            true = target.data.float()
            dice_overlap = 2 * torch.sum(pred * true) / (torch.sum(pred) + torch.sum(true)) * 100
            epoch_overlap += dice_overlap

            # Compute the accuracy.
            accuracy = pred.eq(true).cpu().sum() / true.numel() * 100
            epoch_acc += accuracy

        avg_loss, avg_dice, avg_acc = np.array([epoch_loss, epoch_overlap, epoch_acc]) / conf.val_size
        print(
            '===> ===> Validation Performance', '-' * 60,
            'Loss: %7.5f' % avg_loss, '-' * 2,
            'Dice Overlap: %7.5f' % avg_dice, '-' * 2,
            'Accuracy: %7.5f' % avg_acc
        )

        image, true, pred = reimage(image, true, pred)
        return avg_loss, avg_dice, avg_acc, image, true, pred
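
    # Toy check of the dice-overlap metric used in train()/validate() above
    # (illustrative numbers, not part of the training flow):
    #   pred = [1, 1, 0, 0], true = [1, 0, 0, 1]
    #   2 * sum(pred * true) / (sum(pred) + sum(true)) * 100 = 2 * 1 / (2 + 2) * 100 = 50.0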

    for i in range(start_i, total_i + 1):
        # `train_results` = (train_loss, train_dice, train_acc)
        *train_results, train_im, train_true, train_pred = train()
        for j, stat in enumerate(names[:3]):
            results_dict[stat][i - 1] = train_results[j]

        # `val_results` = (val_loss, val_dice, val_acc)
        *val_results, val_im, val_true, val_pred = validate()
        for j, stat in enumerate(names[3:]):
            results_dict[stat][i - 1] = val_results[j]

        # Visualize - scalars
        epoch_results = np.array(list(zip(train_results, val_results)))
        basic_opts = partial(dict, xlabel='Epoch', legend=['train', 'val'])
        # Visualize - images
        im_titles = ['input', 'label', 'prediction']

        if i == start_i:
            # Windows for scalars
            wins = []
            # Windows for images
            wins_train_im = [vis.images(item, opts=dict(title='train_' + im_titles[j])) for j, item in
                             enumerate([train_im, train_true, train_pred])]
            wins_val_im = [vis.images(item, opts=dict(title='val_' + im_titles[j])) for j, item in
                           enumerate([val_im, val_true, val_pred])]

            # Resume curves from the records
            if i > 1:
                record_results = [np.column_stack((results_dict['train_' + stat][:i], results_dict['val_' + stat][:i]))
                                  for stat in stats]
                for j, stat in enumerate(stats):
                    wins.append(vis.line(X=np.arange(i), Y=record_results[j], opts=basic_opts(title=stat)))
            # Plot from scratch
            elif i == 1:
                for j, stat in enumerate(stats):
                    wins.append(vis.line(X=np.array([i]), Y=epoch_results[None, j], opts=basic_opts(title=stat)))
        # Update the windows
        else:
            for j, win in enumerate(wins):
                vis.updateTrace(X=np.array([i]), Y=epoch_results[j][0, None], win=win, name='train')
                vis.updateTrace(X=np.array([i]), Y=epoch_results[j][1, None], win=win, name='val')
            if i % 10 == 0:
                # Refresh the image windows every 10 epochs
                for j, item in enumerate([train_im, train_true, train_pred]):
                    vis.images(item, opts=dict(title='train_' + im_titles[j]), win=wins_train_im[j])
                for j, item in enumerate([val_im, val_true, val_pred]):
                    vis.images(item, opts=dict(title='val_' + im_titles[j]), win=wins_val_im[j])

        # Save checkpoints
        if i % 20 == 0:
            save_checkpoints(model, i)
            # np.load('path/to/').item()
            np.save(os.path.join(conf.result_dir, 'results_dict.npy'), results_dict)


if __name__ == '__main__':
    main()
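
# Note on persistence: np.save pickles `results_dict` into a 0-d object array,
# and np.load(...).item() unwraps it again (newer NumPy also needs
# allow_pickle=True); with the paths configured above, e.g.:
#   results_dict = np.load('./results/vnet/results_dict.npy').item()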
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import torch
import numpy as np
import matplotlib.pyplot as plt

from torch.nn import init

plt.rcParams['image.cmap'] = 'gray'


# Training configuration
class config:
    def __init__(self):
        self.cuda = torch.cuda.is_available()
        self.batch_size = 1
        self.epochs = 100
        self.augment_size = 500
        self.training_size = 12
        self.val_size = 3
        self.learning_rate = 3e-6
        self.criterion = 'dice'
        self.seed = 714
        self.threads = 24
        self.from_scratch = False
        self.checkpoint_dir = './checkpoints/'
        self.result_dir = './results/'
        self.resume_step = -1
        self.prefix = 'May the force be with you.'


def weights_init(net):
    for m in net.modules():
        if isinstance(m, torch.nn.Conv3d) or isinstance(m, torch.nn.ConvTranspose3d):
            init.kaiming_normal(m.weight)
            init.constant(m.bias, 0.01)


def dice_loss(y_conv, y_true):
    """Compute dice over the **positive** label only, to sidestep class imbalance.
    Arguments:
        y_conv: [batch_size * depth * height * width, 2 ] (torch.cuda.FloatTensor)
        y_true: [batch_size * depth * height * width, (1)] (torch.cuda.LongTensor)
    """
    y_conv = y_conv[:, 1]
    y_true = y_true.float()
    intersection = torch.sum(y_conv * y_true, 0)

    # `dim=0` keeps a Tensor result; the epsilon guards against an empty union.
    union = torch.sum(y_conv * y_conv, 0) + torch.sum(y_true * y_true, 0) + 1e-7
    dice = 2.0 * intersection / union
    return 1 - torch.clamp(dice, 0.0, 1.0 - 1e-7)


def show_center_slices(im_3d, indices=None):
    """Display the three central slices of a 3-d image."""

    if indices is None:
        indices = np.array(im_3d.shape) // 2
    assert len(indices) == 3, 'Expect 3 indices for a 3-d array, but received %d.' % len(indices)

    x_th, y_th, z_th = indices
    fig, axes = plt.subplots(1, 3)
    axes[0].imshow(im_3d[x_th, :, :])
    axes[1].imshow(im_3d[:, y_th, :])
    axes[2].imshow(im_3d[:, :, z_th])
    plt.suptitle('Center slices for spine image')
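

if __name__ == '__main__':
    # Sanity check for dice_loss -- a sketch with hand-made "softmax" outputs.
    probs = torch.FloatTensor([[0.1, 0.9], [0.8, 0.2], [0.7, 0.3]])  # [N, 2]
    labels = torch.LongTensor([1, 0, 0])
    # intersection = 0.9; union = (0.81 + 0.04 + 0.09) + 1 = 1.94
    # loss = 1 - 2 * 0.9 / 1.94 ≈ 0.072
    print(dice_loss(probs, labels))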
--------------------------------------------------------------------------------
/vnet.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch.nn.functional as F


class conv3d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, activation_func=nn.ReLU):
        """
        + Instantiate modules: conv-relu (an optional norm is left disabled below)
        + Assign them as member variables
        """
        super(conv3d, self).__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, padding=1)
        self.relu = activation_func()
        # with learnable parameters
        # self.norm = nn.InstanceNorm3d(out_channels, affine=True)

    def forward(self, x):
        return self.relu(self.conv(x))


class conv3d_x3(nn.Module):
    """Three serial convs with a residual connection.

    Structure:
        inputs --> ① --> ② --> ③ --> outputs
                   ↓ -----> add -----> ↑
    """

    def __init__(self, in_channels, out_channels, kernel_size=3):
        super(conv3d_x3, self).__init__()
        self.conv_1 = conv3d(in_channels, out_channels, kernel_size)
        self.conv_2 = conv3d(out_channels, out_channels, kernel_size)
        self.conv_3 = conv3d(out_channels, out_channels, kernel_size)

    def forward(self, x):
        z_1 = self.conv_1(x)
        z_3 = self.conv_3(self.conv_2(z_1))
        return z_1 + z_3


class deconv3d_x3(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, activation_func=nn.ReLU):
        super(deconv3d_x3, self).__init__()
        self.up = deconv3d_as_up(in_channels, out_channels, kernel_size, stride)
        self.lhs_conv = conv3d(out_channels // 2, out_channels, kernel_size)
        self.conv_x3 = conv3d_x3(out_channels, out_channels, kernel_size)

    def forward(self, lhs, rhs):
        rhs_up = self.up(rhs)
        lhs_conv = self.lhs_conv(lhs)
        rhs_add = crop(rhs_up, lhs_conv) + lhs_conv
        return self.conv_x3(rhs_add)


def crop(large, small):
    """Centre-crop `large` to the spatial size of `small`.
    Both have shape [batch_size, channels, depth, height, width].
    """

    l, s = large.size(), small.size()
    offset = [0, 0, (l[2] - s[2]) // 2, (l[3] - s[3]) // 2, (l[4] - s[4]) // 2]
    return large[..., offset[2]: offset[2] + s[2], offset[3]: offset[3] + s[3], offset[4]: offset[4] + s[4]]


def conv3d_as_pool(in_channels, out_channels, kernel_size=3, stride=2, activation_func=nn.ReLU):
    return nn.Sequential(
        nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding=1),
        activation_func())


def deconv3d_as_up(in_channels, out_channels, kernel_size=3, stride=2, activation_func=nn.ReLU):
    return nn.Sequential(
        nn.ConvTranspose3d(in_channels, out_channels, kernel_size, stride),
        activation_func()
    )


class softmax_out(nn.Module):
    def __init__(self, in_channels, out_channels, criterion):
        super(softmax_out, self).__init__()
        self.conv_1 = nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1)
        self.conv_2 = nn.Conv3d(out_channels, out_channels, kernel_size=1, padding=0)
        if criterion == 'nll':
            self.softmax = F.log_softmax
        else:
            assert criterion == 'dice', "Expect `dice` (dice loss) or `nll` (negative log likelihood loss)."
            self.softmax = F.softmax

    def forward(self, x):
        """Output with shape [batch_size * depth * height * width, 2]."""
        # Do NOT add a normalization layer here, or the activations vanish.
        y_conv = self.conv_2(self.conv_1(x))
        # Put the channel axis last for the softmax.
        y_perm = y_conv.permute(0, 2, 3, 4, 1).contiguous()
        y_flat = y_perm.view(-1, 2)
        return self.softmax(y_flat)
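

# Shape check for `crop` (illustrative): a deconv output of spatial size 9x9x9,
# centre-cropped against a skip connection of size 8x8x8, yields 8x8x8, so the
# elementwise add in deconv3d_x3 is well defined:
#   crop(torch.randn(1, 32, 9, 9, 9), torch.randn(1, 32, 8, 8, 8)).size()
#   # -> (1, 32, 8, 8, 8)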


class VNet(nn.Module):
    def __init__(self, criterion):
        super(VNet, self).__init__()
        self.conv_1 = conv3d_x3(1, 16)
        self.pool_1 = conv3d_as_pool(16, 32)
        self.conv_2 = conv3d_x3(32, 32)
        self.pool_2 = conv3d_as_pool(32, 64)
        self.conv_3 = conv3d_x3(64, 64)
        self.pool_3 = conv3d_as_pool(64, 128)
        self.conv_4 = conv3d_x3(128, 128)
        self.pool_4 = conv3d_as_pool(128, 256)

        self.bottom = conv3d_x3(256, 256)

        self.deconv_4 = deconv3d_x3(256, 256)
        self.deconv_3 = deconv3d_x3(256, 128)
        self.deconv_2 = deconv3d_x3(128, 64)
        self.deconv_1 = deconv3d_x3(64, 32)

        self.out = softmax_out(32, 2, criterion)

    def forward(self, x):
        conv_1 = self.conv_1(x)
        pool = self.pool_1(conv_1)
        conv_2 = self.conv_2(pool)
        pool = self.pool_2(conv_2)
        conv_3 = self.conv_3(pool)
        pool = self.pool_3(conv_3)
        conv_4 = self.conv_4(pool)
        pool = self.pool_4(conv_4)
        bottom = self.bottom(pool)
        deconv = self.deconv_4(conv_4, bottom)
        deconv = self.deconv_3(conv_3, deconv)
        deconv = self.deconv_2(conv_2, deconv)
        deconv = self.deconv_1(conv_1, deconv)
        return self.out(deconv)
--------------------------------------------------------------------------------
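
As a closing sanity check, a sketch of a full forward pass under the PyTorch 0.1.12 API used above (the 32³ input size is an arbitrary choice for illustration):

    import torch
    from torch.autograd import Variable
    from vnet import VNet

    model = VNet('dice')
    x = Variable(torch.randn(1, 1, 32, 32, 32))  # [batch, channel, depth, height, width]
    y = model(x)
    print(y.size())  # (32768, 2): one softmax row per voxel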