├── .gitattributes
├── Data
│   ├── L067
│   │   └── 024
│   │       └── ND
│   │           └── Data.mat
│   └── L506
│       └── 001
│           └── ND
│               └── Data.mat
├── Model_save
│   └── VVtensor_sparse_96_Unet_L2
│       ├── result
│       │   ├── L067_001.mat
│       │   └── L506_001.mat
│       ├── SinoIndices_save
│       │   ├── VVtenser_96_fan_Weight2.dat
│       │   └── VVtenser_96_fan_indices.dat
│       └── Model_save
│           └── best_train_model_params_1_VVtensor_sparse_96_Unet_L2.pkl
├── README.md
├── Model
│   ├── common.py
│   └── networks_1.py
└── VVtensor_sparse_96_Unet_L2.py

/.gitattributes:
--------------------------------------------------------------------------------
1 | *.dat filter=lfs diff=lfs merge=lfs -text
2 | *.pkl filter=lfs diff=lfs merge=lfs -text
3 | *.mat filter=lfs diff=lfs merge=lfs -text
--------------------------------------------------------------------------------
/Data/L067/024/ND/Data.mat:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:100da41c9a5c856449ab517dc599d6494168757f331177d915be88ba1e07776e
3 | size 10550136
--------------------------------------------------------------------------------
/Data/L506/001/ND/Data.mat:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:b085c021a541df1920fbe7b6abcde08a8c7217bcb53c7805a1f71a44fcfc821b
3 | size 10451574
--------------------------------------------------------------------------------
/Model_save/VVtensor_sparse_96_Unet_L2/result/L067_001.mat:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:9862c71fb51a21ac817c6248ac88c59865db29cfd2de6779f7f336f1a3a6b58b
3 | size 3428832
--------------------------------------------------------------------------------
/Model_save/VVtensor_sparse_96_Unet_L2/result/L506_001.mat:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:c4af6e8eb75983888d9dae2c1fec2a603ff58b5174bfc55cc0a125bbc095cca7
3 | size 3428832
--------------------------------------------------------------------------------
/Model_save/VVtensor_sparse_96_Unet_L2/SinoIndices_save/VVtenser_96_fan_Weight2.dat:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:0c0d2aed4a9d3e864104fccb57319d06628c6381b6a41b52634d9adc0efcaf6e
3 | size 134379798
--------------------------------------------------------------------------------
/Model_save/VVtensor_sparse_96_Unet_L2/SinoIndices_save/VVtenser_96_fan_indices.dat:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:76d52e230af8e0030965a15de70552e7143b583afdb4292d539c17cdcf3b257e
3 | size 135506173
--------------------------------------------------------------------------------
/Model_save/VVtensor_sparse_96_Unet_L2/Model_save/best_train_model_params_1_VVtensor_sparse_96_Unet_L2.pkl:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:442f5d23995956724aeb79b9492784c38c8106bf7ccf19f1502e20df918008c5
3 | size 554427064
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # VVBPTensor-net
2 | Learning to Reconstruct CT Images from the VVBP-Tensor.
3 | 
4 | 
5 | ---- VVtensor_sparse_96_Unet_L2.py:
6 | (demo) The main script. Runs VVBPTensor-net for sparse-view CT with 96 views, a U-Net backbone, and an L2 loss.
7 | 
8 | ---- Data: test data.
9 | 
10 | ---- Model: network definitions and helper functions used by the main script.
11 | 
12 | ---- Model_save: a trained model, the precomputed backprojection indices, and sample results.
13 | 
14 | 
15 | For any questions, contact Xi Tao (xtao@smu.edu.cn).
--------------------------------------------------------------------------------
/Model/common.py:
--------------------------------------------------------------------------------
1 | import math
2 | 
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | 
7 | from torch.autograd import Variable
8 | 
9 | def default_conv(in_channels, out_channels, kernel_size, bias=True, dilation=1):
10 |     return nn.Conv2d(
11 |         in_channels, out_channels, kernel_size,
12 |         padding=(kernel_size//2)+dilation-1, bias=bias, dilation=dilation)
13 | 
14 | 
15 | def default_conv1(in_channels, out_channels, kernel_size, bias=True, groups=3):
16 |     return nn.Conv2d(
17 |         in_channels, out_channels, kernel_size,
18 |         padding=(kernel_size//2), bias=bias, groups=groups)
19 | 
20 | 
21 | 
22 | class BasicBlock(nn.Sequential):
23 |     def __init__(
24 |         self, in_channels, out_channels, kernel_size, stride=1, bias=False,
25 |         bn=False, act=nn.ReLU(True)):
26 | 
27 |         m = [nn.Conv2d(
28 |             in_channels, out_channels, kernel_size,
29 |             padding=(kernel_size//2), stride=stride, bias=bias)
30 |         ]
31 |         if bn: m.append(nn.BatchNorm2d(out_channels))
32 |         if act is not None: m.append(act)
33 |         super(BasicBlock, self).__init__(*m)
34 | 
35 | class BBlock(nn.Module):
36 |     def __init__(
37 |         self, conv, in_channels, out_channels, kernel_size,
38 |         bias=True, bn=False, gn=False, act=nn.ReLU(True), res_scale=1, num_groups_im=4):
39 | 
40 |         super(BBlock, self).__init__()
41 |         m = []
42 |         m.append(conv(in_channels, out_channels, kernel_size, bias=bias))
43 |         if bn: m.append(nn.BatchNorm2d(out_channels, eps=1e-4, momentum=0.95))
44 |         if gn: m.append(nn.GroupNorm(num_channels=out_channels, num_groups=num_groups_im))
45 |         m.append(act)
46 | 
47 | 
48 |         self.body = nn.Sequential(*m)
49 |         self.res_scale = res_scale
50 | 
51 |     def forward(self, x):
52 |         x = self.body(x).mul(self.res_scale)
53 |         return x
--------------------------------------------------------------------------------
/Model/networks_1.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import functools
4 | from torch.autograd import Variable
5 | import numpy as np
6 | import pdb
7 | import torch.nn.functional as F
8 | 
9 | ###############################################################################
10 | # Functions
11 | ###############################################################################
12 | 
13 | 
14 | 
15 | class FBPCONVNet_big_drop(nn.Module):
16 |     def __init__(self, input_nc=1, NFS=128, p=0.2):
17 |         super(FBPCONVNet_big_drop, self).__init__()
18 |         # create network model
19 |         self.block_1_1 = None
20 |         self.block_2_1 = None
21 |         self.block_3_1 = None
22 |         self.block_4_1 = None
23 |         self.block_5 = None
24 |         self.block_4_2 = None
25 |         self.block_3_2 = None
26 |         self.block_2_2 = None
27 |         self.block_1_2 = None
28 |         self.input_nc = input_nc
29 |         self.NFS = NFS
30 |         self.p = p
31 |         self.create_model()
32 | 
33 | 
34 |     def forward(self, input):
35 | 
36 |         block_1_1_output = self.block_1_1(input)
37 |         block_1_1_output = self.block_1_1_drop(block_1_1_output)
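        # Encoder path: block_2_1 .. block_4_1 each downsample by 2 (MaxPool2d)
        # and double the channel count; every encoder output is kept so the
        # decoder below can concatenate it back in as a U-Net skip connection.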
38 | 
39 |         block_2_1_output = self.block_2_1(block_1_1_output)
40 |         block_2_1_output = self.block_2_1_drop(block_2_1_output)
41 | 
42 |         block_3_1_output = self.block_3_1(block_2_1_output)
43 |         block_3_1_output = self.block_3_1_drop(block_3_1_output)
44 | 
45 |         block_4_1_output = self.block_4_1(block_3_1_output)
46 |         block_4_1_output = self.block_4_1_drop(block_4_1_output)
47 | 
48 |         block_5_output = self.block_5(block_4_1_output)
49 |         block_5_output = self.block_5_drop(block_5_output)
50 | 
51 |         result = self.block_4_2(torch.cat((block_4_1_output, block_5_output), dim=1))
52 |         result = self.block_4_2_drop(result)
53 | 
54 |         result = self.block_3_2(torch.cat((block_3_1_output, result), dim=1))
55 |         result = self.block_3_2_drop(result)
56 | 
57 |         result = self.block_2_2(torch.cat((block_2_1_output, result), dim=1))
58 |         result = self.block_2_2_drop(result)
59 | 
60 |         result = self.block_1_2(torch.cat((block_1_1_output, result), dim=1))
61 |         result = self.block_1_2_drop(result)
62 | 
63 |         result = result + input
64 |         return result
65 | 
66 |     def create_model(self):
67 |         self.block_1_1_drop = nn.Dropout2d(p=self.p)
68 |         self.block_2_1_drop = nn.Dropout2d(p=self.p)
69 |         self.block_3_1_drop = nn.Dropout2d(p=self.p)
70 |         self.block_4_1_drop = nn.Dropout2d(p=self.p)
71 |         self.block_5_drop = nn.Dropout2d(p=self.p)
72 |         self.block_4_2_drop = nn.Dropout2d(p=self.p)
73 |         self.block_3_2_drop = nn.Dropout2d(p=self.p)
74 |         self.block_2_2_drop = nn.Dropout2d(p=self.p)
75 |         self.block_1_2_drop = nn.Dropout2d(p=self.p)
76 | 
77 |         kernel_size = 3
78 |         padding = kernel_size // 2
79 |         NFS = self.NFS
80 |         # block_1_1
81 |         block_1_1 = []
82 |         block_1_1.extend(self.add_block_conv(in_channels=self.input_nc, out_channels=NFS, kernel_size=kernel_size, stride=1,
83 |                                              padding=padding, batchOn=True, ReluOn=True))
84 |         block_1_1.extend(self.add_block_conv(in_channels=NFS, out_channels=NFS, kernel_size=kernel_size, stride=1,
85 |                                              padding=padding, batchOn=True, ReluOn=True))
86 |         block_1_1.extend(self.add_block_conv(in_channels=NFS, out_channels=NFS, kernel_size=kernel_size, stride=1,
87 |                                              padding=padding, batchOn=True, ReluOn=True))
88 | 
89 |         self.block_1_1 = nn.Sequential(*block_1_1)
90 | 
91 |         # block_2_1
92 |         block_2_1 = [nn.MaxPool2d(kernel_size=2)]
93 |         block_2_1.extend(self.add_block_conv(in_channels=NFS, out_channels=NFS*2, kernel_size=kernel_size, stride=1,
94 |                                              padding=padding, batchOn=True, ReluOn=True))
95 |         block_2_1.extend(self.add_block_conv(in_channels=NFS*2, out_channels=NFS*2, kernel_size=kernel_size, stride=1,
96 |                                              padding=padding, batchOn=True, ReluOn=True))
97 | 
98 |         self.block_2_1 = nn.Sequential(*block_2_1)
99 | 
100 |         # block_3_1
101 |         block_3_1 = [nn.MaxPool2d(kernel_size=2)]
102 |         block_3_1.extend(self.add_block_conv(in_channels=NFS*2, out_channels=NFS*4, kernel_size=kernel_size, stride=1,
103 |                                              padding=padding, batchOn=True, ReluOn=True))
104 |         block_3_1.extend(self.add_block_conv(in_channels=NFS*4, out_channels=NFS*4, kernel_size=kernel_size, stride=1,
105 |                                              padding=padding, batchOn=True, ReluOn=True))
106 | 
107 |         self.block_3_1 = nn.Sequential(*block_3_1)
108 | 
109 |         # block_4_1
110 |         block_4_1 = [nn.MaxPool2d(kernel_size=2)]
111 |         block_4_1.extend(self.add_block_conv(in_channels=NFS*4, out_channels=NFS*8, kernel_size=kernel_size, stride=1,
112 |                                              padding=padding, batchOn=True, ReluOn=True))
113 |         block_4_1.extend(self.add_block_conv(in_channels=NFS*8, out_channels=NFS*8, kernel_size=kernel_size, stride=1,
114 |                                              padding=padding, batchOn=True, ReluOn=True))
115 | 
116 |         self.block_4_1 = nn.Sequential(*block_4_1)
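        # block_5 below is the U-Net bottleneck: one more 2x downsampling to
        # NFS*16 channels, followed by a stride-2 transposed convolution back
        # to NFS*8 so its output can be concatenated with block_4_1's.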
117 | 
118 |         # block_5
119 |         block_5 = [nn.MaxPool2d(kernel_size=2)]
120 |         block_5.extend(self.add_block_conv(in_channels=NFS*8, out_channels=NFS*16, kernel_size=kernel_size, stride=1,
121 |                                            padding=padding, batchOn=True, ReluOn=True))
122 |         block_5.extend(self.add_block_conv(in_channels=NFS*16, out_channels=NFS*16, kernel_size=kernel_size, stride=1,
123 |                                            padding=padding, batchOn=True, ReluOn=True))
124 |         block_5.extend(self.add_block_conv_transpose(in_channels=NFS*16, out_channels=NFS*8, kernel_size=kernel_size, stride=2,
125 |                                                      padding=padding, output_padding=1, batchOn=True, ReluOn=True))
126 |         self.block_5 = nn.Sequential(*block_5)
127 | 
128 |         # block_4_2
129 |         block_4_2 = []
130 |         block_4_2.extend(self.add_block_conv(in_channels=NFS*16, out_channels=NFS*8, kernel_size=kernel_size, stride=1,
131 |                                              padding=padding, batchOn=True, ReluOn=True))
132 |         block_4_2.extend(self.add_block_conv(in_channels=NFS*8, out_channels=NFS*8, kernel_size=kernel_size, stride=1,
133 |                                              padding=padding, batchOn=True, ReluOn=True))
134 |         block_4_2.extend(
135 |             self.add_block_conv_transpose(in_channels=NFS*8, out_channels=NFS*4, kernel_size=kernel_size, stride=2,
136 |                                           padding=padding, output_padding=1, batchOn=True, ReluOn=True))
137 |         self.block_4_2 = nn.Sequential(*block_4_2)
138 | 
139 |         # block_3_2
140 |         block_3_2 = []
141 |         block_3_2.extend(self.add_block_conv(in_channels=NFS*8, out_channels=NFS*4, kernel_size=kernel_size, stride=1,
142 |                                              padding=padding, batchOn=True, ReluOn=True))
143 |         block_3_2.extend(self.add_block_conv(in_channels=NFS*4, out_channels=NFS*4, kernel_size=kernel_size, stride=1,
144 |                                              padding=padding, batchOn=True, ReluOn=True))
145 |         block_3_2.extend(
146 |             self.add_block_conv_transpose(in_channels=NFS*4, out_channels=NFS*2, kernel_size=kernel_size, stride=2,
147 |                                           padding=padding, output_padding=1, batchOn=True, ReluOn=True))
148 |         self.block_3_2 = nn.Sequential(*block_3_2)
149 | 
150 |         # block_2_2
151 |         block_2_2 = []
152 |         block_2_2.extend(self.add_block_conv(in_channels=NFS*4, out_channels=NFS*2, kernel_size=kernel_size, stride=1,
153 |                                              padding=padding, batchOn=True, ReluOn=True))
154 |         block_2_2.extend(self.add_block_conv(in_channels=NFS*2, out_channels=NFS*2, kernel_size=kernel_size, stride=1,
155 |                                              padding=padding, batchOn=True, ReluOn=True))
156 |         block_2_2.extend(
157 |             self.add_block_conv_transpose(in_channels=NFS*2, out_channels=NFS, kernel_size=kernel_size, stride=2,
158 |                                           padding=padding, output_padding=1, batchOn=True, ReluOn=True))
159 |         self.block_2_2 = nn.Sequential(*block_2_2)
160 | 
161 |         # block_1_2
162 |         block_1_2 = []
163 |         block_1_2.extend(self.add_block_conv(in_channels=NFS*2, out_channels=NFS, kernel_size=kernel_size, stride=1,
164 |                                              padding=padding, batchOn=True, ReluOn=True))
165 |         block_1_2.extend(self.add_block_conv(in_channels=NFS, out_channels=NFS, kernel_size=kernel_size, stride=1,
166 |                                              padding=padding, batchOn=True, ReluOn=True))
167 |         block_1_2.extend(self.add_block_conv(in_channels=NFS, out_channels=self.input_nc, kernel_size=1, stride=1,
168 |                                              padding=0, batchOn=False, ReluOn=False))
169 |         self.block_1_2 = nn.Sequential(*block_1_2)
170 | 
171 |     @staticmethod
172 |     def add_block_conv(in_channels, out_channels, kernel_size, stride, padding, batchOn, ReluOn):
173 |         seq = []
174 |         # conv layer
175 |         conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
176 |                          stride=stride, padding=padding)
177 |         seq.append(conv)
178 | 
179 |         # batch norm layer
180 |         batchOn = False  # hard-coded override: batch norm stays disabled regardless of the argument
181 |         if batchOn:
182 |             batch_norm = nn.BatchNorm2d(num_features=out_channels)
183 |             seq.append(batch_norm)
184 | 
185 |         # relu layer
186 |         if ReluOn:
187 |             seq.append(nn.ReLU())
188 |         return seq
189 | 
190 |     @staticmethod
191 |     def add_block_conv_transpose(in_channels, out_channels, kernel_size, stride, padding, output_padding, batchOn, ReluOn):
192 |         seq = []
193 | 
194 |         convt = nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
195 |                                    stride=stride, padding=padding, output_padding=output_padding)
196 |         seq.append(convt)
197 | 
198 |         batchOn = False  # hard-coded override, as in add_block_conv
199 |         if batchOn:
200 |             batch_norm = nn.BatchNorm2d(num_features=out_channels)
201 |             seq.append(batch_norm)
202 | 
203 | 
204 |         if ReluOn:
205 |             seq.append(nn.ReLU())
206 |         return seq
207 | 
--------------------------------------------------------------------------------
/VVtensor_sparse_96_Unet_L2.py:
--------------------------------------------------------------------------------
1 | 
2 | #!/usr/bin/env python3
3 | # -*- coding: utf-8 -*-
4 | """
5 | Created on Wed Aug 15 09:30:30 2019
6 | 
7 | @author: wyb081@smu.edu.cn
8 | """
9 | 
10 | import torch
11 | import numpy as np
12 | import numpy.random as random
13 | from torch.autograd import Function
14 | import torch.nn as nn
15 | import torch.nn.functional as F
16 | from torch.utils.data import Dataset, DataLoader
17 | import glob
18 | from torchvision import transforms
19 | from torch.autograd import Variable
20 | import torch.optim as optim
21 | from torch.optim import lr_scheduler
22 | import torch.nn.init as init
23 | import pickle
24 | import time
25 | import os
26 | import skimage
27 | import copy
28 | import scipy.io
29 | import scipy.io as sio
30 | import pdb
31 | import torchvision.models as models
32 | from Model import networks_1 as networks
33 | from Model import common
34 | from PIL import Image
35 | from PIL import ImageFile
36 | ImageFile.LOAD_TRUNCATED_IMAGES = True
37 | 
38 | # pdb.set_trace()  # debug breakpoint; disabled so the script runs non-interactively
39 | 
40 | is_train = False  # True: train, False: test
41 | re_load = True  # True: load the trained model
42 | 
43 | os.environ["CUDA_VISIBLE_DEVICES"] = "1"
44 | use_cuda = torch.cuda.is_available()
45 | WaterAtValue = 0.0192
46 | source_root_path = '/mnt/storage/VVtenosr-net/Data/'
47 | target_root_path = '/mnt/storage/VVtenosr-net/Model_save/'
48 | Train_reconinfo = {'patients': ['L096', 'L291', 'L109', 'L143', 'L192', 'L286', 'L333', 'L310'], 'SliceThickness': ['']}
49 | Val_reconinfo = {'patients': ['L506', 'L067'], 'SliceThickness': ['']}
50 | ResultFolders = ['Model_save', 'Loss_save', 'Optimizer_save', 'SinoIndices_save']
51 | 
52 | 
53 | reload_mode = 'train'
54 | batch_num = {'train': 200, 'val': 10, 'test': 1}
55 | batch_size = {'train': 1, 'val': 1, 'test': 1}
56 | is_lr_scheduler = False
57 | filter_size = 3
58 | filter_num = 64
59 | padding_size = 1  # (filter_size-1)//2 for filter_size = 3
60 | is_addnoise = False
61 | I0 = 3e8
62 | sigma = 0.01
63 | net_id = 1
64 | net_name = 'VVtensor_sparse_96_Unet_L2'
65 | indices_name = 'VVtenser_96_fan'
66 | 
67 | for target_folder in ResultFolders:
68 |     if not os.path.isdir(target_root_path + net_name + '/' + target_folder):
69 |         os.makedirs(target_root_path + net_name + '/' + target_folder)
70 | 
71 | TestMode = 'test'
72 | 
73 | 
74 | target_folder = 'result'
75 | save_as_mat = True
76 | if is_train is True:
77 |     is_shuffle = True
78 | else:
79 |     is_shuffle = False
80 | if not os.path.isdir(target_root_path + net_name + '/' + target_folder):
81 |     os.makedirs(target_root_path + net_name + '/' + target_folder)
82 | 
83 | gpu_id_end = [0]
84 | 
85 | 
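# Fan-beam scan geometry: a 512x512 image grid with 0.6934 mm pixels, a
# 736-bin arc detector, and 96 views over 360 degrees, i.e. every 12th view
# (sparse_factor) of a 1152-view full scan.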
86 | geo = {'nVoxelX': 512, 'nVoxelY': 512,
87 |        'sVoxelX': 355.0208, 'sVoxelY': 355.0208,
88 |        'dVoxelX': 0.6934, 'dVoxelY': 0.6934,
89 |        'sino_views': 96, 'sparse_factor': 12,
90 |        'nDetecU': 736, 'sDetecU': 0.6934*736,
91 |        'dDetecU': 1.2858, 'DSD': 1085.6, 'DSO': 595.0,
92 |        'offOriginX': 0, 'offOriginY': 0,
93 |        'offDetecU': 0,
94 |        'start_angle': 0, 'end_angle': 360,
95 |        'mode': 'fan',
96 |        }
97 | 
98 | def PixelIndexCal(geo):
99 | 
100 |     nx = geo['nVoxelX']
101 |     ny = geo['nVoxelY']
102 |     offset_x = geo['offOriginX']
103 |     offset_y = geo['offOriginY']
104 | 
105 |     wx = (nx+1)/2 + offset_x
106 |     wy = (ny+1)/2 + offset_y
107 |     is_arc = 1
108 | 
109 |     dx = geo['dVoxelX']
110 |     dy = -geo['dVoxelY']
111 |     dr = geo['dDetecU']
112 |     offset_s = geo['offDetecU']
113 |     na = geo['sino_views']
114 |     orbit = geo['end_angle'] - geo['start_angle']
115 |     orbit_start = geo['start_angle']
116 |     nb = geo['nDetecU']
117 |     dso = geo['DSO']
118 |     dsd = geo['DSD']
119 |     source_offset = 0
120 |     ds = geo['dDetecU']
121 |     xc, yc = np.mgrid[dx-wx*dx : (nx-wx)*dx+dx : dx, dy-wy*dy : (ny-wy)*dy+dy : dy]
122 | 
123 | 
124 |     rr = np.sqrt(xc**2 + yc**2)
125 | 
126 |     betas = np.radians(np.arange(0, na).reshape(na, 1)/na * orbit + orbit_start)
127 |     wb = (nb+1)/2 + offset_s
128 | 
129 | 
130 |     sino_indices = torch.zeros(geo['nVoxelX']*geo['nVoxelY'], geo['sino_views'])
131 |     Weight2 = torch.zeros(geo['nVoxelX']*geo['nVoxelY'], geo['sino_views'])
132 | 
133 |     for ia in range(na):
134 |         print('Na!{}/{}'.format(ia+1, na))
135 |         beta = betas[ia]
136 |         d_loop = dso + xc * np.sin(beta) - yc * np.cos(beta)  # dso - y_beta
137 | 
138 |         r_loop = xc * np.cos(beta) + yc * np.sin(beta) - source_offset  # x_beta - roff
139 | 
140 | 
141 |         if is_arc:
142 |             sprime_ds = (dsd/ds) * np.arctan2(r_loop, d_loop)
143 |             w2 = dsd**2 / (d_loop**2 + r_loop**2)  # [np] image weighting
144 |         else:
145 |             mag = dsd / d_loop
146 |             sprime_ds = mag * r_loop / ds
147 | 
148 |         bb = sprime_ds + wb
149 |         bb = bb + ia * nb - 1
150 | 
151 |         sino_indices[:, ia] = torch.from_numpy(bb).view(-1)
152 |         Weight2[:, ia] = torch.from_numpy(w2).view(-1)
153 | 
154 |     sino_indices = sino_indices.view(-1)
155 |     Weight2 = Weight2.view(-1)
156 | 
157 | 
158 |     return sino_indices, Weight2
159 | 
160 | 
161 | if os.path.isfile(target_root_path + net_name + "/SinoIndices_save/{}_indices.dat".format(indices_name)):
162 |     print('Loading sinoIndices...')
163 |     geo['indices'] = pickle.load(open(target_root_path + net_name + "/SinoIndices_save/{}_indices.dat".format(indices_name), "rb"))
164 |     geo['Weight2'] = pickle.load(open(target_root_path + net_name + "/SinoIndices_save/{}_Weight2.dat".format(indices_name), "rb"))
165 |     print('Done!')
166 | else:
167 |     print('Generating sinoIndices...')
168 |     geo['indices'], geo['Weight2'] = PixelIndexCal(geo)
169 |     f = open(target_root_path + net_name + "/SinoIndices_save/{}_indices.dat".format(indices_name), "wb")
170 |     pickle.dump(geo['indices'], f, True)
171 |     f.close()
172 |     f = open(target_root_path + net_name + "/SinoIndices_save/{}_Weight2.dat".format(indices_name), "wb")
173 |     pickle.dump(geo['Weight2'], f, True)
174 |     f.close()
175 | 
176 |     print('Done!')
177 | 
178 | 
179 | if use_cuda:
180 |     # precompute the integer and fractional parts used for linear interpolation
181 |     geo['il'] = torch.floor(geo['indices'])
182 |     geo['wr'] = geo['indices'] - geo['il']
183 | 
184 |     geo['indices'] = geo['indices'].cuda(gpu_id_end[0])
185 |     geo['wr'] = geo['wr'].cuda(gpu_id_end[0])
186 |     geo['il'] = geo['il'].type(torch.LongTensor).cuda(gpu_id_end[0])
187 |     geo['Weight2'] = geo['Weight2'].cuda(gpu_id_end[0])
188 | 
189 | 
190 | class Backprojected(nn.Module):
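    """Pixel-driven fan-beam backprojection that outputs the VVBP tensor.

    For every image pixel and view, the filtered sinogram is sampled by
    linear interpolation using the precomputed indices geo['il'] and weights
    geo['wr'], multiplied by the distance weighting geo['Weight2'], and the
    resulting per-view backprojections are sorted along the view axis.
    """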
191 |     def __init__(self, geo, bias=True):
192 |         super(Backprojected, self).__init__()
193 |         self.geo = geo
194 | 
195 | 
196 |     def forward(self, input):
197 | 
198 |         input = input.view(-1, self.geo['sino_views']*self.geo['nDetecU'])
199 | 
200 |         input_ = ((1 - geo['wr']) * torch.index_select(input, 1, self.geo['il']) + geo['wr'] * torch.index_select(input, 1, self.geo['il']+1)) * self.geo['Weight2']
201 | 
202 |         input_ = input_.view(-1, self.geo['nVoxelX'], self.geo['nVoxelY'], self.geo['sino_views'])
203 |         input_, _ = input_.sort(3, descending=False)
204 | 
205 | 
206 |         return input_
207 | 
208 | 
209 | 
210 | class VVtenosr_Unet(nn.Module):
211 |     def __init__(self):
212 |         super(VVtenosr_Unet, self).__init__()
213 | 
214 |         self.backprojected = Backprojected(geo)
215 |         self.DeepVVBP_Unt = networks.FBPCONVNet_big_drop(input_nc=128, NFS=128, p=0.1).cuda(gpu_id_end[0])
216 |         self.rec_criterion = nn.MSELoss().cuda(gpu_id_end[0])
217 | 
218 |         act = nn.ReLU(True)
219 | 
220 |         VV_Compress1 = [common.BBlock(common.default_conv, geo['sino_views'], 128, 3, act=act)]
221 |         VV_Compress2 = [common.BBlock(common.default_conv, 128, 128, 3, act=act)]
222 | 
223 |         output_ly = [common.default_conv(128, 1, 3)]
224 | 
225 | 
226 | 
227 |         self.VV_Compress1 = nn.Sequential(*VV_Compress1).cuda(gpu_id_end[0])
228 |         self.VV_Compress2 = nn.Sequential(*VV_Compress2).cuda(gpu_id_end[0])
229 | 
230 | 
231 |         self.output_ly = nn.Sequential(*output_ly).cuda(gpu_id_end[0])
232 | 
233 |     def forward(self, Image_LD, sino_sparse, Image_HD):
234 | 
235 | 
236 |         x = self.backprojected(sino_sparse)
237 |         x = x.permute(0, 3, 1, 2).contiguous()
238 |         Image_LD = torch.sum(x, 1)*np.pi*20/geo['sino_views']  # summing the views gives the FBP image (the factor 20 matches the image scaling)
239 | 
240 |         x = self.VV_Compress1(x)
241 |         x = self.VV_Compress2(x)
242 | 
243 |         x = self.DeepVVBP_Unt(x)
244 |         x = self.output_ly(x)
245 | 
246 | 
247 |         Loss_rec = self.rec_criterion(x, Image_HD)
248 | 
249 | 
250 |         return x, Loss_rec, Image_LD
251 | 
252 | class ToTensor(object):
253 |     """Convert ndarrays in sample to Tensors."""
254 | 
255 |     def __call__(self, image):
256 |         return torch.from_numpy(image).type(torch.FloatTensor)
257 | 
258 | class TrainDicmDataset(Dataset):
259 |     def __init__(self, root_dir, reconinfo, geo, trf_op=None):
260 | 
261 |         self.Raw_data_paths = [glob.glob(root_dir + '{}/*'.format(x)) for x in reconinfo['patients']]
262 |         self.Raw_data_paths = [x for j in self.Raw_data_paths for x in j]
263 |         self.trf_op = trf_op
264 |         self.geo = geo
265 | 
266 | 
267 |     def __len__(self):
268 |         return len(self.Raw_data_paths)
269 | 
270 |     def __getitem__(self, idx):
271 |         # Label: retry with random indices if a file fails to load
272 |         try:
273 |             image_path = self.Raw_data_paths[idx] + '/ND/Data.mat'
274 |             data_HD = sio.loadmat(image_path)
275 |         except ValueError as a:
276 | 
277 |             print('*******************************************************************************')
278 |             print('Exception: ', a)
279 |             print('*******************************************************************************')
280 |             print(self.Raw_data_paths[idx])
281 | 
282 |             try:
283 |                 idx = random.randint(0, len(self.Raw_data_paths), size=([1]))[0]
284 |                 image_path = self.Raw_data_paths[idx] + '/ND/Data.mat'
285 |                 data_HD = sio.loadmat(image_path)
286 |             except ValueError as a:
287 |                 print('*******************************************************************************')
288 |                 print('Exception: ', a)
289 |                 print('*******************************************************************************')
290 |                 print(self.Raw_data_paths[idx])
291 | 
292 |                 try:
293 |                     idx = random.randint(0, len(self.Raw_data_paths), size=([1]))[0]
294 |                     image_path = self.Raw_data_paths[idx] + '/ND/Data.mat'
295 |                     data_HD = sio.loadmat(image_path)
296 |                 except ValueError as a:
297 |                     print('*******************************************************************************')
298 |                     print('Exception: ', a)
299 |                     print('*******************************************************************************')
300 |                     print(self.Raw_data_paths[idx])
301 | 
302 |                     try:
303 |                         idx = random.randint(0, len(self.Raw_data_paths), size=([1]))[0]
304 |                         image_path = self.Raw_data_paths[idx] + '/ND/Data.mat'
305 |                         data_HD = sio.loadmat(image_path)
306 | 
307 |                     except ValueError as a:
308 |                         print('*******************************************************************************')
309 |                         print('Exception: ', a)
310 |                         print('*******************************************************************************')
311 |                         print(self.Raw_data_paths[idx])
312 |                         idx = random.randint(0, len(self.Raw_data_paths), size=([1]))[0]
313 |                         image_path = self.Raw_data_paths[idx] + '/ND/Data.mat'
314 |                         data_HD = sio.loadmat(image_path)
315 | 
316 |         try:
317 |             image_path_LD = self.Raw_data_paths[idx] + '/ND/Data.mat'
318 |             data_LD = sio.loadmat(image_path_LD)
319 |         except ValueError as a:
320 |             try:
321 |                 image_path_LD = self.Raw_data_paths[idx] + '/ND/Data.mat'
322 |                 data_LD = sio.loadmat(image_path_LD)
323 |             except ValueError as a:
324 |                 image_path_LD = self.Raw_data_paths[idx+1] + '/ND/Data.mat'
325 |                 data_LD = sio.loadmat(image_path_LD)
326 | 
327 |         Image_HD = data_HD['Image'] * 20
328 |         sino_sparse = np.transpose(data_LD['SinoFiltered'][:, ::geo['sparse_factor']])
329 | 
330 |         Image_LD = data_HD['Image'] * 20  # placeholder; the sparse-view FBP image is recomputed in the model's forward pass
331 | 
332 |         del data_HD, data_LD
333 |         # pdb.set_trace()
334 | 
335 |         image_path_split = image_path.split('/')
336 |         name = image_path_split[-4] + '_' + image_path_split[-3]
337 | 
338 |         random_list = [ToTensor()]
339 |         transform = transforms.Compose(random_list)
340 |         Image_HD = transform(Image_HD)
341 |         Image_LD = transform(Image_LD)
342 |         sino_sparse = transform(sino_sparse)
343 |         return Image_HD, Image_LD, sino_sparse, name
344 | 
345 | 
346 | class TrainDataSet(Dataset):
347 |     def __init__(self, root_dir, reconinfo, geo, pre_trans_img=None, post_trans_img=None, post_trans_sino=None, addnoise=None):
348 | 
349 |         self.imgset = TrainDicmDataset(root_dir, reconinfo, geo, pre_trans_img)
350 |         self.addnoise = transforms.Compose(addnoise) if addnoise is not None else None
351 |         self.post_trans_img = transforms.Compose(post_trans_img) if post_trans_img is not None else None
352 |         self.post_trans_sino = transforms.Compose(post_trans_sino) if post_trans_sino is not None else None
353 | 
354 |     def __len__(self):
355 |         return len(self.imgset)
356 | 
357 |     def __getitem__(self, idx):
358 |         Image_HD, Image_LD, sino_sparse, name = self.imgset[idx]
359 | 
360 | 
361 |         sample = {'Image_HD': Image_HD, 'Image_LD': Image_LD, 'name': name, 'sino_sparse': sino_sparse}
362 |         sample['Image_HD'].unsqueeze_(0)
363 |         sample['Image_LD'].unsqueeze_(0)
364 |         sample['sino_sparse'].unsqueeze_(0)
365 |         return sample
366 | 
367 | 
368 | pre_trans_img = None
369 | addnoise = [AddNoise(I0, sigma)] if is_addnoise is True else None  # NOTE: AddNoise is not defined in this demo; keep is_addnoise = False
370 | post_trans_img = None
371 | post_trans_sino = None
372 | 
373 | datasets = {'train': TrainDataSet(source_root_path, Train_reconinfo, geo, pre_trans_img, post_trans_img, post_trans_sino, addnoise),
374 |             'val': TrainDataSet(source_root_path, Val_reconinfo, geo, pre_trans_img, post_trans_img, post_trans_sino, addnoise),
375 |             'test': TrainDataSet(source_root_path, Val_reconinfo, geo, pre_trans_img, post_trans_img, post_trans_sino, addnoise)}
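# Each sample is a dict of FloatTensors with a leading channel dimension:
# 'Image_HD' and 'Image_LD' are the (x20 scaled) reference images, and
# 'sino_sparse' holds the filtered projections of the 96 retained views.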
376 | 
377 | kwargs = {'num_workers': 4, 'pin_memory': True}
378 | data = datasets['test'].__getitem__(0)
379 | 
380 | 
381 | dataloaders = {x: DataLoader(datasets[x], batch_size[x], shuffle=is_shuffle, **kwargs) for x in ['train', 'val', 'test']}
382 | 
383 | dataset_sizes = {x: batch_num[x]*batch_size[x] for x in ['train', 'val', 'test']}
384 | 
385 | """ Gradient averaging (for distributed training; unused in this demo and would require `import torch.distributed as dist`). """
386 | def average_gradients(model):
387 |     size = float(dist.get_world_size())
388 |     for param in model.parameters():
389 |         dist.all_reduce(param.grad.data, op=dist.reduce_op.SUM)
390 |         param.grad.data /= size
391 | 
392 | def train_model(model, optimizer, criterion=None, scheduler=None, min_loss=None, pre_losses=None, num_epochs=25):
393 |     since = time.time()
394 | 
395 |     if min_loss is None:
396 |         min_loss = {x: 1.0 for x in ['train', 'val']}
397 | 
398 |     losses = {x: torch.zeros(num_epochs, batch_num[x]) for x in ['train', 'val']}
399 | 
400 |     if pre_losses is not None:
401 |         min_dim = {x: min(losses[x].size(1), pre_losses[x].size(1)) for x in ['train', 'val']}
402 | 
403 |     epoch_loss = {x: 0.0 for x in ['train', 'val']}
404 | 
405 |     print('lr={:.10f}'.format(optimizer.param_groups[0]['lr']))
406 |     for epoch in range(num_epochs):
407 |         if (epoch == 100) or (epoch == 300):
408 |             optimizer.param_groups[0]['lr'] = optimizer.param_groups[0]['lr']*0.2
409 |             print('lr={:.10f}'.format(optimizer.param_groups[0]['lr']))
410 | 
411 |         print('Epoch {}/{}'.format(epoch, num_epochs - 1))
412 |         print('-' * 10)
413 | 
414 |         # Each epoch has a training and a validation phase
415 |         for phase in ['train', 'val']:
416 |             if phase == 'train':
417 |                 if scheduler is not None:
418 |                     scheduler.step()
419 |                 model.train()  # Set model to training mode
420 |             else:
421 |                 model.eval()  # Set model to evaluate mode
422 | 
423 |             running_loss = 0.0
424 | 
425 | 
426 |             # Iterate over data.
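            # Each batch provides a full-dose image (Image_HD), a placeholder
            # low-dose image, and the filtered sparse-view sinogram; the model
            # backprojects the sinogram into the sorted VVBP tensor, compresses
            # it to 128 channels, runs the U-Net, and returns the MSE loss.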
427 |             for i_batch, Sample in enumerate(dataloaders[phase]):
428 | 
429 | 
430 |                 if i_batch == batch_num[phase]:
431 |                     break
432 |                 Len_data = datasets[phase].__len__()
433 | 
434 | 
435 | 
436 |                 Image_HD = Sample['Image_HD']
437 |                 Image_LD = Sample['Image_LD']
438 |                 sino_sparse = Sample['sino_sparse']
439 | 
440 |                 if use_cuda:  # wrap them in Variable
441 |                     Image_HD = Variable(Image_HD).cuda(gpu_id_end[0])
442 |                     Image_LD = Variable(Image_LD).cuda(gpu_id_end[0], non_blocking=True)
443 |                     sino_sparse = Variable(sino_sparse).cuda(gpu_id_end[0], non_blocking=True)
444 |                 else:
445 |                     Image_HD = Variable(Image_HD)
446 |                     Image_LD = Variable(Image_LD)
447 | 
448 | 
449 |                 # zero the parameter gradients
450 |                 optimizer.zero_grad()
451 | 
452 |                 # forward
453 |                 outputs, loss, Image_LD = model(Image_LD, sino_sparse, Image_HD)
454 |                 # pdb.set_trace()
455 |                 if (phase == 'val') and (i_batch == 0):
456 | 
457 |                     scipy.io.savemat(target_root_path + net_name + '/val.mat', mdict={'Image_HD': Image_HD.cpu().data.numpy(), 'Image_LD': Image_LD.cpu().data.numpy(), 'outputs': outputs.cpu().data.numpy()})
458 | 
459 |                 del Image_HD, Image_LD, Sample, outputs, sino_sparse
460 | 
461 |                 # backward + optimize only if in training phase
462 |                 if phase == 'train':
463 |                     loss.backward()
464 | 
465 |                     optimizer.step()
466 | 
467 |                 if 0 == (i_batch % 1):
468 |                     print('{}: {}, subIter {}/{}, subLoss: {:.8f}'.format(net_name, phase, i_batch, batch_num[phase], loss.data.item()))
469 |                 # statistics
470 |                 losses[phase][epoch, i_batch] = loss.data.item()
471 | 
472 |                 running_loss += loss.data.item() * batch_size[phase]
473 |                 del loss
474 | 
475 | 
476 |             epoch_loss[phase] = running_loss / (dataset_sizes[phase])
477 | 
478 |             print('{} Loss: {:.8f}'.format(phase, epoch_loss[phase]))
479 | 
480 |             if phase == 'val':
481 |                 print("Train / Val: {:.8f}".format(epoch_loss['train']/epoch_loss['val']))
482 | 
483 | 
484 | 
485 |             # deep copy the model
486 |             if 0 == (epoch % 1):
487 |                 if 1:  # always save; change to `epoch_loss[phase] < min_loss[phase]` to keep only the best model
488 |                     min_loss[phase] = epoch_loss[phase]
489 |                     f = open(target_root_path + net_name + "/Loss_save/min_loss_{}_{}.dat".format(net_id, net_name), "wb")
490 |                     pickle.dump(min_loss, f, True)
491 |                     f.close()
492 |                     torch.save(model.state_dict(),
493 |                                target_root_path + net_name + "/Model_save/best_{}_model_params_{}_{}.pkl".format(phase, net_id, net_name))
494 |                     if phase == 'train':
495 |                         torch.save(optimizer.state_dict(),
496 |                                    target_root_path + net_name + "/Optimizer_save/optimizer_{}_{}.pkl".format(net_id, net_name))
497 | 
498 | 
499 | 
500 |         if pre_losses is None:
501 |             tmp_losses = {x: losses[x][:epoch+1] for x in ['train', 'val']}
502 |         else:
503 |             tmp_losses = {x: torch.cat((pre_losses[x][:, :min_dim[x]], losses[x][:epoch+1, :min_dim[x]]), 0) for x in ['train', 'val']}
504 |         f = open(target_root_path + net_name + "/Loss_save/losses_{}_{}.dat".format(net_id, net_name), "wb")
505 |         pickle.dump(tmp_losses, f, True)
506 |         f.close()
507 | 
508 | 
509 |         print('lr={:.10f}'.format(optimizer.param_groups[0]['lr']))
510 | 
511 |     time_elapsed = time.time() - since
512 |     print('Training complete in {:.0f}m {:.0f}s'.format(
513 |         time_elapsed // 60, time_elapsed % 60))
514 |     print('Minimum train loss: {:5f}'.format(min_loss['train']))
515 | 
516 | def test_model(model, criterion=None):
517 |     for i_batch, Sample in enumerate(dataloaders[TestMode]):
518 |         print('processing batch_{}...'.format(i_batch))
519 | 
520 | 
521 |         Image_HD = Sample['Image_HD']
522 |         Image_LD = Sample['Image_LD']
523 |         sino_sparse = Sample['sino_sparse']
524 | 
525 |         if use_cuda:  # wrap them in Variable
526 |             Image_HD = Variable(Image_HD).cuda(gpu_id_end[0])
527 |             Image_LD = Variable(Image_LD).cuda(gpu_id_end[0], non_blocking=True)
528 |             sino_sparse = Variable(sino_sparse).cuda(gpu_id_end[0], non_blocking=True)
529 |         else:
530 |             Image_HD = Variable(Image_HD)
531 |             Image_LD = Variable(Image_LD)
532 | 
533 |         outputs, loss, Image_LD = model(Image_LD, sino_sparse, Image_HD)
534 | 
535 | 
536 |         Sample['Image_LD'] = Image_LD/20
537 | 
538 |         Sample['Image_HD'] = Image_HD/20
539 | 
540 |         Sample['output'] = outputs/20
541 |         Sample['loss'] = loss
542 |         data_name = ''.join(Sample['name'])
543 |         Sample.pop('name')
544 | 
545 |         del outputs, loss, Image_LD
546 | 
547 | 
548 |         # pdb.set_trace()
549 |         data_save = {key: value.cpu().data.numpy() for key, value in Sample.items()}
550 |         scipy.io.savemat(target_root_path + net_name + '/result/{}.mat'.format(data_name), mdict=data_save)
551 | 
552 | def weights_init(m):
553 |     classname = m.__class__.__name__
554 |     if classname.find('Conv') != -1 or classname.find('Linear') != -1:  # or classname.find('Backprojected') != -1:
555 |         init.xavier_uniform_(m.weight.data)
556 |         if m.bias is not None:
557 |             m.bias.data.zero_()
558 | 
559 | 
560 | Network_im = VVtenosr_Unet()
561 | 
562 | if re_load is False:
563 |     Network_im.apply(weights_init)
564 |     min_loss = None
565 |     pre_losses = None
566 | else:
567 |     epoch_reload_path = target_root_path + net_name + '/Model_save/best_{}_model_params_{}_{}.pkl'.format(reload_mode, net_id, net_name)
568 |     if os.path.isfile(epoch_reload_path):
569 |         print('reloading previously trained network...')
570 |         checkpoint = torch.load(epoch_reload_path, map_location=lambda storage, loc: storage)
571 |         model_dict = Network_im.state_dict()
572 |         checkpoint = {k: v for k, v in checkpoint.items() if k in model_dict}
573 |         model_dict.update(checkpoint)
574 |         Network_im.load_state_dict(model_dict)
575 |         del checkpoint
576 |         torch.cuda.empty_cache()
577 |         print('done!')
578 |     else:
579 |         Network_im.apply(weights_init)
580 | 
581 |     min_loss_vggath = target_root_path + net_name + "/Loss_save/min_loss_{}_{}.dat".format(net_id, net_name)
582 |     min_loss = pickle.load(open(min_loss_vggath, "rb")) if os.path.isfile(min_loss_vggath) else None
583 | 
584 |     pre_losses_path = target_root_path + net_name + "/Loss_save/losses_{}_{}.dat".format(net_id, net_name)
585 |     pre_losses = pickle.load(open(pre_losses_path, "rb")) if os.path.isfile(pre_losses_path) else None
586 | 
587 | 
588 | criterion = None
589 | 
590 | optimizer_ft = optim.RMSprop(Network_im.parameters(), lr=1e-5, momentum=0.9, weight_decay=0.0000)
591 | # optimizer_ft = optim.Adam(Network_im.parameters(), lr=1e-5, betas=(0.9, 0.99), eps=1e-06, weight_decay=0.0005)
592 | if re_load is True:
593 |     optimizer_reload_path = target_root_path + net_name + "/Optimizer_save/optimizer_{}_{}.pkl".format(net_id, net_name)
594 |     if os.path.isfile(optimizer_reload_path):
595 |         print('reloading previous optimizer...')
596 |         checkpoint = torch.load(optimizer_reload_path, map_location=lambda storage, loc: storage)
597 |         optimizer_ft.load_state_dict(checkpoint)
598 |         del checkpoint
599 |         torch.cuda.empty_cache()
600 |         print('done!')
601 | 
602 | 
603 | 
604 | exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=300, gamma=0.5) if is_lr_scheduler else None
605 | 
606 | if is_train is True:
607 |     train_model(Network_im, optimizer_ft, criterion, exp_lr_scheduler, min_loss, pre_losses, num_epochs=400)
608 | else:
609 |     Network_im.eval()
610 |     test_model(Network_im, criterion)
611 | 
612 | 
613 | 
614 | 
615 | 
616 | 
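# Sketch of how a saved result could be inspected offline. The field names
# ('output', 'Image_HD', 'Image_LD', 'loss', 'sino_sparse') are the keys that
# test_model() writes via scipy.io.savemat, and the images were divided by 20
# there to undo the training-time scaling; `load_result_mat` is a hypothetical
# helper, not part of the original pipeline.
def load_result_mat(mat_path):
    data = sio.loadmat(mat_path)  # e.g. .../result/L067_001.mat
    return data['output'], data['Image_HD'], data['Image_LD']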
--------------------------------------------------------------------------------