├── pancreas.jpg ├── _fast_functions.so ├── pytorch_iou └── __init__.py ├── draw_loss_2.py ├── ramps.py ├── draw_loss_1.py ├── losses.py ├── pytorch_gauss └── __init__.py ├── dataprocess.py ├── logs └── id168 │ └── README.md ├── README.md ├── layers.py ├── vgg.py ├── fast_functions.py ├── Data.py ├── init.py ├── coarse_testing.py ├── training.py ├── oracle_testing.py ├── model.py ├── oracle_fusion.py ├── utils.py ├── coarse_fusion.py ├── coarse2fine_testing.py ├── f2.sh ├── f0.sh ├── f1.sh └── f3.sh /pancreas.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kunzhan/CKS_Pancreas/HEAD/pancreas.jpg -------------------------------------------------------------------------------- /_fast_functions.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kunzhan/CKS_Pancreas/HEAD/_fast_functions.so -------------------------------------------------------------------------------- /pytorch_iou/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.autograd import Variable 4 | import numpy as np 5 | 6 | def _iou(pred, target, size_average = True): 7 | b = pred.shape[0] 8 | IoU = 0.0 9 | for i in range(0,b): 10 | Iand1 = torch.sum(target[i,:,:,:]*pred[i,:,:,:]) 11 | Ior1 = torch.sum(target[i,:,:,:]) + torch.sum(pred[i,:,:,:])-Iand1 12 | IoU1 = Iand1/Ior1 13 | IoU = IoU + (1-IoU1) 14 | return IoU/b 15 | 16 | class IOU(torch.nn.Module): 17 | def __init__(self, size_average = True): 18 | super(IOU, self).__init__() 19 | self.size_average = size_average 20 | def forward(self, pred, target): 21 | return _iou(pred, target, self.size_average) 22 | -------------------------------------------------------------------------------- /draw_loss_2.py: -------------------------------------------------------------------------------- 1 | import re 2 | import 
"""Extract per-iteration 'Loss' values from a training log and plot them.

Keeps the original two-stage flow: the parsed loss numbers are dumped to an
intermediate '1.txt' file (retained as a side effect) and then read back for
plotting.
"""
import re
import ipdb  # noqa: F401 -- kept for the commented-out set_trace() calls below
import matplotlib.pyplot as plt
import os.path as osp

filepath = "/home/datasets/Pancreas82NIH/logs/FD0:Z3_1_20211101_191714.txt"

# Fix: the original used open(filepath, "r").read(), leaking the file handle.
with open(filepath, "r") as log_file:
    txt = log_file.read()

result = ""
# Grab the fixed-width chunk of text that follows each 'Loss' occurrence.
test_text = re.findall("Loss+(......................)", txt)
result = result + '\n'.join(test_text)

result = result.replace(" ", "")
#result = result.replace("/","")
#print(result)

loss = ""
# Keep the 6 characters preceding each comma -- the numeric loss value.
avg_loss = re.findall("(......)+,", result)
loss = loss + '\n'.join(avg_loss)
#print(loss)
mode = {'Loss'}
count = 0
Loss = []
x = []

with open('1.txt', 'w') as f:
    f.write(loss)

with open("1.txt", 'r') as f1:
    while True:
        #ipdb.set_trace()
        line = f1.readline().replace("\n", "")
        if line == '':
            break
        count += 1
        if mode == {'Loss'}:
            Loss.append(float(line))
            x.append(count)

#ipdb.set_trace()
if mode == {'Loss'}:
    plt.plot(x, Loss)

plt.savefig('Z3_1_20211101_191714')
plt.show()
def sigmoid_rampup(current, rampup_length):
    """Exponential rampup from https://arxiv.org/abs/1610.02242"""
    if rampup_length == 0:
        return 1.0
    clipped = np.clip(current, 0.0, rampup_length)
    phase = 1.0 - clipped / rampup_length
    return float(np.exp(-5.0 * phase ** 2))


def linear_rampup(current, rampup_length):
    """Linear rampup"""
    assert current >= 0 and rampup_length >= 0
    return 1.0 if current >= rampup_length else current / rampup_length


def cosine_rampdown(current, rampdown_length):
    """Cosine rampdown from https://arxiv.org/abs/1608.03983"""
    assert 0 <= current <= rampdown_length
    return float(0.5 * (1 + np.cos(np.pi * current / rampdown_length)))
"""Plot Coarse/Fine/Avg accuracy curves parsed from a training log."""
import re
import matplotlib.pyplot as plt
import os.path as osp

fullpath = osp.abspath('/home/datasets/Pancreas82NIH/logs/FD0:X3_1_20211101_191714.txt')
# mode = {'Loss'}
mode = {'Coarse', 'Fine', 'Avg'}

filedir, filename = osp.split(fullpath)
count = 0
# Fix: 'Loss' was never initialised, so running with mode == {'Loss'}
# crashed with NameError on the first append.
Loss = []
Coarse, Fine, Avg, x = [], [], [], []
with open(fullpath, 'r') as f:
    rbsh = f.readline()  # skip the header line
    while True:
        line = f.readline()
        if line == '':
            break
        if not line.startswith('0X'):
            continue
        count += 1
        # Fix: removed the leftover ipdb.set_trace() call -- ipdb is not
        # imported in this script, so it raised NameError on the first
        # matching line.
        line = line.replace(' ', '').replace('\t', '')
        pattern = re.compile(r'\w*.\w+')
        find_list = pattern.findall(line)
        if mode == {'Loss'}:
            Loss.append(float(find_list[0]))
        elif mode == {'Coarse', 'Fine', 'Avg'}:
            Coarse.append(float(find_list[0]))
            Fine.append(float(find_list[1]))
            Avg.append(float(find_list[2]))
        x.append(count)

pngName = filename.split('.')[0]

if mode == {'Loss'}:
    plt.plot(x, Loss)
elif mode == {'Coarse', 'Fine', 'Avg'}:
    plt.plot(x, Coarse, color='red', marker='o', linestyle='dashed', linewidth=2, markersize=1)
    plt.plot(x, Fine, color='green', marker='o', linestyle='dashed', linewidth=2, markersize=1)
    plt.plot(x, Avg, color='blue', marker='o', linestyle='dashed', linewidth=2, markersize=1)
    plt.legend(labels=('Coarse', 'Fine', 'Avg'))

plt.savefig(osp.join(filedir, pngName))
plt.show()
def softmax_mse_loss(input_logits, target_logits):
    """Takes softmax on both sides and returns MSE loss.

    Note:
    - Returns the sum over all examples. Divide by the batch size afterwards
      if you want the mean.
    - NOTE(review): gradients are NOT blocked here -- detach target_logits at
      the call site if the target must be treated as a constant.
    """
    assert input_logits.size() == target_logits.size()
    input_softmax = F.softmax(input_logits, dim=1)
    target_softmax = F.softmax(target_logits, dim=1)
    num_classes = input_logits.size()[1]
    # Fix: size_average=False is deprecated; reduction='sum' is the exact
    # modern equivalent (same value, no deprecation warning).
    return F.mse_loss(input_softmax, target_softmax, reduction='sum') / num_classes


def softmax_kl_loss(input_logits, target_logits):
    """Takes softmax on both sides and returns KL divergence.

    Note:
    - Returns the sum over all examples. Divide by the batch size afterwards
      if you want the mean.
    - NOTE(review): as with softmax_mse_loss, detach target_logits at the
      call site if the target should receive no gradient.
    """
    assert input_logits.size() == target_logits.size()
    input_log_softmax = F.log_softmax(input_logits, dim=1)
    target_softmax = F.softmax(target_logits, dim=1)
    # Fix: size_average=False replaced by reduction='sum' (deprecated alias).
    return F.kl_div(input_log_softmax, target_softmax, reduction='sum')
def symmetric_mse_loss(input1, input2):
    """MSE-style loss whose gradient flows to *both* arguments.

    Note:
    - Returns the sum over all examples. Divide by the batch size afterwards
      if you want the mean.
    - Sends gradients to both input1 and input2 (unlike F.mse_loss, neither
      side is a constant target).
    """
    assert input1.size() == input2.size()
    num_classes = input1.size()[1]
    diff = input1 - input2
    return torch.sum(diff * diff) / num_classes
def bce(img1, img2, window_size = 11, size_average = True):
    """Functional counterpart of the Gauss module.

    Builds a fresh Gaussian window matched to img1's channel count, moves it
    to img1's device/dtype, and delegates to _bce (an SSIM-style smoothed
    similarity term).
    """
    channel = img1.size(1)
    window = create_window(window_size, channel)
    if img1.is_cuda:
        window = window.cuda(img1.get_device())
    window = window.type_as(img1)
    return _bce(img1, img2, window, window_size, channel, size_average)
class RandomTranslateWithReflect:
    """Translate image randomly

    Translate vertically and horizontally by n pixels where
    n is integer drawn uniformly independently for each axis
    from [-max_translation, max_translation].

    Fill the uncovered blank area with reflect padding.
    """

    def __init__(self, max_translation):
        self.max_translation = max_translation

    def __call__(self, old_image):
        # Fix: PIL's Image was referenced but never imported anywhere in this
        # module (NameError at runtime). Imported lazily so importing the
        # module does not require PIL until a transform actually runs.
        from PIL import Image

        xtranslation, ytranslation = np.random.randint(-self.max_translation,
                                                       self.max_translation + 1,
                                                       size=2)
        xpad, ypad = abs(xtranslation), abs(ytranslation)
        xsize, ysize = old_image.size

        # Mirrored copies used to reflect-pad each edge and corner.
        flipped_lr = old_image.transpose(Image.FLIP_LEFT_RIGHT)
        flipped_tb = old_image.transpose(Image.FLIP_TOP_BOTTOM)
        flipped_both = old_image.transpose(Image.ROTATE_180)

        new_image = Image.new("RGB", (xsize + 2 * xpad, ysize + 2 * ypad))

        new_image.paste(old_image, (xpad, ypad))

        new_image.paste(flipped_lr, (xpad + xsize - 1, ypad))
        new_image.paste(flipped_lr, (xpad - xsize + 1, ypad))

        new_image.paste(flipped_tb, (xpad, ypad + ysize - 1))
        new_image.paste(flipped_tb, (xpad, ypad - ysize + 1))

        new_image.paste(flipped_both, (xpad - xsize + 1, ypad - ysize + 1))
        new_image.paste(flipped_both, (xpad + xsize - 1, ypad - ysize + 1))
        new_image.paste(flipped_both, (xpad - xsize + 1, ypad + ysize - 1))
        new_image.paste(flipped_both, (xpad + xsize - 1, ypad + ysize - 1))

        new_image = new_image.crop((xpad - xtranslation,
                                    ypad - ytranslation,
                                    xpad + xsize - xtranslation,
                                    ypad + ysize - ytranslation))

        return new_image

class TransformTwice:
    """Apply the wrapped (possibly stochastic) transform twice.

    Returns the pair of independent outputs -- used to produce the two views
    required by consistency-regularisation training.
    """

    def __init__(self, transform):
        self.transform = transform

    def __call__(self, inp):
        out1 = self.transform(inp)
        out2 = self.transform(inp)
        return out1, out2
anatomical variability. Most existing methods use a two-stage framework: the coarse and the fine. We argue that both stages have the same purpose of improving pancreatic-pixel classification accuracies. Inspired by this observation, we transfer fine-model weights to the coarse. If we directly copy the pre-trained fine model, the performance is low due to the domain gap of the different input images, so we further propose a momentum update strategy for transferring models. Our momentum update stands on another observation that input images are in three different domains: the small image cropped by the ground-truth bounding box ($D_1$), the small image cropped by the coarse predicted bounding box ($D_2$), and the large raw image ($D_3$). The momentum update training approach of the coarse model is cast into three steps: train the coarse model by $D_1$ firstly, by $D_2$ secondly, and by $D_3$ thirdly. In the three steps, we copy the model weights step-by-step with the momentum update approach in order to improve the coarse accuracy. The coarse benefits from domain adaptation since the first-step model is trained with strong supervision of the ground-truth bounding box and has a good pancreatic pixel-wise accuracy. The second and the third steps gradually adapt to domain $D_3$ that is the true domain of the coarse model. A higher detection accuracy produces a better region proposal and it enables the fine stage to obtain a better segmentation accuracy. We conduct several experiments on the NIH dataset with different neural network backbones and the results show we obtain the state-of-the-art performance in terms of DSC metric.
3 | 4 | # Experiment 5 | ```sh 6 | bash f2.sh 7 | ``` 8 | 9 | # Citation 10 | We appreciate it if you cite the following paper: 11 | ``` 12 | @InProceedings{TangMICCAI2022, 13 | author = {Yumou Tang and Zhibo Tian and Saisai Wang and Xueming Wen and Kun Zhan}, 14 | title = {Momentum update for pancreas segmentation}, 15 | booktitle = {ICIP}, 16 | year = {2022} 17 | } 18 | 19 | ``` 20 | 21 | # Contact 22 | https://kunzhan.github.io/ 23 | 24 | If you have any questions, feel free to contact me. (Email: `ice.echo#gmail.com`) 25 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Curriculum Knowledge Switching for Pancreas Segmentation 2 | Pancreas segmentation is very difficult since pancreas occupies only a very small fraction less than 0.5\% of a CT volume and suffers from high anatomical variability. Most existing methods use a two-stage framework: the coarse and the fine. We argue that both stages have the same purpose of improving pancreatic-pixel classification accuracies. Inspired by this observation, we transfer fine-model weights to the coarse. If we directly copy the pre-trained fine model, the performance is low due to the domain gap of the different input images, so we further propose a momentum update strategy for transferring models. Our momentum update stands on the other observation that input images are in three different domains: the small image cropped by the ground-truth bounding box ($D_1$), the small image cropped by the coarse predicted bounding box ($D_2$), and the large raw image ($D_3$). The momentum update training approach of the coarse model is cast into three steps: train the coarse model by $D_1$ firstly, by $D_2$ secondly, and by $D_3$ thirdly. In the three steps, we copy the model weights step-by-step with the momentum update approach in order to improve the coarse accuracy. 
The coarse benefits from domain adaptively since the first-step model is trained with strong supervision of the ground-truth bounding box and has a good pancreatic pixel-wise accuracy. The second and the third steps gradually adapt to domain $D_3$ that is the true domain of the coarse model. A higher detection accuracy produces a better region proposal and it renders the fine obtain a better segmentation accuracy. We conduct several experiments on the NIH dataset with different neural network backbones and the results show we obtain the state-of-the-art performance in terms of DSC metric. 3 | 4 | ![](pancreas.jpg) 5 | 6 | 7 | # Experiment 8 | ```sh 9 | bash f2.sh 10 | ``` 11 | 12 | # Citation 13 | We appreciate it if you cite the following paper: 14 | ``` 15 | @InProceedings{TangMICCAI2022, 16 | author = {Yumou Tang and Kun Zhan and Zhibo Tian and Mingxuan Zhang and Saisai Wang and Xueming Wen}, 17 | title = {Curriculum knowledge switching for pancreas segmentation}, 18 | booktitle = {ICIP}, 19 | year = {2023} 20 | } 21 | 22 | ``` 23 | 24 | # Contact 25 | https://kunzhan.github.io/ 26 | 27 | If you have any questions, feel free to contact me. 
(Email: `ice.echo#gmail.com`) 28 | -------------------------------------------------------------------------------- /layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from utils import init_weights 4 | from torch.nn import functional as F 5 | #import ipdb 6 | 7 | class unetConv2(nn.Module): 8 | def __init__(self, in_size, out_size, is_batchnorm, n=2, ks=3, stride=1, padding=1): 9 | super(unetConv2, self).__init__() 10 | self.n = n 11 | self.ks = ks 12 | self.stride = stride 13 | self.padding = padding 14 | s = stride 15 | p = padding 16 | if is_batchnorm: 17 | for i in range(1, n+1): 18 | conv = nn.Sequential(nn.Conv2d(in_size, out_size, ks, s, p), 19 | nn.BatchNorm2d(out_size), 20 | nn.ReLU(inplace=True),) 21 | setattr(self, 'conv%d'%i, conv) 22 | in_size = out_size 23 | 24 | else: 25 | for i in range(1, n+1): 26 | conv = nn.Sequential(nn.Conv2d(in_size, out_size, ks, s, p), 27 | nn.ReLU(inplace=True),) 28 | setattr(self, 'conv%d'%i, conv) 29 | in_size = out_size 30 | 31 | # initialise the blocks 32 | for m in self.children(): 33 | init_weights(m, init_type='kaiming') 34 | 35 | def forward(self, inputs): 36 | x = inputs 37 | for i in range(1, self.n+1): 38 | conv = getattr(self, 'conv%d'%i) 39 | x = conv(x) 40 | 41 | return x 42 | 43 | class unetUp(nn.Module): 44 | def __init__(self, in_size, out_size, is_deconv, n_concat=2): 45 | super(unetUp, self).__init__() 46 | self.conv = unetConv2(in_size+(n_concat-2)*out_size, out_size, False) 47 | if is_deconv: 48 | self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2, padding=0, bias=False) 49 | else: 50 | self.up = nn.Sequential( 51 | nn.UpsamplingBilinear2d(scale_factor=2), 52 | nn.Conv2d(in_size, out_size, 1)) 53 | 54 | # initialise the blocks 55 | for m in self.children(): 56 | if m.__class__.__name__.find('unetConv2') != -1: continue 57 | init_weights(m, init_type='kaiming') 58 | 59 | def forward(self, 
class VGG(nn.Module):
    """VGG backbone: feature extractor, 7x7 adaptive pool, 3-layer classifier.

    The classifier expects the feature extractor to emit 512 channels
    (512 * 7 * 7 after pooling/flattening).
    """

    def __init__(self, features, num_classes=1000):
        super(VGG, self).__init__()
        self.features = features
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, num_classes),
        )
        self._initialize_weights()

    def forward(self, x):
        feats = self.features(x)
        pooled = self.avgpool(feats)
        flat = torch.flatten(pooled, 1)
        return self.classifier(flat)

    def _initialize_weights(self):
        # Standard VGG init: Kaiming for convs, unit/zero for BN, N(0, 0.01)
        # for linear layers.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, 0, 0.01)
                nn.init.constant_(module.bias, 0)
def VGG16(pretrained, in_channels, **kwargs):
    """Build a VGG-16 feature extractor; optionally load ImageNet weights.

    The avgpool and classifier heads are deleted after (optional) weight
    loading -- only `features` is used downstream.
    """
    model = VGG(make_layers(cfgs["D"], batch_norm = False, in_channels = in_channels), **kwargs)
    if pretrained:
        url = "https://download.pytorch.org/models/vgg16-397923af.pth"
        model.load_state_dict(load_state_dict_from_url(url, model_dir="./model_data"))

    # The fully connected head is never used by the segmentation models.
    del model.avgpool
    del model.classifier
    return model
6 | 7 | 8 | 9 | 10 | 11 | from sys import version_info 12 | if version_info >= (2, 6, 0): 13 | def swig_import_helper(): 14 | from os.path import dirname 15 | import imp 16 | fp = None 17 | try: 18 | fp, pathname, description = imp.find_module('_fast_functions', [dirname(__file__)]) 19 | except ImportError: 20 | import _fast_functions 21 | return _fast_functions 22 | if fp is not None: 23 | try: 24 | _mod = imp.load_module('_fast_functions', fp, pathname, description) 25 | finally: 26 | fp.close() 27 | return _mod 28 | _fast_functions = swig_import_helper() 29 | del swig_import_helper 30 | else: 31 | import _fast_functions 32 | del version_info 33 | try: 34 | _swig_property = property 35 | except NameError: 36 | pass # Python < 2.2 doesn't have 'property'. 37 | 38 | 39 | def _swig_setattr_nondynamic(self, class_type, name, value, static=1): 40 | if (name == "thisown"): 41 | return self.this.own(value) 42 | if (name == "this"): 43 | if type(value).__name__ == 'SwigPyObject': 44 | self.__dict__[name] = value 45 | return 46 | method = class_type.__swig_setmethods__.get(name, None) 47 | if method: 48 | return method(self, value) 49 | if (not static): 50 | if _newclass: 51 | object.__setattr__(self, name, value) 52 | else: 53 | self.__dict__[name] = value 54 | else: 55 | raise AttributeError("You cannot add attributes to %s" % self) 56 | 57 | 58 | def _swig_setattr(self, class_type, name, value): 59 | return _swig_setattr_nondynamic(self, class_type, name, value, 0) 60 | 61 | 62 | def _swig_getattr_nondynamic(self, class_type, name, static=1): 63 | if (name == "thisown"): 64 | return self.this.own() 65 | method = class_type.__swig_getmethods__.get(name, None) 66 | if method: 67 | return method(self) 68 | if (not static): 69 | return object.__getattr__(self, name) 70 | else: 71 | raise AttributeError(name) 72 | 73 | def _swig_getattr(self, class_type, name): 74 | return _swig_getattr_nondynamic(self, class_type, name, 0) 75 | 76 | 77 | def _swig_repr(self): 78 | try: 79 
class DataLayer(data.Dataset):
    """Dataset of 2-D CT slices for one plane (X/Y/Z) of the current fold.

    Slices are filtered to those with at least `min_pixels` organ pixels and
    whose volume belongs to the current training fold. load_data() (defined
    further down in this file) consumes the index fields set by __getitem__.
    """
    def __init__(self, data_path, current_fold, organ_number, low_range, high_range, \
        slice_threshold, slice_thickness, organ_ID, plane):
        self.low_range = low_range
        self.high_range = high_range
        self.slice_thickness = slice_thickness
        self.organ_ID = organ_ID

        image_list = open(training_set_filename(current_fold), 'r').read().splitlines()
        # Fix: np.int was deprecated in NumPy 1.20 and removed in 1.24; it was
        # an alias of the builtin int, so substituting int is behaviour-preserving.
        self.training_image_set = np.zeros((len(image_list)), dtype = int)
        for i in range(len(image_list)):
            s = image_list[i].split(' ')
            self.training_image_set[i] = int(s[0])
        slice_list = open(list_training[plane], 'r').read().splitlines()
        self.slices = len(slice_list)
        self.image_ID = np.zeros((self.slices), dtype = int)
        self.slice_ID = np.zeros((self.slices), dtype = int)
        self.image_filename = ['' for l in range(self.slices)]
        self.label_filename = ['' for l in range(self.slices)]
        self.average = np.zeros((self.slices))
        self.pixels = np.zeros((self.slices), dtype = int)

        for l in range(self.slices):
            s = slice_list[l].split(' ')
            self.image_ID[l] = s[0]
            self.slice_ID[l] = s[1]
            self.image_filename[l] = s[2]
            self.label_filename[l] = s[3]
            self.average[l] = float(s[4])
            self.pixels[l] = int(s[organ_ID * 5])
        if slice_threshold <= 1:  # threshold given as a fraction of positive slices
            pixels_index = sorted(range(self.slices), key = lambda l: self.pixels[l])
            last_index = int(math.floor((self.pixels > 0).sum() * slice_threshold))
            min_pixels = self.pixels[pixels_index[-last_index]]
        else: # or an absolute pixel count, set up directly
            min_pixels = slice_threshold
        self.active_index = [l for l, p in enumerate(self.pixels)
            if p >= min_pixels and self.image_ID[l] in self.training_image_set] # true active

    def __getitem__(self, index):
        # NOTE(review): the neighbouring-slice indices are stashed on self for
        # load_data(); this makes the dataset stateful, which is fragile if a
        # single instance is ever shared between consumers -- confirm usage.
        self.index1 = self.active_index[index]
        self.index0 = self.index1 - 1
        # Clamp to the same slice when index1 is the first slice of a volume.
        if self.index1 == 0 or self.slice_ID[self.index0] != self.slice_ID[self.index1] - 1:
            self.index0 = self.index1
        self.index2 = self.index1 + 1
        # Clamp to the same slice when index1 is the last slice of a volume.
        if self.index1 == self.slices - 1 or self.slice_ID[self.index2] != self.slice_ID[self.index1] + 1:
            self.index2 = self.index1
        self.data, self.label = self.load_data()
        return torch.from_numpy(self.data), torch.from_numpy(self.label)

    def __len__(self):
        return len(self.active_index)
np.zeros((3, width, height), dtype = np.float32) 75 | image[0, ...] = image0 76 | image[1, ...] = np.load(self.image_filename[self.index1]) 77 | image[2, ...] = np.load(self.image_filename[self.index2]) 78 | label = np.zeros((3, width, height), dtype = np.uint8) 79 | label[0, ...] = np.load(self.label_filename[self.index0]) 80 | label[1, ...] = np.load(self.label_filename[self.index1]) 81 | label[2, ...] = np.load(self.label_filename[self.index2]) 82 | np.minimum(np.maximum(image, self.low_range, image), self.high_range, image) 83 | image -= self.low_range 84 | image /= (self.high_range - self.low_range) 85 | label = is_organ(label, self.organ_ID).astype(np.uint8) 86 | return image, label -------------------------------------------------------------------------------- /init.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import sys 4 | import time 5 | from utils import * 6 | 7 | data_path = sys.argv[1] 8 | organ_number = int(sys.argv[2]) 9 | folds = int(sys.argv[3]) 10 | low_range = int(sys.argv[4]) 11 | high_range = int(sys.argv[5]) 12 | 13 | image_list = [] 14 | image_filename = [] 15 | keyword = '' 16 | for directory, _, file_ in os.walk(image_path): 17 | for filename in sorted(file_): 18 | if keyword in filename: 19 | image_list.append(os.path.join(directory, filename)) 20 | image_filename.append(os.path.splitext(filename)[0]) 21 | label_list = [] 22 | label_filename = [] 23 | for directory, _, file_ in os.walk(label_path): 24 | for filename in sorted(file_): 25 | if keyword in filename: 26 | label_list.append(os.path.join(directory, filename)) 27 | label_filename.append(os.path.splitext(filename)[0]) 28 | if len(image_list) != len(label_list): 29 | exit('Error: the number of labels and the number of images are not equal!') 30 | total_samples = len(image_list) 31 | 32 | for plane in ['X', 'Y', 'Z']: 33 | output = open(list_training[plane], 'w') # create txt 34 | output.close() 35 | 
print('Initialization starts.') 36 | 37 | for i in range(total_samples): 38 | start_time = time.time() 39 | print('Processing ' + str(i + 1) + ' out of ' + str(total_samples) + ' files.') 40 | image = np.load(image_list[i]) 41 | label = np.load(label_list[i]) 42 | print(' 3D volume is loaded: ' + str(time.time() - start_time) + ' second(s) elapsed.') 43 | for plane in ['X', 'Y', 'Z']: 44 | if plane == 'X': 45 | slice_number = label.shape[0] 46 | elif plane == 'Y': 47 | slice_number = label.shape[1] 48 | elif plane == 'Z': 49 | slice_number = label.shape[2] 50 | print(' Processing data on ' + plane + ' plane (' + str(slice_number) + ' slices): ' + \ 51 | str(time.time() - start_time) + ' second(s) elapsed.') 52 | image_directory_ = os.path.join(image_path_[plane], image_filename[i]) 53 | if not os.path.exists(image_directory_): 54 | os.makedirs(image_directory_) 55 | label_directory_ = os.path.join(label_path_[plane], label_filename[i]) 56 | if not os.path.exists(label_directory_): 57 | os.makedirs(label_directory_) 58 | print(' Slicing data: ' + str(time.time() - start_time) + ' second(s) elapsed.') 59 | sum_ = np.zeros((slice_number, organ_number + 1), dtype = np.int) 60 | minA = np.zeros((slice_number, organ_number + 1), dtype = np.int) 61 | maxA = np.zeros((slice_number, organ_number + 1), dtype = np.int) 62 | minB = np.zeros((slice_number, organ_number + 1), dtype = np.int) 63 | maxB = np.zeros((slice_number, organ_number + 1), dtype = np.int) 64 | average = np.zeros((slice_number), dtype = np.float) 65 | for j in range(0, slice_number): 66 | image_filename_ = os.path.join( \ 67 | image_path_[plane], image_filename[i], '{:0>4}'.format(j) + '.npy') 68 | label_filename_ = os.path.join( \ 69 | label_path_[plane], label_filename[i], '{:0>4}'.format(j) + '.npy') 70 | if plane == 'X': 71 | image_ = image[j, :, :] 72 | label_ = label[j, :, :] 73 | elif plane == 'Y': 74 | image_ = image[:, j, :] 75 | label_ = label[:, j, :] 76 | elif plane == 'Z': 77 | image_ = 
image[:, :, j] 78 | label_ = label[:, :, j] 79 | if not os.path.isfile(image_filename_) or not os.path.isfile(label_filename_): 80 | np.save(image_filename_, image_) # main function, no truncate 81 | np.save(label_filename_, label_) 82 | np.minimum(np.maximum(image_, low_range, image_), high_range, image_) 83 | 84 | average[j] = float(image_.sum()) / (image_.shape[0] * image_.shape[1]) 85 | for o in range(1, organ_number + 1): 86 | sum_[j, o] = (is_organ(label_, o)).sum() 87 | arr = np.nonzero(is_organ(label_, o)) 88 | minA[j, o] = 0 if not len(arr[0]) else min(arr[0]) # [A*B] min/max nonzero 89 | maxA[j, o] = 0 if not len(arr[0]) else max(arr[0]) 90 | minB[j, o] = 0 if not len(arr[1]) else min(arr[1]) 91 | maxB[j, o] = 0 if not len(arr[1]) else max(arr[1]) 92 | print(' Writing training lists: ' + str(time.time() - start_time) + ' second(s) elapsed.') 93 | output = open(list_training[plane], 'a+') 94 | for j in range(0, slice_number): 95 | image_filename_ = os.path.join( \ 96 | image_path_[plane], image_filename[i], '{:0>4}'.format(j) + '.npy') 97 | label_filename_ = os.path.join( \ 98 | label_path_[plane], label_filename[i], '{:0>4}'.format(j) + '.npy') 99 | output.write(str(i) + ' ' + str(j)) 100 | output.write(' ' + image_filename_ + ' ' + label_filename_) 101 | output.write(' ' + str(average[j])) 102 | for o in range(1, organ_number + 1): 103 | output.write(' ' + str(sum_[j, o]) + ' ' + str(minA[j, o]) + \ 104 | ' ' + str(maxA[j, o]) + ' ' + str(minB[j, o]) + ' ' + str(maxB[j, o])) 105 | output.write('\n') 106 | output.close() 107 | print(' ' + plane + ' plane is done: ' + \ 108 | str(time.time() - start_time) + ' second(s) elapsed.') 109 | print('Processed ' + str(i + 1) + ' out of ' + str(total_samples) + ' files: ' + \ 110 | str(time.time() - start_time) + ' second(s) elapsed.') 111 | 112 | print('Writing training image list.') 113 | for f in range(folds): 114 | list_training_ = training_set_filename(f) 115 | output = open(list_training_, 'w') 116 | for i in 
range(total_samples): 117 | if in_training_set(total_samples, i, folds, f): 118 | output.write(str(i) + ' ' + image_list[i] + ' ' + label_list[i] + '\n') 119 | output.close() 120 | print('Writing testing image list.') 121 | for f in range(folds): 122 | list_testing_ = testing_set_filename(f) 123 | output = open(list_testing_, 'w') 124 | for i in range(total_samples): 125 | if not in_training_set(total_samples, i, folds, f): 126 | output.write(str(i) + ' ' + image_list[i] + ' ' + label_list[i] + '\n') 127 | output.close() 128 | print('Initialization is done.') -------------------------------------------------------------------------------- /coarse_testing.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import sys 4 | import time 5 | from utils import * 6 | from model import * 7 | 8 | data_path = sys.argv[1] 9 | current_fold = int(sys.argv[2]) 10 | organ_number = int(sys.argv[3]) 11 | low_range = int(sys.argv[4]) 12 | high_range = int(sys.argv[5]) 13 | slice_threshold = float(sys.argv[6]) 14 | slice_thickness = int(sys.argv[7]) 15 | organ_ID = int(sys.argv[8]) 16 | plane = sys.argv[9] 17 | GPU_ID = int(sys.argv[10]) 18 | learning_rate1 = float(sys.argv[11]) 19 | learning_rate_m1 = int(sys.argv[12]) 20 | learning_rate2 = float(sys.argv[13]) 21 | learning_rate_m2 = int(sys.argv[14]) 22 | crop_margin = int(sys.argv[15]) 23 | crop_prob = float(sys.argv[16]) 24 | crop_sample_batch = int(sys.argv[17]) 25 | snapshot_path = os.path.join(snapshot_path, 'SIJ_training_' + \ 26 | sys.argv[11] + 'x' + str(learning_rate_m1) + ',' + str(crop_margin)) 27 | result_path = os.path.join(result_path, 'coarse_testing_' + \ 28 | sys.argv[11] + 'x' + str(learning_rate_m1) + ',' + str(crop_margin)) 29 | epoch = 'e' + sys.argv[18] + sys.argv[19] + sys.argv[20] + sys.argv[21] 30 | epoch_list = [epoch] 31 | timestamp = sys.argv[22] 32 | 33 | snapshot_name = snapshot_name_from_timestamp(snapshot_path, \ 34 | 
current_fold, plane, 'I', slice_thickness, organ_ID, timestamp)  # completes the call begun above
if snapshot_name == '':
    exit('Error: no valid snapshot directories are detected!')
snapshot_directory = os.path.join(snapshot_path, snapshot_name)
print('Snapshot directory: ' + snapshot_directory + ' .')
snapshot = [snapshot_directory]
print(str(len(snapshot)) + ' snapshots are to be evaluated.')
for t in range(len(snapshot)):
    print(' Snapshot #' + str(t + 1) + ': ' + snapshot[t] + ' .')
result_name = snapshot_name

os.environ["CUDA_VISIBLE_DEVICES"]= str(GPU_ID)

# one line per test volume: "<id> <image path> <label path>"
volume_list = open(testing_set_filename(current_fold), 'r').read().splitlines()
while volume_list[len(volume_list) - 1] == '':
    volume_list.pop()
DSC = np.zeros((len(snapshot), len(volume_list)))
result_directory = os.path.join(result_path, result_name, 'volumes')
if not os.path.exists(result_directory):
    os.makedirs(result_directory)
result_file = os.path.join(result_path, result_name, 'results.txt')
output = open(result_file, 'w')
output.close()

for t in range(len(snapshot)):
    output = open(result_file, 'a+')
    output.write('Evaluating snapshot ' + str(epoch_list[t]) + ':\n')
    output.close()
    # only load the network if at least one prediction volume is missing
    finished = True
    for i in range(len(volume_list)):
        volume_file = volume_filename_testing(result_directory, epoch_list[t], i)
        if not os.path.isfile(volume_file):
            finished = False
            break
    if not finished:
        # TEST='C' makes RSTN.forward run the coarse model only
        net = RSTN(crop_margin=crop_margin, crop_prob=crop_prob, \
            crop_sample_batch=crop_sample_batch, TEST='C').cuda()
        net.load_state_dict(torch.load(snapshot[t]))
        net.eval()

    for i in range(len(volume_list)):
        start_time = time.time()
        print('Testing ' + str(i + 1) + ' out of ' + str(len(volume_list)) + ' testcases, ' + \
            str(t + 1) + ' out of ' + str(len(snapshot)) + ' snapshots.')
        volume_file = volume_filename_testing(result_directory, epoch_list[t], i)
        s = volume_list[i].split(' ')
        label = np.load(s[2])
        label = is_organ(label, organ_ID).astype(np.uint8)
        if not os.path.isfile(volume_file):
            image = np.load(s[1]).astype(np.float32)
            # clip to [low_range, high_range] in place, rescale to [0, 1]
            np.minimum(np.maximum(image, low_range, image), high_range, image)
            image -= low_range
            image /= (high_range - low_range)
            print(' Data loading is finished: ' + \
                str(time.time() - start_time) + ' second(s) elapsed.')
            pred = np.zeros(image.shape, dtype = np.float32)
            minR = 0
            if plane == 'X':
                maxR = image.shape[0]
                shape_ = (1, 3, image.shape[1], image.shape[2])
            elif plane == 'Y':
                maxR = image.shape[1]
                shape_ = (1, 3, image.shape[0], image.shape[2])
            elif plane == 'Z':
                maxR = image.shape[2]
                shape_ = (1, 3, image.shape[0], image.shape[1])
            for j in range(minR, maxR):
                # sID: the 3 slice indices fed as channels (clamped at ends)
                if slice_thickness == 1:
                    sID = [j, j, j]
                elif slice_thickness == 3:
                    sID = [max(minR, j - 1), j, min(maxR - 1, j + 1)]
                # bring the cut plane to axis 0 so shape is (3, H', W')
                if plane == 'X':
                    image_ = image[sID, :, :].astype(np.float32)
                elif plane == 'Y':
                    image_ = image[:, sID, :].transpose(1, 0, 2).astype(np.float32)
                elif plane == 'Z':
                    image_ = image[:, :, sID].transpose(2, 0, 1).astype(np.float32)

                image_ = image_.reshape((1, 3, image_.shape[1], image_.shape[2]))
                image_ = torch.from_numpy(image_).cuda().float()
                #pdb.set_trace()
                out = net(image_, 1).data.cpu().numpy()[0, :, :, :]

                # scatter the 3-channel output back into the volume; with
                # slice_thickness == 3 each slice accumulates contributions
                # from up to 3 neighbouring predictions (averaged below)
                if slice_thickness == 1:
                    if plane == 'X':
                        pred[j, :, :] = out
                    elif plane == 'Y':
                        pred[:, j, :] = out
                    elif plane == 'Z':
                        pred[:, :, j] = out
                elif slice_thickness == 3:
                    if plane == 'X':
                        if j == minR:
                            pred[j: j + 2, :, :] += out[1: 3, :, :]
                        elif j == maxR - 1:
                            pred[j - 1: j + 1, :, :] += out[0: 2, :, :]
                        else:
                            pred[j - 1: j + 2, :, :] += out[...]
                    elif plane == 'Y':
                        if j == minR:
                            pred[:, j: j + 2, :] += out[1: 3, :, :].transpose(1, 0, 2)
                        elif j == maxR - 1:
                            pred[:, j - 1: j + 1, :] += out[0: 2, :, :].transpose(1, 0, 2)
                        else:
                            pred[:, j - 1: j + 2, :] += out[...].transpose(1, 0, 2)
                    elif plane == 'Z':
                        if j == minR:
                            pred[:, :, j: j + 2] += out[1: 3, :, :].transpose(1, 2, 0)
                        elif j == maxR - 1:
                            pred[:, :, j - 1: j + 1] += out[0: 2, :, :].transpose(1, 2, 0)
                        else:
                            pred[:, :, j - 1: j + 2] += out[...].transpose(1, 2, 0)
            # divide by the number of contributions (2 at the ends, 3 inside)
            if slice_thickness == 3:
                if plane == 'X':
                    pred[minR, :, :] /= 2
                    pred[minR + 1: maxR - 1, :, :] /= 3
                    pred[maxR - 1, :, :] /= 2
                elif plane == 'Y':
                    pred[:, minR, :] /= 2
                    pred[:, minR + 1: maxR - 1, :] /= 3
                    pred[:, maxR - 1, :] /= 2
                elif plane == 'Z':
                    pred[:, :, minR] /= 2
                    pred[:, :, minR + 1: maxR - 1] /= 3
                    pred[:, :, maxR - 1] /= 2
            print(' Testing is finished: ' + str(time.time() - start_time) + ' second(s) elapsed.')
            # store probabilities as uint8 in [0, 255]; threshold at 0.5 (128)
            pred = np.around(pred * 255).astype(np.uint8)
            np.savez_compressed(volume_file, volume = pred)
            print(' Data saving is finished: ' + \
                str(time.time() - start_time) + ' second(s) elapsed.')
            pred_temp = (pred >= 128)
        else:
            # cached prediction from a previous run
            volume_data = np.load(volume_file)
            pred = volume_data['volume'].astype(np.uint8)
            print(' Testing result is loaded: ' + \
                str(time.time() - start_time) + ' second(s) elapsed.')
            pred_temp = (pred >= 128)

        DSC[t, i], inter_sum, pred_sum, label_sum = DSC_computation(label, pred_temp)
        print(' DSC = 2 * ' + str(inter_sum) + ' / (' + str(pred_sum) + \
            ' + ' + str(label_sum) + ') = ' + str(DSC[t, i]) + ' .')
        output = open(result_file, 'a+')
        output.write(' Testcase ' + str(i + 1) + ': DSC = 2 * ' + str(inter_sum) + ' / (' + \
            str(pred_sum) + ' + ' + str(label_sum) + ') = ' + str(DSC[t, i]) + ' .\n')
        output.close()
        # NOTE(review): the both-empty case is zeroed only after printing and
        # logging the raw DSC_computation value — so the per-case log and the
        # average can disagree on such cases; kept as-is.
        if pred_sum == 0 and label_sum == 0:
            DSC[t, i] = 0
        print(' DSC computation is finished: ' + \
            str(time.time() - start_time) + ' second(s) elapsed.')

    print('Snapshot ' + str(epoch_list[t]) + ': average DSC = ' + str(np.mean(DSC[t, :])) + ' .')
    output = open(result_file, 'a+')
    output.write('Snapshot ' + str(epoch_list[t]) + \
        ': average DSC = ' + str(np.mean(DSC[t, :])) + ' .\n')
    output.close()

print('The testing process is finished.')
for t in range(len(snapshot)):
    print(' Snapshot ' + str(epoch_list[t]) + ': average DSC = ' + str(np.mean(DSC[t, :])) + ' .')

# ---- training.py (head): imports and argv parsing; the file continues
# past this chunk.
from operator import mod
import os
import sys
import time
from utils import *
from model import *
import ipdb
import pytorch_iou
import pytorch_gauss
from ramps import *
import torchvision.transforms as transforms


if __name__ == '__main__':
    data_path = sys.argv[1]
    current_fold = sys.argv[2]
    organ_number = int(sys.argv[3])
    low_range = int(sys.argv[4])
    high_range = int(sys.argv[5])
    slice_threshold = float(sys.argv[6])
    slice_thickness = int(sys.argv[7])
    organ_ID = int(sys.argv[8])
    plane = sys.argv[9]
    GPU_ID = int(sys.argv[10])
    learning_rate1 = float(sys.argv[11])
    learning_rate_m1 = int(sys.argv[12])
    learning_rate2 = float(sys.argv[13])
    learning_rate_m2 = int(sys.argv[14])
    crop_margin = int(sys.argv[15])
    crop_prob = float(sys.argv[16])
    crop_sample_batch = int(sys.argv[17])
    # snapshot_path base comes from utils; specialised per hyper-parameters
    snapshot_path = os.path.join(snapshot_path, 'SIJ_training_' + \
        sys.argv[11] + 'x' + str(learning_rate_m1) + ',' + str(crop_margin))
    # epochs for the three RSTN stages plus lr-decay setting
    epoch = {}
    epoch['S'] = int(sys.argv[18])
    epoch['I'] = int(sys.argv[19])
| epoch['J'] = int(sys.argv[20]) 38 | epoch['lr_decay'] = int(sys.argv[21]) 39 | timestamp = sys.argv[22] 40 | 41 | if not os.path.exists(snapshot_path): 42 | os.makedirs(snapshot_path) 43 | 44 | Unet_weights = os.path.join(pretrained_model_path, 'unet_voc.pth') 45 | if not os.path.isfile(Unet_weights): 46 | raise RuntimeError('Please Download from the Internet ...') 47 | 48 | from Data import DataLayer 49 | training_set = DataLayer(data_path=data_path, current_fold=int(current_fold), organ_number=organ_number, \ 50 | low_range=low_range, high_range=high_range, slice_threshold=slice_threshold, slice_thickness=slice_thickness, \ 51 | organ_ID=organ_ID, plane=plane) 52 | 53 | batch_size = 1 54 | os.environ["CUDA_VISIBLE_DEVICES"]= str(GPU_ID) 55 | trainloader = torch.utils.data.DataLoader(training_set, batch_size=batch_size, shuffle=True, num_workers=16, drop_last=True) 56 | print(current_fold + plane, len(trainloader)) 57 | print(epoch) 58 | 59 | RSTN_model = RSTN(crop_margin=crop_margin, \ 60 | crop_prob=crop_prob, crop_sample_batch=crop_sample_batch) 61 | RSTN_snapshot = {} 62 | 63 | model_parameters = filter(lambda p: p.requires_grad, RSTN_model.parameters()) 64 | params = sum([np.prod(p.size()) for p in model_parameters]) 65 | print('model parameters:', params) 66 | 67 | #pdb.set_trace() 68 | for param in RSTN_model.coarse_model.parameters(): 69 | param.detach_() 70 | 71 | optimizer = torch.optim.SGD( 72 | [ 73 | {'params': get_parameters(RSTN_model, coarse=False, bias=False, parallel=False), 74 | 'lr': learning_rate1 * 10}, 75 | {'params': get_parameters(RSTN_model, coarse=False, bias=True, parallel=False), 76 | 'lr': learning_rate1 * 20, 'weight_decay': 0} 77 | ], 78 | lr=learning_rate1, 79 | momentum=0.99, 80 | weight_decay=0.0005) 81 | 82 | criterion = DSC_loss() 83 | COARSE_WEIGHT = 1 / 3 84 | 85 | bce_loss = nn.BCELoss(size_average=True) 86 | gauss_loss = pytorch_gauss.Gauss(window_size=11,size_average=True) 87 | iou_loss = 
pytorch_iou.IOU(size_average=True) 88 | 89 | def update_ema_variables(model, ema_model, alpha): 90 | for ema_param, param in zip(ema_model.parameters(), model.parameters()): 91 | ema_param.data.mul_(alpha).add_(1 - alpha, param.data) 92 | 93 | def update_variables(model, ema_model): 94 | for ema_param, param in zip(ema_model.parameters(), model.parameters()): 95 | ema_param.data = param.data 96 | 97 | def overall_loss(pred,target): 98 | gauss_out = 1 - gauss_loss(pred, target) 99 | iou_out = iou_loss(pred, target) 100 | bce_out = bce_loss(pred, target) 101 | 102 | loss = bce_out + gauss_out + iou_out 103 | 104 | return loss 105 | 106 | RSTN_model = RSTN_model.cuda() 107 | RSTN_model.train() 108 | 109 | for mode in ['S','I','J']: 110 | if mode == 'S': 111 | RSTN_dict = RSTN_model.state_dict() 112 | pretrained_dict = torch.load(Unet_weights) 113 | 114 | w = pretrained_dict['final.weight'][20,:,:,:] 115 | b = pretrained_dict['final.bias'][20] 116 | pretrained_dict['final.weight'] = pretrained_dict['final.weight'][:3,:,:,:] 117 | pretrained_dict['final.bias'] = pretrained_dict['final.bias'][:3] 118 | pretrained_dict['final.weight'][0,:,:,:] = w 119 | pretrained_dict['final.bias'][0] = b 120 | pretrained_dict['final.weight'][1,:,:,:] = w 121 | pretrained_dict['final.bias'][1] = b 122 | pretrained_dict['final.weight'][2,:,:,:] = w 123 | pretrained_dict['final.bias'][2] = b 124 | # 1. filter out unnecessary keys 125 | pretrained_dict_coarse = {'coarse_model.' + k : v 126 | for k, v in pretrained_dict.items() 127 | if 'coarse_model.' + k in RSTN_dict and 'score' not in k} 128 | pretrained_dict_fine = {'fine_model.' + k : v 129 | for k, v in pretrained_dict.items() 130 | if 'fine_model.' + k in RSTN_dict and 'score' not in k} 131 | pretrained_dict_fine_ema = {'fine_model_ema.' + k : v 132 | for k, v in pretrained_dict.items() 133 | if 'fine_model_ema.' + k in RSTN_dict and 'score' not in k} 134 | # 2. 
overwrite entries in the existing state dict 135 | RSTN_dict.update(pretrained_dict_coarse) 136 | RSTN_dict.update(pretrained_dict_fine) 137 | RSTN_dict.update(pretrained_dict_fine_ema) 138 | 139 | # 3. load the new state dict 140 | RSTN_model.load_state_dict(RSTN_dict) 141 | print(plane + mode, 'load pre-trained unet_voc model successfully!') 142 | 143 | elif mode == 'I': 144 | print(plane + mode, 'load S model successfully!') 145 | # update_variables(RSTN_model.coarse_model, RSTN_model.fine_model_ema) 146 | elif mode == 'J': 147 | update_variables(RSTN_model.fine_model_ema, RSTN_model.fine_model) 148 | print(plane + mode, 'reload pretrained model for fine successfully!') 149 | 150 | else: 151 | raise ValueError("wrong value of mode, should be in ['S']") 152 | 153 | try: 154 | for e in range(epoch[mode]): 155 | total_loss = 0.0 156 | total_fine_loss = 0.0 157 | total_coarse_loss = 0.0 158 | start = time.time() 159 | for index, (image, label) in enumerate(trainloader): 160 | 161 | start_it = time.time() 162 | optimizer.zero_grad() 163 | image, label = image.cuda().float(), label.cuda().float() 164 | if mode == 'J': 165 | coarse_prob, fine_prob, fine_prob_gt = RSTN_model(image, e+1, label, mode=mode) 166 | fine_loss = overall_loss(fine_prob, label) 167 | fine_loss_gt = overall_loss(fine_prob_gt, label) 168 | loss = fine_loss + fine_loss_gt 169 | else: 170 | coarse_prob, fine_prob = RSTN_model(image, e+1, label, mode=mode) 171 | fine_loss = overall_loss(fine_prob, label) 172 | loss = fine_loss 173 | coarse_loss = overall_loss(coarse_prob, label) 174 | total_loss += loss.item() 175 | total_fine_loss += fine_loss.item() 176 | total_coarse_loss += coarse_loss.item() 177 | loss.backward() 178 | optimizer.step() 179 | 180 | if mode == 'S': 181 | update_ema_variables(RSTN_model.fine_model, RSTN_model.coarse_model, 0.999) 182 | elif mode == 'I': 183 | update_ema_variables(RSTN_model.fine_model, RSTN_model.coarse_model, 0.999) 184 | else: 185 | 
update_ema_variables(RSTN_model.fine_model, RSTN_model.fine_model_ema, 0.999) 186 | 187 | 188 | print(current_fold + plane + mode, "Epoch[%d/%d], Iter[%05d], Coarse/Fine Loss %.4f/%.4f, Time Elapsed %.2fs" \ 189 | %(e+1, epoch[mode], index, coarse_loss.item(),fine_loss.item(), time.time()-start_it)) 190 | 191 | del image, label, fine_prob,coarse_prob, loss, fine_loss,coarse_loss 192 | 193 | print(current_fold + plane + mode, "Epoch[%d], Total Coarse/Fine Loss %.4f/%.4f, Time elapsed %.2fs" \ 194 | %(e+1, total_coarse_loss / len(trainloader),total_fine_loss / len(trainloader),time.time()-start)) 195 | except KeyboardInterrupt: 196 | print('!' * 10 , 'save before quitting ...') 197 | finally: 198 | if mode == 'J': 199 | update_variables(RSTN_model.fine_model_ema, RSTN_model.fine_model) 200 | snapshot_name = 'FD' + current_fold + ':' + \ 201 | plane + mode + str(slice_thickness) + '_' + str(organ_ID) + '_' + timestamp 202 | RSTN_snapshot[mode] = os.path.join(snapshot_path, snapshot_name) + '.pkl' 203 | torch.save(RSTN_model.state_dict(), RSTN_snapshot[mode]) 204 | print('#' * 10 , 'end of ' + current_fold + plane + mode + ' training stage!') -------------------------------------------------------------------------------- /oracle_testing.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import sys 4 | import time 5 | from utils import * 6 | from model import * 7 | 8 | data_path = sys.argv[1] 9 | current_fold = int(sys.argv[2]) 10 | organ_number = int(sys.argv[3]) 11 | low_range = int(sys.argv[4]) 12 | high_range = int(sys.argv[5]) 13 | slice_threshold = float(sys.argv[6]) 14 | slice_thickness = int(sys.argv[7]) 15 | organ_ID = int(sys.argv[8]) 16 | plane = sys.argv[9] 17 | GPU_ID = int(sys.argv[10]) 18 | learning_rate1 = float(sys.argv[11]) 19 | learning_rate_m1 = int(sys.argv[12]) 20 | learning_rate2 = float(sys.argv[13]) 21 | learning_rate_m2 = int(sys.argv[14]) 22 | crop_margin = 
int(sys.argv[15])  # completes the assignment begun above
crop_prob = float(sys.argv[16])
crop_sample_batch = int(sys.argv[17])
# snapshot_path / result_path base values come from utils; specialised here
snapshot_path = os.path.join(snapshot_path, 'SIJ_training_' + \
    sys.argv[11] + 'x' + str(learning_rate_m1) + ',' + str(crop_margin))
result_path = os.path.join(result_path, 'oracle_testing_' + \
    sys.argv[11] + 'x' + str(learning_rate_m1) + ',' + str(crop_margin))
epoch = 'e' + sys.argv[18] + sys.argv[19] + sys.argv[20] + sys.argv[21]
epoch_list = [epoch]

timestamp = sys.argv[22]

snapshot_name = snapshot_name_from_timestamp(snapshot_path, \
    current_fold, plane, 'S', slice_thickness, organ_ID, timestamp)
if snapshot_name == '':
    exit('Error: no valid snapshot directories are detected!')
snapshot_directory = os.path.join(snapshot_path, snapshot_name)
print('Snapshot directory: ' + snapshot_directory + ' .')
snapshot = [snapshot_directory]
print(str(len(snapshot)) + ' snapshots are to be evaluated.')
for t in range(len(snapshot)):
    print(' Snapshot #' + str(t + 1) + ': ' + snapshot[t] + ' .')
result_name = snapshot_name

os.environ["CUDA_VISIBLE_DEVICES"]= str(GPU_ID)

# one line per test volume: "<id> <image path> <label path>"
volume_list = open(testing_set_filename(current_fold), 'r').read().splitlines()
while volume_list[len(volume_list) - 1] == '':
    volume_list.pop()
DSC = np.zeros((len(snapshot), len(volume_list)))
result_directory = os.path.join(result_path, result_name, 'volumes')
if not os.path.exists(result_directory):
    os.makedirs(result_directory)
result_file = os.path.join(result_path, result_name, 'results.txt')
output = open(result_file, 'w')
output.close()

for t in range(len(snapshot)):
    output = open(result_file, 'a+')
    output.write('Evaluating snapshot ' + str(epoch_list[t]) + ':\n')
    output.close()
    # only load the network if at least one prediction volume is missing
    finished = True
    for i in range(len(volume_list)):
        volume_file = volume_filename_testing(result_directory, epoch_list[t], i)
        if not os.path.isfile(volume_file):
            finished = False
            break
    if not finished:
        # TEST='O' (oracle): fine model is fed crops from the ground truth
        net = RSTN(crop_margin=crop_margin, crop_prob=crop_prob, \
            crop_sample_batch=crop_sample_batch, TEST='O').cuda()
        net.load_state_dict(torch.load(snapshot[t]))
        net.eval()

    for i in range(len(volume_list)):
        start_time = time.time()
        print('Testing ' + str(i + 1) + ' out of ' + str(len(volume_list)) + ' testcases, ' + \
            str(t + 1) + ' out of ' + str(len(snapshot)) + ' snapshots.')
        volume_file = volume_filename_testing(result_directory, epoch_list[t], i)
        s = volume_list[i].split(' ')
        label = np.load(s[2])
        label = is_organ(label, organ_ID).astype(np.uint8)
        if not os.path.isfile(volume_file):
            image = np.load(s[1]).astype(np.float32)
            # clip to [low_range, high_range] in place, rescale to [0, 1]
            np.minimum(np.maximum(image, low_range, image), high_range, image)
            image -= low_range
            image /= (high_range - low_range)
            print(' Data loading is finished: ' + \
                str(time.time() - start_time) + ' second(s) elapsed.')
            pred = np.zeros(image.shape, dtype = np.float32)
            # per-axis foreground counts, used to skip all-background slices
            label_sumX = np.sum(label, axis = (1, 2))
            label_sumY = np.sum(label, axis = (0, 2))
            label_sumZ = np.sum(label, axis = (0, 1))
            if label_sumX.sum() == 0:
                continue
            minR = 0
            if plane == 'X':
                maxR = image.shape[0]
                shape_ = (1, 3, image.shape[1], image.shape[2])
            elif plane == 'Y':
                maxR = image.shape[1]
                shape_ = (1, 3, image.shape[0], image.shape[2])
            elif plane == 'Z':
                maxR = image.shape[2]
                shape_ = (1, 3, image.shape[0], image.shape[1])
            for j in range(minR, maxR):
                # sID: the 3 slice indices fed as channels (clamped at ends)
                if slice_thickness == 1:
                    sID = [j, j, j]
                elif slice_thickness == 3:
                    sID = [max(minR, j - 1), j, min(maxR - 1, j + 1)]
                # skip slice triples with no ground-truth foreground; bring
                # the cut plane to axis 0 so shape is (3, H', W')
                if plane == 'X':
                    if label_sumX[sID].sum() == 0:
                        continue
                    image_ = image[sID, :, :].astype(np.float32)
                    label_ = label[sID, :, :].astype(np.float32)
                elif plane == 'Y':
                    if label_sumY[sID].sum() == 0:
                        continue
                    image_ = image[:, sID, :].transpose(1, 0, 2).astype(np.float32)
                    label_ = label[:, sID, :].transpose(1, 0, 2).astype(np.float32)
                elif plane == 'Z':
                    if label_sumZ[sID].sum() == 0:
                        continue
                    image_ = image[:, :, sID].transpose(2, 0, 1).astype(np.float32)
                    label_ = label[:, :, sID].transpose(2, 0, 1).astype(np.float32)

                image_ = image_.reshape((1, 3, image_.shape[1], image_.shape[2]))
                image_ = torch.from_numpy(image_).cuda().float()
                label_ = label_.reshape((1, 3, label_.shape[1], label_.shape[2]))
                label_ = torch.from_numpy(label_).cuda().float()
                # oracle pass: the label drives the crop inside RSTN.forward
                out = net(image_,1, label_).data.cpu().numpy()[0, :, :, :]

                # scatter the 3-channel output back into the volume; with
                # slice_thickness == 3 each slice accumulates contributions
                # from up to 3 neighbouring predictions (averaged below)
                if slice_thickness == 1:
                    if plane == 'X':
                        pred[j, :, :] = out
                    elif plane == 'Y':
                        pred[:, j, :] = out
                    elif plane == 'Z':
                        pred[:, :, j] = out
                elif slice_thickness == 3:
                    if plane == 'X':
                        if j == minR:
                            pred[j: j + 2, :, :] += out[1: 3, :, :]
                        elif j == maxR - 1:
                            pred[j - 1: j + 1, :, :] += out[0: 2, :, :]
                        else:
                            pred[j - 1: j + 2, :, :] += out[...]
                    elif plane == 'Y':
                        if j == minR:
                            pred[:, j: j + 2, :] += out[1: 3, :, :].transpose(1, 0, 2)
                        elif j == maxR - 1:
                            pred[:, j - 1: j + 1, :] += out[0: 2, :, :].transpose(1, 0, 2)
                        else:
                            pred[:, j - 1: j + 2, :] += out[...].transpose(1, 0, 2)
                    elif plane == 'Z':
                        if j == minR:
                            pred[:, :, j: j + 2] += out[1: 3, :, :].transpose(1, 2, 0)
                        elif j == maxR - 1:
                            pred[:, :, j - 1: j + 1] += out[0: 2, :, :].transpose(1, 2, 0)
                        else:
                            pred[:, :, j - 1: j + 2] += out[...].transpose(1, 2, 0)
            # divide by the number of contributions (2 at the ends, 3 inside)
            if slice_thickness == 3:
                if plane == 'X':
                    pred[minR, :, :] /= 2
                    pred[minR + 1: maxR - 1, :, :] /= 3
                    pred[maxR - 1, :, :] /= 2
                elif plane == 'Y':
                    pred[:, minR, :] /= 2
                    pred[:, minR + 1: maxR - 1, :] /= 3
                    pred[:, maxR - 1, :] /= 2
                elif plane == 'Z':
                    pred[:, :, minR] /= 2
                    pred[:, :, minR + 1: maxR - 1] /= 3
                    pred[:, :, maxR - 1] /= 2
            print(' Testing is finished: ' + str(time.time() - start_time) + ' second(s) elapsed.')
            # store probabilities as uint8 in [0, 255]; threshold at 0.5 (128)
            pred = np.around(pred * 255).astype(np.uint8)
            np.savez_compressed(volume_file, volume = pred)
            print(' Data saving is finished: ' + \
                str(time.time() - start_time) + ' second(s) elapsed.')
            pred_temp = (pred >= 128)
        else:
            # cached prediction from a previous run
            volume_data = np.load(volume_file)
            pred = volume_data['volume'].astype(np.uint8)
            print(' Testing result is loaded: ' + \
                str(time.time() - start_time) + ' second(s) elapsed.')
            pred_temp = (pred >= 128)

        DSC[t, i], inter_sum, pred_sum, label_sum = DSC_computation(label, pred_temp)
        print(' DSC = 2 * ' + str(inter_sum) + ' / (' + str(pred_sum) + \
            ' + ' + str(label_sum) + ') = ' + str(DSC[t, i]) + ' .')
        output = open(result_file, 'a+')
        output.write(' Testcase ' + str(i + 1) + ': DSC = 2 * ' + str(inter_sum) + ' / (' + \
            str(pred_sum) + ' + ' + str(label_sum) + ') = ' + str(DSC[t, i]) + ' .\n')
        output.close()
        # NOTE(review): the both-empty case is zeroed only after printing and
        # logging the raw DSC_computation value; kept as-is.
        if pred_sum == 0 and label_sum == 0:
            DSC[t, i] = 0
        print(' DSC computation is finished: ' + \
            str(time.time() - start_time) + ' second(s) elapsed.')

    print('Snapshot ' + str(epoch_list[t]) + ': average DSC = ' + str(np.mean(DSC[t, :])) + ' .')
    output = open(result_file, 'a+')
    output.write('Snapshot ' + str(epoch_list[t]) + \
        ': average DSC = ' + str(np.mean(DSC[t, :])) + ' .\n')
    output.close()

print('The testing process is finished.')
for t in range(len(snapshot)):
    print(' Snapshot ' + str(epoch_list[t]) + ': average DSC = ' + str(np.mean(DSC[t, :])) + ' .')

# ---- model.py: VGG16-backbone U-Net and the RSTN coarse-to-fine wrapper ----
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from vgg import VGG16

class unetUp(nn.Module):
    """U-Net decoder step: upsample the deeper feature map to the skip
    connection's spatial size, concatenate, then apply two 3x3 conv+ReLU."""
    def __init__(self, in_size, out_size):
        super(unetUp, self).__init__()
        self.conv1 = nn.Conv2d(in_size, out_size, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_size, out_size, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, inputs1, inputs2):
        # inputs1: skip features (target size); inputs2: deeper features
        mid = F.interpolate(inputs2,(inputs1.shape[2],inputs1.shape[3]), mode='bilinear', align_corners=True)
        outputs = torch.cat([inputs1, mid], 1)
        outputs = self.conv1(outputs)
        outputs = self.relu(outputs)
        outputs = self.conv2(outputs)
        outputs = self.relu(outputs)
        return outputs

class Unet(nn.Module):
    """U-Net with a VGG16 encoder; returns raw (pre-sigmoid) logits with
    num_classes output channels at the resolution of the first VGG stage."""
    def __init__(self, num_classes=3, in_channels=3, pretrained=False):
        super(Unet, self).__init__()
        self.vgg = VGG16(pretrained=pretrained,in_channels=in_channels)
        # channel counts after concatenating skip + upsampled features
        in_filters = [192, 384, 768, 1024]
        out_filters = [64, 128, 256, 512]
        self.up_concat4 = unetUp(in_filters[3], out_filters[3])
        self.up_concat3 = unetUp(in_filters[2], out_filters[2])
        self.up_concat2 = unetUp(in_filters[1], out_filters[1])
        self.up_concat1 = unetUp(in_filters[0], out_filters[0])
        self.final = nn.Conv2d(out_filters[0], num_classes, 1)

    def forward(self, inputs):
        # slice the VGG feature extractor into its 5 stages
        feat1 = self.vgg.features[ :4 ](inputs)
        feat2 = self.vgg.features[4 :9 ](feat1)
        feat3 = self.vgg.features[9 :16](feat2)
        feat4 = self.vgg.features[16:23](feat3)
        feat5 = self.vgg.features[23:-1](feat4)

        up4 = self.up_concat4(feat4, feat5)
        up3 = self.up_concat3(feat3, up4)
        up2 = self.up_concat2(feat2, up3)
        up1 = self.up_concat1(feat1, up2)

        final = self.final(up1)

        return final

    def _initialize_weights(self, *stages):
        # Kaiming init for convs, unit/zero init for batch-norm
        for modules in stages:
            for module in modules.modules():
                if isinstance(module, nn.Conv2d):
                    nn.init.kaiming_normal_(module.weight)
                    if module.bias is not None:
                        module.bias.data.zero_()
                elif isinstance(module, nn.BatchNorm2d):
                    module.weight.data.fill_(1)
                    module.bias.data.zero_()

class RSTN(nn.Module):
    """Recurrent Saliency Transformation Network wrapper: a coarse U-Net
    whose thresholded output crops the input for a fine U-Net (plus an EMA
    copy of the fine model).

    TEST=None  -> training forward, selected by mode in {'S','I','J'}
    TEST='C'   -> coarse model only
    TEST='O'   -> oracle: crop from the ground-truth label
    TEST='F'   -> fine model on a crop from a supplied mask/score
    """
    def __init__(self, crop_margin=25, crop_prob=0.5, \
        crop_sample_batch=1, n_class=3, TEST=None):
        super(RSTN, self).__init__()
        self.TEST = TEST
        self.margin = crop_margin
        self.prob = crop_prob
        self.batch = crop_sample_batch
        self.coarse_model = Unet()
        self.fine_model = Unet()
        self.fine_model_ema = Unet()
        self._initialize_weights()

    def _initialize_weights(self):
        # NOTE(review): the children of this module are coarse_model /
        # fine_model / fine_model_ema, so none of the names below ever
        # match — these branches are effectively dead code, kept as-is.
        for name, mod in self.named_children():
            if name == 'saliency1':
                nn.init.xavier_normal_(mod.weight.data)
                mod.bias.data.fill_(1)
            elif name == 'saliency2':
                mod.weight.data.zero_()
                mod.bias.data = torch.tensor([1.0, 1.5, 2.0])
            elif name == 'tt':
                nn.init.xavier_normal_(mod.weight.data)
                mod.bias.data.fill_(1)
            elif name == 'zz':
                nn.init.xavier_normal_(mod.weight.data)
                mod.bias.data.fill_(1)
            elif name == 'oo':
                nn.init.xavier_normal_(mod.weight.data)
                mod.bias.data.fill_(1)

    def forward(self, image, e, label=None, mode=None, score=None, mask=None):
        """Run one of the training/testing paths.

        image: input batch; e: current epoch (1-based, used by mode 'I');
        label: ground truth (training / oracle); mode: 'S'/'I'/'J' when
        training; score/mask: only for TEST == 'F'.
        """
        if self.TEST is None:
            assert label is not None and mode is not None \
                and score is None and mask is None

            if mode == 'S':
                # separate stage: fine model sees ground-truth crops
                h = image
                h = self.coarse_model(h)
                h = torch.sigmoid(h)
                coarse_prob = h
                cropped_image, crop_info = self.crop(label, image)
                h = cropped_image
                h = self.fine_model(h)
                h = self.uncrop(crop_info, h, image)
                h = torch.sigmoid(h)
                fine_prob = h

                return coarse_prob,fine_prob

            elif mode == 'I':
                if e <= 2:
                    # early epochs: crop from the coarse prediction
                    # (falling back to the label inside crop() if empty)
                    h = image
                    h = self.coarse_model(h)
                    h = torch.sigmoid(h)
                    coarse_prob = h
                    cropped_image, crop_info = self.crop(coarse_prob, image, label)
                    h = cropped_image
                    h = self.fine_model(h)
                    h = self.uncrop(crop_info, h, image)
                    h = torch.sigmoid(h)
                    fine_prob = h
                    return coarse_prob,fine_prob

                elif e > 2:
                    # later epochs: fine model on the full image; a zero
                    # tensor stands in for the unused coarse output
                    coarse_prob = image*0
                    h = image
                    h = self.fine_model(h)
                    h = torch.sigmoid(h)
                    fine_prob = h
                    return coarse_prob,fine_prob

            elif mode == 'J':
                # joint stage: one fine pass from the coarse crop and one
                # from the ground-truth crop
                h = image
                h = self.coarse_model(h)
                h = torch.sigmoid(h)
                coarse_prob = h

                cropped_image, crop_info = self.crop(coarse_prob, image, label)

                h = cropped_image
                h = self.fine_model(h)
                h = self.uncrop(crop_info, h, image)
                h = torch.sigmoid(h)
                fine_prob = h
                cropped_image, crop_info = self.crop(label, image)
                h = cropped_image
                h = self.fine_model(h)
                h = self.uncrop(crop_info, h, image)
                h = torch.sigmoid(h)
                fine_prob_gt = h
                return coarse_prob, fine_prob, fine_prob_gt


        elif self.TEST == 'C':
            assert label is None and mode is None and \
                score is None and mask is None
            h = image
            h = self.coarse_model(h)
            h = torch.sigmoid(h)
            coarse_prob = h
            return coarse_prob

        elif self.TEST == 'O':
            assert label is not None and mode is None and \
                score is None and mask is None

            cropped_image, crop_info = self.crop(label, image)
            h = cropped_image
            h = self.fine_model(h)
            h = self.uncrop(crop_info, h, image)
            h = torch.sigmoid(h)
            fine_prob = h
            return fine_prob

        elif self.TEST == 'F':
            assert label is None and mode is None \
                and score is not None and mask is not None
            h = score
            cropped_image, crop_info = self.crop(mask, image)
            h = cropped_image
            fine_prob = self.fine_model(h)
            fine_prob = self.uncrop(crop_info, fine_prob, image)
            fine_prob = torch.sigmoid(fine_prob)
            return fine_prob

        else:
            raise ValueError("wrong value of TEST, should be in [None , 'O']")

    def crop(self, prob_map, saliency_data, label=None):
        """Crop saliency_data to the bounding box of prob_map >= 0.5,
        expanded by the current margins; returns (crop, crop_info bbox)."""
        (N, C, W, H) = prob_map.shape

        binary_mask = (prob_map >= 0.5) # torch.uint8
        # fall back to the label's box when the prediction is all-background
        if label is not None and binary_mask.sum().item() == 0:
            binary_mask = (label >= 0.5)

        if self.TEST is not None:
            # deterministic margins at test time
            self.left = self.margin
            self.right = self.margin
            self.top = self.margin
            self.bottom = self.margin
        else:
            # randomised margins during training
            self.update_margin()

        if binary_mask.sum().item() == 0: # avoid this by pre-condition in TEST 'F'
            # nothing to crop: use the whole frame
            minA = 0
            maxA = W
            minB = 0
            maxB = H
            self.no_forward = True
        else:
            if N > 1:
                # per-sample boxes: zero out everything outside each box
                mask = torch.zeros(size = (N, C, W, H))
                for n in range(N):
                    cur_mask = binary_mask[n, :, :, :]
                    arr = torch.nonzero(cur_mask)
                    minA = arr[:, 1].min().item()
                    maxA = arr[:, 1].max().item()
                    minB = arr[:, 2].min().item()
                    maxB = arr[:, 2].max().item()
                    bbox = [int(max(minA - self.left, 0)), int(min(maxA + self.right + 1, W)), \
                        int(max(minB - self.top, 0)), int(min(maxB + self.bottom + 1, H))]
                    mask[n, :, bbox[0]: bbox[1], bbox[2]: bbox[3]] = 1
                saliency_data = saliency_data * mask.cuda()

            # batch-wide bounding box used for the actual slice
            arr = torch.nonzero(binary_mask)
            minA = arr[:, 2].min().item()
            maxA = arr[:, 2].max().item()
            minB = arr[:, 3].min().item()
            maxB = arr[:, 3].max().item()
            self.no_forward = False

        bbox = [int(max(minA - self.left, 0)), int(min(maxA + self.right + 1, W)), \
            int(max(minB - self.top, 0)), int(min(maxB + self.bottom + 1, H))]
        cropped_image = saliency_data[:, :, bbox[0]: bbox[1], \
            bbox[2]: bbox[3]]

        if self.no_forward == True and self.TEST == 'F':
            cropped_image = torch.zeros_like(cropped_image).cuda()

        # bbox is carried as a tensor so uncrop() can restore placement
        crop_info = np.zeros((1, 4), dtype = np.int16)
        crop_info[0] = bbox
        crop_info = torch.from_numpy(crop_info).cuda()

        return cropped_image, crop_info

    def update_margin(self):
        """With probability self.prob, draw each of the four margins as a
        mean of self.batch uniform samples in [0, 2*margin]; otherwise use
        the fixed margin on all sides."""
        MAX_INT = 256
        if random.randint(0, MAX_INT - 1) >= MAX_INT * self.prob:
            self.left = self.margin
            self.right = self.margin
            self.top = self.margin
            self.bottom = self.margin
        else:
            a = np.zeros(self.batch * 4, dtype = np.uint8)
            for i in range(self.batch * 4):
                a[i] = random.randint(0, self.margin * 2)
            self.left = int(a[0: self.batch].sum() / self.batch)
            self.right = int(a[self.batch: self.batch * 2].sum() / self.batch)
            self.top = int(a[self.batch * 2: self.batch * 3].sum() / self.batch)
            self.bottom = int(a[self.batch * 3: self.batch * 4].sum() / self.batch)

    def uncrop(self, crop_info, cropped_image, image):
        """Paste cropped logits back at their bbox; everything outside is a
        large negative logit so sigmoid() maps it to ~0."""
        uncropped_image = torch.ones_like(image).cuda()
        uncropped_image *= (-9999999)
        bbox = crop_info[0]
        uncropped_image[:, :, bbox[0].item(): bbox[1].item(), bbox[2].item(): bbox[3].item()] = cropped_image
        return uncropped_image

# get_parameters continues past this chunk; signature intentionally dangling.
def get_parameters(model, coarse=True, \
bias=False, parallel=False): 274 | print('coarse, bias', coarse, bias) 275 | if parallel: 276 | for name, mod in model.named_children(): 277 | print('parallel', name) 278 | model = mod 279 | break 280 | for name, mod in model.named_children(): 281 | if name == 'coarse_model' and coarse \ 282 | or name in ['saliency1', 'saliency2', 'fine_model'] and not coarse: 283 | # or name in ['saliency1', 'saliency2', 'fine_model', 'fine_model_ema'] and not coarse: 284 | print(name) 285 | for n, m in mod.named_modules(): 286 | if isinstance(m, nn.Conv2d): 287 | print(n, m) 288 | if bias and m.bias is not None: 289 | yield m.bias 290 | elif not bias: 291 | yield m.weight 292 | elif isinstance(m, nn.ConvTranspose2d): 293 | # weight is frozen because it is just a bilinear upsampling 294 | if bias: 295 | assert m.bias is None 296 | 297 | class DSC_loss(nn.Module): 298 | def __init__(self): 299 | super(DSC_loss, self).__init__() 300 | self.epsilon = 0.000001 301 | return 302 | def forward(self, pred, target): # soft mode. per item. 
303 | batch_num = pred.shape[0] 304 | pred = pred.contiguous().view(batch_num, -1) 305 | target = target.contiguous().view(batch_num, -1) 306 | DSC = (2 * (pred * target).sum(1) + self.epsilon) / \ 307 | ((pred + target).sum(1) + self.epsilon) 308 | return 1 - DSC.sum() / float(batch_num) 309 | -------------------------------------------------------------------------------- /oracle_fusion.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import sys 4 | import time 5 | from utils import * 6 | 7 | data_path = sys.argv[1] 8 | current_fold = int(sys.argv[2]) 9 | organ_number = int(sys.argv[3]) 10 | low_range = int(sys.argv[4]) 11 | high_range = int(sys.argv[5]) 12 | slice_threshold = float(sys.argv[6]) 13 | slice_thickness = int(sys.argv[7]) 14 | organ_ID = int(sys.argv[8]) 15 | GPU_ID = int(sys.argv[9]) 16 | learning_rate1 = float(sys.argv[10]) 17 | learning_rate_m1 = int(sys.argv[11]) 18 | learning_rate2 = float(sys.argv[12]) 19 | learning_rate_m2 = int(sys.argv[13]) 20 | crop_margin = int(sys.argv[14]) 21 | result_path = os.path.join(result_path, 'oracle_testing_' + \ 22 | sys.argv[10] + 'x' + str(learning_rate_m1) + ',' + str(crop_margin)) 23 | epoch = 'e' + sys.argv[15] + sys.argv[16] + sys.argv[17] + sys.argv[18] 24 | epoch_list = [epoch] 25 | threshold = float(sys.argv[19]) 26 | timestamp = {} 27 | timestamp['X'] = sys.argv[20] 28 | timestamp['Y'] = sys.argv[21] 29 | timestamp['Z'] = sys.argv[22] 30 | 31 | volume_list = open(testing_set_filename(current_fold), 'r').read().splitlines() 32 | while volume_list[len(volume_list) - 1] == '': 33 | volume_list.pop() 34 | 35 | result_name_ = {} 36 | result_directory_ = {} 37 | for plane in ['X', 'Y', 'Z']: 38 | result_name__ = result_name_from_timestamp(result_path, current_fold, \ 39 | plane, 'S', slice_thickness, organ_ID, volume_list, timestamp[plane]) 40 | if result_name__ == '': 41 | exit(' Error: no valid result directories are detected!') 
42 | result_directory__ = os.path.join(result_path, result_name__, 'volumes') 43 | print(' Result directory for plane ' + plane + ': ' + result_directory__ + ' .') 44 | if result_name__.startswith('FD'): 45 | index_ = result_name__.find(':') 46 | result_name__ = result_name__[index_ + 1: ] 47 | result_name_[plane] = result_name__ 48 | result_directory_[plane] = result_directory__ 49 | 50 | DSC_X = np.zeros((len(volume_list))) 51 | DSC_Y = np.zeros((len(volume_list))) 52 | DSC_Z = np.zeros((len(volume_list))) 53 | DSC_F1 = np.zeros((len(volume_list))) 54 | DSC_F2 = np.zeros((len(volume_list))) 55 | DSC_F3 = np.zeros((len(volume_list))) 56 | DSC_F1P = np.zeros((len(volume_list))) 57 | DSC_F2P = np.zeros((len(volume_list))) 58 | DSC_F3P = np.zeros((len(volume_list))) 59 | 60 | result_name = 'FD' + str(current_fold) + ':' + 'fusion:' + result_name_['X'] + ',' + \ 61 | result_name_['Y'] + ',' + result_name_['Z'] + '_' + epoch + ',' + str(threshold) 62 | result_directory = os.path.join(result_path, result_name, 'volumes') 63 | if not os.path.exists(result_directory): 64 | os.makedirs(result_directory) 65 | 66 | result_file = os.path.join(result_path, result_name, 'results.txt') 67 | output = open(result_file, 'w') 68 | output.close() 69 | output = open(result_file, 'a+') 70 | output.write('Fusing results of ' + str(len(epoch_list)) + ' snapshots:\n') 71 | output.close() 72 | 73 | for i in range(len(volume_list)): 74 | start_time = time.time() 75 | print('Testing ' + str(i + 1) + ' out of ' + str(len(volume_list)) + ' testcases.') 76 | output = open(result_file, 'a+') 77 | output.write(' Testcase ' + str(i + 1) + ':\n') 78 | output.close() 79 | s = volume_list[i].split(' ') 80 | label = np.load(s[2]) 81 | label = is_organ(label, organ_ID).astype(np.uint8) 82 | 83 | for plane in ['X', 'Y', 'Z']: 84 | #pdb.set_trace() 85 | volume_file = volume_filename_fusion(result_directory, plane, i) 86 | pred = np.zeros(label.shape, dtype = np.float32) 87 | for t in 
range(len(epoch_list)): 88 | volume_file_ = volume_filename_testing(result_directory_[plane], epoch_list[t], i) 89 | pred += np.load(volume_file_)['volume'] 90 | pred_ = (pred >= threshold * 255 * len(epoch_list)) 91 | if not os.path.isfile(volume_file): 92 | np.savez_compressed(volume_file, volume = pred_) 93 | DSC_, inter_sum, pred_sum, label_sum = DSC_computation(label, pred_) 94 | print(' DSC_' + plane + ' = 2 * ' + str(inter_sum) + ' / (' + \ 95 | str(pred_sum) + ' + ' + str(label_sum) + ') = ' + str(DSC_) + ' .') 96 | output = open(result_file, 'a+') 97 | output.write(' DSC_' + plane + ' = 2 * ' + str(inter_sum) + ' / (' + \ 98 | str(pred_sum) + ' + ' + str(label_sum) + ') = ' + str(DSC_) + ' .\n') 99 | output.close() 100 | if pred_sum == 0 and label_sum == 0: 101 | DSC_ = 0 102 | pred /= (255 * len(epoch_list)) 103 | if plane == 'X': 104 | pred_X = pred 105 | DSC_X[i] = DSC_ 106 | elif plane == 'Y': 107 | pred_Y = pred 108 | DSC_Y[i] = DSC_ 109 | elif plane == 'Z': 110 | pred_Z = pred 111 | DSC_Z[i] = DSC_ 112 | 113 | volume_file_F1 = volume_filename_fusion(result_directory, 'F1', i) 114 | volume_file_F2 = volume_filename_fusion(result_directory, 'F2', i) 115 | volume_file_F3 = volume_filename_fusion(result_directory, 'F3', i) 116 | 117 | if not os.path.isfile(volume_file_F1) or not os.path.isfile(volume_file_F2) or \ 118 | not os.path.isfile(volume_file_F3): 119 | pred_total = pred_X + pred_Y + pred_Z 120 | if os.path.isfile(volume_file_F1): 121 | pred_F1 = np.load(volume_file_F1)['volume'].astype(np.uint8) 122 | else: 123 | pred_F1 = (pred_total >= 0.5).astype(np.uint8) 124 | np.savez_compressed(volume_file_F1, volume = pred_F1) 125 | DSC_F1[i], inter_sum, pred_sum, label_sum = DSC_computation(label, pred_F1) 126 | print(' DSC_F1 = 2 * ' + str(inter_sum) + ' / (' + str(pred_sum) + ' + ' \ 127 | + str(label_sum) + ') = ' + str(DSC_F1[i]) + ' .') 128 | output = open(result_file, 'a+') 129 | output.write(' DSC_F1 = 2 * ' + str(inter_sum) + ' / (' + \ 130 | 
str(pred_sum) + ' + ' + str(label_sum) + ') = ' + str(DSC_F1[i]) + ' .\n') 131 | output.close() 132 | if pred_sum == 0 and label_sum == 0: 133 | DSC_F1[i] = 0 134 | 135 | if os.path.isfile(volume_file_F2): 136 | pred_F2 = np.load(volume_file_F2)['volume'].astype(np.uint8) 137 | else: 138 | pred_F2 = (pred_total >= 1.5).astype(np.uint8) 139 | np.savez_compressed(volume_file_F2, volume = pred_F2) 140 | DSC_F2[i], inter_sum, pred_sum, label_sum = DSC_computation(label, pred_F2) 141 | print(' DSC_F2 = 2 * ' + str(inter_sum) + ' / (' + str(pred_sum) + ' + ' + \ 142 | str(label_sum) + ') = ' + str(DSC_F2[i]) + ' .') 143 | output = open(result_file, 'a+') 144 | output.write(' DSC_F2 = 2 * ' + str(inter_sum) + ' / (' + \ 145 | str(pred_sum) + ' + ' + str(label_sum) + ') = ' + str(DSC_F2[i]) + ' .\n') 146 | output.close() 147 | if pred_sum == 0 and label_sum == 0: 148 | DSC_F2[i] = 0 149 | 150 | if os.path.isfile(volume_file_F3): 151 | pred_F3 = np.load(volume_file_F3)['volume'].astype(np.uint8) 152 | else: 153 | pred_F3 = (pred_total >= 2.5).astype(np.uint8) 154 | np.savez_compressed(volume_file_F3, volume = pred_F3) 155 | DSC_F3[i], inter_sum, pred_sum, label_sum = DSC_computation(label, pred_F3) 156 | print(' DSC_F3 = 2 * ' + str(inter_sum) + ' / (' + str(pred_sum) + ' + ' + \ 157 | str(label_sum) + ') = ' + str(DSC_F3[i]) + ' .') 158 | output = open(result_file, 'a+') 159 | output.write(' DSC_F3 = 2 * ' + str(inter_sum) + ' / (' + \ 160 | str(pred_sum) + ' + ' + str(label_sum) + ') = ' + str(DSC_F3[i]) + ' .\n') 161 | output.close() 162 | if pred_sum == 0 and label_sum == 0: 163 | DSC_F3[i] = 0 164 | 165 | volume_file_F1P = volume_filename_fusion(result_directory, 'F1P', i) 166 | volume_file_F2P = volume_filename_fusion(result_directory, 'F2P', i) 167 | volume_file_F3P = volume_filename_fusion(result_directory, 'F3P', i) 168 | S = pred_F3 169 | if (S.sum() == 0): 170 | S = pred_F2 171 | if (S.sum() == 0): 172 | S = pred_F1 173 | 174 | if os.path.isfile(volume_file_F1P): 
175 | pred_F1P = np.load(volume_file_F1P)['volume'].astype(np.uint8) 176 | else: 177 | pred_F1P = post_processing(pred_F1, S, 0.5, organ_ID) 178 | np.savez_compressed(volume_file_F1P, volume = pred_F1P) 179 | DSC_F1P[i], inter_sum, pred_sum, label_sum = DSC_computation(label, pred_F1P) 180 | print(' DSC_F1P = 2 * ' + str(inter_sum) + ' / (' + str(pred_sum) + ' + ' + \ 181 | str(label_sum) + ') = ' + str(DSC_F1P[i]) + ' .') 182 | output = open(result_file, 'a+') 183 | output.write(' DSC_F1P = 2 * ' + str(inter_sum) + ' / (' + \ 184 | str(pred_sum) + ' + ' + str(label_sum) + ') = ' + str(DSC_F1P[i]) + ' .\n') 185 | output.close() 186 | if pred_sum == 0 and label_sum == 0: 187 | DSC_F1P[i] = 0 188 | 189 | if os.path.isfile(volume_file_F2P): 190 | pred_F2P = np.load(volume_file_F2P)['volume'].astype(np.uint8) 191 | else: 192 | pred_F2P = post_processing(pred_F2, S, 0.5, organ_ID) 193 | np.savez_compressed(volume_file_F2P, volume = pred_F2P) 194 | DSC_F2P[i], inter_sum, pred_sum, label_sum = DSC_computation(label, pred_F2P) 195 | print(' DSC_F2P = 2 * ' + str(inter_sum) + ' / (' + str(pred_sum) + ' + ' + \ 196 | str(label_sum) + ') = ' + str(DSC_F2P[i]) + ' .') 197 | output = open(result_file, 'a+') 198 | output.write(' DSC_F2P = 2 * ' + str(inter_sum) + ' / (' + \ 199 | str(pred_sum) + ' + ' + str(label_sum) + ') = ' + str(DSC_F2P[i]) + ' .\n') 200 | output.close() 201 | if pred_sum == 0 and label_sum == 0: 202 | DSC_F2P[i] = 0 203 | 204 | if os.path.isfile(volume_file_F3P): 205 | pred_F3P = np.load(volume_file_F3P)['volume'].astype(np.uint8) 206 | else: 207 | pred_F3P = post_processing(pred_F3, S, 0.5, organ_ID) 208 | np.savez_compressed(volume_file_F3P, volume = pred_F3P) 209 | DSC_F3P[i], inter_sum, pred_sum, label_sum = DSC_computation(label, pred_F3P) 210 | print(' DSC_F3P = 2 * ' + str(inter_sum) + ' / (' + str(pred_sum) + ' + ' + \ 211 | str(label_sum) + ') = ' + str(DSC_F3P[i]) + ' .') 212 | output = open(result_file, 'a+') 213 | output.write(' DSC_F3P = 2 * ' 
+ str(inter_sum) + ' / (' + \ 214 | str(pred_sum) + ' + ' + str(label_sum) + ') = ' + str(DSC_F3P[i]) + ' .\n') 215 | output.close() 216 | if pred_sum == 0 and label_sum == 0: 217 | DSC_F3P[i] = 0 218 | 219 | pred_X = None 220 | pred_Y = None 221 | pred_Z = None 222 | pred_F1 = None 223 | pred_F2 = None 224 | pred_F3 = None 225 | pred_F1P = None 226 | pred_F2P = None 227 | pred_F3P = None 228 | 229 | output = open(result_file, 'a+') 230 | print('Average DSC_X = ' + str(np.mean(DSC_X)) + ' .') 231 | output.write('Average DSC_X = ' + str(np.mean(DSC_X)) + ' .\n') 232 | print('Average DSC_Y = ' + str(np.mean(DSC_Y)) + ' .') 233 | output.write('Average DSC_Y = ' + str(np.mean(DSC_Y)) + ' .\n') 234 | print('Average DSC_Z = ' + str(np.mean(DSC_Z)) + ' .') 235 | output.write('Average DSC_Z = ' + str(np.mean(DSC_Z)) + ' .\n') 236 | print('Average DSC_F1 = ' + str(np.mean(DSC_F1)) + ' .') 237 | output.write('Average DSC_F1 = ' + str(np.mean(DSC_F1)) + ' .\n') 238 | print('Average DSC_F2 = ' + str(np.mean(DSC_F2)) + ' .') 239 | output.write('Average DSC_F2 = ' + str(np.mean(DSC_F2)) + ' .\n') 240 | print('Average DSC_F3 = ' + str(np.mean(DSC_F3)) + ' .') 241 | output.write('Average DSC_F3 = ' + str(np.mean(DSC_F3)) + ' .\n') 242 | print('Average DSC_F1P = ' + str(np.mean(DSC_F1P)) + ' .') 243 | output.write('Average DSC_F1P = ' + str(np.mean(DSC_F1P)) + ' .\n') 244 | print('Average DSC_F2P = ' + str(np.mean(DSC_F2P)) + ' .') 245 | output.write('Average DSC_F2P = ' + str(np.mean(DSC_F2P)) + ' .\n') 246 | print('Average DSC_F3P = ' + str(np.mean(DSC_F3P)) + ' .') 247 | output.write('Average DSC_F3P = ' + str(np.mean(DSC_F3P)) + ' .\n') 248 | output.close() 249 | print('The fusion process is finished.') -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import sys 4 | import math 5 | import torch.nn as nn 6 | import 
fast_functions as ff 7 | from torch.nn import init 8 | 9 | def init_weights(net, init_type='normal'): 10 | #print('initialization method [%s]' % init_type) 11 | if init_type == 'kaiming': 12 | net.apply(weights_init_kaiming) 13 | else: 14 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 15 | 16 | def weights_init_kaiming(m): 17 | classname = m.__class__.__name__ 18 | #print(classname) 19 | if classname.find('Conv') != -1: 20 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 21 | elif classname.find('Linear') != -1: 22 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 23 | elif classname.find('BatchNorm') != -1: 24 | init.normal_(m.weight.data, 1.0, 0.02) 25 | init.constant_(m.bias.data, 0.0) 26 | 27 | ### compute model params 28 | def count_param(model): 29 | param_count = 0 30 | for param in model.parameters(): 31 | param_count += param.view(-1).size()[0] 32 | return param_count 33 | #################################################################################################### 34 | # returning the binary label map by the organ ID (especially useful under overlapping cases) 35 | # label: the label matrix 36 | # organ_ID: the organ ID 37 | def is_organ(label, organ_ID): 38 | return label == organ_ID 39 | 40 | #################################################################################################### 41 | # determining if a sample belongs to the training set by the fold number 42 | # total_samples: the total number of samples 43 | # i: sample ID, an integer in [0, total_samples - 1] 44 | # folds: the total number of folds 45 | # current_fold: the current fold ID, an integer in [0, folds - 1] 46 | def in_training_set(total_samples, i, folds, current_fold): 47 | fold_remainder = folds - total_samples % folds 48 | fold_size = (total_samples - total_samples % folds) / folds 49 | start_index = fold_size * current_fold + max(0, current_fold - fold_remainder) 50 | end_index = fold_size * (current_fold + 
1) + max(0, current_fold + 1 - fold_remainder) 51 | return not (i >= start_index and i < end_index) 52 | 53 | #################################################################################################### 54 | # returning the filename of the training set according to the current fold ID 55 | def training_set_filename(current_fold): 56 | return os.path.join(list_path, 'training_' + 'FD' + str(current_fold) + '.txt') 57 | 58 | #################################################################################################### 59 | # returning the filename of the testing set according to the current fold ID 60 | def testing_set_filename(current_fold): 61 | return os.path.join(list_path, 'testing_' + 'FD' + str(current_fold) + '.txt') 62 | 63 | #################################################################################################### 64 | # returning the filename of the log file 65 | def log_filename(snapshot_directory): 66 | count = 0 67 | while True: 68 | count += 1 69 | if count == 1: 70 | log_file_ = os.path.join(snapshot_directory, 'log.txt') 71 | else: 72 | log_file_ = os.path.join(snapshot_directory, 'log' + str(count) + '.txt') 73 | if not os.path.isfile(log_file_): 74 | return log_file_ 75 | 76 | #################################################################################################### 77 | # returning the snapshot name 78 | def snapshot_name_from_timestamp(snapshot_path, \ 79 | current_fold, plane, stage_code, slice_thickness, organ_ID, timestamp): 80 | snapshot_prefix = 'FD' + str(current_fold) + ':' + plane + \ 81 | stage_code + str(slice_thickness) + '_' + str(organ_ID) 82 | if len(timestamp) == 15: 83 | snapshot_prefix = snapshot_prefix + '_' + timestamp 84 | snapshot_name = snapshot_prefix + '.pkl' 85 | if os.path.isfile(os.path.join(snapshot_path, snapshot_name)): 86 | return snapshot_name 87 | else: 88 | return '' 89 | 90 | #################################################################################################### 91 
| # returning the result name 92 | def result_name_from_timestamp(result_path, current_fold, \ 93 | plane, stage_code, slice_thickness, organ_ID, volume_list, timestamp): 94 | result_prefix = 'FD' + str(current_fold) + ':' + plane + \ 95 | stage_code + str(slice_thickness) + '_' + str(organ_ID) 96 | if len(timestamp) == 15: 97 | result_prefix = result_prefix + '_' + timestamp 98 | result_name = result_prefix + '.pkl' 99 | if os.path.exists(os.path.join(result_path, result_name, 'volumes')): 100 | return result_name 101 | else: 102 | return '' 103 | 104 | #################################################################################################### 105 | # returning the volume filename as in the testing stage 106 | def volume_filename_testing(result_directory, t, i): 107 | return os.path.join(result_directory, str(t) + '_' + str(i + 1) + '.npz') 108 | 109 | #################################################################################################### 110 | # returning the volume filename as in the fusion stage 111 | def volume_filename_fusion(result_directory, code, i): 112 | return os.path.join(result_directory, code + '_' + str(i + 1) + '.npz') 113 | 114 | #################################################################################################### 115 | # returning the volume filename as in the coarse-to-fine testing stage 116 | def volume_filename_coarse2fine(result_directory, r, i): 117 | return os.path.join(result_directory, 'R' + str(r) + '_' + str(i + 1) + '.npz') 118 | 119 | #################################################################################################### 120 | # computing the DSC together with other values based on the label and prediction volumes 121 | # def DSC_computation(label, pred): 122 | # pred_sum = pred.sum() 123 | # label_sum = label.sum() 124 | # inter_sum = np.logical_and(pred, label).sum() 125 | # return 2 * float(inter_sum) / (pred_sum + label_sum), inter_sum, pred_sum, label_sum 126 | 127 | def 
DSC_computation(label, pred): 128 | P = np.zeros(3, dtype = np.uint32) 129 | ff.DSC_computation(label, pred, P) 130 | return 2 * float(P[2]) / (P[0] + P[1]), P[2], P[1], P[0] 131 | 132 | #################################################################################################### 133 | # post-processing: preserving the largest connecting component(s) and discarding other voxels 134 | # The floodfill algorithm is used to detect the connecting components. 135 | # In the future version, this function is to be replaced by a C module for speedup! 136 | # F: a binary volume, the volume to be post-processed 137 | # S: a binary volume, the seed voxels (currently defined as those predicted as FG by all 3 views) 138 | # NOTE: a connected component will not be considered if it does not contain any seed voxels 139 | # threshold: a floating point number in [0, 1] determining if a connected component is accepted 140 | # NOTE: accepted if it is not smaller larger than the largest volume times this number 141 | # NOTE: 1 means to only keep the largest one(s), 0 means to keep all 142 | # organ_ID: passed in case that each organ needs to be dealt with differently 143 | # def post_processing(F, S, threshold, organ_ID): 144 | # if F.sum() == 0: 145 | # return F 146 | # if F.sum() >= np.product(F.shape) / 2: 147 | # return F 148 | # height = F.shape[0] 149 | # width = F.shape[1] 150 | # depth = F.shape[2] 151 | # ll = np.array(np.nonzero(S)) 152 | # marked = np.zeros_like(F, dtype = np.bool) 153 | # queue = np.zeros((F.sum(), 3), dtype = np.int) 154 | # volume = np.zeros(F.sum(), dtype = np.int) 155 | # head = 0 156 | # tail = 0 157 | # bestHead = 0 158 | # bestTail = 0 159 | # bestHead2 = 0 160 | # bestTail2 = 0 161 | # for l in range(ll.shape[1]): 162 | # if not marked[ll[0, l], ll[1, l], ll[2, l]]: 163 | # temp = head 164 | # marked[ll[0, l], ll[1, l], ll[2, l]] = True 165 | # queue[tail, :] = [ll[0, l], ll[1, l], ll[2, l]] 166 | # tail = tail + 1 167 | # while (head < tail): 
168 | # t1 = queue[head, 0] 169 | # t2 = queue[head, 1] 170 | # t3 = queue[head, 2] 171 | # if t1 > 0 and F[t1 - 1, t2, t3] and not marked[t1 - 1, t2, t3]: 172 | # marked[t1 - 1, t2, t3] = True 173 | # queue[tail, :] = [t1 - 1, t2, t3] 174 | # tail = tail + 1 175 | # if t1 < height - 1 and F[t1 + 1, t2, t3] and not marked[t1 + 1, t2, t3]: 176 | # marked[t1 + 1, t2, t3] = True 177 | # queue[tail, :] = [t1 + 1, t2, t3] 178 | # tail = tail + 1 179 | # if t2 > 0 and F[t1, t2 - 1, t3] and not marked[t1, t2 - 1, t3]: 180 | # marked[t1, t2 - 1, t3] = True 181 | # queue[tail, :] = [t1, t2 - 1, t3] 182 | # tail = tail + 1 183 | # if t2 < width - 1 and F[t1, t2 + 1, t3] and not marked[t1, t2 + 1, t3]: 184 | # marked[t1, t2 + 1, t3] = True 185 | # queue[tail, :] = [t1, t2 + 1, t3] 186 | # tail = tail + 1 187 | # if t3 > 0 and F[t1, t2, t3 - 1] and not marked[t1, t2, t3 - 1]: 188 | # marked[t1, t2, t3 - 1] = True 189 | # queue[tail, :] = [t1, t2, t3 - 1] 190 | # tail = tail + 1 191 | # if t3 < depth - 1 and F[t1, t2, t3 + 1] and not marked[t1, t2, t3 + 1]: 192 | # marked[t1, t2, t3 + 1] = True 193 | # queue[tail, :] = [t1, t2, t3 + 1] 194 | # tail = tail + 1 195 | # head = head + 1 196 | # if tail - temp > bestTail - bestHead: 197 | # bestHead2 = bestHead 198 | # bestTail2 = bestTail 199 | # bestHead = temp 200 | # bestTail = tail 201 | # elif tail - temp > bestTail2 - bestHead2: 202 | # bestHead2 = temp 203 | # bestTail2 = tail 204 | # volume[temp: tail] = tail - temp 205 | # volume = volume[0: tail] 206 | # target_voxel = np.where(volume >= (bestTail - bestHead) * threshold) 207 | # F0 = np.zeros_like(F, dtype = np.bool) 208 | # F0[tuple(map(tuple, np.transpose(queue[target_voxel, :])))] = True 209 | # return F0.astype(np.uint8) 210 | 211 | def post_processing(F, S, threshold, organ_ID): 212 | ff.post_processing(F, S, threshold, False) 213 | return F 214 | 215 | #################################################################################################### 216 | # 
defining the common variables used throughout the entire flowchart 217 | data_path =sys.argv[1] 218 | image_path = os.path.join(data_path, 'images') 219 | image_path_ = {} 220 | for plane in ['X', 'Y', 'Z']: 221 | image_path_[plane] = os.path.join(data_path, 'images_' + plane) 222 | if not os.path.exists(image_path_[plane]): 223 | os.makedirs(image_path_[plane]) 224 | label_path = os.path.join(data_path, 'labels') 225 | label_path_ = {} 226 | for plane in ['X', 'Y', 'Z']: 227 | label_path_[plane] = os.path.join(data_path, 'labels_' + plane) 228 | if not os.path.exists(label_path_[plane]): 229 | os.makedirs(label_path_[plane]) 230 | list_path = os.path.join(data_path, 'lists') 231 | if not os.path.exists(list_path): 232 | os.makedirs(list_path) 233 | list_training = {} 234 | for plane in ['X', 'Y', 'Z']: 235 | list_training[plane] = os.path.join(list_path, 'training_' + plane + '.txt') 236 | model_path = os.path.join(data_path, 'models') 237 | if not os.path.exists(model_path): 238 | os.makedirs(model_path) 239 | pretrained_model_path = os.path.join(data_path, 'models', 'pretrained') 240 | if not os.path.exists(pretrained_model_path): 241 | os.makedirs(pretrained_model_path) 242 | snapshot_path = os.path.join(data_path, 'models', 'snapshots') 243 | if not os.path.exists(snapshot_path): 244 | os.makedirs(snapshot_path) 245 | log_path = os.path.join(data_path, 'logs') 246 | if not os.path.exists(log_path): 247 | os.makedirs(log_path) 248 | result_path = os.path.join(data_path, 'results') 249 | if not os.path.exists(result_path): 250 | os.makedirs(result_path) 251 | -------------------------------------------------------------------------------- /coarse_fusion.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import sys 4 | import time 5 | from utils import * 6 | 7 | data_path = sys.argv[1] 8 | current_fold = int(sys.argv[2]) 9 | organ_number = int(sys.argv[3]) 10 | low_range = int(sys.argv[4]) 
11 | high_range = int(sys.argv[5]) 12 | slice_threshold = float(sys.argv[6]) 13 | slice_thickness = int(sys.argv[7]) 14 | organ_ID = int(sys.argv[8]) 15 | GPU_ID = int(sys.argv[9]) 16 | learning_rate1 = float(sys.argv[10]) 17 | learning_rate_m1 = int(sys.argv[11]) 18 | learning_rate2 = float(sys.argv[12]) 19 | learning_rate_m2 = int(sys.argv[13]) 20 | crop_margin = int(sys.argv[14]) 21 | result_path = os.path.join(result_path, 'coarse_testing_' + \ 22 | sys.argv[10] + 'x' + str(learning_rate_m1) + ',' + str(crop_margin)) 23 | epoch = 'e' + sys.argv[15] + sys.argv[16] + sys.argv[17] + sys.argv[18] 24 | epoch_list = [epoch] 25 | threshold = float(sys.argv[19]) 26 | timestamp = {} 27 | 28 | timestamp['X'] = sys.argv[20] 29 | timestamp['Y'] = sys.argv[21] 30 | timestamp['Z'] = sys.argv[22] 31 | 32 | volume_list = open(testing_set_filename(current_fold), 'r').read().splitlines() 33 | while volume_list[len(volume_list) - 1] == '': 34 | volume_list.pop() 35 | 36 | result_name_ = {} 37 | result_directory_ = {} 38 | for plane in ['X', 'Y', 'Z']: 39 | result_name__ = result_name_from_timestamp(result_path, current_fold, \ 40 | plane, 'I', slice_thickness, organ_ID, volume_list, timestamp[plane]) 41 | if result_name__ == '': 42 | exit(' Error: no valid result directories are detected!') 43 | result_directory__ = os.path.join(result_path, result_name__, 'volumes') 44 | print(' Result directory for plane ' + plane + ': ' + result_directory__ + ' .') 45 | if result_name__.startswith('FD'): 46 | index_ = result_name__.find(':') 47 | result_name__ = result_name__[index_ + 1: ] 48 | result_name_[plane] = result_name__ 49 | result_directory_[plane] = result_directory__ 50 | 51 | DSC_X = np.zeros((len(volume_list))) 52 | DSC_Y = np.zeros((len(volume_list))) 53 | DSC_Z = np.zeros((len(volume_list))) 54 | DSC_F1 = np.zeros((len(volume_list))) 55 | DSC_F2 = np.zeros((len(volume_list))) 56 | DSC_F3 = np.zeros((len(volume_list))) 57 | DSC_F1P = np.zeros((len(volume_list))) 58 | DSC_F2P = 
np.zeros((len(volume_list))) 59 | DSC_F3P = np.zeros((len(volume_list))) 60 | 61 | result_name = 'FD' + str(current_fold) + ':' + 'fusion:' + result_name_['X'] + ',' + \ 62 | result_name_['Y'] + ',' + result_name_['Z'] + '_' + epoch + ',' + str(threshold) 63 | result_directory = os.path.join(result_path, result_name, 'volumes') 64 | if not os.path.exists(result_directory): 65 | os.makedirs(result_directory) 66 | 67 | result_file = os.path.join(result_path, result_name, 'results.txt') 68 | output = open(result_file, 'w') 69 | output.close() 70 | output = open(result_file, 'a+') 71 | output.write('Fusing results of ' + str(len(epoch_list)) + ' snapshots:\n') 72 | output.close() 73 | 74 | for i in range(len(volume_list)): 75 | start_time = time.time() 76 | print('Testing ' + str(i + 1) + ' out of ' + str(len(volume_list)) + ' testcases.') 77 | output = open(result_file, 'a+') 78 | output.write(' Testcase ' + str(i + 1) + ':\n') 79 | output.close() 80 | s = volume_list[i].split(' ') 81 | label = np.load(s[2]) 82 | label = is_organ(label, organ_ID).astype(np.uint8) 83 | 84 | for plane in ['X', 'Y', 'Z']: 85 | volume_file = volume_filename_fusion(result_directory, plane, i) 86 | pred = np.zeros(label.shape, dtype = np.float32) 87 | for t in range(len(epoch_list)): 88 | volume_file_ = volume_filename_testing(result_directory_[plane], epoch_list[t], i) 89 | pred += np.load(volume_file_)['volume'] 90 | pred_ = (pred >= threshold * 255 * len(epoch_list)) 91 | if not os.path.isfile(volume_file): 92 | np.savez_compressed(volume_file, volume = pred_) 93 | DSC_, inter_sum, pred_sum, label_sum = DSC_computation(label, pred_) 94 | print(' DSC_' + plane + ' = 2 * ' + str(inter_sum) + ' / (' + \ 95 | str(pred_sum) + ' + ' + str(label_sum) + ') = ' + str(DSC_) + ' .') 96 | output = open(result_file, 'a+') 97 | output.write(' DSC_' + plane + ' = 2 * ' + str(inter_sum) + ' / (' + \ 98 | str(pred_sum) + ' + ' + str(label_sum) + ') = ' + str(DSC_) + ' .\n') 99 | output.close() 100 | if 
pred_sum == 0 and label_sum == 0: 101 | DSC_ = 0 102 | pred /= (255 * len(epoch_list)) 103 | if plane == 'X': 104 | pred_X = pred 105 | DSC_X[i] = DSC_ 106 | elif plane == 'Y': 107 | pred_Y = pred 108 | DSC_Y[i] = DSC_ 109 | elif plane == 'Z': 110 | pred_Z = pred 111 | DSC_Z[i] = DSC_ 112 | 113 | volume_file_F1 = volume_filename_fusion(result_directory, 'F1', i) 114 | volume_file_F2 = volume_filename_fusion(result_directory, 'F2', i) 115 | volume_file_F3 = volume_filename_fusion(result_directory, 'F3', i) 116 | 117 | if not os.path.isfile(volume_file_F1) or not os.path.isfile(volume_file_F2) or \ 118 | not os.path.isfile(volume_file_F3): 119 | pred_total = pred_X + pred_Y + pred_Z 120 | if os.path.isfile(volume_file_F1): 121 | pred_F1 = np.load(volume_file_F1)['volume'].astype(np.uint8) 122 | else: 123 | pred_F1 = (pred_total >= 0.5).astype(np.uint8) 124 | np.savez_compressed(volume_file_F1, volume = pred_F1) 125 | DSC_F1[i], inter_sum, pred_sum, label_sum = DSC_computation(label, pred_F1) 126 | print(' DSC_F1 = 2 * ' + str(inter_sum) + ' / (' + str(pred_sum) + ' + ' \ 127 | + str(label_sum) + ') = ' + str(DSC_F1[i]) + ' .') 128 | output = open(result_file, 'a+') 129 | output.write(' DSC_F1 = 2 * ' + str(inter_sum) + ' / (' + \ 130 | str(pred_sum) + ' + ' + str(label_sum) + ') = ' + str(DSC_F1[i]) + ' .\n') 131 | output.close() 132 | if pred_sum == 0 and label_sum == 0: 133 | DSC_F1[i] = 0 134 | 135 | if os.path.isfile(volume_file_F2): 136 | pred_F2 = np.load(volume_file_F2)['volume'].astype(np.uint8) 137 | else: 138 | pred_F2 = (pred_total >= 1.5).astype(np.uint8) 139 | np.savez_compressed(volume_file_F2, volume = pred_F2) 140 | DSC_F2[i], inter_sum, pred_sum, label_sum = DSC_computation(label, pred_F2) 141 | print(' DSC_F2 = 2 * ' + str(inter_sum) + ' / (' + str(pred_sum) + ' + ' + \ 142 | str(label_sum) + ') = ' + str(DSC_F2[i]) + ' .') 143 | output = open(result_file, 'a+') 144 | output.write(' DSC_F2 = 2 * ' + str(inter_sum) + ' / (' + \ 145 | str(pred_sum) + 
' + ' + str(label_sum) + ') = ' + str(DSC_F2[i]) + ' .\n') 146 | output.close() 147 | if pred_sum == 0 and label_sum == 0: 148 | DSC_F2[i] = 0 149 | 150 | if os.path.isfile(volume_file_F3): 151 | pred_F3 = np.load(volume_file_F3)['volume'].astype(np.uint8) 152 | else: 153 | pred_F3 = (pred_total >= 2.5).astype(np.uint8) 154 | np.savez_compressed(volume_file_F3, volume = pred_F3) 155 | DSC_F3[i], inter_sum, pred_sum, label_sum = DSC_computation(label, pred_F3) 156 | print(' DSC_F3 = 2 * ' + str(inter_sum) + ' / (' + str(pred_sum) + ' + ' + \ 157 | str(label_sum) + ') = ' + str(DSC_F3[i]) + ' .') 158 | output = open(result_file, 'a+') 159 | output.write(' DSC_F3 = 2 * ' + str(inter_sum) + ' / (' + \ 160 | str(pred_sum) + ' + ' + str(label_sum) + ') = ' + str(DSC_F3[i]) + ' .\n') 161 | output.close() 162 | if pred_sum == 0 and label_sum == 0: 163 | DSC_F3[i] = 0 164 | 165 | volume_file_F1P = volume_filename_fusion(result_directory, 'F1P', i) 166 | volume_file_F2P = volume_filename_fusion(result_directory, 'F2P', i) 167 | volume_file_F3P = volume_filename_fusion(result_directory, 'F3P', i) 168 | S = pred_F3 169 | if (S.sum() == 0): 170 | S = pred_F2 171 | if (S.sum() == 0): 172 | S = pred_F1 173 | 174 | if os.path.isfile(volume_file_F1P): 175 | pred_F1P = np.load(volume_file_F1P)['volume'].astype(np.uint8) 176 | else: 177 | pred_F1P = post_processing(pred_F1, S, 0.5, organ_ID) 178 | np.savez_compressed(volume_file_F1P, volume = pred_F1P) 179 | DSC_F1P[i], inter_sum, pred_sum, label_sum = DSC_computation(label, pred_F1P) 180 | print(' DSC_F1P = 2 * ' + str(inter_sum) + ' / (' + str(pred_sum) + ' + ' + \ 181 | str(label_sum) + ') = ' + str(DSC_F1P[i]) + ' .') 182 | output = open(result_file, 'a+') 183 | output.write(' DSC_F1P = 2 * ' + str(inter_sum) + ' / (' + \ 184 | str(pred_sum) + ' + ' + str(label_sum) + ') = ' + str(DSC_F1P[i]) + ' .\n') 185 | output.close() 186 | if pred_sum == 0 and label_sum == 0: 187 | DSC_F1P[i] = 0 188 | 189 | if 
os.path.isfile(volume_file_F2P): 190 | pred_F2P = np.load(volume_file_F2P)['volume'].astype(np.uint8) 191 | else: 192 | pred_F2P = post_processing(pred_F2, S, 0.5, organ_ID) 193 | np.savez_compressed(volume_file_F2P, volume = pred_F2P) 194 | DSC_F2P[i], inter_sum, pred_sum, label_sum = DSC_computation(label, pred_F2P) 195 | print(' DSC_F2P = 2 * ' + str(inter_sum) + ' / (' + str(pred_sum) + ' + ' + \ 196 | str(label_sum) + ') = ' + str(DSC_F2P[i]) + ' .') 197 | output = open(result_file, 'a+') 198 | output.write(' DSC_F2P = 2 * ' + str(inter_sum) + ' / (' + \ 199 | str(pred_sum) + ' + ' + str(label_sum) + ') = ' + str(DSC_F2P[i]) + ' .\n') 200 | output.close() 201 | if pred_sum == 0 and label_sum == 0: 202 | DSC_F2P[i] = 0 203 | 204 | if os.path.isfile(volume_file_F3P): 205 | pred_F3P = np.load(volume_file_F3P)['volume'].astype(np.uint8) 206 | else: 207 | pred_F3P = post_processing(pred_F3, S, 0.5, organ_ID) 208 | np.savez_compressed(volume_file_F3P, volume = pred_F3P) 209 | DSC_F3P[i], inter_sum, pred_sum, label_sum = DSC_computation(label, pred_F3P) 210 | print(' DSC_F3P = 2 * ' + str(inter_sum) + ' / (' + str(pred_sum) + ' + ' + \ 211 | str(label_sum) + ') = ' + str(DSC_F3P[i]) + ' .') 212 | output = open(result_file, 'a+') 213 | output.write(' DSC_F3P = 2 * ' + str(inter_sum) + ' / (' + \ 214 | str(pred_sum) + ' + ' + str(label_sum) + ') = ' + str(DSC_F3P[i]) + ' .\n') 215 | output.close() 216 | if pred_sum == 0 and label_sum == 0: 217 | DSC_F3P[i] = 0 218 | 219 | pred_X = None 220 | pred_Y = None 221 | pred_Z = None 222 | pred_F1 = None 223 | pred_F2 = None 224 | pred_F3 = None 225 | pred_F1P = None 226 | pred_F2P = None 227 | pred_F3P = None 228 | 229 | output = open(result_file, 'a+') 230 | print('Average DSC_X = ' + str(np.mean(DSC_X)) + ' .') 231 | output.write('Average DSC_X = ' + str(np.mean(DSC_X)) + ' .\n') 232 | print('Average DSC_Y = ' + str(np.mean(DSC_Y)) + ' .') 233 | output.write('Average DSC_Y = ' + str(np.mean(DSC_Y)) + ' .\n') 234 | 
print('Average DSC_Z = ' + str(np.mean(DSC_Z)) + ' .') 235 | output.write('Average DSC_Z = ' + str(np.mean(DSC_Z)) + ' .\n') 236 | print('Average DSC_F1 = ' + str(np.mean(DSC_F1)) + ' .') 237 | output.write('Average DSC_F1 = ' + str(np.mean(DSC_F1)) + ' .\n') 238 | print('Average DSC_F2 = ' + str(np.mean(DSC_F2)) + ' .') 239 | output.write('Average DSC_F2 = ' + str(np.mean(DSC_F2)) + ' .\n') 240 | print('Average DSC_F3 = ' + str(np.mean(DSC_F3)) + ' .') 241 | output.write('Average DSC_F3 = ' + str(np.mean(DSC_F3)) + ' .\n') 242 | print('Average DSC_F1P = ' + str(np.mean(DSC_F1P)) + ' .') 243 | output.write('Average DSC_F1P = ' + str(np.mean(DSC_F1P)) + ' .\n') 244 | print('Average DSC_F2P = ' + str(np.mean(DSC_F2P)) + ' .') 245 | output.write('Average DSC_F2P = ' + str(np.mean(DSC_F2P)) + ' .\n') 246 | print('Average DSC_F3P = ' + str(np.mean(DSC_F3P)) + ' .') 247 | output.write('Average DSC_F3P = ' + str(np.mean(DSC_F3P)) + ' .\n') 248 | output.close() 249 | print('The fusion process is finished.') -------------------------------------------------------------------------------- /coarse2fine_testing.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import sys 4 | import time 5 | from utils import * 6 | from model import * 7 | # import pdb 8 | 9 | data_path = sys.argv[1] 10 | current_fold = int(sys.argv[2]) 11 | organ_number = int(sys.argv[3]) 12 | low_range = int(sys.argv[4]) 13 | high_range = int(sys.argv[5]) 14 | slice_threshold = float(sys.argv[6]) 15 | slice_thickness = int(sys.argv[7]) 16 | organ_ID = int(sys.argv[8]) 17 | GPU_ID = int(sys.argv[9]) 18 | learning_rate1 = float(sys.argv[10]) 19 | learning_rate_m1 = int(sys.argv[11]) 20 | learning_rate2 = float(sys.argv[12]) 21 | learning_rate_m2 = int(sys.argv[13]) 22 | crop_margin = int(sys.argv[14]) 23 | 24 | fine_snapshot_path = os.path.join(snapshot_path, 'SIJ_training_' + \ 25 | sys.argv[10] + 'x' + str(learning_rate_m1) + ',' + 
str(crop_margin)) 26 | coarse_result_path = os.path.join(result_path, 'coarse_testing_' + \ 27 | sys.argv[10] + 'x' + str(learning_rate_m1) + ',' + str(crop_margin)) 28 | coarse2fine_result_path = os.path.join(result_path, 'coarse2fine_testing_' + \ 29 | sys.argv[10] + 'x' + str(learning_rate_m1) + ',' + str(crop_margin)) 30 | 31 | epoch = 'e' + sys.argv[15] + sys.argv[16] + sys.argv[17] + sys.argv[18] 32 | epoch_list = [epoch] 33 | coarse_threshold = float(sys.argv[19]) 34 | fine_threshold = float(sys.argv[20]) 35 | max_rounds = int(sys.argv[21]) 36 | timestamp = {} 37 | timestamp['X'] = sys.argv[22] 38 | timestamp['Y'] = sys.argv[23] 39 | timestamp['Z'] = sys.argv[24] 40 | 41 | 42 | volume_list = open(testing_set_filename(current_fold), 'r').read().splitlines() 43 | while volume_list[len(volume_list) - 1] == '': 44 | volume_list.pop() 45 | 46 | print('Looking for snapshots:') 47 | fine_snapshot_ = {} 48 | fine_snapshot_name_ = {} 49 | for plane in ['X', 'Y', 'Z']: 50 | #pdb.set_trace() 51 | fine_snapshot_name = snapshot_name_from_timestamp(fine_snapshot_path, \ 52 | current_fold, plane, 'J', slice_thickness, organ_ID, timestamp[plane]) 53 | if fine_snapshot_name == '': 54 | exit(' Error: no valid snapshot directories are detected!') 55 | fine_snapshot_directory = os.path.join(fine_snapshot_path, fine_snapshot_name) 56 | print(' Snapshot directory 1 for plane ' + plane + ': ' + fine_snapshot_directory + ' .') 57 | fine_snapshot = [fine_snapshot_directory] 58 | print(' ' + str(len(fine_snapshot)) + ' snapshots are to be evaluated.') 59 | for t in range(len(fine_snapshot)): 60 | print(' Snapshot #' + str(t + 1) + ': ' + fine_snapshot[t] + ' .') 61 | fine_snapshot_[plane] = fine_snapshot 62 | fine_snapshot_name = fine_snapshot_name.split(':')[1] 63 | fine_snapshot_name_[plane] = fine_snapshot_name.split('.')[0] 64 | 65 | print('In the coarse stage:') 66 | coarse_result_name_ = {} 67 | coarse_result_directory_ = {} 68 | for plane in ['X', 'Y', 'Z']: 69 | 
#pdb.set_trace() 70 | coarse_result_name__ = result_name_from_timestamp(coarse_result_path, current_fold, \ 71 | plane, 'I', slice_thickness, organ_ID, volume_list, timestamp[plane]) 72 | if coarse_result_name__ == '': 73 | exit(' Error: no valid result directories are detected!') 74 | coarse_result_directory__ = os.path.join(coarse_result_path, coarse_result_name__, 'volumes') 75 | print(' Result directory for plane ' + plane + ': ' + coarse_result_directory__ + ' .') 76 | if coarse_result_name__.startswith('FD'): 77 | index_ = coarse_result_name__.find(':') 78 | coarse_result_name__ = coarse_result_name__[index_ + 1: ] 79 | coarse_result_name_[plane] = coarse_result_name__ 80 | coarse_result_directory_[plane] = coarse_result_directory__ 81 | 82 | coarse2fine_result_name = 'FD' + str(current_fold) + ':' + \ 83 | fine_snapshot_name_['X'] + ',' + \ 84 | fine_snapshot_name_['Y'] + ',' + \ 85 | fine_snapshot_name_['Z'] + ':' + \ 86 | epoch + '_' + str(coarse_threshold) + '_' + \ 87 | str(fine_threshold) + ',' + str(max_rounds) 88 | coarse2fine_result_directory = os.path.join( \ 89 | coarse2fine_result_path, coarse2fine_result_name, 'volumes') 90 | 91 | finished = np.ones((len(volume_list)), dtype = np.int) 92 | for i in range(len(volume_list)): 93 | for r in range(max_rounds + 1): 94 | volume_file = volume_filename_coarse2fine(coarse2fine_result_directory, r, i) 95 | if not os.path.isfile(volume_file): 96 | finished[i] = 0 97 | break 98 | finished_all = (finished.sum() == len(volume_list)) 99 | if finished_all: 100 | exit() 101 | 102 | os.environ["CUDA_VISIBLE_DEVICES"]= str(GPU_ID) 103 | net_ = {} 104 | for plane in ['X', 'Y', 'Z']: 105 | net_[plane] = [] 106 | for t in range(len(epoch_list)): 107 | net = RSTN(crop_margin=crop_margin, TEST='F').cuda() 108 | net.load_state_dict(torch.load(fine_snapshot_[plane][t])) 109 | net.eval() 110 | net_[plane].append(net) 111 | 112 | DSC = np.zeros((max_rounds + 1, len(volume_list))) 113 | DSC_90 = np.zeros((len(volume_list))) 
114 | DSC_95 = np.zeros((len(volume_list))) 115 | DSC_98 = np.zeros((len(volume_list))) 116 | DSC_99 = np.zeros((len(volume_list))) 117 | coarse2fine_result_directory = os.path.join(coarse2fine_result_path, \ 118 | coarse2fine_result_name, 'volumes') 119 | if not os.path.exists(coarse2fine_result_directory): 120 | os.makedirs(coarse2fine_result_directory) 121 | coarse2fine_result_file = os.path.join(coarse2fine_result_path, \ 122 | coarse2fine_result_name, 'results.txt') 123 | output = open(coarse2fine_result_file, 'w') 124 | output.close() 125 | output = open(coarse2fine_result_file, 'a+') 126 | output.write('Fusing results of ' + str(len(epoch_list)) + ' snapshots:\n') 127 | output.close() 128 | 129 | for i in range(len(volume_list)): 130 | start_time = time.time() 131 | print('Testing ' + str(i + 1) + ' out of ' + str(len(volume_list)) + ' testcases.') 132 | output = open(coarse2fine_result_file, 'a+') 133 | output.write(' Testcase ' + str(i + 1) + ':\n') 134 | output.close() 135 | s = volume_list[i].split(' ') 136 | label = np.load(s[2]) 137 | label = is_organ(label, organ_ID).astype(np.uint8) 138 | finished = True 139 | for r in range(max_rounds + 1): 140 | volume_file = volume_filename_coarse2fine(coarse2fine_result_directory, r, i) 141 | if not os.path.isfile(volume_file): 142 | finished = False 143 | break 144 | if not finished: 145 | image = np.load(s[1]).astype(np.float32) 146 | np.minimum(np.maximum(image, low_range, image), high_range, image) 147 | image -= low_range 148 | image /= (high_range - low_range) 149 | imageX = image 150 | imageY = image.transpose(1, 0, 2).copy() 151 | imageZ = image.transpose(2, 0, 1).copy() 152 | print(' Data loading is finished: ' + str(time.time() - start_time) + ' second(s) elapsed.') 153 | for r in range(max_rounds + 1): 154 | print(' Iteration round ' + str(r) + ':') 155 | volume_file = volume_filename_coarse2fine(coarse2fine_result_directory, r, i) 156 | if not finished: 157 | if r == 0: # coarse majority voting 158 | 
pred_ = np.zeros(label.shape, dtype = np.float32) 159 | for plane in ['X', 'Y', 'Z']: 160 | for t in range(len(epoch_list)): 161 | volume_file_ = volume_filename_testing( \ 162 | coarse_result_directory_[plane], epoch_list[t], i) 163 | volume_data = np.load(volume_file_) 164 | pred_ += volume_data['volume'] 165 | pred_ /= (255 * len(epoch_list) * 3) 166 | print(' Fusion is finished: ' + \ 167 | str(time.time() - start_time) + ' second(s) elapsed.') 168 | else: 169 | mask_sumX = np.sum(mask, axis = (1, 2)) 170 | if mask_sumX.sum() == 0: 171 | continue 172 | mask_sumY = np.sum(mask, axis = (0, 2)) 173 | mask_sumZ = np.sum(mask, axis = (0, 1)) 174 | scoreX = score 175 | scoreY = score.transpose(1, 0, 2).copy() 176 | scoreZ = score.transpose(2, 0, 1).copy() 177 | maskX = mask 178 | maskY = mask.transpose(1, 0, 2).copy() 179 | maskZ = mask.transpose(2, 0, 1).copy() 180 | pred_ = np.zeros(label.shape, dtype = np.float32) 181 | for plane in ['X', 'Y', 'Z']: 182 | for t in range(len(epoch_list)): 183 | net = net_[plane][t] 184 | minR = 0 185 | if plane == 'X': 186 | maxR = label.shape[0] 187 | shape_ = (1, 3, image.shape[1], image.shape[2]) 188 | pred__ = np.zeros((image.shape[0], image.shape[1], image.shape[2]), \ 189 | dtype = np.float32) 190 | elif plane == 'Y': 191 | maxR = label.shape[1] 192 | shape_ = (1, 3, image.shape[0], image.shape[2]) 193 | pred__ = np.zeros((image.shape[1], image.shape[0], image.shape[2]), \ 194 | dtype = np.float32) 195 | elif plane == 'Z': 196 | maxR = label.shape[2] 197 | shape_ = (1, 3, image.shape[0], image.shape[1]) 198 | pred__ = np.zeros((image.shape[2], image.shape[0], image.shape[1]), \ 199 | dtype = np.float32) 200 | for j in range(minR, maxR): 201 | if slice_thickness == 1: 202 | sID = [j, j, j] 203 | elif slice_thickness == 3: 204 | sID = [max(minR, j - 1), j, min(maxR - 1, j + 1)] 205 | if (plane == 'X' and mask_sumX[sID].sum() == 0) or \ 206 | (plane == 'Y' and mask_sumY[sID].sum() == 0) or \ 207 | (plane == 'Z' and 
mask_sumZ[sID].sum() == 0): 208 | continue 209 | if plane == 'X': 210 | image_ = imageX[sID, :, :] 211 | score_ = scoreX[sID, :, :] 212 | mask_ = maskX[sID, :, :] 213 | elif plane == 'Y': 214 | image_ = imageY[sID, :, :] 215 | score_ = scoreY[sID, :, :] 216 | mask_ = maskY[sID, :, :] 217 | elif plane == 'Z': 218 | image_ = imageZ[sID, :, :] 219 | score_ = scoreZ[sID, :, :] 220 | mask_ = maskZ[sID, :, :] 221 | 222 | image_ = image_.reshape(1, 3, image_.shape[1], image_.shape[2]) 223 | score_ = score_.reshape(1, 3, score_.shape[1], score_.shape[2]) 224 | mask_ = mask_.reshape(1, 3, mask_.shape[1], mask_.shape[2]) 225 | image_ = torch.from_numpy(image_).cuda().float() 226 | score_ = torch.from_numpy(score_).cuda().float() 227 | mask_ = torch.from_numpy(mask_).cuda().float() 228 | out = net(image_,1, score=score_, mask=mask_).data.cpu().numpy()[0, :, :, :] 229 | 230 | if slice_thickness == 1: 231 | pred__[j, :, :] = out 232 | elif slice_thickness == 3: 233 | if j == minR: 234 | pred__[minR: minR + 2, :, :] += out[1: 3, :, :] 235 | elif j == maxR - 1: 236 | pred__[maxR - 2: maxR, :, :] += out[0: 2, :, :] 237 | else: 238 | pred__[j - 1: j + 2, :, :] += out 239 | if slice_thickness == 3: 240 | pred__[minR, :, :] /= 2 241 | pred__[minR + 1: maxR - 1, :, :] /= 3 242 | pred__[maxR - 1, :, :] /= 2 243 | print(' Testing on plane ' + plane + ' and snapshot ' + str(t + 1) + \ 244 | ' is finished: ' + str(time.time() - start_time) + \ 245 | ' second(s) elapsed.') 246 | if plane == 'X': 247 | pred_ += pred__ 248 | elif plane == 'Y': 249 | pred_ += pred__.transpose(1, 0, 2) 250 | elif plane == 'Z': 251 | pred_ += pred__.transpose(1, 2, 0) 252 | pred_ /= (len(epoch_list) * 3) 253 | print(' Testing is finished: ' + \ 254 | str(time.time() - start_time) + ' second(s) elapsed.') 255 | 256 | pred = (pred_ >= fine_threshold).astype(np.uint8) 257 | if r > 0: 258 | pred = post_processing(pred, pred, 0.5, organ_ID) 259 | np.savez_compressed(volume_file, volume = pred) 260 | else: 261 | 
volume_data = np.load(volume_file) 262 | pred = volume_data['volume'].astype(np.uint8) 263 | print(' Testing result is loaded: ' + \ 264 | str(time.time() - start_time) + ' second(s) elapsed.') 265 | 266 | DSC[r, i], inter_sum, pred_sum, label_sum = DSC_computation(label, pred) 267 | print(' DSC = 2 * ' + str(inter_sum) + ' / (' + str(pred_sum) + ' + ' + \ 268 | str(label_sum) + ') = ' + str(DSC[r, i]) + ' .') 269 | output = open(coarse2fine_result_file, 'a+') 270 | output.write(' Round ' + str(r) + ', ' + 'DSC = 2 * ' + str(inter_sum) + ' / (' + \ 271 | str(pred_sum) + ' + ' + str(label_sum) + ') = ' + str(DSC[r, i]) + ' .\n') 272 | output.close() 273 | 274 | if pred_sum == 0 and label_sum == 0: 275 | DSC[r, i] = 0 276 | if r > 0: 277 | inter_DSC, inter_sum, pred_sum, label_sum = DSC_computation(mask, pred) 278 | if pred_sum == 0 and label_sum == 0: 279 | inter_DSC = 1 280 | print(' Inter-iteration DSC = 2 * ' + str(inter_sum) + ' / (' + \ 281 | str(pred_sum) + ' + ' + str(label_sum) + ') = ' + str(inter_DSC) + ' .') 282 | output = open(coarse2fine_result_file, 'a+') 283 | output.write(' Inter-iteration DSC = 2 * ' + str(inter_sum) + ' / (' + \ 284 | str(pred_sum) + ' + ' + str(label_sum) + ') = ' + str(inter_DSC) + ' .\n') 285 | output.close() 286 | if DSC_90[i] == 0 and (r == max_rounds or inter_DSC >= 0.90): 287 | DSC_90[i] = DSC[r, i] 288 | if DSC_95[i] == 0 and (r == max_rounds or inter_DSC >= 0.95): 289 | DSC_95[i] = DSC[r, i] 290 | if DSC_98[i] == 0 and (r == max_rounds or inter_DSC >= 0.98): 291 | DSC_98[i] = DSC[r, i] 292 | if DSC_99[i] == 0 and (r == max_rounds or inter_DSC >= 0.99): 293 | DSC_99[i] = DSC[r, i] 294 | if r <= max_rounds: 295 | if not finished: 296 | score = pred_ # [0,1] 297 | mask = pred # {0,1} after postprocessing 298 | 299 | for r in range(max_rounds + 1): 300 | print('Round ' + str(r) + ', ' + 'Average DSC = ' + str(np.mean(DSC[r, :])) + ' .') 301 | output = open(coarse2fine_result_file, 'a+') 302 | output.write('Round ' + str(r) + 
', ' + 'Average DSC = ' + str(np.mean(DSC[r, :])) + ' .\n') 303 | output.close() 304 | 305 | print('DSC threshold = 0.90, ' + 'Average DSC = ' + str(np.mean(DSC_90)) + ' std = ' + str(np.std(DSC_90)) + ' .') 306 | print('DSC threshold = 0.95, ' + 'Average DSC = ' + str(np.mean(DSC_95)) + ' std = ' + str(np.std(DSC_95)) + ' .') 307 | print('DSC threshold = 0.98, ' + 'Average DSC = ' + str(np.mean(DSC_98)) + ' std = ' + str(np.std(DSC_98)) + ' .') 308 | print('DSC threshold = 0.99, ' + 'Average DSC = ' + str(np.mean(DSC_99)) + ' std = ' + str(np.std(DSC_99)) + ' .') 309 | output = open(coarse2fine_result_file, 'a+') 310 | output.write('DSC threshold = 0.90, ' + 'Average DSC = ' + str(np.mean(DSC_90)) + ' std = ' + str(np.std(DSC_90)) + ' .\n') 311 | output.write('DSC threshold = 0.95, ' + 'Average DSC = ' + str(np.mean(DSC_95)) + ' std = ' + str(np.std(DSC_95)) + ' .\n') 312 | output.write('DSC threshold = 0.98, ' + 'Average DSC = ' + str(np.mean(DSC_98)) + ' std = ' + str(np.std(DSC_98)) + ' .\n') 313 | output.write('DSC threshold = 0.99, ' + 'Average DSC = ' + str(np.mean(DSC_99)) + ' std = ' + str(np.std(DSC_99)) + ' .\n') 314 | output.close() 315 | print('The coarse-to-fine testing process is finished.') -------------------------------------------------------------------------------- /f2.sh: -------------------------------------------------------------------------------- 1 | #################################################################################################### 2 | # RSTN: Recurrent Saliency Transformation Network for organ segmentation framework # 3 | # This is PyTorch 0.4.0 Python 3.6 verison of OrganSegRSTN in CAFFE Python 2.7 . # 4 | # Author: Tianwei Ni, Huangjie Zheng, Lingxi Xie. # 5 | # # 6 | # If you use our codes, please cite our paper accordingly: # 7 | # Qihang Yu, Lingxi Xie, Yan Wang, Yuyin Zhou, Elliot K. Fishman, Alan L. 
Yuille, # 8 | # "Recurrent Saliency Transformation Network: # 9 | # Incorporating Multi-Stage Visual Cues for Small Organ Segmentation", # 10 | # in IEEE Conference on Computer Vision and Pattern Recognition, 2018. # 11 | # # 12 | # NOTE: this program can be used for multi-organ segmentation. # 13 | # Please also refer to its previous version, OrganSegC2F. # 14 | #################################################################################################### 15 | 16 | #################################################################################################### 17 | # variables for conveniencer 18 | CURRENT_ORGAN_ID=1 19 | CURRENT_PLANE=A 20 | CURRENT_FOLD=2 21 | CURRENT_GPU=0 22 | 23 | #################################################################################################### 24 | # turn on these switches to execute each module 25 | ENABLE_INITIALIZATION=0 26 | ENABLE_TRAINING=1 27 | ENABLE_COARSE_TESTING=1 28 | ENABLE_COARSE_FUSION=1 29 | ENABLE_ORACLE_TESTING=1 30 | ENABLE_ORACLE_FUSION=1 31 | ENABLE_COARSE2FINE_TESTING=1 32 | # training settings: X|Y|Z 33 | TRAINING_ORGAN_ID=$CURRENT_ORGAN_ID 34 | TRAINING_PLANE=$CURRENT_PLANE 35 | TRAINING_GPU=$CURRENT_GPU 36 | # coarse_testing settings: X|Y|Z, before this, coarse-scaled models shall be ready 37 | COARSE_TESTING_ORGAN_ID=$CURRENT_ORGAN_ID 38 | COARSE_TESTING_PLANE=$CURRENT_PLANE 39 | COARSE_TESTING_GPU=$CURRENT_GPU 40 | # coarse_fusion settings: before this, coarse-scaled results on 3 views shall be ready 41 | COARSE_FUSION_ORGAN_ID=$CURRENT_ORGAN_ID 42 | # oracle_testing settings: X|Y|Z, before this, fine-scaled models shall be ready 43 | ORACLE_TESTING_ORGAN_ID=$CURRENT_ORGAN_ID 44 | ORACLE_TESTING_PLANE=$CURRENT_PLANE 45 | ORACLE_TESTING_GPU=$CURRENT_GPU 46 | # oracle_fusion settings: before this, fine-scaled results on 3 views shall be ready 47 | ORACLE_FUSION_ORGAN_ID=$CURRENT_ORGAN_ID 48 | # fine_testing settings: before this, both coarse-scaled and fine-scaled models shall be ready 49 
| COARSE2FINE_TESTING_ORGAN_ID=$CURRENT_ORGAN_ID 50 | COARSE2FINE_TESTING_GPU=$CURRENT_GPU 51 | 52 | #################################################################################################### 53 | # defining the root path which stores image and label data 54 | DATA_PATH='/data/pan/' 55 | 56 | #################################################################################################### 57 | # data initialization: only needs to be run once 58 | # variables 59 | ORGAN_NUMBER=1 60 | FOLDS=4 61 | LOW_RANGE=-100 62 | HIGH_RANGE=240 63 | # init.py : data_path, organ_number, folds, low_range, high_range 64 | if [ "$ENABLE_INITIALIZATION" = "1" ] 65 | then 66 | python init.py \ 67 | $DATA_PATH $ORGAN_NUMBER $FOLDS $LOW_RANGE $HIGH_RANGE 68 | fi 69 | 70 | #################################################################################################### 71 | # the individual and joint training processes 72 | # variables 73 | SLICE_THRESHOLD=0.98 74 | SLICE_THICKNESS=3 75 | LEARNING_RATE1=1e-5 76 | LEARNING_RATE2=1e-5 77 | LEARNING_RATE_M1=10 78 | LEARNING_RATE_M2=10 79 | TRAINING_MARGIN=20 80 | TRAINING_PROB=0.5 81 | TRAINING_SAMPLE_BATCH=1 82 | TRAINING_EPOCH_S=6 83 | TRAINING_EPOCH_I=6 84 | TRAINING_EPOCH_J=6 85 | LR_DECAY_EPOCH_J_STEP=2 86 | if [ "$ENABLE_TRAINING" = "1" ] 87 | then 88 | TRAINING_TIMESTAMP=$(date +'%Y%m%d_%H%M%S') 89 | else 90 | TRAINING_TIMESTAMP='20220221_210935' 91 | fi 92 | # training.py : data_path, current_fold, organ_number, low_range, high_range, 93 | # slice_threshold, slice_thickness, organ_ID, plane, GPU_ID, 94 | # learning_rate1, learning_rate2 (not used), margin, prob, sample_batch, 95 | # step, ·max_iterations1, max_iterations2 (not used), fraction, timestamp 96 | if [ "$ENABLE_TRAINING" = "1" ] 97 | then 98 | if [ "$TRAINING_PLANE" = "X" ] || [ "$TRAINING_PLANE" = "A" ] 99 | then 100 | TRAINING_MODELNAME=X${SLICE_THICKNESS}_${TRAINING_ORGAN_ID} 101 | 
TRAINING_LOG=${DATA_PATH}logs/FD${CURRENT_FOLD}:${TRAINING_MODELNAME}_${TRAINING_TIMESTAMP}.txt 102 | python training.py \ 103 | $DATA_PATH $CURRENT_FOLD $ORGAN_NUMBER $LOW_RANGE $HIGH_RANGE \ 104 | $SLICE_THRESHOLD $SLICE_THICKNESS \ 105 | $TRAINING_ORGAN_ID X $TRAINING_GPU \ 106 | $LEARNING_RATE1 $LEARNING_RATE_M1 $LEARNING_RATE2 $LEARNING_RATE_M2 \ 107 | $TRAINING_MARGIN $TRAINING_PROB $TRAINING_SAMPLE_BATCH \ 108 | $TRAINING_EPOCH_S $TRAINING_EPOCH_I $TRAINING_EPOCH_J \ 109 | $LR_DECAY_EPOCH_J_STEP $TRAINING_TIMESTAMP 1 2>&1 | tee $TRAINING_LOG 110 | fi 111 | if [ "$TRAINING_PLANE" = "Y" ] || [ "$TRAINING_PLANE" = "A" ] 112 | then 113 | TRAINING_MODELNAME=Y${SLICE_THICKNESS}_${TRAINING_ORGAN_ID} 114 | TRAINING_LOG=${DATA_PATH}logs/FD${CURRENT_FOLD}:${TRAINING_MODELNAME}_${TRAINING_TIMESTAMP}.txt 115 | python training.py \ 116 | $DATA_PATH $CURRENT_FOLD $ORGAN_NUMBER $LOW_RANGE $HIGH_RANGE \ 117 | $SLICE_THRESHOLD $SLICE_THICKNESS \ 118 | $TRAINING_ORGAN_ID Y $TRAINING_GPU \ 119 | $LEARNING_RATE1 $LEARNING_RATE_M1 $LEARNING_RATE2 $LEARNING_RATE_M2 \ 120 | $TRAINING_MARGIN $TRAINING_PROB $TRAINING_SAMPLE_BATCH \ 121 | $TRAINING_EPOCH_S $TRAINING_EPOCH_I $TRAINING_EPOCH_J \ 122 | $LR_DECAY_EPOCH_J_STEP $TRAINING_TIMESTAMP 1 2>&1 | tee $TRAINING_LOG 123 | fi 124 | if [ "$TRAINING_PLANE" = "Z" ] || [ "$TRAINING_PLANE" = "A" ] 125 | then 126 | TRAINING_MODELNAME=Z${SLICE_THICKNESS}_${TRAINING_ORGAN_ID} 127 | TRAINING_LOG=${DATA_PATH}logs/FD${CURRENT_FOLD}:${TRAINING_MODELNAME}_${TRAINING_TIMESTAMP}.txt 128 | python training.py \ 129 | $DATA_PATH $CURRENT_FOLD $ORGAN_NUMBER $LOW_RANGE $HIGH_RANGE \ 130 | $SLICE_THRESHOLD $SLICE_THICKNESS \ 131 | $TRAINING_ORGAN_ID Z $TRAINING_GPU \ 132 | $LEARNING_RATE1 $LEARNING_RATE_M1 $LEARNING_RATE2 $LEARNING_RATE_M2 \ 133 | $TRAINING_MARGIN $TRAINING_PROB $TRAINING_SAMPLE_BATCH \ 134 | $TRAINING_EPOCH_S $TRAINING_EPOCH_I $TRAINING_EPOCH_J \ 135 | $LR_DECAY_EPOCH_J_STEP $TRAINING_TIMESTAMP 1 2>&1 | tee $TRAINING_LOG 136 | fi 137 | 
fi 138 | 139 | #################################################################################################### 140 | # the coarse-scaled testing processes 141 | # variables 142 | COARSE_TESTING_EPOCH_S=$TRAINING_EPOCH_S 143 | COARSE_TESTING_EPOCH_I=$TRAINING_EPOCH_I 144 | COARSE_TESTING_EPOCH_J=$TRAINING_EPOCH_J 145 | COARSE_TESTING_EPOCH_STEP=$LR_DECAY_EPOCH_J_STEP 146 | COARSE_TIMESTAMP1=$TRAINING_TIMESTAMP 147 | COARSE_TIMESTAMP2=$TRAINING_TIMESTAMP 148 | # coarse_testing.py : data_path, current_fold, organ_number, low_range, high_range, 149 | # slice_threshold, slice_thickness, organ_ID, plane, GPU_ID, 150 | # learning_rate1, learning_rate2, margin, prob, sample_batch, 151 | # EPOCH_S, EPOCH_I, EPOCH_J, EPOCH_STEP, 152 | # timestamp1, timestamp2 (optional) 153 | if [ "$ENABLE_COARSE_TESTING" = "1" ] 154 | then 155 | if [ "$COARSE_TESTING_PLANE" = "X" ] || [ "$COARSE_TESTING_PLANE" = "A" ] 156 | then 157 | python coarse_testing.py \ 158 | $DATA_PATH $CURRENT_FOLD $ORGAN_NUMBER $LOW_RANGE $HIGH_RANGE \ 159 | $SLICE_THRESHOLD $SLICE_THICKNESS \ 160 | $COARSE_TESTING_ORGAN_ID X $COARSE_TESTING_GPU \ 161 | $LEARNING_RATE1 $LEARNING_RATE_M1 $LEARNING_RATE2 $LEARNING_RATE_M2 \ 162 | $TRAINING_MARGIN $TRAINING_PROB $TRAINING_SAMPLE_BATCH \ 163 | $COARSE_TESTING_EPOCH_S $COARSE_TESTING_EPOCH_I \ 164 | $COARSE_TESTING_EPOCH_J $COARSE_TESTING_EPOCH_STEP \ 165 | $COARSE_TIMESTAMP1 $COARSE_TIMESTAMP2 166 | fi 167 | if [ "$COARSE_TESTING_PLANE" = "Y" ] || [ "$COARSE_TESTING_PLANE" = "A" ] 168 | then 169 | python coarse_testing.py \ 170 | $DATA_PATH $CURRENT_FOLD $ORGAN_NUMBER $LOW_RANGE $HIGH_RANGE \ 171 | $SLICE_THRESHOLD $SLICE_THICKNESS \ 172 | $COARSE_TESTING_ORGAN_ID Y $COARSE_TESTING_GPU \ 173 | $LEARNING_RATE1 $LEARNING_RATE_M1 $LEARNING_RATE2 $LEARNING_RATE_M2 \ 174 | $TRAINING_MARGIN $TRAINING_PROB $TRAINING_SAMPLE_BATCH \ 175 | $COARSE_TESTING_EPOCH_S $COARSE_TESTING_EPOCH_I \ 176 | $COARSE_TESTING_EPOCH_J $COARSE_TESTING_EPOCH_STEP \ 177 | $COARSE_TIMESTAMP1 
$COARSE_TIMESTAMP2 178 | fi 179 | if [ "$COARSE_TESTING_PLANE" = "Z" ] || [ "$COARSE_TESTING_PLANE" = "A" ] 180 | then 181 | python coarse_testing.py \ 182 | $DATA_PATH $CURRENT_FOLD $ORGAN_NUMBER $LOW_RANGE $HIGH_RANGE \ 183 | $SLICE_THRESHOLD $SLICE_THICKNESS \ 184 | $COARSE_TESTING_ORGAN_ID Z $COARSE_TESTING_GPU \ 185 | $LEARNING_RATE1 $LEARNING_RATE_M1 $LEARNING_RATE2 $LEARNING_RATE_M2 \ 186 | $TRAINING_MARGIN $TRAINING_PROB $TRAINING_SAMPLE_BATCH \ 187 | $COARSE_TESTING_EPOCH_S $COARSE_TESTING_EPOCH_I \ 188 | $COARSE_TESTING_EPOCH_J $COARSE_TESTING_EPOCH_STEP \ 189 | $COARSE_TIMESTAMP1 $COARSE_TIMESTAMP2 190 | fi 191 | fi 192 | 193 | #################################################################################################### 194 | # the coarse-scaled fusion process 195 | # variables 196 | COARSE_FUSION_EPOCH_S=$TRAINING_EPOCH_S 197 | COARSE_FUSION_EPOCH_I=$TRAINING_EPOCH_I 198 | COARSE_FUSION_EPOCH_J=$TRAINING_EPOCH_J 199 | COARSE_FUSION_EPOCH_STEP=$LR_DECAY_EPOCH_J_STEP 200 | COARSE_FUSION_THRESHOLD=0.5 201 | COARSE_TIMESTAMP1_X=$TRAINING_TIMESTAMP 202 | COARSE_TIMESTAMP1_Y=$TRAINING_TIMESTAMP 203 | COARSE_TIMESTAMP1_Z=$TRAINING_TIMESTAMP 204 | COARSE_TIMESTAMP2_X=$TRAINING_TIMESTAMP 205 | COARSE_TIMESTAMP2_Y=$TRAINING_TIMESTAMP 206 | COARSE_TIMESTAMP2_Z=$TRAINING_TIMESTAMP 207 | # coarse_fusion.py : data_path, current_fold, organ_number, low_range, high_range, 208 | # slice_threshold, slice_thickness, organ_ID, plane, GPU_ID, 209 | # learning_rate1, learning_rate_m1, learning_rate2, learning_rate_m2, margin, 210 | # EPOCH_S, EPOCH_I, EPOCH_J, EPOCH_STEP, threshold, 211 | # timestamp1_X, timestamp1_Y, timestamp1_Z, 212 | # timestamp2_X (optional), timestamp2_Y (optional), timestamp2_Z (optional) 213 | if [ "$ENABLE_COARSE_FUSION" = "1" ] 214 | then 215 | python coarse_fusion.py \ 216 | $DATA_PATH $CURRENT_FOLD $ORGAN_NUMBER $LOW_RANGE $HIGH_RANGE \ 217 | $SLICE_THRESHOLD $SLICE_THICKNESS $COARSE_TESTING_ORGAN_ID $COARSE_TESTING_GPU \ 218 | 
$LEARNING_RATE1 $LEARNING_RATE_M1 $LEARNING_RATE2 $LEARNING_RATE_M2 $TRAINING_MARGIN \ 219 | $COARSE_FUSION_EPOCH_S $COARSE_FUSION_EPOCH_I $COARSE_FUSION_EPOCH_J \ 220 | $COARSE_FUSION_EPOCH_STEP $COARSE_FUSION_THRESHOLD \ 221 | $COARSE_TIMESTAMP1_X $COARSE_TIMESTAMP1_Y $COARSE_TIMESTAMP1_Z \ 222 | $COARSE_TIMESTAMP2_X $COARSE_TIMESTAMP2_Y $COARSE_TIMESTAMP2_Z 223 | fi 224 | 225 | #################################################################################################### 226 | # the oracle testing processes 227 | # variables 228 | ORACLE_TESTING_EPOCH_S=$TRAINING_EPOCH_S 229 | ORACLE_TESTING_EPOCH_I=$TRAINING_EPOCH_I 230 | ORACLE_TESTING_EPOCH_J=$TRAINING_EPOCH_J 231 | ORACLE_TESTING_EPOCH_STEP=$LR_DECAY_EPOCH_J_STEP 232 | ORACLE_TIMESTAMP1=$TRAINING_TIMESTAMP 233 | ORACLE_TIMESTAMP2=$TRAINING_TIMESTAMP 234 | # oracle_testing.py : data_path, current_fold, organ_number, low_range, high_range, 235 | # slice_threshold, slice_thickness, organ_ID, plane, GPU_ID, 236 | # learning_rate1, learning_rate_m1, learning_rate2, learning_rate_m2, 237 | # margin, prob, sample_batch, 238 | # EPOCH_S, EPOCH_I, EPOCH_J, EPOCH_STEP, 239 | # timestamp1, timestamp2 (optional) 240 | if [ "$ENABLE_ORACLE_TESTING" = "1" ] 241 | then 242 | if [ "$ORACLE_TESTING_PLANE" = "X" ] || [ "$ORACLE_TESTING_PLANE" = "A" ] 243 | then 244 | python oracle_testing.py \ 245 | $DATA_PATH $CURRENT_FOLD $ORGAN_NUMBER $LOW_RANGE $HIGH_RANGE \ 246 | $SLICE_THRESHOLD $SLICE_THICKNESS \ 247 | $ORACLE_TESTING_ORGAN_ID X $ORACLE_TESTING_GPU \ 248 | $LEARNING_RATE1 $LEARNING_RATE_M1 $LEARNING_RATE2 $LEARNING_RATE_M2 \ 249 | $TRAINING_MARGIN $TRAINING_PROB $TRAINING_SAMPLE_BATCH \ 250 | $ORACLE_TESTING_EPOCH_S $ORACLE_TESTING_EPOCH_I \ 251 | $ORACLE_TESTING_EPOCH_J $ORACLE_TESTING_EPOCH_STEP \ 252 | $ORACLE_TIMESTAMP1 $ORACLE_TIMESTAMP2 253 | fi 254 | if [ "$ORACLE_TESTING_PLANE" = "Y" ] || [ "$ORACLE_TESTING_PLANE" = "A" ] 255 | then 256 | python oracle_testing.py \ 257 | $DATA_PATH $CURRENT_FOLD 
$ORGAN_NUMBER $LOW_RANGE $HIGH_RANGE \ 258 | $SLICE_THRESHOLD $SLICE_THICKNESS \ 259 | $ORACLE_TESTING_ORGAN_ID Y $ORACLE_TESTING_GPU \ 260 | $LEARNING_RATE1 $LEARNING_RATE_M1 $LEARNING_RATE2 $LEARNING_RATE_M2 \ 261 | $TRAINING_MARGIN $TRAINING_PROB $TRAINING_SAMPLE_BATCH \ 262 | $ORACLE_TESTING_EPOCH_S $ORACLE_TESTING_EPOCH_I \ 263 | $ORACLE_TESTING_EPOCH_J $ORACLE_TESTING_EPOCH_STEP \ 264 | $ORACLE_TIMESTAMP1 $ORACLE_TIMESTAMP2 265 | fi 266 | if [ "$ORACLE_TESTING_PLANE" = "Z" ] || [ "$ORACLE_TESTING_PLANE" = "A" ] 267 | then 268 | python oracle_testing.py \ 269 | $DATA_PATH $CURRENT_FOLD $ORGAN_NUMBER $LOW_RANGE $HIGH_RANGE \ 270 | $SLICE_THRESHOLD $SLICE_THICKNESS \ 271 | $ORACLE_TESTING_ORGAN_ID Z $ORACLE_TESTING_GPU \ 272 | $LEARNING_RATE1 $LEARNING_RATE_M1 $LEARNING_RATE2 $LEARNING_RATE_M2 \ 273 | $TRAINING_MARGIN $TRAINING_PROB $TRAINING_SAMPLE_BATCH \ 274 | $ORACLE_TESTING_EPOCH_S $ORACLE_TESTING_EPOCH_I \ 275 | $ORACLE_TESTING_EPOCH_J $ORACLE_TESTING_EPOCH_STEP \ 276 | $ORACLE_TIMESTAMP1 $ORACLE_TIMESTAMP2 277 | fi 278 | fi 279 | 280 | #################################################################################################### 281 | # the oracle-scaled fusion process 282 | # variables 283 | ORACLE_FUSION_EPOCH_S=$TRAINING_EPOCH_S 284 | ORACLE_FUSION_EPOCH_I=$TRAINING_EPOCH_I 285 | ORACLE_FUSION_EPOCH_J=$TRAINING_EPOCH_J 286 | ORACLE_FUSION_EPOCH_STEP=$LR_DECAY_EPOCH_J_STEP 287 | ORACLE_FUSION_THRESHOLD=0.5 288 | ORACLE_TIMESTAMP1_X=$TRAINING_TIMESTAMP 289 | ORACLE_TIMESTAMP1_Y=$TRAINING_TIMESTAMP 290 | ORACLE_TIMESTAMP1_Z=$TRAINING_TIMESTAMP 291 | ORACLE_TIMESTAMP2_X=$TRAINING_TIMESTAMP 292 | ORACLE_TIMESTAMP2_Y=$TRAINING_TIMESTAMP 293 | ORACLE_TIMESTAMP2_Z=$TRAINING_TIMESTAMP 294 | # oracle_fusion.py : data_path, current_fold, organ_number, low_range, high_range, 295 | # slice_threshold, slice_thickness, organ_ID, plane, GPU_ID, 296 | # learning_rate1, learning_rate_m1, learning_rate2, learning_rate_m2, margin, 297 | # EPOCH_S, EPOCH_I, EPOCH_J, 
EPOCH_STEP, threshold, 298 | # timestamp1_X, timestamp1_Y, timestamp1_Z, 299 | # timestamp2_X (optional), timestamp2_Y (optional), timestamp2_Z (optional) 300 | if [ "$ENABLE_ORACLE_FUSION" = "1" ] 301 | then 302 | python oracle_fusion.py \ 303 | $DATA_PATH $CURRENT_FOLD $ORGAN_NUMBER $LOW_RANGE $HIGH_RANGE \ 304 | $SLICE_THRESHOLD $SLICE_THICKNESS $ORACLE_TESTING_ORGAN_ID $ORACLE_TESTING_GPU \ 305 | $LEARNING_RATE1 $LEARNING_RATE_M1 $LEARNING_RATE2 $LEARNING_RATE_M2 $TRAINING_MARGIN \ 306 | $ORACLE_FUSION_EPOCH_S $ORACLE_FUSION_EPOCH_I $ORACLE_FUSION_EPOCH_J \ 307 | $ORACLE_FUSION_EPOCH_STEP $ORACLE_FUSION_THRESHOLD \ 308 | $ORACLE_TIMESTAMP1_X $ORACLE_TIMESTAMP1_Y $ORACLE_TIMESTAMP1_Z \ 309 | $ORACLE_TIMESTAMP2_X $ORACLE_TIMESTAMP2_Y $ORACLE_TIMESTAMP2_Z 310 | fi 311 | 312 | #################################################################################################### 313 | # the coarse-to-fine testing process 314 | # variables 315 | FINE_TESTING_EPOCH_S=$TRAINING_EPOCH_S 316 | FINE_TESTING_EPOCH_I=$TRAINING_EPOCH_I 317 | FINE_TESTING_EPOCH_J=$TRAINING_EPOCH_J 318 | FINE_TESTING_EPOCH_STEP=$LR_DECAY_EPOCH_J_STEP 319 | FINE_FUSION_THRESHOLD=0.5 320 | COARSE2FINE_TIMESTAMP1_X=$TRAINING_TIMESTAMP 321 | COARSE2FINE_TIMESTAMP1_Y=$TRAINING_TIMESTAMP 322 | COARSE2FINE_TIMESTAMP1_Z=$TRAINING_TIMESTAMP 323 | COARSE2FINE_TIMESTAMP2_X=$TRAINING_TIMESTAMP 324 | COARSE2FINE_TIMESTAMP2_Y=$TRAINING_TIMESTAMP 325 | COARSE2FINE_TIMESTAMP2_Z=$TRAINING_TIMESTAMP 326 | MAX_ROUNDS=10 327 | # coarse2fine_testing.py : data_path, current_fold, organ_number, low_range, high_range, 328 | # slice_threshold, slice_thickness, organ_ID, GPU_ID, 329 | # learning_rate1, learning_rate_m1, learning_rate2, learning_rate_m2, margin, 330 | # coarse_fusion_starting_iterations, coarse_fusion_step, coarse_fusion_max_iterations, 331 | # coarse_fusion_threshold, coarse_fusion_code, 332 | # EPOCH_S, EPOCH_I, EPOCH_J, EPOCH_STEP, 333 | # fine_fusion_threshold, max_rounds, 334 | # timestamp1_X, 
timestamp1_Y, timestamp1_Z, 335 | #                          timestamp2_X (optional), timestamp2_Y (optional), timestamp2_Z (optional) 336 | if [ "$ENABLE_COARSE2FINE_TESTING" = "1" ] 337 | then 338 | python coarse2fine_testing.py \ 339 | $DATA_PATH $CURRENT_FOLD $ORGAN_NUMBER $LOW_RANGE $HIGH_RANGE \ 340 | $SLICE_THRESHOLD $SLICE_THICKNESS $COARSE2FINE_TESTING_ORGAN_ID $COARSE2FINE_TESTING_GPU \ 341 | $LEARNING_RATE1 $LEARNING_RATE_M1 $LEARNING_RATE2 $LEARNING_RATE_M2 $TRAINING_MARGIN \ 342 | $FINE_TESTING_EPOCH_S $FINE_TESTING_EPOCH_I $FINE_TESTING_EPOCH_J $FINE_TESTING_EPOCH_STEP \ 343 | $COARSE_FUSION_THRESHOLD $FINE_FUSION_THRESHOLD $MAX_ROUNDS \ 344 | $COARSE2FINE_TIMESTAMP1_X $COARSE2FINE_TIMESTAMP1_Y $COARSE2FINE_TIMESTAMP1_Z \ 345 | $COARSE2FINE_TIMESTAMP2_X $COARSE2FINE_TIMESTAMP2_Y $COARSE2FINE_TIMESTAMP2_Z \ 346 | | tee ./logs/FD${CURRENT_FOLD}_test_${TRAINING_TIMESTAMP}.txt 347 | fi 348 | 349 | #################################################################################################### 350 | -------------------------------------------------------------------------------- /f0.sh: -------------------------------------------------------------------------------- 1 | #################################################################################################### 2 | # RSTN: Recurrent Saliency Transformation Network for organ segmentation framework # 3 | # This is PyTorch 0.4.0 Python 3.6 version of OrganSegRSTN in CAFFE Python 2.7 . # 4 | # Author: Tianwei Ni, Huangjie Zheng, Lingxi Xie. # 5 | # # 6 | # If you use our codes, please cite our paper accordingly: # 7 | # Qihang Yu, Lingxi Xie, Yan Wang, Yuyin Zhou, Elliot K. Fishman, Alan L. Yuille, # 8 | # "Recurrent Saliency Transformation Network: # 9 | # Incorporating Multi-Stage Visual Cues for Small Organ Segmentation", # 10 | # in IEEE Conference on Computer Vision and Pattern Recognition, 2018. # 11 | # # 12 | # NOTE: this program can be used for multi-organ segmentation. 
# 13 | # Please also refer to its previous version, OrganSegC2F. # 14 | #################################################################################################### 15 | 16 | #################################################################################################### 17 | # variables for convenience 18 | CURRENT_ORGAN_ID=1 19 | CURRENT_PLANE=A 20 | CURRENT_FOLD=0 21 | CURRENT_GPU=0 22 | 23 | #################################################################################################### 24 | # turn on these switches to execute each module 25 | ENABLE_INITIALIZATION=0 26 | ENABLE_TRAINING=1 27 | ENABLE_COARSE_TESTING=1 28 | ENABLE_COARSE_FUSION=1 29 | ENABLE_ORACLE_TESTING=1 30 | ENABLE_ORACLE_FUSION=1 31 | ENABLE_COARSE2FINE_TESTING=1 32 | # training settings: X|Y|Z 33 | TRAINING_ORGAN_ID=$CURRENT_ORGAN_ID 34 | TRAINING_PLANE=$CURRENT_PLANE 35 | TRAINING_GPU=$CURRENT_GPU 36 | # coarse_testing settings: X|Y|Z, before this, coarse-scaled models shall be ready 37 | COARSE_TESTING_ORGAN_ID=$CURRENT_ORGAN_ID 38 | COARSE_TESTING_PLANE=$CURRENT_PLANE 39 | COARSE_TESTING_GPU=$CURRENT_GPU 40 | # coarse_fusion settings: before this, coarse-scaled results on 3 views shall be ready 41 | COARSE_FUSION_ORGAN_ID=$CURRENT_ORGAN_ID 42 | # oracle_testing settings: X|Y|Z, before this, fine-scaled models shall be ready 43 | ORACLE_TESTING_ORGAN_ID=$CURRENT_ORGAN_ID 44 | ORACLE_TESTING_PLANE=$CURRENT_PLANE 45 | ORACLE_TESTING_GPU=$CURRENT_GPU 46 | # oracle_fusion settings: before this, fine-scaled results on 3 views shall be ready 47 | ORACLE_FUSION_ORGAN_ID=$CURRENT_ORGAN_ID 48 | # fine_testing settings: before this, both coarse-scaled and fine-scaled models shall be ready 49 | COARSE2FINE_TESTING_ORGAN_ID=$CURRENT_ORGAN_ID 50 | COARSE2FINE_TESTING_GPU=$CURRENT_GPU 51 | 52 | #################################################################################################### 53 | # defining the root path which stores image and label data 54 | 
DATA_PATH='/home/datasets/Pancreas82NIH/' 55 | 56 | #################################################################################################### 57 | # data initialization: only needs to be run once 58 | # variables 59 | ORGAN_NUMBER=1 60 | FOLDS=4 61 | LOW_RANGE=-100 62 | HIGH_RANGE=240 63 | # init.py : data_path, organ_number, folds, low_range, high_range 64 | if [ "$ENABLE_INITIALIZATION" = "1" ] 65 | then 66 | python init.py \ 67 | $DATA_PATH $ORGAN_NUMBER $FOLDS $LOW_RANGE $HIGH_RANGE 68 | fi 69 | 70 | #################################################################################################### 71 | # the individual and joint training processes 72 | # variables 73 | SLICE_THRESHOLD=0.98 74 | SLICE_THICKNESS=3 75 | LEARNING_RATE1=1e-5 76 | LEARNING_RATE2=1e-5 77 | LEARNING_RATE_M1=10 78 | LEARNING_RATE_M2=10 79 | TRAINING_MARGIN=20 80 | TRAINING_PROB=0.5 81 | TRAINING_SAMPLE_BATCH=1 82 | TRAINING_EPOCH_S=6 83 | TRAINING_EPOCH_I=6 84 | TRAINING_EPOCH_J=6 85 | LR_DECAY_EPOCH_J_STEP=2 86 | if [ "$ENABLE_TRAINING" = "1" ] 87 | then 88 | TRAINING_TIMESTAMP=$(date +'%Y%m%d_%H%M%S') 89 | else 90 | TRAINING_TIMESTAMP='20220221_210935' 91 | fi 92 | # training.py : data_path, current_fold, organ_number, low_range, high_range, 93 | # slice_threshold, slice_thickness, organ_ID, plane, GPU_ID, 94 | # learning_rate1, learning_rate2 (not used), margin, prob, sample_batch, 95 | # step, ·max_iterations1, max_iterations2 (not used), fraction, timestamp 96 | if [ "$ENABLE_TRAINING" = "1" ] 97 | then 98 | if [ "$TRAINING_PLANE" = "X" ] || [ "$TRAINING_PLANE" = "A" ] 99 | then 100 | TRAINING_MODELNAME=X${SLICE_THICKNESS}_${TRAINING_ORGAN_ID} 101 | TRAINING_LOG=${DATA_PATH}logs/FD${CURRENT_FOLD}:${TRAINING_MODELNAME}_${TRAINING_TIMESTAMP}.txt 102 | python training.py \ 103 | $DATA_PATH $CURRENT_FOLD $ORGAN_NUMBER $LOW_RANGE $HIGH_RANGE \ 104 | $SLICE_THRESHOLD $SLICE_THICKNESS \ 105 | $TRAINING_ORGAN_ID X $TRAINING_GPU \ 106 | $LEARNING_RATE1 $LEARNING_RATE_M1 
$LEARNING_RATE2 $LEARNING_RATE_M2 \ 107 | $TRAINING_MARGIN $TRAINING_PROB $TRAINING_SAMPLE_BATCH \ 108 | $TRAINING_EPOCH_S $TRAINING_EPOCH_I $TRAINING_EPOCH_J \ 109 | $LR_DECAY_EPOCH_J_STEP $TRAINING_TIMESTAMP 1 2>&1 | tee $TRAINING_LOG 110 | fi 111 | if [ "$TRAINING_PLANE" = "Y" ] || [ "$TRAINING_PLANE" = "A" ] 112 | then 113 | TRAINING_MODELNAME=Y${SLICE_THICKNESS}_${TRAINING_ORGAN_ID} 114 | TRAINING_LOG=${DATA_PATH}logs/FD${CURRENT_FOLD}:${TRAINING_MODELNAME}_${TRAINING_TIMESTAMP}.txt 115 | python training.py \ 116 | $DATA_PATH $CURRENT_FOLD $ORGAN_NUMBER $LOW_RANGE $HIGH_RANGE \ 117 | $SLICE_THRESHOLD $SLICE_THICKNESS \ 118 | $TRAINING_ORGAN_ID Y $TRAINING_GPU \ 119 | $LEARNING_RATE1 $LEARNING_RATE_M1 $LEARNING_RATE2 $LEARNING_RATE_M2 \ 120 | $TRAINING_MARGIN $TRAINING_PROB $TRAINING_SAMPLE_BATCH \ 121 | $TRAINING_EPOCH_S $TRAINING_EPOCH_I $TRAINING_EPOCH_J \ 122 | $LR_DECAY_EPOCH_J_STEP $TRAINING_TIMESTAMP 1 2>&1 | tee $TRAINING_LOG 123 | fi 124 | if [ "$TRAINING_PLANE" = "Z" ] || [ "$TRAINING_PLANE" = "A" ] 125 | then 126 | TRAINING_MODELNAME=Z${SLICE_THICKNESS}_${TRAINING_ORGAN_ID} 127 | TRAINING_LOG=${DATA_PATH}logs/FD${CURRENT_FOLD}:${TRAINING_MODELNAME}_${TRAINING_TIMESTAMP}.txt 128 | python training.py \ 129 | $DATA_PATH $CURRENT_FOLD $ORGAN_NUMBER $LOW_RANGE $HIGH_RANGE \ 130 | $SLICE_THRESHOLD $SLICE_THICKNESS \ 131 | $TRAINING_ORGAN_ID Z $TRAINING_GPU \ 132 | $LEARNING_RATE1 $LEARNING_RATE_M1 $LEARNING_RATE2 $LEARNING_RATE_M2 \ 133 | $TRAINING_MARGIN $TRAINING_PROB $TRAINING_SAMPLE_BATCH \ 134 | $TRAINING_EPOCH_S $TRAINING_EPOCH_I $TRAINING_EPOCH_J \ 135 | $LR_DECAY_EPOCH_J_STEP $TRAINING_TIMESTAMP 1 2>&1 | tee $TRAINING_LOG 136 | fi 137 | fi 138 | 139 | #################################################################################################### 140 | # the coarse-scaled testing processes 141 | # variables 142 | COARSE_TESTING_EPOCH_S=$TRAINING_EPOCH_S 143 | COARSE_TESTING_EPOCH_I=$TRAINING_EPOCH_I 144 | 
COARSE_TESTING_EPOCH_J=$TRAINING_EPOCH_J 145 | COARSE_TESTING_EPOCH_STEP=$LR_DECAY_EPOCH_J_STEP 146 | COARSE_TIMESTAMP1=$TRAINING_TIMESTAMP 147 | COARSE_TIMESTAMP2=$TRAINING_TIMESTAMP 148 | # coarse_testing.py : data_path, current_fold, organ_number, low_range, high_range, 149 | # slice_threshold, slice_thickness, organ_ID, plane, GPU_ID, 150 | # learning_rate1, learning_rate2, margin, prob, sample_batch, 151 | # EPOCH_S, EPOCH_I, EPOCH_J, EPOCH_STEP, 152 | # timestamp1, timestamp2 (optional) 153 | if [ "$ENABLE_COARSE_TESTING" = "1" ] 154 | then 155 | if [ "$COARSE_TESTING_PLANE" = "X" ] || [ "$COARSE_TESTING_PLANE" = "A" ] 156 | then 157 | python coarse_testing.py \ 158 | $DATA_PATH $CURRENT_FOLD $ORGAN_NUMBER $LOW_RANGE $HIGH_RANGE \ 159 | $SLICE_THRESHOLD $SLICE_THICKNESS \ 160 | $COARSE_TESTING_ORGAN_ID X $COARSE_TESTING_GPU \ 161 | $LEARNING_RATE1 $LEARNING_RATE_M1 $LEARNING_RATE2 $LEARNING_RATE_M2 \ 162 | $TRAINING_MARGIN $TRAINING_PROB $TRAINING_SAMPLE_BATCH \ 163 | $COARSE_TESTING_EPOCH_S $COARSE_TESTING_EPOCH_I \ 164 | $COARSE_TESTING_EPOCH_J $COARSE_TESTING_EPOCH_STEP \ 165 | $COARSE_TIMESTAMP1 $COARSE_TIMESTAMP2 166 | fi 167 | if [ "$COARSE_TESTING_PLANE" = "Y" ] || [ "$COARSE_TESTING_PLANE" = "A" ] 168 | then 169 | python coarse_testing.py \ 170 | $DATA_PATH $CURRENT_FOLD $ORGAN_NUMBER $LOW_RANGE $HIGH_RANGE \ 171 | $SLICE_THRESHOLD $SLICE_THICKNESS \ 172 | $COARSE_TESTING_ORGAN_ID Y $COARSE_TESTING_GPU \ 173 | $LEARNING_RATE1 $LEARNING_RATE_M1 $LEARNING_RATE2 $LEARNING_RATE_M2 \ 174 | $TRAINING_MARGIN $TRAINING_PROB $TRAINING_SAMPLE_BATCH \ 175 | $COARSE_TESTING_EPOCH_S $COARSE_TESTING_EPOCH_I \ 176 | $COARSE_TESTING_EPOCH_J $COARSE_TESTING_EPOCH_STEP \ 177 | $COARSE_TIMESTAMP1 $COARSE_TIMESTAMP2 178 | fi 179 | if [ "$COARSE_TESTING_PLANE" = "Z" ] || [ "$COARSE_TESTING_PLANE" = "A" ] 180 | then 181 | python coarse_testing.py \ 182 | $DATA_PATH $CURRENT_FOLD $ORGAN_NUMBER $LOW_RANGE $HIGH_RANGE \ 183 | $SLICE_THRESHOLD $SLICE_THICKNESS \ 184 | 
$COARSE_TESTING_ORGAN_ID Z $COARSE_TESTING_GPU \ 185 | $LEARNING_RATE1 $LEARNING_RATE_M1 $LEARNING_RATE2 $LEARNING_RATE_M2 \ 186 | $TRAINING_MARGIN $TRAINING_PROB $TRAINING_SAMPLE_BATCH \ 187 | $COARSE_TESTING_EPOCH_S $COARSE_TESTING_EPOCH_I \ 188 | $COARSE_TESTING_EPOCH_J $COARSE_TESTING_EPOCH_STEP \ 189 | $COARSE_TIMESTAMP1 $COARSE_TIMESTAMP2 190 | fi 191 | fi 192 | 193 | #################################################################################################### 194 | # the coarse-scaled fusion process 195 | # variables 196 | COARSE_FUSION_EPOCH_S=$TRAINING_EPOCH_S 197 | COARSE_FUSION_EPOCH_I=$TRAINING_EPOCH_I 198 | COARSE_FUSION_EPOCH_J=$TRAINING_EPOCH_J 199 | COARSE_FUSION_EPOCH_STEP=$LR_DECAY_EPOCH_J_STEP 200 | COARSE_FUSION_THRESHOLD=0.5 201 | COARSE_TIMESTAMP1_X=$TRAINING_TIMESTAMP 202 | COARSE_TIMESTAMP1_Y=$TRAINING_TIMESTAMP 203 | COARSE_TIMESTAMP1_Z=$TRAINING_TIMESTAMP 204 | COARSE_TIMESTAMP2_X=$TRAINING_TIMESTAMP 205 | COARSE_TIMESTAMP2_Y=$TRAINING_TIMESTAMP 206 | COARSE_TIMESTAMP2_Z=$TRAINING_TIMESTAMP 207 | # coarse_fusion.py : data_path, current_fold, organ_number, low_range, high_range, 208 | # slice_threshold, slice_thickness, organ_ID, plane, GPU_ID, 209 | # learning_rate1, learning_rate_m1, learning_rate2, learning_rate_m2, margin, 210 | # EPOCH_S, EPOCH_I, EPOCH_J, EPOCH_STEP, threshold, 211 | # timestamp1_X, timestamp1_Y, timestamp1_Z, 212 | # timestamp2_X (optional), timestamp2_Y (optional), timestamp2_Z (optional) 213 | if [ "$ENABLE_COARSE_FUSION" = "1" ] 214 | then 215 | python coarse_fusion.py \ 216 | $DATA_PATH $CURRENT_FOLD $ORGAN_NUMBER $LOW_RANGE $HIGH_RANGE \ 217 | $SLICE_THRESHOLD $SLICE_THICKNESS $COARSE_TESTING_ORGAN_ID $COARSE_TESTING_GPU \ 218 | $LEARNING_RATE1 $LEARNING_RATE_M1 $LEARNING_RATE2 $LEARNING_RATE_M2 $TRAINING_MARGIN \ 219 | $COARSE_FUSION_EPOCH_S $COARSE_FUSION_EPOCH_I $COARSE_FUSION_EPOCH_J \ 220 | $COARSE_FUSION_EPOCH_STEP $COARSE_FUSION_THRESHOLD \ 221 | $COARSE_TIMESTAMP1_X $COARSE_TIMESTAMP1_Y 
$COARSE_TIMESTAMP1_Z \ 222 | $COARSE_TIMESTAMP2_X $COARSE_TIMESTAMP2_Y $COARSE_TIMESTAMP2_Z 223 | fi 224 | 225 | #################################################################################################### 226 | # the oracle testing processes 227 | # variables 228 | ORACLE_TESTING_EPOCH_S=$TRAINING_EPOCH_S 229 | ORACLE_TESTING_EPOCH_I=$TRAINING_EPOCH_I 230 | ORACLE_TESTING_EPOCH_J=$TRAINING_EPOCH_J 231 | ORACLE_TESTING_EPOCH_STEP=$LR_DECAY_EPOCH_J_STEP 232 | ORACLE_TIMESTAMP1=$TRAINING_TIMESTAMP 233 | ORACLE_TIMESTAMP2=$TRAINING_TIMESTAMP 234 | # oracle_testing.py : data_path, current_fold, organ_number, low_range, high_range, 235 | # slice_threshold, slice_thickness, organ_ID, plane, GPU_ID, 236 | # learning_rate1, learning_rate_m1, learning_rate2, learning_rate_m2, 237 | # margin, prob, sample_batch, 238 | # EPOCH_S, EPOCH_I, EPOCH_J, EPOCH_STEP, 239 | # timestamp1, timestamp2 (optional) 240 | if [ "$ENABLE_ORACLE_TESTING" = "1" ] 241 | then 242 | if [ "$ORACLE_TESTING_PLANE" = "X" ] || [ "$ORACLE_TESTING_PLANE" = "A" ] 243 | then 244 | python oracle_testing.py \ 245 | $DATA_PATH $CURRENT_FOLD $ORGAN_NUMBER $LOW_RANGE $HIGH_RANGE \ 246 | $SLICE_THRESHOLD $SLICE_THICKNESS \ 247 | $ORACLE_TESTING_ORGAN_ID X $ORACLE_TESTING_GPU \ 248 | $LEARNING_RATE1 $LEARNING_RATE_M1 $LEARNING_RATE2 $LEARNING_RATE_M2 \ 249 | $TRAINING_MARGIN $TRAINING_PROB $TRAINING_SAMPLE_BATCH \ 250 | $ORACLE_TESTING_EPOCH_S $ORACLE_TESTING_EPOCH_I \ 251 | $ORACLE_TESTING_EPOCH_J $ORACLE_TESTING_EPOCH_STEP \ 252 | $ORACLE_TIMESTAMP1 $ORACLE_TIMESTAMP2 253 | fi 254 | if [ "$ORACLE_TESTING_PLANE" = "Y" ] || [ "$ORACLE_TESTING_PLANE" = "A" ] 255 | then 256 | python oracle_testing.py \ 257 | $DATA_PATH $CURRENT_FOLD $ORGAN_NUMBER $LOW_RANGE $HIGH_RANGE \ 258 | $SLICE_THRESHOLD $SLICE_THICKNESS \ 259 | $ORACLE_TESTING_ORGAN_ID Y $ORACLE_TESTING_GPU \ 260 | $LEARNING_RATE1 $LEARNING_RATE_M1 $LEARNING_RATE2 $LEARNING_RATE_M2 \ 261 | $TRAINING_MARGIN $TRAINING_PROB $TRAINING_SAMPLE_BATCH \ 262 | 
$ORACLE_TESTING_EPOCH_S $ORACLE_TESTING_EPOCH_I \ 263 | $ORACLE_TESTING_EPOCH_J $ORACLE_TESTING_EPOCH_STEP \ 264 | $ORACLE_TIMESTAMP1 $ORACLE_TIMESTAMP2 265 | fi 266 | if [ "$ORACLE_TESTING_PLANE" = "Z" ] || [ "$ORACLE_TESTING_PLANE" = "A" ] 267 | then 268 | python oracle_testing.py \ 269 | $DATA_PATH $CURRENT_FOLD $ORGAN_NUMBER $LOW_RANGE $HIGH_RANGE \ 270 | $SLICE_THRESHOLD $SLICE_THICKNESS \ 271 | $ORACLE_TESTING_ORGAN_ID Z $ORACLE_TESTING_GPU \ 272 | $LEARNING_RATE1 $LEARNING_RATE_M1 $LEARNING_RATE2 $LEARNING_RATE_M2 \ 273 | $TRAINING_MARGIN $TRAINING_PROB $TRAINING_SAMPLE_BATCH \ 274 | $ORACLE_TESTING_EPOCH_S $ORACLE_TESTING_EPOCH_I \ 275 | $ORACLE_TESTING_EPOCH_J $ORACLE_TESTING_EPOCH_STEP \ 276 | $ORACLE_TIMESTAMP1 $ORACLE_TIMESTAMP2 277 | fi 278 | fi 279 | 280 | #################################################################################################### 281 | # the oracle-scaled fusion process 282 | # variables 283 | ORACLE_FUSION_EPOCH_S=$TRAINING_EPOCH_S 284 | ORACLE_FUSION_EPOCH_I=$TRAINING_EPOCH_I 285 | ORACLE_FUSION_EPOCH_J=$TRAINING_EPOCH_J 286 | ORACLE_FUSION_EPOCH_STEP=$LR_DECAY_EPOCH_J_STEP 287 | ORACLE_FUSION_THRESHOLD=0.5 288 | ORACLE_TIMESTAMP1_X=$TRAINING_TIMESTAMP 289 | ORACLE_TIMESTAMP1_Y=$TRAINING_TIMESTAMP 290 | ORACLE_TIMESTAMP1_Z=$TRAINING_TIMESTAMP 291 | ORACLE_TIMESTAMP2_X=$TRAINING_TIMESTAMP 292 | ORACLE_TIMESTAMP2_Y=$TRAINING_TIMESTAMP 293 | ORACLE_TIMESTAMP2_Z=$TRAINING_TIMESTAMP 294 | # oracle_fusion.py : data_path, current_fold, organ_number, low_range, high_range, 295 | # slice_threshold, slice_thickness, organ_ID, plane, GPU_ID, 296 | # learning_rate1, learning_rate_m1, learning_rate2, learning_rate_m2, margin, 297 | # EPOCH_S, EPOCH_I, EPOCH_J, EPOCH_STEP, threshold, 298 | # timestamp1_X, timestamp1_Y, timestamp1_Z, 299 | # timestamp2_X (optional), timestamp2_Y (optional), timestamp2_Z (optional) 300 | if [ "$ENABLE_ORACLE_FUSION" = "1" ] 301 | then 302 | python oracle_fusion.py \ 303 | $DATA_PATH $CURRENT_FOLD 
$ORGAN_NUMBER $LOW_RANGE $HIGH_RANGE \ 304 | $SLICE_THRESHOLD $SLICE_THICKNESS $ORACLE_TESTING_ORGAN_ID $ORACLE_TESTING_GPU \ 305 | $LEARNING_RATE1 $LEARNING_RATE_M1 $LEARNING_RATE2 $LEARNING_RATE_M2 $TRAINING_MARGIN \ 306 | $ORACLE_FUSION_EPOCH_S $ORACLE_FUSION_EPOCH_I $ORACLE_FUSION_EPOCH_J \ 307 | $ORACLE_FUSION_EPOCH_STEP $ORACLE_FUSION_THRESHOLD \ 308 | $ORACLE_TIMESTAMP1_X $ORACLE_TIMESTAMP1_Y $ORACLE_TIMESTAMP1_Z \ 309 | $ORACLE_TIMESTAMP2_X $ORACLE_TIMESTAMP2_Y $ORACLE_TIMESTAMP2_Z 310 | fi 311 | 312 | #################################################################################################### 313 | # the coarse-to-fine testing process 314 | # variables 315 | FINE_TESTING_EPOCH_S=$TRAINING_EPOCH_S 316 | FINE_TESTING_EPOCH_I=$TRAINING_EPOCH_I 317 | FINE_TESTING_EPOCH_J=$TRAINING_EPOCH_J 318 | FINE_TESTING_EPOCH_STEP=$LR_DECAY_EPOCH_J_STEP 319 | FINE_FUSION_THRESHOLD=0.5 320 | COARSE2FINE_TIMESTAMP1_X=$TRAINING_TIMESTAMP 321 | COARSE2FINE_TIMESTAMP1_Y=$TRAINING_TIMESTAMP 322 | COARSE2FINE_TIMESTAMP1_Z=$TRAINING_TIMESTAMP 323 | COARSE2FINE_TIMESTAMP2_X=$TRAINING_TIMESTAMP 324 | COARSE2FINE_TIMESTAMP2_Y=$TRAINING_TIMESTAMP 325 | COARSE2FINE_TIMESTAMP2_Z=$TRAINING_TIMESTAMP 326 | MAX_ROUNDS=10 327 | # coarse2fine_testing.py : data_path, current_fold, organ_number, low_range, high_range, 328 | # slice_threshold, slice_thickness, organ_ID, GPU_ID, 329 | # learning_rate1, learning_rate_m1, learning_rate2, learning_rate_m2, margin, 330 | # coarse_fusion_starting_iterations, coarse_fusion_step, coarse_fusion_max_iterations, 331 | # coarse_fusion_threshold, coarse_fusion_code, 332 | # EPOCH_S, EPOCH_I, EPOCH_J, EPOCH_STEP, 333 | # fine_fusion_threshold, max_rounds, 334 | # timestamp1_X, timestamp1_Y, timestamp1_Z, 335 | # timestamp2_X (optional), timestamp2_Y (optional), timestamp2_Z (optional) 336 | if [ "$ENABLE_COARSE2FINE_TESTING" = "1" ] 337 | then 338 | python coarse2fine_testing.py \ 339 | $DATA_PATH $CURRENT_FOLD $ORGAN_NUMBER $LOW_RANGE $HIGH_RANGE \ 
340 | $SLICE_THRESHOLD $SLICE_THICKNESS $COARSE2FINE_TESTING_ORGAN_ID $COARSE2FINE_TESTING_GPU \ 341 | $LEARNING_RATE1 $LEARNING_RATE_M1 $LEARNING_RATE2 $LEARNING_RATE_M2 $TRAINING_MARGIN \ 342 | $FINE_TESTING_EPOCH_S $FINE_TESTING_EPOCH_I $FINE_TESTING_EPOCH_J $FINE_TESTING_EPOCH_STEP \ 343 | $COARSE_FUSION_THRESHOLD $FINE_FUSION_THRESHOLD $MAX_ROUNDS \ 344 | $COARSE2FINE_TIMESTAMP1_X $COARSE2FINE_TIMESTAMP1_Y $COARSE2FINE_TIMESTAMP1_Z \ 345 | $COARSE2FINE_TIMESTAMP2_X $COARSE2FINE_TIMESTAMP2_Y $COARSE2FINE_TIMESTAMP2_Z \ 346 | | tee ./logs/FD${CURRENT_FOLD}_test_${TRAINING_TIMESTAMP}.txt 347 | fi 348 | 349 | #################################################################################################### 350 | -------------------------------------------------------------------------------- /f1.sh: -------------------------------------------------------------------------------- 1 | #################################################################################################### 2 | # RSTN: Recurrent Saliency Transformation Network for organ segmentation framework # 3 | # This is PyTorch 0.4.0 Python 3.6 version of OrganSegRSTN in CAFFE Python 2.7 . # 4 | # Author: Tianwei Ni, Huangjie Zheng, Lingxi Xie. # 5 | # # 6 | # If you use our codes, please cite our paper accordingly: # 7 | # Qihang Yu, Lingxi Xie, Yan Wang, Yuyin Zhou, Elliot K. Fishman, Alan L. Yuille, # 8 | # "Recurrent Saliency Transformation Network: # 9 | # Incorporating Multi-Stage Visual Cues for Small Organ Segmentation", # 10 | # in IEEE Conference on Computer Vision and Pattern Recognition, 2018. # 11 | # # 12 | # NOTE: this program can be used for multi-organ segmentation. # 13 | # Please also refer to its previous version, OrganSegC2F. 
# 14 | #################################################################################################### 15 | 16 | #################################################################################################### 17 | # variables for convenience 18 | CURRENT_ORGAN_ID=1 19 | CURRENT_PLANE=A 20 | CURRENT_FOLD=1 21 | CURRENT_GPU=0 22 | 23 | #################################################################################################### 24 | # turn on these switches to execute each module 25 | ENABLE_INITIALIZATION=0 26 | ENABLE_TRAINING=1 27 | ENABLE_COARSE_TESTING=1 28 | ENABLE_COARSE_FUSION=1 29 | ENABLE_ORACLE_TESTING=1 30 | ENABLE_ORACLE_FUSION=1 31 | ENABLE_COARSE2FINE_TESTING=1 32 | # training settings: X|Y|Z 33 | TRAINING_ORGAN_ID=$CURRENT_ORGAN_ID 34 | TRAINING_PLANE=$CURRENT_PLANE 35 | TRAINING_GPU=$CURRENT_GPU 36 | # coarse_testing settings: X|Y|Z, before this, coarse-scaled models shall be ready 37 | COARSE_TESTING_ORGAN_ID=$CURRENT_ORGAN_ID 38 | COARSE_TESTING_PLANE=$CURRENT_PLANE 39 | COARSE_TESTING_GPU=$CURRENT_GPU 40 | # coarse_fusion settings: before this, coarse-scaled results on 3 views shall be ready 41 | COARSE_FUSION_ORGAN_ID=$CURRENT_ORGAN_ID 42 | # oracle_testing settings: X|Y|Z, before this, fine-scaled models shall be ready 43 | ORACLE_TESTING_ORGAN_ID=$CURRENT_ORGAN_ID 44 | ORACLE_TESTING_PLANE=$CURRENT_PLANE 45 | ORACLE_TESTING_GPU=$CURRENT_GPU 46 | # oracle_fusion settings: before this, fine-scaled results on 3 views shall be ready 47 | ORACLE_FUSION_ORGAN_ID=$CURRENT_ORGAN_ID 48 | # fine_testing settings: before this, both coarse-scaled and fine-scaled models shall be ready 49 | COARSE2FINE_TESTING_ORGAN_ID=$CURRENT_ORGAN_ID 50 | COARSE2FINE_TESTING_GPU=$CURRENT_GPU 51 | 52 | #################################################################################################### 53 | # defining the root path which stores image and label data 54 | DATA_PATH='/home/datasets/Pancreas82NIH/' 55 | 56 | 
#################################################################################################### 57 | # data initialization: only needs to be run once 58 | # variables 59 | ORGAN_NUMBER=1 60 | FOLDS=4 61 | LOW_RANGE=-100 62 | HIGH_RANGE=240 63 | # init.py : data_path, organ_number, folds, low_range, high_range 64 | if [ "$ENABLE_INITIALIZATION" = "1" ] 65 | then 66 | python init.py \ 67 | $DATA_PATH $ORGAN_NUMBER $FOLDS $LOW_RANGE $HIGH_RANGE 68 | fi 69 | 70 | #################################################################################################### 71 | # the individual and joint training processes 72 | # variables 73 | SLICE_THRESHOLD=0.98 74 | SLICE_THICKNESS=3 75 | LEARNING_RATE1=1e-5 76 | LEARNING_RATE2=1e-5 77 | LEARNING_RATE_M1=10 78 | LEARNING_RATE_M2=10 79 | TRAINING_MARGIN=20 80 | TRAINING_PROB=0.5 81 | TRAINING_SAMPLE_BATCH=1 82 | TRAINING_EPOCH_S=6 83 | TRAINING_EPOCH_I=6 84 | TRAINING_EPOCH_J=6 85 | LR_DECAY_EPOCH_J_STEP=2 86 | if [ "$ENABLE_TRAINING" = "1" ] 87 | then 88 | TRAINING_TIMESTAMP=$(date +'%Y%m%d_%H%M%S') 89 | else 90 | TRAINING_TIMESTAMP='20220221_210935' 91 | fi 92 | # training.py : data_path, current_fold, organ_number, low_range, high_range, 93 | # slice_threshold, slice_thickness, organ_ID, plane, GPU_ID, 94 | # learning_rate1, learning_rate2 (not used), margin, prob, sample_batch, 95 | # step, ·max_iterations1, max_iterations2 (not used), fraction, timestamp 96 | if [ "$ENABLE_TRAINING" = "1" ] 97 | then 98 | if [ "$TRAINING_PLANE" = "X" ] || [ "$TRAINING_PLANE" = "A" ] 99 | then 100 | TRAINING_MODELNAME=X${SLICE_THICKNESS}_${TRAINING_ORGAN_ID} 101 | TRAINING_LOG=${DATA_PATH}logs/FD${CURRENT_FOLD}:${TRAINING_MODELNAME}_${TRAINING_TIMESTAMP}.txt 102 | python training.py \ 103 | $DATA_PATH $CURRENT_FOLD $ORGAN_NUMBER $LOW_RANGE $HIGH_RANGE \ 104 | $SLICE_THRESHOLD $SLICE_THICKNESS \ 105 | $TRAINING_ORGAN_ID X $TRAINING_GPU \ 106 | $LEARNING_RATE1 $LEARNING_RATE_M1 $LEARNING_RATE2 $LEARNING_RATE_M2 \ 107 | $TRAINING_MARGIN 
$TRAINING_PROB $TRAINING_SAMPLE_BATCH \ 108 | $TRAINING_EPOCH_S $TRAINING_EPOCH_I $TRAINING_EPOCH_J \ 109 | $LR_DECAY_EPOCH_J_STEP $TRAINING_TIMESTAMP 1 2>&1 | tee $TRAINING_LOG 110 | fi 111 | if [ "$TRAINING_PLANE" = "Y" ] || [ "$TRAINING_PLANE" = "A" ] 112 | then 113 | TRAINING_MODELNAME=Y${SLICE_THICKNESS}_${TRAINING_ORGAN_ID} 114 | TRAINING_LOG=${DATA_PATH}logs/FD${CURRENT_FOLD}:${TRAINING_MODELNAME}_${TRAINING_TIMESTAMP}.txt 115 | python training.py \ 116 | $DATA_PATH $CURRENT_FOLD $ORGAN_NUMBER $LOW_RANGE $HIGH_RANGE \ 117 | $SLICE_THRESHOLD $SLICE_THICKNESS \ 118 | $TRAINING_ORGAN_ID Y $TRAINING_GPU \ 119 | $LEARNING_RATE1 $LEARNING_RATE_M1 $LEARNING_RATE2 $LEARNING_RATE_M2 \ 120 | $TRAINING_MARGIN $TRAINING_PROB $TRAINING_SAMPLE_BATCH \ 121 | $TRAINING_EPOCH_S $TRAINING_EPOCH_I $TRAINING_EPOCH_J \ 122 | $LR_DECAY_EPOCH_J_STEP $TRAINING_TIMESTAMP 1 2>&1 | tee $TRAINING_LOG 123 | fi 124 | if [ "$TRAINING_PLANE" = "Z" ] || [ "$TRAINING_PLANE" = "A" ] 125 | then 126 | TRAINING_MODELNAME=Z${SLICE_THICKNESS}_${TRAINING_ORGAN_ID} 127 | TRAINING_LOG=${DATA_PATH}logs/FD${CURRENT_FOLD}:${TRAINING_MODELNAME}_${TRAINING_TIMESTAMP}.txt 128 | python training.py \ 129 | $DATA_PATH $CURRENT_FOLD $ORGAN_NUMBER $LOW_RANGE $HIGH_RANGE \ 130 | $SLICE_THRESHOLD $SLICE_THICKNESS \ 131 | $TRAINING_ORGAN_ID Z $TRAINING_GPU \ 132 | $LEARNING_RATE1 $LEARNING_RATE_M1 $LEARNING_RATE2 $LEARNING_RATE_M2 \ 133 | $TRAINING_MARGIN $TRAINING_PROB $TRAINING_SAMPLE_BATCH \ 134 | $TRAINING_EPOCH_S $TRAINING_EPOCH_I $TRAINING_EPOCH_J \ 135 | $LR_DECAY_EPOCH_J_STEP $TRAINING_TIMESTAMP 1 2>&1 | tee $TRAINING_LOG 136 | fi 137 | fi 138 | 139 | #################################################################################################### 140 | # the coarse-scaled testing processes 141 | # variables 142 | COARSE_TESTING_EPOCH_S=$TRAINING_EPOCH_S 143 | COARSE_TESTING_EPOCH_I=$TRAINING_EPOCH_I 144 | COARSE_TESTING_EPOCH_J=$TRAINING_EPOCH_J 145 | COARSE_TESTING_EPOCH_STEP=$LR_DECAY_EPOCH_J_STEP 
146 | COARSE_TIMESTAMP1=$TRAINING_TIMESTAMP 147 | COARSE_TIMESTAMP2=$TRAINING_TIMESTAMP 148 | # coarse_testing.py : data_path, current_fold, organ_number, low_range, high_range, 149 | # slice_threshold, slice_thickness, organ_ID, plane, GPU_ID, 150 | # learning_rate1, learning_rate2, margin, prob, sample_batch, 151 | # EPOCH_S, EPOCH_I, EPOCH_J, EPOCH_STEP, 152 | # timestamp1, timestamp2 (optional) 153 | if [ "$ENABLE_COARSE_TESTING" = "1" ] 154 | then 155 | if [ "$COARSE_TESTING_PLANE" = "X" ] || [ "$COARSE_TESTING_PLANE" = "A" ] 156 | then 157 | python coarse_testing.py \ 158 | $DATA_PATH $CURRENT_FOLD $ORGAN_NUMBER $LOW_RANGE $HIGH_RANGE \ 159 | $SLICE_THRESHOLD $SLICE_THICKNESS \ 160 | $COARSE_TESTING_ORGAN_ID X $COARSE_TESTING_GPU \ 161 | $LEARNING_RATE1 $LEARNING_RATE_M1 $LEARNING_RATE2 $LEARNING_RATE_M2 \ 162 | $TRAINING_MARGIN $TRAINING_PROB $TRAINING_SAMPLE_BATCH \ 163 | $COARSE_TESTING_EPOCH_S $COARSE_TESTING_EPOCH_I \ 164 | $COARSE_TESTING_EPOCH_J $COARSE_TESTING_EPOCH_STEP \ 165 | $COARSE_TIMESTAMP1 $COARSE_TIMESTAMP2 166 | fi 167 | if [ "$COARSE_TESTING_PLANE" = "Y" ] || [ "$COARSE_TESTING_PLANE" = "A" ] 168 | then 169 | python coarse_testing.py \ 170 | $DATA_PATH $CURRENT_FOLD $ORGAN_NUMBER $LOW_RANGE $HIGH_RANGE \ 171 | $SLICE_THRESHOLD $SLICE_THICKNESS \ 172 | $COARSE_TESTING_ORGAN_ID Y $COARSE_TESTING_GPU \ 173 | $LEARNING_RATE1 $LEARNING_RATE_M1 $LEARNING_RATE2 $LEARNING_RATE_M2 \ 174 | $TRAINING_MARGIN $TRAINING_PROB $TRAINING_SAMPLE_BATCH \ 175 | $COARSE_TESTING_EPOCH_S $COARSE_TESTING_EPOCH_I \ 176 | $COARSE_TESTING_EPOCH_J $COARSE_TESTING_EPOCH_STEP \ 177 | $COARSE_TIMESTAMP1 $COARSE_TIMESTAMP2 178 | fi 179 | if [ "$COARSE_TESTING_PLANE" = "Z" ] || [ "$COARSE_TESTING_PLANE" = "A" ] 180 | then 181 | python coarse_testing.py \ 182 | $DATA_PATH $CURRENT_FOLD $ORGAN_NUMBER $LOW_RANGE $HIGH_RANGE \ 183 | $SLICE_THRESHOLD $SLICE_THICKNESS \ 184 | $COARSE_TESTING_ORGAN_ID Z $COARSE_TESTING_GPU \ 185 | $LEARNING_RATE1 $LEARNING_RATE_M1 $LEARNING_RATE2 
$LEARNING_RATE_M2 \ 186 | $TRAINING_MARGIN $TRAINING_PROB $TRAINING_SAMPLE_BATCH \ 187 | $COARSE_TESTING_EPOCH_S $COARSE_TESTING_EPOCH_I \ 188 | $COARSE_TESTING_EPOCH_J $COARSE_TESTING_EPOCH_STEP \ 189 | $COARSE_TIMESTAMP1 $COARSE_TIMESTAMP2 190 | fi 191 | fi 192 | 193 | #################################################################################################### 194 | # the coarse-scaled fusion process 195 | # variables 196 | COARSE_FUSION_EPOCH_S=$TRAINING_EPOCH_S 197 | COARSE_FUSION_EPOCH_I=$TRAINING_EPOCH_I 198 | COARSE_FUSION_EPOCH_J=$TRAINING_EPOCH_J 199 | COARSE_FUSION_EPOCH_STEP=$LR_DECAY_EPOCH_J_STEP 200 | COARSE_FUSION_THRESHOLD=0.5 201 | COARSE_TIMESTAMP1_X=$TRAINING_TIMESTAMP 202 | COARSE_TIMESTAMP1_Y=$TRAINING_TIMESTAMP 203 | COARSE_TIMESTAMP1_Z=$TRAINING_TIMESTAMP 204 | COARSE_TIMESTAMP2_X=$TRAINING_TIMESTAMP 205 | COARSE_TIMESTAMP2_Y=$TRAINING_TIMESTAMP 206 | COARSE_TIMESTAMP2_Z=$TRAINING_TIMESTAMP 207 | # coarse_fusion.py : data_path, current_fold, organ_number, low_range, high_range, 208 | # slice_threshold, slice_thickness, organ_ID, plane, GPU_ID, 209 | # learning_rate1, learning_rate_m1, learning_rate2, learning_rate_m2, margin, 210 | # EPOCH_S, EPOCH_I, EPOCH_J, EPOCH_STEP, threshold, 211 | # timestamp1_X, timestamp1_Y, timestamp1_Z, 212 | # timestamp2_X (optional), timestamp2_Y (optional), timestamp2_Z (optional) 213 | if [ "$ENABLE_COARSE_FUSION" = "1" ] 214 | then 215 | python coarse_fusion.py \ 216 | $DATA_PATH $CURRENT_FOLD $ORGAN_NUMBER $LOW_RANGE $HIGH_RANGE \ 217 | $SLICE_THRESHOLD $SLICE_THICKNESS $COARSE_TESTING_ORGAN_ID $COARSE_TESTING_GPU \ 218 | $LEARNING_RATE1 $LEARNING_RATE_M1 $LEARNING_RATE2 $LEARNING_RATE_M2 $TRAINING_MARGIN \ 219 | $COARSE_FUSION_EPOCH_S $COARSE_FUSION_EPOCH_I $COARSE_FUSION_EPOCH_J \ 220 | $COARSE_FUSION_EPOCH_STEP $COARSE_FUSION_THRESHOLD \ 221 | $COARSE_TIMESTAMP1_X $COARSE_TIMESTAMP1_Y $COARSE_TIMESTAMP1_Z \ 222 | $COARSE_TIMESTAMP2_X $COARSE_TIMESTAMP2_Y $COARSE_TIMESTAMP2_Z 223 | fi 224 | 225 | 
#################################################################################################### 226 | # the oracle testing processes 227 | # variables 228 | ORACLE_TESTING_EPOCH_S=$TRAINING_EPOCH_S 229 | ORACLE_TESTING_EPOCH_I=$TRAINING_EPOCH_I 230 | ORACLE_TESTING_EPOCH_J=$TRAINING_EPOCH_J 231 | ORACLE_TESTING_EPOCH_STEP=$LR_DECAY_EPOCH_J_STEP 232 | ORACLE_TIMESTAMP1=$TRAINING_TIMESTAMP 233 | ORACLE_TIMESTAMP2=$TRAINING_TIMESTAMP 234 | # oracle_testing.py : data_path, current_fold, organ_number, low_range, high_range, 235 | # slice_threshold, slice_thickness, organ_ID, plane, GPU_ID, 236 | # learning_rate1, learning_rate_m1, learning_rate2, learning_rate_m2, 237 | # margin, prob, sample_batch, 238 | # EPOCH_S, EPOCH_I, EPOCH_J, EPOCH_STEP, 239 | # timestamp1, timestamp2 (optional) 240 | if [ "$ENABLE_ORACLE_TESTING" = "1" ] 241 | then 242 | if [ "$ORACLE_TESTING_PLANE" = "X" ] || [ "$ORACLE_TESTING_PLANE" = "A" ] 243 | then 244 | python oracle_testing.py \ 245 | $DATA_PATH $CURRENT_FOLD $ORGAN_NUMBER $LOW_RANGE $HIGH_RANGE \ 246 | $SLICE_THRESHOLD $SLICE_THICKNESS \ 247 | $ORACLE_TESTING_ORGAN_ID X $ORACLE_TESTING_GPU \ 248 | $LEARNING_RATE1 $LEARNING_RATE_M1 $LEARNING_RATE2 $LEARNING_RATE_M2 \ 249 | $TRAINING_MARGIN $TRAINING_PROB $TRAINING_SAMPLE_BATCH \ 250 | $ORACLE_TESTING_EPOCH_S $ORACLE_TESTING_EPOCH_I \ 251 | $ORACLE_TESTING_EPOCH_J $ORACLE_TESTING_EPOCH_STEP \ 252 | $ORACLE_TIMESTAMP1 $ORACLE_TIMESTAMP2 253 | fi 254 | if [ "$ORACLE_TESTING_PLANE" = "Y" ] || [ "$ORACLE_TESTING_PLANE" = "A" ] 255 | then 256 | python oracle_testing.py \ 257 | $DATA_PATH $CURRENT_FOLD $ORGAN_NUMBER $LOW_RANGE $HIGH_RANGE \ 258 | $SLICE_THRESHOLD $SLICE_THICKNESS \ 259 | $ORACLE_TESTING_ORGAN_ID Y $ORACLE_TESTING_GPU \ 260 | $LEARNING_RATE1 $LEARNING_RATE_M1 $LEARNING_RATE2 $LEARNING_RATE_M2 \ 261 | $TRAINING_MARGIN $TRAINING_PROB $TRAINING_SAMPLE_BATCH \ 262 | $ORACLE_TESTING_EPOCH_S $ORACLE_TESTING_EPOCH_I \ 263 | $ORACLE_TESTING_EPOCH_J $ORACLE_TESTING_EPOCH_STEP \ 264 
| $ORACLE_TIMESTAMP1 $ORACLE_TIMESTAMP2 265 | fi 266 | if [ "$ORACLE_TESTING_PLANE" = "Z" ] || [ "$ORACLE_TESTING_PLANE" = "A" ] 267 | then 268 | python oracle_testing.py \ 269 | $DATA_PATH $CURRENT_FOLD $ORGAN_NUMBER $LOW_RANGE $HIGH_RANGE \ 270 | $SLICE_THRESHOLD $SLICE_THICKNESS \ 271 | $ORACLE_TESTING_ORGAN_ID Z $ORACLE_TESTING_GPU \ 272 | $LEARNING_RATE1 $LEARNING_RATE_M1 $LEARNING_RATE2 $LEARNING_RATE_M2 \ 273 | $TRAINING_MARGIN $TRAINING_PROB $TRAINING_SAMPLE_BATCH \ 274 | $ORACLE_TESTING_EPOCH_S $ORACLE_TESTING_EPOCH_I \ 275 | $ORACLE_TESTING_EPOCH_J $ORACLE_TESTING_EPOCH_STEP \ 276 | $ORACLE_TIMESTAMP1 $ORACLE_TIMESTAMP2 277 | fi 278 | fi 279 | 280 | #################################################################################################### 281 | # the oracle-scaled fusion process 282 | # variables 283 | ORACLE_FUSION_EPOCH_S=$TRAINING_EPOCH_S 284 | ORACLE_FUSION_EPOCH_I=$TRAINING_EPOCH_I 285 | ORACLE_FUSION_EPOCH_J=$TRAINING_EPOCH_J 286 | ORACLE_FUSION_EPOCH_STEP=$LR_DECAY_EPOCH_J_STEP 287 | ORACLE_FUSION_THRESHOLD=0.5 288 | ORACLE_TIMESTAMP1_X=$TRAINING_TIMESTAMP 289 | ORACLE_TIMESTAMP1_Y=$TRAINING_TIMESTAMP 290 | ORACLE_TIMESTAMP1_Z=$TRAINING_TIMESTAMP 291 | ORACLE_TIMESTAMP2_X=$TRAINING_TIMESTAMP 292 | ORACLE_TIMESTAMP2_Y=$TRAINING_TIMESTAMP 293 | ORACLE_TIMESTAMP2_Z=$TRAINING_TIMESTAMP 294 | # oracle_fusion.py : data_path, current_fold, organ_number, low_range, high_range, 295 | # slice_threshold, slice_thickness, organ_ID, plane, GPU_ID, 296 | # learning_rate1, learning_rate_m1, learning_rate2, learning_rate_m2, margin, 297 | # EPOCH_S, EPOCH_I, EPOCH_J, EPOCH_STEP, threshold, 298 | # timestamp1_X, timestamp1_Y, timestamp1_Z, 299 | # timestamp2_X (optional), timestamp2_Y (optional), timestamp2_Z (optional) 300 | if [ "$ENABLE_ORACLE_FUSION" = "1" ] 301 | then 302 | python oracle_fusion.py \ 303 | $DATA_PATH $CURRENT_FOLD $ORGAN_NUMBER $LOW_RANGE $HIGH_RANGE \ 304 | $SLICE_THRESHOLD $SLICE_THICKNESS $ORACLE_TESTING_ORGAN_ID 
$ORACLE_TESTING_GPU \ 305 | $LEARNING_RATE1 $LEARNING_RATE_M1 $LEARNING_RATE2 $LEARNING_RATE_M2 $TRAINING_MARGIN \ 306 | $ORACLE_FUSION_EPOCH_S $ORACLE_FUSION_EPOCH_I $ORACLE_FUSION_EPOCH_J \ 307 | $ORACLE_FUSION_EPOCH_STEP $ORACLE_FUSION_THRESHOLD \ 308 | $ORACLE_TIMESTAMP1_X $ORACLE_TIMESTAMP1_Y $ORACLE_TIMESTAMP1_Z \ 309 | $ORACLE_TIMESTAMP2_X $ORACLE_TIMESTAMP2_Y $ORACLE_TIMESTAMP2_Z 310 | fi 311 | 312 | #################################################################################################### 313 | # the coarse-to-fine testing process 314 | # variables 315 | FINE_TESTING_EPOCH_S=$TRAINING_EPOCH_S 316 | FINE_TESTING_EPOCH_I=$TRAINING_EPOCH_I 317 | FINE_TESTING_EPOCH_J=$TRAINING_EPOCH_J 318 | FINE_TESTING_EPOCH_STEP=$LR_DECAY_EPOCH_J_STEP 319 | FINE_FUSION_THRESHOLD=0.5 320 | COARSE2FINE_TIMESTAMP1_X=$TRAINING_TIMESTAMP 321 | COARSE2FINE_TIMESTAMP1_Y=$TRAINING_TIMESTAMP 322 | COARSE2FINE_TIMESTAMP1_Z=$TRAINING_TIMESTAMP 323 | COARSE2FINE_TIMESTAMP2_X=$TRAINING_TIMESTAMP 324 | COARSE2FINE_TIMESTAMP2_Y=$TRAINING_TIMESTAMP 325 | COARSE2FINE_TIMESTAMP2_Z=$TRAINING_TIMESTAMP 326 | MAX_ROUNDS=10 327 | # coarse2fine_testing.py : data_path, current_fold, organ_number, low_range, high_range, 328 | # slice_threshold, slice_thickness, organ_ID, GPU_ID, 329 | # learning_rate1, learning_rate_m1, learning_rate2, learning_rate_m2, margin, 330 | # coarse_fusion_starting_iterations, coarse_fusion_step, coarse_fusion_max_iterations, 331 | # coarse_fusion_threshold, coarse_fusion_code, 332 | # EPOCH_S, EPOCH_I, EPOCH_J, EPOCH_STEP, 333 | # fine_fusion_threshold, max_rounds, 334 | # timestamp1_X, timestamp1_Y, timestamp1_Z, 335 | # timestamp2_X (optional), timestamp2_Y (optional), timestamp2_Z (optional) 336 | if [ "$ENABLE_COARSE2FINE_TESTING" = "1" ] 337 | then 338 | python coarse2fine_testing.py \ 339 | $DATA_PATH $CURRENT_FOLD $ORGAN_NUMBER $LOW_RANGE $HIGH_RANGE \ 340 | $SLICE_THRESHOLD $SLICE_THICKNESS $COARSE2FINE_TESTING_ORGAN_ID $COARSE2FINE_TESTING_GPU \ 341 | 
$LEARNING_RATE1 $LEARNING_RATE_M1 $LEARNING_RATE2 $LEARNING_RATE_M2 $TRAINING_MARGIN \ 342 | $FINE_TESTING_EPOCH_S $FINE_TESTING_EPOCH_I $FINE_TESTING_EPOCH_J $FINE_TESTING_EPOCH_STEP \ 343 | $COARSE_FUSION_THRESHOLD $FINE_FUSION_THRESHOLD $MAX_ROUNDS \ 344 | $COARSE2FINE_TIMESTAMP1_X $COARSE2FINE_TIMESTAMP1_Y $COARSE2FINE_TIMESTAMP1_Z \ 345 | $COARSE2FINE_TIMESTAMP2_X $COARSE2FINE_TIMESTAMP2_Y $COARSE2FINE_TIMESTAMP2_Z \ 346 | | tee ./logs/FD${CURRENT_FOLD}_test_${TRAINING_TIMESTAMP}.txt 347 | fi 348 | 349 | #################################################################################################### 350 | -------------------------------------------------------------------------------- /f3.sh: -------------------------------------------------------------------------------- 1 | #################################################################################################### 2 | # RSTN: Recurrent Saliency Transformation Network for organ segmentation framework # 3 | # This is PyTorch 0.4.0 Python 3.6 verison of OrganSegRSTN in CAFFE Python 2.7 . # 4 | # Author: Tianwei Ni, Huangjie Zheng, Lingxi Xie. # 5 | # # 6 | # If you use our codes, please cite our paper accordingly: # 7 | # Qihang Yu, Lingxi Xie, Yan Wang, Yuyin Zhou, Elliot K. Fishman, Alan L. Yuille, # 8 | # "Recurrent Saliency Transformation Network: # 9 | # Incorporating Multi-Stage Visual Cues for Small Organ Segmentation", # 10 | # in IEEE Conference on Computer Vision and Pattern Recognition, 2018. # 11 | # # 12 | # NOTE: this program can be used for multi-organ segmentation. # 13 | # Please also refer to its previous version, OrganSegC2F. 
# 14 | #################################################################################################### 15 | 16 | #################################################################################################### 17 | # variables for convenience 18 | CURRENT_ORGAN_ID=1 19 | CURRENT_PLANE=A 20 | CURRENT_FOLD=3 21 | CURRENT_GPU=0 22 | 23 | #################################################################################################### 24 | # turn on these switches to execute each module 25 | ENABLE_INITIALIZATION=0 26 | ENABLE_TRAINING=1 27 | ENABLE_COARSE_TESTING=1 28 | ENABLE_COARSE_FUSION=1 29 | ENABLE_ORACLE_TESTING=1 30 | ENABLE_ORACLE_FUSION=1 31 | ENABLE_COARSE2FINE_TESTING=1 32 | # training settings: X|Y|Z 33 | TRAINING_ORGAN_ID=$CURRENT_ORGAN_ID 34 | TRAINING_PLANE=$CURRENT_PLANE 35 | TRAINING_GPU=$CURRENT_GPU 36 | # coarse_testing settings: X|Y|Z, before this, coarse-scaled models shall be ready 37 | COARSE_TESTING_ORGAN_ID=$CURRENT_ORGAN_ID 38 | COARSE_TESTING_PLANE=$CURRENT_PLANE 39 | COARSE_TESTING_GPU=$CURRENT_GPU 40 | # coarse_fusion settings: before this, coarse-scaled results on 3 views shall be ready 41 | COARSE_FUSION_ORGAN_ID=$CURRENT_ORGAN_ID 42 | # oracle_testing settings: X|Y|Z, before this, fine-scaled models shall be ready 43 | ORACLE_TESTING_ORGAN_ID=$CURRENT_ORGAN_ID 44 | ORACLE_TESTING_PLANE=$CURRENT_PLANE 45 | ORACLE_TESTING_GPU=$CURRENT_GPU 46 | # oracle_fusion settings: before this, fine-scaled results on 3 views shall be ready 47 | ORACLE_FUSION_ORGAN_ID=$CURRENT_ORGAN_ID 48 | # fine_testing settings: before this, both coarse-scaled and fine-scaled models shall be ready 49 | COARSE2FINE_TESTING_ORGAN_ID=$CURRENT_ORGAN_ID 50 | COARSE2FINE_TESTING_GPU=$CURRENT_GPU 51 | 52 | #################################################################################################### 53 | # defining the root path which stores image and label data 54 | DATA_PATH='/home/datasets/Pancreas82NIH/' 55 | 56 | 
#################################################################################################### 57 | # data initialization: only needs to be run once 58 | # variables 59 | ORGAN_NUMBER=1 60 | FOLDS=4 61 | LOW_RANGE=-100 62 | HIGH_RANGE=240 63 | # init.py : data_path, organ_number, folds, low_range, high_range 64 | if [ "$ENABLE_INITIALIZATION" = "1" ] 65 | then 66 | python init.py \ 67 | $DATA_PATH $ORGAN_NUMBER $FOLDS $LOW_RANGE $HIGH_RANGE 68 | fi 69 | 70 | #################################################################################################### 71 | # the individual and joint training processes 72 | # variables 73 | SLICE_THRESHOLD=0.98 74 | SLICE_THICKNESS=3 75 | LEARNING_RATE1=1e-5 76 | LEARNING_RATE2=1e-5 77 | LEARNING_RATE_M1=10 78 | LEARNING_RATE_M2=10 79 | TRAINING_MARGIN=20 80 | TRAINING_PROB=0.5 81 | TRAINING_SAMPLE_BATCH=1 82 | TRAINING_EPOCH_S=6 83 | TRAINING_EPOCH_I=6 84 | TRAINING_EPOCH_J=6 85 | LR_DECAY_EPOCH_J_STEP=2 86 | if [ "$ENABLE_TRAINING" = "1" ] 87 | then 88 | TRAINING_TIMESTAMP=$(date +'%Y%m%d_%H%M%S') 89 | else 90 | TRAINING_TIMESTAMP='20220221_210935' 91 | fi 92 | # training.py : data_path, current_fold, organ_number, low_range, high_range, 93 | #               slice_threshold, slice_thickness, organ_ID, plane, GPU_ID, 94 | #               learning_rate1, learning_rate_m1, learning_rate2, learning_rate_m2, margin, prob, sample_batch, 95 | #               EPOCH_S, EPOCH_I, EPOCH_J, EPOCH_STEP, timestamp, fraction 96 | if [ "$ENABLE_TRAINING" = "1" ] 97 | then 98 | 	if [ "$TRAINING_PLANE" = "X" ] || [ "$TRAINING_PLANE" = "A" ] 99 | 	then 100 | 		TRAINING_MODELNAME=X${SLICE_THICKNESS}_${TRAINING_ORGAN_ID} 101 | 		TRAINING_LOG=${DATA_PATH}logs/FD${CURRENT_FOLD}:${TRAINING_MODELNAME}_${TRAINING_TIMESTAMP}.txt 102 | 		python training.py \ 103 | 			$DATA_PATH $CURRENT_FOLD $ORGAN_NUMBER $LOW_RANGE $HIGH_RANGE \ 104 | 			$SLICE_THRESHOLD $SLICE_THICKNESS \ 105 | 			$TRAINING_ORGAN_ID X $TRAINING_GPU \ 106 | 			$LEARNING_RATE1 $LEARNING_RATE_M1 $LEARNING_RATE2 $LEARNING_RATE_M2 \ 107 | 			$TRAINING_MARGIN 
$TRAINING_PROB $TRAINING_SAMPLE_BATCH \ 108 | $TRAINING_EPOCH_S $TRAINING_EPOCH_I $TRAINING_EPOCH_J \ 109 | $LR_DECAY_EPOCH_J_STEP $TRAINING_TIMESTAMP 1 2>&1 | tee $TRAINING_LOG 110 | fi 111 | if [ "$TRAINING_PLANE" = "Y" ] || [ "$TRAINING_PLANE" = "A" ] 112 | then 113 | TRAINING_MODELNAME=Y${SLICE_THICKNESS}_${TRAINING_ORGAN_ID} 114 | TRAINING_LOG=${DATA_PATH}logs/FD${CURRENT_FOLD}:${TRAINING_MODELNAME}_${TRAINING_TIMESTAMP}.txt 115 | python training.py \ 116 | $DATA_PATH $CURRENT_FOLD $ORGAN_NUMBER $LOW_RANGE $HIGH_RANGE \ 117 | $SLICE_THRESHOLD $SLICE_THICKNESS \ 118 | $TRAINING_ORGAN_ID Y $TRAINING_GPU \ 119 | $LEARNING_RATE1 $LEARNING_RATE_M1 $LEARNING_RATE2 $LEARNING_RATE_M2 \ 120 | $TRAINING_MARGIN $TRAINING_PROB $TRAINING_SAMPLE_BATCH \ 121 | $TRAINING_EPOCH_S $TRAINING_EPOCH_I $TRAINING_EPOCH_J \ 122 | $LR_DECAY_EPOCH_J_STEP $TRAINING_TIMESTAMP 1 2>&1 | tee $TRAINING_LOG 123 | fi 124 | if [ "$TRAINING_PLANE" = "Z" ] || [ "$TRAINING_PLANE" = "A" ] 125 | then 126 | TRAINING_MODELNAME=Z${SLICE_THICKNESS}_${TRAINING_ORGAN_ID} 127 | TRAINING_LOG=${DATA_PATH}logs/FD${CURRENT_FOLD}:${TRAINING_MODELNAME}_${TRAINING_TIMESTAMP}.txt 128 | python training.py \ 129 | $DATA_PATH $CURRENT_FOLD $ORGAN_NUMBER $LOW_RANGE $HIGH_RANGE \ 130 | $SLICE_THRESHOLD $SLICE_THICKNESS \ 131 | $TRAINING_ORGAN_ID Z $TRAINING_GPU \ 132 | $LEARNING_RATE1 $LEARNING_RATE_M1 $LEARNING_RATE2 $LEARNING_RATE_M2 \ 133 | $TRAINING_MARGIN $TRAINING_PROB $TRAINING_SAMPLE_BATCH \ 134 | $TRAINING_EPOCH_S $TRAINING_EPOCH_I $TRAINING_EPOCH_J \ 135 | $LR_DECAY_EPOCH_J_STEP $TRAINING_TIMESTAMP 1 2>&1 | tee $TRAINING_LOG 136 | fi 137 | fi 138 | 139 | #################################################################################################### 140 | # the coarse-scaled testing processes 141 | # variables 142 | COARSE_TESTING_EPOCH_S=$TRAINING_EPOCH_S 143 | COARSE_TESTING_EPOCH_I=$TRAINING_EPOCH_I 144 | COARSE_TESTING_EPOCH_J=$TRAINING_EPOCH_J 145 | COARSE_TESTING_EPOCH_STEP=$LR_DECAY_EPOCH_J_STEP 
146 | COARSE_TIMESTAMP1=$TRAINING_TIMESTAMP 147 | COARSE_TIMESTAMP2=$TRAINING_TIMESTAMP 148 | # coarse_testing.py : data_path, current_fold, organ_number, low_range, high_range, 149 | #                     slice_threshold, slice_thickness, organ_ID, plane, GPU_ID, 150 | #                     learning_rate1, learning_rate_m1, learning_rate2, learning_rate_m2, margin, prob, sample_batch, 151 | #                     EPOCH_S, EPOCH_I, EPOCH_J, EPOCH_STEP, 152 | #                     timestamp1, timestamp2 (optional) 153 | if [ "$ENABLE_COARSE_TESTING" = "1" ] 154 | then 155 | 	if [ "$COARSE_TESTING_PLANE" = "X" ] || [ "$COARSE_TESTING_PLANE" = "A" ] 156 | 	then 157 | 		python coarse_testing.py \ 158 | 			$DATA_PATH $CURRENT_FOLD $ORGAN_NUMBER $LOW_RANGE $HIGH_RANGE \ 159 | 			$SLICE_THRESHOLD $SLICE_THICKNESS \ 160 | 			$COARSE_TESTING_ORGAN_ID X $COARSE_TESTING_GPU \ 161 | 			$LEARNING_RATE1 $LEARNING_RATE_M1 $LEARNING_RATE2 $LEARNING_RATE_M2 \ 162 | 			$TRAINING_MARGIN $TRAINING_PROB $TRAINING_SAMPLE_BATCH \ 163 | 			$COARSE_TESTING_EPOCH_S $COARSE_TESTING_EPOCH_I \ 164 | 			$COARSE_TESTING_EPOCH_J $COARSE_TESTING_EPOCH_STEP \ 165 | 			$COARSE_TIMESTAMP1 $COARSE_TIMESTAMP2 166 | 	fi 167 | 	if [ "$COARSE_TESTING_PLANE" = "Y" ] || [ "$COARSE_TESTING_PLANE" = "A" ] 168 | 	then 169 | 		python coarse_testing.py \ 170 | 			$DATA_PATH $CURRENT_FOLD $ORGAN_NUMBER $LOW_RANGE $HIGH_RANGE \ 171 | 			$SLICE_THRESHOLD $SLICE_THICKNESS \ 172 | 			$COARSE_TESTING_ORGAN_ID Y $COARSE_TESTING_GPU \ 173 | 			$LEARNING_RATE1 $LEARNING_RATE_M1 $LEARNING_RATE2 $LEARNING_RATE_M2 \ 174 | 			$TRAINING_MARGIN $TRAINING_PROB $TRAINING_SAMPLE_BATCH \ 175 | 			$COARSE_TESTING_EPOCH_S $COARSE_TESTING_EPOCH_I \ 176 | 			$COARSE_TESTING_EPOCH_J $COARSE_TESTING_EPOCH_STEP \ 177 | 			$COARSE_TIMESTAMP1 $COARSE_TIMESTAMP2 178 | 	fi 179 | 	if [ "$COARSE_TESTING_PLANE" = "Z" ] || [ "$COARSE_TESTING_PLANE" = "A" ] 180 | 	then 181 | 		python coarse_testing.py \ 182 | 			$DATA_PATH $CURRENT_FOLD $ORGAN_NUMBER $LOW_RANGE $HIGH_RANGE \ 183 | 			$SLICE_THRESHOLD $SLICE_THICKNESS \ 184 | 			$COARSE_TESTING_ORGAN_ID Z $COARSE_TESTING_GPU \ 185 | 			$LEARNING_RATE1 $LEARNING_RATE_M1 $LEARNING_RATE2 
$LEARNING_RATE_M2 \ 186 | 			$TRAINING_MARGIN $TRAINING_PROB $TRAINING_SAMPLE_BATCH \ 187 | 			$COARSE_TESTING_EPOCH_S $COARSE_TESTING_EPOCH_I \ 188 | 			$COARSE_TESTING_EPOCH_J $COARSE_TESTING_EPOCH_STEP \ 189 | 			$COARSE_TIMESTAMP1 $COARSE_TIMESTAMP2 190 | 	fi 191 | fi 192 | 193 | #################################################################################################### 194 | # the coarse-scaled fusion process 195 | # variables 196 | COARSE_FUSION_EPOCH_S=$TRAINING_EPOCH_S 197 | COARSE_FUSION_EPOCH_I=$TRAINING_EPOCH_I 198 | COARSE_FUSION_EPOCH_J=$TRAINING_EPOCH_J 199 | COARSE_FUSION_EPOCH_STEP=$LR_DECAY_EPOCH_J_STEP 200 | COARSE_FUSION_THRESHOLD=0.5 201 | COARSE_TIMESTAMP1_X=$TRAINING_TIMESTAMP 202 | COARSE_TIMESTAMP1_Y=$TRAINING_TIMESTAMP 203 | COARSE_TIMESTAMP1_Z=$TRAINING_TIMESTAMP 204 | COARSE_TIMESTAMP2_X=$TRAINING_TIMESTAMP 205 | COARSE_TIMESTAMP2_Y=$TRAINING_TIMESTAMP 206 | COARSE_TIMESTAMP2_Z=$TRAINING_TIMESTAMP 207 | # coarse_fusion.py : data_path, current_fold, organ_number, low_range, high_range, 208 | #                    slice_threshold, slice_thickness, organ_ID, GPU_ID, 209 | #                    learning_rate1, learning_rate_m1, learning_rate2, learning_rate_m2, margin, 210 | #                    EPOCH_S, EPOCH_I, EPOCH_J, EPOCH_STEP, threshold, 211 | #                    timestamp1_X, timestamp1_Y, timestamp1_Z, 212 | #                    timestamp2_X (optional), timestamp2_Y (optional), timestamp2_Z (optional) 213 | if [ "$ENABLE_COARSE_FUSION" = "1" ] 214 | then 215 | 	python coarse_fusion.py \ 216 | 		$DATA_PATH $CURRENT_FOLD $ORGAN_NUMBER $LOW_RANGE $HIGH_RANGE \ 217 | 		$SLICE_THRESHOLD $SLICE_THICKNESS $COARSE_TESTING_ORGAN_ID $COARSE_TESTING_GPU \ 218 | 		$LEARNING_RATE1 $LEARNING_RATE_M1 $LEARNING_RATE2 $LEARNING_RATE_M2 $TRAINING_MARGIN \ 219 | 		$COARSE_FUSION_EPOCH_S $COARSE_FUSION_EPOCH_I $COARSE_FUSION_EPOCH_J \ 220 | 		$COARSE_FUSION_EPOCH_STEP $COARSE_FUSION_THRESHOLD \ 221 | 		$COARSE_TIMESTAMP1_X $COARSE_TIMESTAMP1_Y $COARSE_TIMESTAMP1_Z \ 222 | 		$COARSE_TIMESTAMP2_X $COARSE_TIMESTAMP2_Y $COARSE_TIMESTAMP2_Z 223 | fi 224 | 225 | 
#################################################################################################### 226 | # the oracle testing processes 227 | # variables 228 | ORACLE_TESTING_EPOCH_S=$TRAINING_EPOCH_S 229 | ORACLE_TESTING_EPOCH_I=$TRAINING_EPOCH_I 230 | ORACLE_TESTING_EPOCH_J=$TRAINING_EPOCH_J 231 | ORACLE_TESTING_EPOCH_STEP=$LR_DECAY_EPOCH_J_STEP 232 | ORACLE_TIMESTAMP1=$TRAINING_TIMESTAMP 233 | ORACLE_TIMESTAMP2=$TRAINING_TIMESTAMP 234 | # oracle_testing.py : data_path, current_fold, organ_number, low_range, high_range, 235 | # slice_threshold, slice_thickness, organ_ID, plane, GPU_ID, 236 | # learning_rate1, learning_rate_m1, learning_rate2, learning_rate_m2, 237 | # margin, prob, sample_batch, 238 | # EPOCH_S, EPOCH_I, EPOCH_J, EPOCH_STEP, 239 | # timestamp1, timestamp2 (optional) 240 | if [ "$ENABLE_ORACLE_TESTING" = "1" ] 241 | then 242 | if [ "$ORACLE_TESTING_PLANE" = "X" ] || [ "$ORACLE_TESTING_PLANE" = "A" ] 243 | then 244 | python oracle_testing.py \ 245 | $DATA_PATH $CURRENT_FOLD $ORGAN_NUMBER $LOW_RANGE $HIGH_RANGE \ 246 | $SLICE_THRESHOLD $SLICE_THICKNESS \ 247 | $ORACLE_TESTING_ORGAN_ID X $ORACLE_TESTING_GPU \ 248 | $LEARNING_RATE1 $LEARNING_RATE_M1 $LEARNING_RATE2 $LEARNING_RATE_M2 \ 249 | $TRAINING_MARGIN $TRAINING_PROB $TRAINING_SAMPLE_BATCH \ 250 | $ORACLE_TESTING_EPOCH_S $ORACLE_TESTING_EPOCH_I \ 251 | $ORACLE_TESTING_EPOCH_J $ORACLE_TESTING_EPOCH_STEP \ 252 | $ORACLE_TIMESTAMP1 $ORACLE_TIMESTAMP2 253 | fi 254 | if [ "$ORACLE_TESTING_PLANE" = "Y" ] || [ "$ORACLE_TESTING_PLANE" = "A" ] 255 | then 256 | python oracle_testing.py \ 257 | $DATA_PATH $CURRENT_FOLD $ORGAN_NUMBER $LOW_RANGE $HIGH_RANGE \ 258 | $SLICE_THRESHOLD $SLICE_THICKNESS \ 259 | $ORACLE_TESTING_ORGAN_ID Y $ORACLE_TESTING_GPU \ 260 | $LEARNING_RATE1 $LEARNING_RATE_M1 $LEARNING_RATE2 $LEARNING_RATE_M2 \ 261 | $TRAINING_MARGIN $TRAINING_PROB $TRAINING_SAMPLE_BATCH \ 262 | $ORACLE_TESTING_EPOCH_S $ORACLE_TESTING_EPOCH_I \ 263 | $ORACLE_TESTING_EPOCH_J $ORACLE_TESTING_EPOCH_STEP \ 264 
| $ORACLE_TIMESTAMP1 $ORACLE_TIMESTAMP2 265 | 	fi 266 | 	if [ "$ORACLE_TESTING_PLANE" = "Z" ] || [ "$ORACLE_TESTING_PLANE" = "A" ] 267 | 	then 268 | 		python oracle_testing.py \ 269 | 			$DATA_PATH $CURRENT_FOLD $ORGAN_NUMBER $LOW_RANGE $HIGH_RANGE \ 270 | 			$SLICE_THRESHOLD $SLICE_THICKNESS \ 271 | 			$ORACLE_TESTING_ORGAN_ID Z $ORACLE_TESTING_GPU \ 272 | 			$LEARNING_RATE1 $LEARNING_RATE_M1 $LEARNING_RATE2 $LEARNING_RATE_M2 \ 273 | 			$TRAINING_MARGIN $TRAINING_PROB $TRAINING_SAMPLE_BATCH \ 274 | 			$ORACLE_TESTING_EPOCH_S $ORACLE_TESTING_EPOCH_I \ 275 | 			$ORACLE_TESTING_EPOCH_J $ORACLE_TESTING_EPOCH_STEP \ 276 | 			$ORACLE_TIMESTAMP1 $ORACLE_TIMESTAMP2 277 | 	fi 278 | fi 279 | 280 | #################################################################################################### 281 | # the oracle-scaled fusion process 282 | # variables 283 | ORACLE_FUSION_EPOCH_S=$TRAINING_EPOCH_S 284 | ORACLE_FUSION_EPOCH_I=$TRAINING_EPOCH_I 285 | ORACLE_FUSION_EPOCH_J=$TRAINING_EPOCH_J 286 | ORACLE_FUSION_EPOCH_STEP=$LR_DECAY_EPOCH_J_STEP 287 | ORACLE_FUSION_THRESHOLD=0.5 288 | ORACLE_TIMESTAMP1_X=$TRAINING_TIMESTAMP 289 | ORACLE_TIMESTAMP1_Y=$TRAINING_TIMESTAMP 290 | ORACLE_TIMESTAMP1_Z=$TRAINING_TIMESTAMP 291 | ORACLE_TIMESTAMP2_X=$TRAINING_TIMESTAMP 292 | ORACLE_TIMESTAMP2_Y=$TRAINING_TIMESTAMP 293 | ORACLE_TIMESTAMP2_Z=$TRAINING_TIMESTAMP 294 | # oracle_fusion.py : data_path, current_fold, organ_number, low_range, high_range, 295 | #                    slice_threshold, slice_thickness, organ_ID, GPU_ID, 296 | #                    learning_rate1, learning_rate_m1, learning_rate2, learning_rate_m2, margin, 297 | #                    EPOCH_S, EPOCH_I, EPOCH_J, EPOCH_STEP, threshold, 298 | #                    timestamp1_X, timestamp1_Y, timestamp1_Z, 299 | #                    timestamp2_X (optional), timestamp2_Y (optional), timestamp2_Z (optional) 300 | if [ "$ENABLE_ORACLE_FUSION" = "1" ] 301 | then 302 | 	python oracle_fusion.py \ 303 | 		$DATA_PATH $CURRENT_FOLD $ORGAN_NUMBER $LOW_RANGE $HIGH_RANGE \ 304 | 		$SLICE_THRESHOLD $SLICE_THICKNESS $ORACLE_TESTING_ORGAN_ID 
$ORACLE_TESTING_GPU \ 305 | 		$LEARNING_RATE1 $LEARNING_RATE_M1 $LEARNING_RATE2 $LEARNING_RATE_M2 $TRAINING_MARGIN \ 306 | 		$ORACLE_FUSION_EPOCH_S $ORACLE_FUSION_EPOCH_I $ORACLE_FUSION_EPOCH_J \ 307 | 		$ORACLE_FUSION_EPOCH_STEP $ORACLE_FUSION_THRESHOLD \ 308 | 		$ORACLE_TIMESTAMP1_X $ORACLE_TIMESTAMP1_Y $ORACLE_TIMESTAMP1_Z \ 309 | 		$ORACLE_TIMESTAMP2_X $ORACLE_TIMESTAMP2_Y $ORACLE_TIMESTAMP2_Z 310 | fi 311 | 312 | #################################################################################################### 313 | # the coarse-to-fine testing process 314 | # variables 315 | FINE_TESTING_EPOCH_S=$TRAINING_EPOCH_S 316 | FINE_TESTING_EPOCH_I=$TRAINING_EPOCH_I 317 | FINE_TESTING_EPOCH_J=$TRAINING_EPOCH_J 318 | FINE_TESTING_EPOCH_STEP=$LR_DECAY_EPOCH_J_STEP 319 | FINE_FUSION_THRESHOLD=0.5 320 | COARSE2FINE_TIMESTAMP1_X=$TRAINING_TIMESTAMP 321 | COARSE2FINE_TIMESTAMP1_Y=$TRAINING_TIMESTAMP 322 | COARSE2FINE_TIMESTAMP1_Z=$TRAINING_TIMESTAMP 323 | COARSE2FINE_TIMESTAMP2_X=$TRAINING_TIMESTAMP 324 | COARSE2FINE_TIMESTAMP2_Y=$TRAINING_TIMESTAMP 325 | COARSE2FINE_TIMESTAMP2_Z=$TRAINING_TIMESTAMP 326 | MAX_ROUNDS=10 327 | # coarse2fine_testing.py : data_path, current_fold, organ_number, low_range, high_range, 328 | #                          slice_threshold, slice_thickness, organ_ID, GPU_ID, 329 | #                          learning_rate1, learning_rate_m1, learning_rate2, learning_rate_m2, margin, 330 | #                          EPOCH_S, EPOCH_I, EPOCH_J, EPOCH_STEP, 331 | #                          coarse_fusion_threshold, 332 | #                          fine_fusion_threshold, 333 | #                          max_rounds, 334 | #                          timestamp1_X, timestamp1_Y, timestamp1_Z, 335 | #                          timestamp2_X (optional), timestamp2_Y (optional), timestamp2_Z (optional) 336 | if [ "$ENABLE_COARSE2FINE_TESTING" = "1" ] 337 | then 338 | 	python coarse2fine_testing.py \ 339 | 		$DATA_PATH $CURRENT_FOLD $ORGAN_NUMBER $LOW_RANGE $HIGH_RANGE \ 340 | 		$SLICE_THRESHOLD $SLICE_THICKNESS $COARSE2FINE_TESTING_ORGAN_ID $COARSE2FINE_TESTING_GPU \ 341 | 		
$LEARNING_RATE1 $LEARNING_RATE_M1 $LEARNING_RATE2 $LEARNING_RATE_M2 $TRAINING_MARGIN \ 342 | $FINE_TESTING_EPOCH_S $FINE_TESTING_EPOCH_I $FINE_TESTING_EPOCH_J $FINE_TESTING_EPOCH_STEP \ 343 | $COARSE_FUSION_THRESHOLD $FINE_FUSION_THRESHOLD $MAX_ROUNDS \ 344 | $COARSE2FINE_TIMESTAMP1_X $COARSE2FINE_TIMESTAMP1_Y $COARSE2FINE_TIMESTAMP1_Z \ 345 | $COARSE2FINE_TIMESTAMP2_X $COARSE2FINE_TIMESTAMP2_Y $COARSE2FINE_TIMESTAMP2_Z \ 346 | | tee ./logs/FD${CURRENT_FOLD}_test_${TRAINING_TIMESTAMP}.txt 347 | fi 348 | 349 | #################################################################################################### 350 | --------------------------------------------------------------------------------