├── LICENSE ├── README.md ├── ScatterNet.py ├── scatter_correct.py ├── train_network.py └── trained_models └── model_103.trch /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 David Christoffer Hansen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | This repository contains the code and the trained model for the article **ScatterNet: a convolutional neural network for cone-beam CT intensity correction**. 2 | 3 | 4 | **scatter_correct.py** - Scatter-corrects the input .mha projection files using the pretrained model. Please note that this is a proof of concept. 5 | 6 | **train_network.py** - The code for training the model. This needs to be edited to load your own datasets. 
7 | 8 | This code should not be used clinically without thorough testing, and even then no guarantees are made for correctness, usefulness or applicability. -------------------------------------------------------------------------------- /ScatterNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import numpy as np 6 | 7 | class SqueezeExcitation(nn.Module): 8 | def __init__(self,channels,squeeze_channels=None): 9 | if squeeze_channels is None: 10 | squeeze_channels = channels//8 11 | super(SqueezeExcitation,self).__init__() 12 | 13 | self.channels = channels 14 | self.fc1 = nn.Conv2d(channels,squeeze_channels,kernel_size=1) 15 | self.fc2 = nn.Conv2d(squeeze_channels,channels,kernel_size=1) 16 | 17 | def forward(self,x): 18 | out = F.avg_pool2d(x,x.size()[2:]) 19 | out = F.relu(self.fc1(out)) 20 | out = self.fc2(out) 21 | out = F.sigmoid(out) 22 | return x*out 23 | 24 | class DownBlock(nn.Module): 25 | def __init__(self,inchannels,channels,activation=nn.ReLU, batchnorm=False,squeeze = False,residual=True): 26 | super(DownBlock,self).__init__() 27 | self.residual=residual 28 | self.activation1 = activation() 29 | self.activation2 = activation() 30 | self.activation3 = activation() 31 | self.downconv = nn.Conv2d(inchannels,channels,kernel_size=2,stride=2,padding=1) 32 | 33 | self.conv1 = nn.Conv2d(channels,channels,kernel_size=3,padding=1) 34 | self.conv2 = nn.Conv2d(channels,channels,kernel_size=3,padding=1) 35 | if (batchnorm): 36 | self.bnorm1 = nn.BatchNorm3d(channels) 37 | self.bnorm2 = nn.BatchNorm3d(channels) 38 | self.bnorm3 = nn.BatchNorm3d(channels) 39 | if (squeeze): 40 | self.squeeze = SqueezeExcitation(channels) 41 | else: 42 | self.squeeze = None 43 | self.batchnorm=batchnorm 44 | 45 | def forward(self,x): 46 | 47 | down = self.downconv(x) 48 | if (self.batchnorm): 49 | down = self.bnorm1(down) 50 | down = self.activation1(down) 51 
| # print("Down",down.size()) 52 | out = self.conv1(down) 53 | if (self.batchnorm): 54 | out = self.bnorm2(out) 55 | out = self.activation2(out) 56 | out = self.conv2(out) 57 | if (self.batchnorm): 58 | out = self.bnorm3(out) 59 | if self.squeeze is not None: 60 | out = self.squeeze(out) 61 | if self.residual: 62 | out += down 63 | out = self.activation3(out) 64 | return out 65 | 66 | 67 | class ResBlock(nn.Module): 68 | def __init__(self,inchannels,channels,activation=nn.ReLU,batchnorm = False, squeeze = False,residual=True): 69 | super(ResBlock,self).__init__() 70 | self.residual = residual 71 | self.activation1 = activation() 72 | self.activation2 = activation() 73 | self.activation3 = activation() 74 | self.conv0 = nn.Conv2d(inchannels,channels,kernel_size=3,padding=1) 75 | if (batchnorm): 76 | self.bnorm1 = nn.BatchNorm3d(channels) 77 | self.bnorm2 = nn.BatchNorm3d(channels) 78 | self.bnorm3 = nn.BatchNorm3d(channels) 79 | self.conv1 = nn.ConvTranspose2d(channels,channels,kernel_size=3,padding=1) 80 | self.conv2 = nn.ConvTranspose2d(channels,channels,kernel_size=3,padding=1) 81 | self.batchnorm = batchnorm 82 | if squeeze: 83 | self.squeeze = SqueezeExcitation(channels) 84 | else: 85 | self.squeeze = None 86 | 87 | def forward(self,x): 88 | up = self.conv0(x) 89 | if self.batchnorm: 90 | up = self.bnorm1(up) 91 | up = self.activation1(up) 92 | # print("Up",up.size()) 93 | out = self.conv1(up) 94 | if self.batchnorm: 95 | out = self.bnorm2(out) 96 | out = self.activation2(out) 97 | out = self.conv2(out) 98 | if self.batchnorm: 99 | out = self.bnorm3(out) 100 | if self.squeeze is not None: 101 | out = self.squeeze(out) 102 | if self.residual: 103 | out += up 104 | 105 | out = self.activation3(out) 106 | return out 107 | class UpBlock(nn.Module): 108 | def __init__(self,inchannels,channels,activation=nn.ReLU,batchnorm = False, squeeze = False,residual=True): 109 | super(UpBlock,self).__init__() 110 | self.residual=residual 111 | self.activation1 = activation() 
112 | self.activation2 = activation() 113 | self.activation3 = activation() 114 | self.upconv = nn.Conv2d(inchannels,channels,kernel_size=3,padding=1) 115 | if (batchnorm): 116 | self.bnorm1 = nn.BatchNorm3d(channels) 117 | self.bnorm2 = nn.BatchNorm3d(channels) 118 | self.bnorm3 = nn.BatchNorm3d(channels) 119 | self.conv1 = nn.ConvTranspose2d(channels,channels,kernel_size=3,padding=1) 120 | self.conv2 = nn.ConvTranspose2d(channels,channels,kernel_size=3,padding=1) 121 | self.batchnorm = batchnorm 122 | if squeeze: 123 | self.squeeze = SqueezeExcitation(channels) 124 | else: 125 | self.squeeze = None 126 | 127 | def forward(self,x): 128 | up = F.upsample(x,scale_factor=2) 129 | up = self.upconv(up) 130 | if self.batchnorm: 131 | up = self.bnorm1(up) 132 | up = self.activation1(up) 133 | # print("Up",up.size()) 134 | out = self.conv1(up) 135 | if self.batchnorm: 136 | out = self.bnorm2(out) 137 | out = self.activation2(out) 138 | out = self.conv2(out) 139 | if self.batchnorm: 140 | out = self.bnorm3(out) 141 | if self.squeeze is not None: 142 | out = self.squeeze(out) 143 | if self.residual: 144 | out += up 145 | 146 | out = self.activation3(out) 147 | return out 148 | 149 | 150 | class ConvertNet(nn.Module): 151 | def __init__(self,init_channels,activation= nn.ReLU): 152 | super(ConvertNet,self).__init__() 153 | self.conv1 = nn.Sequential(nn.Conv2d(1,init_channels,kernel_size=3,padding=1),activation(), 154 | nn.Conv2d(init_channels,1,kernel_size=3,padding=1),activation()) 155 | 156 | def forward(self,x): 157 | return self.conv1(x) 158 | 159 | 160 | class ScatterNet(nn.Module): 161 | def __init__(self,init_channels, layer_channels,batchnorm = False, squeeze = False, skip_first = False,activation = nn.ReLU,exp = False,residual=True): 162 | """ 163 | 164 | :type exp: bool 165 | """ 166 | super(ScatterNet,self).__init__() 167 | 168 | self.activation = activation 169 | self.conv1 = ConvertNet(init_channels,activation=activation) 170 | 171 | self.conv2 = 
ResBlock(1,layer_channels[0],activation=activation,batchnorm=batchnorm,squeeze=squeeze,residual=residual) 172 | self.upblocks = nn.ModuleList() 173 | self.downblocks = nn.ModuleList() 174 | previous_channels = layer_channels[0] 175 | for channels in layer_channels[1:]: 176 | self.downblocks.append(DownBlock(previous_channels, channels, self.activation,batchnorm,squeeze,residual=residual)) 177 | previous_channels = channels 178 | 179 | self.mixBlock = nn.ModuleList() 180 | for channels in reversed(layer_channels[:-1]): 181 | self.mixBlock.append(nn.Sequential(nn.Conv2d(channels*2,channels,kernel_size=1),self.activation())) 182 | self.upblocks.append(UpBlock(previous_channels,channels,self.activation,batchnorm,squeeze,residual=residual)) 183 | previous_channels = channels 184 | 185 | 186 | self.dconvFinal = nn.ConvTranspose2d(layer_channels[0],1,kernel_size=1,padding=0) 187 | self.skip_first = skip_first 188 | self.exp = exp 189 | 190 | def forward(self,x): 191 | if self.skip_first: 192 | level1 = x 193 | else: 194 | level1 = self.conv1(x) 195 | if self.exp: 196 | level1 = torch.exp(-level1)*2**(16) 197 | 198 | previous = self.conv2(level1) 199 | layers = [previous] 200 | xsize = x.size() 201 | for block in self.downblocks: 202 | previous = block(previous) 203 | layers.append(previous) 204 | 205 | layers = list(reversed(layers[:-1])) 206 | for block,shortcut,mixer in zip(self.upblocks,layers,self.mixBlock): 207 | previous = block(previous) 208 | psize = previous.size() 209 | ssize = shortcut.size() 210 | if (psize != ssize): 211 | diff = np.array(ssize,dtype=int) - np.array(psize,dtype=int) 212 | # print(diff) 213 | previous = F.pad(previous,(0,int(diff[-1]),0,int(diff[-2])),mode="replicate") 214 | # print(previous.size(),shortcut.size()) 215 | previous = torch.cat([previous,shortcut],dim=1) 216 | previous = mixer(previous) 217 | 218 | 219 | previous = self.dconvFinal(previous) 220 | if self.skip_first: 221 | return previous 222 | 223 | if self.exp: 224 | previous = 
torch.clamp(level1-previous,min=1e-6) 225 | # previous = -torch.log(previous) 226 | # previous = -torch.log(torch.clamp(previous,1e-6,1)) 227 | else: 228 | previous = previous+level1 229 | return previous 230 | 231 | 232 | 233 | 234 | 235 | 236 | 237 | 238 | -------------------------------------------------------------------------------- /scatter_correct.py: -------------------------------------------------------------------------------- 1 | import SimpleITK as sitk 2 | import torch 3 | import torch.backends.cudnn as cudnn 4 | cudnn.benchmark = True 5 | from torch.utils.data import DataLoader,TensorDataset 6 | import torch.nn as nn 7 | from torch.autograd import Variable 8 | import numpy as np 9 | import os 10 | import argparse 11 | 12 | 13 | parser = argparse.ArgumentParser() 14 | 15 | parser.add_argument("files",nargs='+',help='Projection files to scatter correct') 16 | parser.add_argument('--output_dir',help="Output directory") 17 | 18 | args = parser.parse_args() 19 | 20 | import ScatterNet 21 | 22 | model = ScatterNet.ScatterNet(init_channels=8,layer_channels=[8,16,32,64,128,256],batchnorm=False,squeeze=False, 23 | activation=nn.PReLU,exp=False, 24 | skip_first=False,residual=True) 25 | 26 | 27 | state_dict = torch.load("trained_models/model_103.trch") #Model used for article 28 | model = nn.DataParallel(model) 29 | model.load_state_dict(state_dict) 30 | 31 | for proj_file in args.files: 32 | print(dir) 33 | stk_projections = sitk.ReadImage(proj_file) 34 | 35 | 36 | 37 | data = sitk.GetArrayFromImage(stk_projections) 38 | 39 | print("Loaded") 40 | data = np.pad(data,[(0,0),(4,4),(4,4)],mode="edge") 41 | 42 | print("Padded") 43 | 44 | 45 | 46 | loader= DataLoader(TensorDataset(torch.from_numpy(data[:,np.newaxis,...]),torch.from_numpy(data[:,np.newaxis,...])),batch_size=8,pin_memory=True) 47 | 48 | total_projections = [] 49 | for projections,_ in loader: 50 | with torch.no_grad(): 51 | var = Variable(projections.float()) 52 | data_net_corrected = model(var) 53 | 
54 | data_net_corrected = data_net_corrected.data.cpu().numpy() 55 | total_projections.append(data_net_corrected) 56 | 57 | total_projections = np.concatenate(total_projections,axis=0)[:,0,...] 58 | total_projections[np.isinf(total_projections)] = 0 59 | total_projections = total_projections[:,4:-4,4:-4] 60 | 61 | 62 | 63 | total_projections = sitk.GetImageFromArray(total_projections) 64 | total_projections.CopyInformation(stk_projections) 65 | sitk.WriteImage(total_projections,args.output_dir + "/" + os.path.basename(proj_file)) 66 | 67 | 68 | 69 | 70 | 71 | -------------------------------------------------------------------------------- /train_network.py: -------------------------------------------------------------------------------- 1 | import SimpleITK as sitk 2 | import numpy as np 3 | import torch 4 | import torch.backends.cudnn 5 | import torch.nn as nn 6 | import torch.optim.lr_scheduler 7 | from torch.autograd import Variable 8 | from torch.utils.data import TensorDataset, DataLoader 9 | 10 | torch.backends.cudnn.benchmark = True 11 | 12 | 13 | def load_mha(mha_file): 14 | itk = sitk.ReadImage(mha_file) 15 | return sitk.GetArrayFromImage(itk) 16 | 17 | 18 | def load_projections(files): 19 | projections = [load_mha(f) for f in files] 20 | projections = np.concatenate(projections, 0) 21 | projections = np.pad(projections, [(0, 0), (4, 4), (4, 4)], mode="edge") 22 | projections[projections < 0] = 0 23 | return projections[:, np.newaxis, ...] 
24 | 25 | 26 | class ProjectionDatasSet(TensorDataset): 27 | # Mixup 28 | def __init__(self, data_array, target_array, distribution=np.random.rand): 29 | super(ProjectionDatasSet, self).__init__(data_array, target_array) 30 | self.distribution = distribution 31 | 32 | def __getitem__(self, item): 33 | data, target = super(ProjectionDatasSet, self).__getitem__(item) 34 | 35 | other_item = np.random.randint(0, self.__len__()) 36 | 37 | mix = self.distribution() 38 | data2, target2 = super(ProjectionDatasSet, self).__getitem__(other_item) 39 | 40 | data_mixed = data * mix + data2 * (1 - mix) 41 | target_mixed = target * mix + target2 * (1 - mix) 42 | 43 | return data_mixed, target_mixed 44 | 45 | 46 | class AverageMeter(object): 47 | """Computes and stores the average and current value""" 48 | 49 | def __init__(self): 50 | self.reset() 51 | 52 | def reset(self): 53 | self.val = 0 54 | self.avg = 0 55 | self.sum = 0 56 | self.count = 0 57 | 58 | def update(self, val, n=1): 59 | self.val = val 60 | self.sum += val * n 61 | self.count += n 62 | self.avg = self.sum / self.count 63 | 64 | 65 | def is_Conv_type(m): 66 | return isinstance(m, nn.Conv3d) or isinstance(m, nn.ConvTranspose3d) or isinstance(m, nn.Conv2d) or \ 67 | isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d) or \ 68 | isinstance(m, nn.Linear) 69 | 70 | 71 | def InitModel(model): 72 | for m in model.modules(): 73 | if is_Conv_type(m): 74 | nn.init.orthogonal(m.weight.data) 75 | 76 | 77 | import time 78 | import tensorboardX 79 | 80 | writer = tensorboardX.SummaryWriter() 81 | 82 | use_cuda = True 83 | 84 | 85 | def train_model(model, optimizer, dset_loaders, num_epochs=200, scheduler=None, start_epoch=0, criterion=nn.MSELoss()): 86 | since = time.time() 87 | batch_time = AverageMeter() 88 | running_loss = {"val": AverageMeter(), "train": AverageMeter()} 89 | best_model = model 90 | best_acc = 0.0 91 | 92 | for epoch in range(start_epoch, num_epochs): 93 | # 
optimizer.update_step() 94 | print('Epoch {}/{}'.format(epoch, num_epochs - 1)) 95 | print('-' * 10) 96 | k = 0 97 | 98 | for phase in ['train', 'val']: 99 | 100 | if phase == 'train': 101 | model.train(True) # Set model to training mode 102 | else: 103 | model.train(False) # Set model to evaluate mode 104 | 105 | running_corrects = 0 106 | i = 0 107 | report = 200 108 | 109 | # Iterate over data. 110 | for data in dset_loaders[phase]: 111 | # get the inputs 112 | inputs, targets = data 113 | 114 | # wrap them in Variable 115 | if use_cuda: 116 | inputs, targets = Variable(inputs.float().cuda(async=True)), Variable( 117 | targets.float().cuda(async=True)) 118 | else: 119 | inputs, targets = Variable(inputs.float()), Variable( 120 | targets.float()) 121 | 122 | if phase == "val": 123 | inputs.volatile = True 124 | targets.volatile = True 125 | # zero the parameter gradients 126 | optimizer.zero_grad() 127 | 128 | # forward 129 | outputs = model(inputs) 130 | 131 | loss = criterion(outputs, targets) 132 | 133 | # backward + optimize only if in training phase 134 | if phase == "train": 135 | loss.backward() 136 | optimizer.step() 137 | 138 | if isinstance(criterion, nn.MSELoss): 139 | base_loss = loss 140 | else: 141 | base_loss = nn.MSELoss()(torch.log(outputs), torch.log(targets)) 142 | print("Penguins were here") 143 | print("Loss", base_loss.data[0], "Time", 144 | time.time() - since) 145 | batch_time.update(time.time() - since) 146 | since = time.time() 147 | # statistics 148 | 149 | running_loss[phase].update(base_loss.data[0], n=outputs.size()[0]) 150 | i += 1 151 | 152 | writer.add_scalar('Loss_' + phase, running_loss[phase].avg, epoch) 153 | if phase == 'train': 154 | if scheduler is not None: 155 | scheduler.step(running_loss[phase].avg) 156 | 157 | batch_time.reset() 158 | 159 | for phase in ["val", "train"]: 160 | running_loss[phase].reset() 161 | 162 | torch.save(model.state_dict(), writer.file_writer.get_logdir() + "/model_" + str(epoch) + ".trch") 163 | 
164 | print() 165 | 166 | writer.close() 167 | return model 168 | 169 | 170 | training_patients = [2, 3, 4, 5, 7, 19, 21, 22, 23, 24, 25, 26, 27, 28, 29] 171 | 172 | projection_files = ["NewProjections/CBCTcor" + str(k) + "/ProjectionData/CBCT_projections_rtk_binned.mha" for k in 173 | training_patients] 174 | corprojection_files = ["NewProjections/CBCTcor" + str(k) + "/ProjectionData/CBCT_projections_cor_CF_1.6.mha" for k in 175 | training_patients] 176 | 177 | test_patients = [8, 9, 10, 12, 13, 14, 15] 178 | 179 | test_projection_files = ["NewProjections/CBCTcor" + str(k) + "/ProjectionData/CBCT_projections_rtk_binned.mha" for k in 180 | test_patients] 181 | test_corprojection_files = ["NewProjections/CBCTcor" + str(k) + "/ProjectionData/CBCT_projections_cor_CF_1.6.mha" for k 182 | in test_patients] 183 | 184 | distribution = np.random.rand 185 | 186 | train_loader = DataLoader(TensorDataset(torch.from_numpy(load_projections(projection_files)), 187 | torch.from_numpy(load_projections(corprojection_files))), batch_size=8, 188 | shuffle=True, pin_memory=True) 189 | test_loader = DataLoader(TensorDataset(torch.from_numpy(load_projections(test_projection_files)), 190 | torch.from_numpy(load_projections(test_corprojection_files))), batch_size=8, 191 | shuffle=False, pin_memory=True) 192 | 193 | import ScatterNet 194 | 195 | model = ScatterNet.ScatterNet(init_channels=8, layer_channels=[8, 16, 32, 64, 128, 256], batchnorm=False, squeeze=False, 196 | activation=nn.PReLU, exp=False, 197 | skip_first=False, residual=True) 198 | 199 | InitModel(model) 200 | 201 | torch.save(model, writer.file_writer.get_logdir() + "/base_model.trch") 202 | 203 | dummy_data = None 204 | 205 | if use_cuda: 206 | model = nn.DataParallel(model).cuda() 207 | 208 | optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=0) 209 | 210 | train_model(model, optimizer, {"val": test_loader, "train": train_loader}, num_epochs=10000) 211 | 
-------------------------------------------------------------------------------- /trained_models/model_103.trch: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dchansen/ScatterNet/1c5ea5d03ee838a3410fc0c2c12188e689d8d68b/trained_models/model_103.trch --------------------------------------------------------------------------------