├── images ├── 3D SRGAN(D).png ├── 3D SRGAN(G).png └── Upsamplings.png ├── .idea ├── vcs.xml ├── misc.xml ├── inspectionProfiles │ └── profiles_settings.xml ├── modules.xml ├── 3D-GAN-superresolution.iml └── workspace.xml ├── utils.py ├── README.md ├── dataset.py └── model.py /images/3D SRGAN(D).png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ashishpatel26/3D-GAN-superresolution/master/images/3D SRGAN(D).png -------------------------------------------------------------------------------- /images/3D SRGAN(G).png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ashishpatel26/3D-GAN-superresolution/master/images/3D SRGAN(G).png -------------------------------------------------------------------------------- /images/Upsamplings.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ashishpatel26/3D-GAN-superresolution/master/images/Upsamplings.png -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 7 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- 
# utils.py -- custom Keras layer and tensor/array helpers used by the 3D SRGAN.
import numpy as np
import tensorflow as tf
import keras as K
from keras.utils import conv_utils
from keras.layers.convolutional import UpSampling3D  # shadowed by the class defined below
from keras.layers import Layer as KerasLayer
from keras.engine import InputSpec
from tensorlayer.layers import *


class UpSampling3D(KerasLayer):
    """Keras layer that enlarges the three spatial axes of a 5-D tensor.

    The base class is the Keras ``Layer`` imported under an alias: the bare
    name ``Layer`` in this module is only provided by the tensorlayer star
    import, while every method here follows the Keras layer contract
    (``input_spec``, ``compute_output_shape``, ``call``, ``get_config``).

    Args:
        size: int or tuple of 3 ints, the factor applied to each spatial axis.
        data_format: ``'channels_last'`` (default, the layout the original
            shape arithmetic assumed) or ``'channels_first'``.
        **kwargs: forwarded to the base layer constructor.

    Raises:
        ValueError: if ``data_format`` is not one of the two supported values.
    """

    def __init__(self, size=(2, 2, 2), data_format='channels_last', **kwargs):
        if data_format not in ('channels_last', 'channels_first'):
            raise ValueError(
                "data_format must be 'channels_last' or 'channels_first', got %r" % (data_format,))
        # call() and get_config() both read this attribute, so it has to exist.
        self.data_format = data_format
        self.size = conv_utils.normalize_tuple(size, 3, 'size')
        self.input_spec = InputSpec(ndim=5)
        super(UpSampling3D, self).__init__(**kwargs)

    def compute_output_shape(self, input_shape):
        """Return the input shape with each known spatial dimension multiplied by ``size``."""
        first_spatial = 1 if self.data_format == 'channels_last' else 2
        scaled = [
            factor * input_shape[first_spatial + offset]
            if input_shape[first_spatial + offset] is not None else None
            for offset, factor in enumerate(self.size)
        ]
        if self.data_format == 'channels_last':
            return (input_shape[0], scaled[0], scaled[1], scaled[2], input_shape[4])
        return (input_shape[0], input_shape[1], scaled[0], scaled[1], scaled[2])

    def call(self, inputs):
        """Resize ``inputs`` with the Keras backend's ``resize_volumes``."""
        # ``K`` is the top-level keras package here, so the op lives in K.backend.
        return K.backend.resize_volumes(inputs,
                                        self.size[0], self.size[1], self.size[2],
                                        self.data_format)

    def get_config(self):
        """Return the base layer config extended with ``size`` and ``data_format``."""
        config = {'size': self.size,
                  'data_format': self.data_format}
        base_config = super(UpSampling3D, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


def smooth_gan_labels(y):
    """Replace hard GAN labels with random "smoothed" targets.

    Entries of ``y`` equal to 0 are replaced by samples from U[0.0, 0.3);
    every other entry is replaced by a sample from U[0.7, 1.2).

    The choice is made element-wise inside the graph with ``tf.where``. A
    Python ``if y == 0`` is evaluated once while the graph is built and cannot
    look at the label values, so it cannot pick the range per label.

    Args:
        y: label tensor (the callers pass ``tf.ones_like`` / ``tf.zeros_like``
            of the discriminator logits).

    Returns:
        A float32 tensor with the same shape as ``y``.
    """
    shape = tf.shape(y)
    low = tf.random_uniform(shape=shape, minval=0.0, maxval=0.3)
    high = tf.random_uniform(shape=shape, minval=0.7, maxval=1.2)
    return tf.where(tf.equal(y, tf.zeros_like(y)), low, high)


def subPixelConv3d(net, img_width, img_height, img_depth, stepsToEnd, n_out_channel):
    """Trade channels for spatial resolution (factor 2 per spatial axis).

    The tensor is split in two along axis 4 and re-joined along axis 3, split
    in two along axis 4 again and re-joined along axis 2, and finally reshaped
    to the target spatial size with ``n_out_channel`` channels.

    Args:
        net: 5-D input tensor; axis 0 is the batch axis and axis 4 is the one
            being split.
        img_width, img_height, img_depth: spatial size of the final network
            output.
        stepsToEnd: divisor applied to the final size (``size / (2 * stepsToEnd)``)
            to obtain the spatial size entering this step.
        n_out_channel: number of channels of the returned tensor.

    Returns:
        Tensor of shape ``(batch, 2 * a, 2 * b, 2 * z, n_out_channel)`` where
        ``a, b, z`` are the per-axis input sizes derived from the arguments.
    """
    r = 2
    a = int(img_width / (2 * stepsToEnd))
    b = int(img_height / (2 * stepsToEnd))
    z = int(img_depth / (2 * stepsToEnd))
    bsize = tf.shape(net)[0]  # dynamic size: the batch dimension may be undefined

    xs = tf.split(net, r, 4)
    xr = tf.concat(xs, 3)
    xss = tf.split(xr, r, 4)
    xrr = tf.concat(xss, 2)
    return tf.reshape(xrr, (bsize, r * a, r * b, r * z, n_out_channel))


def _stitch_octants(patches, core_shape, margin):
    """Assemble eight overlapping corner patches into a single volume.

    Patch ``4 * x + 2 * y + z`` (each of x, y, z being 0 or 1) fills the low
    (0) or high (1) half of the corresponding axis. Low halves copy the first
    ``core`` voxels of the patch; high halves skip the first ``margin`` voxels.
    """
    volume = np.empty([2 * core_shape[0], 2 * core_shape[1], 2 * core_shape[2], 1])
    for index in range(8):
        upper = ((index >> 2) & 1, (index >> 1) & 1, index & 1)
        dst = tuple(slice(core, None) if hi else slice(0, core)
                    for hi, core in zip(upper, core_shape))
        src = tuple(slice(margin, None) if hi else slice(0, core)
                    for hi, core in zip(upper, core_shape))
        volume[dst + (slice(None),)] = patches[(index,) + src + (slice(None),)]
    return volume


def aggregate(patches):
    """Rebuild a 224x224x152x1 volume from its 8 patches (overlap margin 16)."""
    return _stitch_octants(patches, (112, 112, 76), 16)


def aggregate2(patches):
    """Rebuild a 112x112x76x1 volume from its 8 patches (overlap margin 8)."""
    return _stitch_octants(patches, (56, 56, 38), 8)
3D-GAN-superresolution 2 | Here we present the implementation in TensorFlow of our work to generate high resolution MRI scans from low resolution images using Generative Adversarial Networks (GANs), accepted in the [Medical Imaging with Deep Learning Conference – Amsterdam. 4 - 6th July 2018.](https://midl.amsterdam/) 3 | 4 | Discriminator network 5 | ![alt text](https://github.com/imatge-upc/3D-GAN-superresolution/blob/master/images/3D%20SRGAN(D).png) 6 | 7 | Generator network 8 | ![alt text](https://github.com/imatge-upc/3D-GAN-superresolution/blob/master/images/3D%20SRGAN(G).png) 9 | 10 | In this work we propose an architecture for MRI super-resolution that completely exploits the available volumetric information contained in MRI scans, using 3D convolutions to process the volumes and taking advantage of an adversarial framework, improving the realism of the generated volumes. 11 | The model is based on the [SRGAN network](https://arxiv.org/abs/1609.04802). The adversarial loss uses least squares to stabilize the training, and the generator loss, in addition to the adversarial term, contains a content term based on mean square error and image gradients in order to improve the quality of the generated images. We explore three different methods for the upsampling phase: an upsampling layer which uses nearest neighbors to replicate consecutive pixels followed by a convolutional layer to improve the approximation, sub-pixel convolution layers as proposed in [Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network](https://arxiv.org/abs/1609.05158) and a modification of this method, [Checkerboard artifact free sub-pixel convolution](https://arxiv.org/pdf/1707.02937.pdf), that alleviates the checkerboard artifacts produced by sub-pixel convolution layers (see [Deconvolution and Checkerboard Artifacts](https://distill.pub/2016/deconv-checkerboard/) for more information). 
12 | 13 | Comparison of the upsampling methods used 14 | ![alt text](https://github.com/imatge-upc/3D-GAN-superresolution/blob/master/images/Upsamplings.png) 15 | 16 | ### Data 17 | We used a set of normal control T1-weighted images from the Alzheimer’s Disease Neuroimaging Initiative (ADNI) database (see www.adni-info.org for details). Skull stripping is performed on all volumes and part of the background is removed. Final volumes have dimensions 224x224x152. Due to memory constraints the training is patch-based; for each volume we extract patches of size 128x128x92, with a step of 112x112x76, so there are 8 patches per volume, with an overlap of 16x16x16. We have a total of 589 volumes; 470 are used for training while 119 are used for testing. We use batches of two patches, thus for each volume we perform 4 iterations. This code is prepared to run experiments with the image processing and dimensions explained above. 18 | 19 | The code expects that the database is inside the folder specified by the data_path in the Train_dataset script. Inside there should be a folder for each of the patients containing a 'T1_brain_extractedBrainExtractionMask.nii.gz' file. This file was created by taking the original images from ADNI and skull-stripping them. We use the nibabel library to load the images. 20 | 21 | ### Training 22 | To train the network the model.py script is used. When calling the script you should specify: 23 | + -path_prediction: Path to save training predictions. 24 | + -checkpoint_dir: Path to save checkpoints. 25 | + -residual_blocks: Number of residual blocks. 26 | + -upsampling_factor: Upsampling factor. 27 | + -subpixel_NN: Use subpixel nearest neighbour. 28 | + -nn: Use Upsampling3D + nearest neighbour, RC. 29 | + -feature_size: Number of filters. 30 | 31 | By default it will use the sub-pixel convolution layers, 32 filters, 6 residual blocks and an upsampling factor of 4. 
32 | 33 | If you want to resume training, specify the checkpoint to use with the restore argument when calling the script: 34 | + -restore: Checkpoint path to restore training. 35 | 36 | ``` 37 | python model.py -path_prediction YOURPATH -checkpoint_dir YOURCHECKPOINTPATH -residual_blocks 8 -upsampling_factor 2 -subpixel_NN True -feature_size 64 38 | ``` 39 | 40 | ### Testing 41 | To test the network the model.py script is also used. When calling the script you should specify the same arguments as before for the configuration of the model and the new paths used. Also, the argument evaluate should be True: 42 | + -path_volumes: Path to save test volumes. 43 | + -checkpoint_dir_restore: Path to restore checkpoints. 44 | + -residual_blocks: Number of residual blocks. 45 | + -upsampling_factor: Upsampling factor. 46 | + -subpixel_NN: Use subpixel nearest neighbour. 47 | + -nn: Use Upsampling3D + nearest neighbour, RC. 48 | + -feature_size: Number of filters. 49 | + -evaluate: Test the model. 50 | 51 | ``` 52 | python model.py -path_volumes YOURPATH -checkpoint_dir_restore YOURCHECKPOINTPATH -residual_blocks 8 -upsampling_factor 2 -subpixel_NN True -feature_size 64 -evaluate True 53 | ``` 54 | 55 | # Contact 56 | If you have any general questions about our work or code which may be of interest to other researchers, please use the public issues section of this GitHub repo. 
# dataset.py -- patch-based loader for the skull-stripped ADNI T1 volumes.
import numpy as np
import nibabel as nib
import math
import os
from skimage.util import view_as_windows


class Train_dataset(object):
    """Load batches of subjects and cut them into overlapping 3-D patches.

    Every subject volume is zero-padded to 256x256x184, cropped to
    224x224x152 and split into windows of ``patch + margin`` voxels taken
    with a step of ``patch - margin`` voxels.

    Args:
        batch_size: number of subjects loaded per iteration.
        overlapping: divisor used in the ``num_patches`` computation.
    """

    # Name inside each subject folder of the skull-stripped image / its mask.
    _BRAIN_FILE = 'T1_brain_extractedBrainExtractionBrain.nii.gz'
    _MASK_FILE = 'T1_brain_extractedBrainExtractionMask.nii.gz'
    # Entry of the data folder that is not a subject directory.
    _SKIPPED_ENTRY = 'ADNI_SCREENING_CLINICAL_FILE_08_02_17.csv'
    # Volumes are padded to this shape, then cropped to 224x224x152.
    _PADDED_SHAPE = (256, 256, 184)
    _CROP = (slice(16, 240), slice(16, 240), slice(16, 168))

    def __init__(self, batch_size, overlapping=1):
        self.batch_size = batch_size
        self.data_path = '/imatge/isanchez/projects/neuro/ADNI-Screening-1.5T'
        self.subject_list = os.listdir(self.data_path)
        # NOTE(review): drops whatever entry sits at position 120 of the
        # listing; raises IndexError when the folder has fewer entries.
        self.subject_list = np.delete(self.subject_list, 120)
        self.heigth_patch = 112  # 128
        self.width_patch = 112  # 128
        self.depth_patch = 76  # 92
        self.margin = 16
        self.overlapping = overlapping
        self.num_patches = (math.ceil((224 / (self.heigth_patch)) / (self.overlapping))) * (
            math.ceil((224 / (self.width_patch)) / (self.overlapping))) * (
            math.ceil((152 / (self.depth_patch)) / (self.overlapping)))

    @staticmethod
    def _split_padding(target, size):
        """Return ``(before, after)`` padding that grows ``size`` to ``target``.

        When the difference is odd the extra voxel goes to ``before``.
        """
        total = target - size
        if total < 0:
            raise ValueError('axis of size %d does not fit in the padded size %d' % (size, target))
        after = total // 2
        return (total - after, after)

    def _load_cropped_batch(self, iteration, image_name):
        """Load ``image_name`` for every subject of the batch, padded and cropped.

        Returns an array of shape ``(batch_size, 224, 224, 152)``. Slots that
        receive no subject (skipped entry, short final batch) stay zero rather
        than holding uninitialised memory.
        """
        start = iteration * self.batch_size
        subject_batch = self.subject_list[start:start + self.batch_size]
        volumes = np.zeros([self.batch_size, 224, 224, 152])
        slot = 0
        for subject in subject_batch:
            if subject == self._SKIPPED_ENTRY:
                continue
            proxy = nib.load(os.path.join(self.data_path, subject, image_name))
            data = np.array(proxy.dataobj)
            pad_width = [self._split_padding(target, size)
                         for target, size in zip(self._PADDED_SHAPE, proxy.shape[:3])]
            data_padded = np.pad(data, pad_width, 'constant', constant_values=0)
            volumes[slot] = data_padded[self._CROP]  # remove background
            slot += 1
        return volumes

    def _extract_patches(self, volumes):
        """Cut every volume into windows and stack them with a channel axis.

        Returns an array of shape ``(batch_size * num_patches, W, H, D, 1)``
        where ``W, H, D`` are the patch sizes plus the margin. Windows are
        stored in the order produced by iterating the three window-grid axes
        with the last one varying fastest.
        """
        window = (self.width_patch + self.margin,
                  self.heigth_patch + self.margin,
                  self.depth_patch + self.margin)
        step = (self.width_patch - self.margin,
                self.heigth_patch - self.margin,
                self.depth_patch - self.margin)
        patches = np.empty([self.batch_size * self.num_patches] + list(window) + [1])
        cursor = 0
        for volume in volumes:
            windows = view_as_windows(volume, window_shape=window, step=step)
            for position in np.ndindex(*windows.shape[:3]):
                patches[cursor, ..., 0] = windows[position]
                cursor += 1
        return patches

    def mask(self, iteration):
        """Return the brain-mask patches for batch ``iteration``.

        The mask volumes go through the same crop as the images, so that mask
        patch ``k`` covers exactly the voxels of ``patches_true`` patch ``k``.
        """
        return self._extract_patches(self._load_cropped_batch(iteration, self._MASK_FILE))

    def patches_true(self, iteration):
        """Return the image patches for batch ``iteration``."""
        return self._extract_patches(self.data_true(iteration))

    def data_true(self, iteration):
        """Return the cropped 224x224x152 image volumes for batch ``iteration``."""
        return self._load_cropped_batch(iteration, self._BRAIN_FILE)
121 | 122 | 123 | 124 | 125 | 126 | 136 | 137 | 138 | 139 | 158 | 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | 175 | 188 | 189 | 202 | 203 | 221 | 222 | 234 | 235 | 236 | 237 | 238 | 239 | 240 | 241 | 242 | 243 | 244 | 245 | 264 | 265 | 281 | 282 | 283 | 285 | 286 | 287 | 288 | 1529748036876 289 | 293 | 294 | 295 | 296 | 297 | 298 | 299 | 300 | 301 | 302 | 303 | 304 | 305 | 306 | 307 | 308 | 309 | 310 | 311 | 312 | 313 | 314 | 315 | 316 | 317 | 318 | 319 | 320 | 321 | 322 | 323 | 325 | 326 | 328 | 329 | 330 | 332 | 333 | 334 | 335 | 336 | 337 | 338 | 339 | 340 | 341 | 342 | 343 | 344 | 345 | 346 | 347 | 348 | 349 | 350 | 351 | 352 | 353 | 354 | 355 | 356 | 357 | 358 | 359 | 360 | 361 | 362 | 363 | 364 | 365 | 366 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorlayer as tl 3 | from tensorlayer.layers import * 4 | from dataset import Train_dataset 5 | import math 6 | from scipy.ndimage.interpolation import zoom 7 | from scipy.ndimage.filters import gaussian_filter 8 | from utils import smooth_gan_labels, aggregate, subPixelConv3d 9 | import nibabel as nib 10 | import os 11 | from skimage.measure import compare_ssim as ssim 12 | from skimage.measure import compare_psnr as psnr 13 | from keras.layers.convolutional import UpSampling3D 14 | import argparse 15 | 16 | 17 | def lrelu1(x): 18 | return tf.maximum(x, 0.25 * x) 19 | 20 | 21 | def lrelu2(x): 22 | return tf.maximum(x, 0.3 * x) 23 | 24 | 25 | def discriminator(input_disc, kernel, reuse, is_train=True): 26 | w_init = tf.random_normal_initializer(stddev=0.02) 27 | batch_size = 1 28 | div_patches = 4 29 | num_patches = 8 30 | img_width = 128 31 | img_height = 128 32 | img_depth = 92 33 | with tf.variable_scope("SRGAN_d", reuse=reuse): 34 | tl.layers.set_name_reuse(reuse) 35 | 
input_disc.set_shape([int((batch_size * num_patches) / div_patches), img_width, img_height, img_depth, 1], ) 36 | x = InputLayer(input_disc, name='in') 37 | x = Conv3dLayer(x, act=lrelu2, shape=[kernel, kernel, kernel, 1, 32], strides=[1, 1, 1, 1, 1], 38 | padding='SAME', W_init=w_init, name='conv1') 39 | x = Conv3dLayer(x, shape=[kernel, kernel, kernel, 32, 32], strides=[1, 2, 2, 2, 1], 40 | padding='SAME', W_init=w_init, name='conv2') 41 | 42 | x = BatchNormLayer(x, is_train=is_train, name='BN1-conv2', act=lrelu2) 43 | 44 | x = Conv3dLayer(x, shape=[kernel, kernel, kernel, 32, 64], strides=[1, 1, 1, 1, 1], 45 | padding='SAME', W_init=w_init, name='conv3') 46 | x = BatchNormLayer(x, is_train=is_train, name='BN1-conv3', act=lrelu2) 47 | x = Conv3dLayer(x, shape=[kernel, kernel, kernel, 64, 64], strides=[1, 2, 2, 2, 1], 48 | padding='SAME', W_init=w_init, name='conv4') 49 | x = BatchNormLayer(x, is_train=is_train, name='BN1-conv4', act=lrelu2) 50 | 51 | x = Conv3dLayer(x, shape=[kernel, kernel, kernel, 64, 128], strides=[1, 1, 1, 1, 1], 52 | padding='SAME', W_init=w_init, name='conv5') 53 | x = BatchNormLayer(x, is_train=is_train, name='BN1-conv5', act=lrelu2) 54 | x = Conv3dLayer(x, shape=[kernel, kernel, kernel, 128, 128], strides=[1, 2, 2, 2, 1], 55 | padding='SAME', W_init=w_init, name='conv6') 56 | x = BatchNormLayer(x, is_train=is_train, name='BN1-conv6', act=lrelu2) 57 | 58 | x = Conv3dLayer(x, shape=[kernel, kernel, kernel, 128, 256], strides=[1, 1, 1, 1, 1], 59 | padding='SAME', W_init=w_init, name='conv7') 60 | x = BatchNormLayer(x, is_train=is_train, name='BN1-conv7', act=lrelu2) 61 | x = Conv3dLayer(x, shape=[kernel, kernel, kernel, 256, 256], strides=[1, 2, 2, 2, 1], 62 | padding='SAME', W_init=w_init, name='conv8') 63 | x = BatchNormLayer(x, is_train=is_train, name='BN1-conv8', act=lrelu2) 64 | 65 | x = FlattenLayer(x, name='flatten') 66 | x = DenseLayer(x, n_units=1024, act=lrelu2, name='dense1') 67 | x = DenseLayer(x, n_units=1, name='dense2') 68 | 
69 | logits = x.outputs 70 | x.outputs = tf.nn.sigmoid(x.outputs, name='output') 71 | 72 | return x, logits 73 | 74 | 75 | def generator(input_gen, kernel, nb, upscaling_factor, reuse, feature_size, img_width, img_height, img_depth, 76 | subpixel_NN, nn, is_train=True): 77 | w_init = tf.random_normal_initializer(stddev=0.02) 78 | 79 | w_init_subpixel1 = np.random.normal(scale=0.02, size=[3, 3, 3, 64, feature_size]) 80 | w_init_subpixel1 = zoom(w_init_subpixel1, [2, 2, 2, 1, 1], order=0) 81 | w_init_subpixel1_last = tf.constant_initializer(w_init_subpixel1) 82 | w_init_subpixel2 = np.random.normal(scale=0.02, size=[3, 3, 3, 64, 64]) 83 | w_init_subpixel2 = zoom(w_init_subpixel2, [2, 2, 2, 1, 1], order=0) 84 | w_init_subpixel2_last = tf.constant_initializer(w_init_subpixel2) 85 | 86 | with tf.variable_scope("SRGAN_g", reuse=reuse): 87 | tl.layers.set_name_reuse(reuse) 88 | x = InputLayer(input_gen, name='in') 89 | x = Conv3dLayer(x, shape=[kernel, kernel, kernel, 1, feature_size], strides=[1, 1, 1, 1, 1], 90 | padding='SAME', W_init=w_init, name='conv1') 91 | x = BatchNormLayer(x, act=lrelu1, is_train=is_train, name='BN-conv1') 92 | inputRB = x 93 | inputadd = x 94 | 95 | # residual blocks 96 | for i in range(nb): 97 | x = Conv3dLayer(x, shape=[kernel, kernel, kernel, feature_size, feature_size], strides=[1, 1, 1, 1, 1], 98 | padding='SAME', W_init=w_init, name='conv1-rb/%s' % i) 99 | x = BatchNormLayer(x, act=lrelu1, is_train=is_train, name='BN1-rb/%s' % i) 100 | x = Conv3dLayer(x, shape=[kernel, kernel, kernel, feature_size, feature_size], strides=[1, 1, 1, 1, 1], 101 | padding='SAME', W_init=w_init, name='conv2-rb/%s' % i) 102 | x = BatchNormLayer(x, is_train=is_train, name='BN2-rb/%s' % i, ) 103 | # short skip connection 104 | x = ElementwiseLayer([x, inputadd], tf.add, name='add-rb/%s' % i) 105 | inputadd = x 106 | 107 | # large skip connection 108 | x = Conv3dLayer(x, shape=[kernel, kernel, kernel, feature_size, feature_size], strides=[1, 1, 1, 1, 1], 109 | 
padding='SAME', W_init=w_init, name='conv2') 110 | x = BatchNormLayer(x, is_train=is_train, name='BN-conv2') 111 | x = ElementwiseLayer([x, inputRB], tf.add, name='add-conv2') 112 | 113 | # ____________SUBPIXEL-NN______________# 114 | 115 | if subpixel_NN: 116 | # upscaling block 1 117 | if upscaling_factor == 4: 118 | img_height_deconv = int(img_height / 2) 119 | img_width_deconv = int(img_width / 2) 120 | img_depth_deconv = int(img_depth / 2) 121 | else: 122 | img_height_deconv = img_height 123 | img_width_deconv = img_width 124 | img_depth_deconv = img_depth 125 | 126 | x = DeConv3dLayer(x, shape=[kernel * 2, kernel * 2, kernel * 2, 64, feature_size], 127 | act=lrelu1, strides=[1, 2, 2, 2, 1], 128 | output_shape=[tf.shape(input_gen)[0], img_height_deconv, img_width_deconv, 129 | img_depth_deconv, 64], 130 | padding='SAME', W_init=w_init_subpixel1_last, name='conv1-ub-subpixelnn/1') 131 | 132 | # upscaling block 2 133 | if upscaling_factor == 4: 134 | x = DeConv3dLayer(x, shape=[kernel * 2, kernel * 2, kernel * 2, 64, 64], 135 | act=lrelu1, strides=[1, 2, 2, 2, 1], padding='SAME', 136 | output_shape=[tf.shape(input_gen)[0], img_height, img_width, 137 | img_depth, 64], 138 | W_init=w_init_subpixel2_last, name='conv1-ub-subpixelnn/2') 139 | 140 | x = Conv3dLayer(x, shape=[kernel, kernel, kernel, 64, 1], strides=[1, 1, 1, 1, 1], 141 | padding='SAME', W_init=w_init, name='convlast-subpixelnn') 142 | 143 | # ____________RC______________# 144 | 145 | elif nn: 146 | # upscaling block 1 147 | x = Conv3dLayer(x, shape=[kernel, kernel, kernel, feature_size, 64], act=lrelu1, 148 | strides=[1, 1, 1, 1, 1], 149 | padding='SAME', W_init=w_init, name='conv1-ub/1') 150 | x = UpSampling3D(name='UpSampling3D_1')(x.outputs) 151 | x = Conv3dLayer(InputLayer(x, name='in ub1 conv2'), 152 | shape=[kernel, kernel, kernel, 64, 64], 153 | act=lrelu1, 154 | strides=[1, 1, 1, 1, 1], 155 | padding='SAME', W_init=w_init, name='conv2-ub/1') 156 | 157 | # upscaling block 2 158 | if 
upscaling_factor == 4: 159 | x = Conv3dLayer(x, shape=[kernel, kernel, kernel, 64, 64], act=lrelu1, 160 | strides=[1, 1, 1, 1, 1], 161 | padding='SAME', W_init=w_init, name='conv1-ub/2') 162 | x = UpSampling3D(name='UpSampling3D_1')(x.outputs) 163 | x = Conv3dLayer(InputLayer(x, name='in ub2 conv2'), shape=[kernel, kernel, kernel, 64, 164 | 64], act=lrelu1, 165 | strides=[1, 1, 1, 1, 1], 166 | padding='SAME', W_init=w_init, name='conv2-ub/2') 167 | 168 | x = Conv3dLayer(x, shape=[kernel, kernel, kernel, 64, 1], strides=[1, 1, 1, 1, 1], 169 | act=tf.nn.tanh, padding='SAME', W_init=w_init, name='convlast') 170 | 171 | # ____________SUBPIXEL - BASELINE______________# 172 | 173 | else: 174 | 175 | if upscaling_factor == 4: 176 | steps_to_end = 2 177 | else: 178 | steps_to_end = 1 179 | 180 | # upscaling block 1 181 | x = Conv3dLayer(x, shape=[kernel, kernel, kernel, feature_size, 64], act=lrelu1, 182 | strides=[1, 1, 1, 1, 1], 183 | padding='SAME', W_init=w_init, name='conv1-ub/1') 184 | arguments = {'img_width': img_width, 'img_height': img_height, 'img_depth': img_depth, 185 | 'stepsToEnd': steps_to_end, 186 | 'n_out_channel': int(64 / 8)} 187 | x = LambdaLayer(x, fn=subPixelConv3d, fn_args=arguments, name='SubPixel1') 188 | 189 | # upscaling block 2 190 | if upscaling_factor == 4: 191 | x = Conv3dLayer(x, shape=[kernel, kernel, kernel, int((64) / 8), 64], act=lrelu1, 192 | strides=[1, 1, 1, 1, 1], 193 | padding='SAME', W_init=w_init, name='conv1-ub/2') 194 | arguments = {'img_width': img_width, 'img_height': img_height, 'img_depth': img_depth, 'stepsToEnd': 1, 195 | 'n_out_channel': int(64 / 8)} 196 | x = LambdaLayer(x, fn=subPixelConv3d, fn_args=arguments, name='SubPixel2') 197 | 198 | x = Conv3dLayer(x, shape=[kernel, kernel, kernel, int(64 / 8), 1], strides=[1, 1, 1, 1, 1], 199 | padding='SAME', W_init=w_init, name='convlast') 200 | 201 | return x 202 | 203 | 204 | def train(upscaling_factor, residual_blocks, feature_size, path_prediction, checkpoint_dir, 
img_width, img_height, 205 | img_depth, subpixel_NN, nn, restore, batch_size=1, div_patches=4, epochs=10): 206 | traindataset = Train_dataset(batch_size) 207 | iterations_train = math.ceil((len(traindataset.subject_list) * 0.8) / batch_size) 208 | num_patches = traindataset.num_patches 209 | 210 | # ##========================== DEFINE MODEL ============================## 211 | t_input_gen = tf.placeholder('float32', [int((batch_size * num_patches) / div_patches), None, 212 | None, None, 1], 213 | name='t_image_input_to_SRGAN_generator') 214 | t_target_image = tf.placeholder('float32', [int((batch_size * num_patches) / div_patches), 215 | img_width, img_height, img_depth, 1], 216 | name='t_target_image') 217 | t_input_mask = tf.placeholder('float32', [int((batch_size * num_patches) / div_patches), 218 | img_width, img_height, img_depth, 1], 219 | name='t_image_input_mask') 220 | 221 | net_gen = generator(input_gen=t_input_gen, kernel=3, nb=residual_blocks, upscaling_factor=upscaling_factor, 222 | img_height=img_height, img_width=img_width, img_depth=img_depth, subpixel_NN=subpixel_NN, nn=nn, 223 | feature_size=feature_size, is_train=True, reuse=False) 224 | net_d, disc_out_real = discriminator(input_disc=t_target_image, kernel=3, is_train=True, reuse=False) 225 | _, disc_out_fake = discriminator(input_disc=net_gen.outputs, kernel=3, is_train=True, reuse=True) 226 | 227 | # test 228 | gen_test = generator(t_input_gen, kernel=3, nb=residual_blocks, upscaling_factor=upscaling_factor, 229 | img_height=img_height, img_width=img_width, img_depth=img_depth, subpixel_NN=subpixel_NN, 230 | nn=nn, 231 | feature_size=feature_size, is_train=True, reuse=True) 232 | 233 | # ###========================== DEFINE TRAIN OPS ==========================### 234 | 235 | if np.random.uniform() > 0.1: 236 | # give correct classifications 237 | y_gan_real = tf.ones_like(disc_out_real) 238 | y_gan_fake = tf.zeros_like(disc_out_real) 239 | else: 240 | # give wrong classifications (noisy 
labels) 241 | y_gan_real = tf.zeros_like(disc_out_real) 242 | y_gan_fake = tf.ones_like(disc_out_real) 243 | 244 | d_loss_real = tf.reduce_mean(tf.square(disc_out_real - smooth_gan_labels(y_gan_real)), 245 | name='d_loss_real') 246 | d_loss_fake = tf.reduce_mean(tf.square(disc_out_fake - smooth_gan_labels(y_gan_fake)), 247 | name='d_loss_fake') 248 | d_loss = d_loss_real + d_loss_fake 249 | 250 | mse_loss = tf.reduce_sum( 251 | tf.square(net_gen.outputs - t_target_image), axis=[0, 1, 2, 3, 4], name='g_loss_mse') 252 | 253 | dx_real = t_target_image[:, 1:, :, :, :] - t_target_image[:, :-1, :, :, :] 254 | dy_real = t_target_image[:, :, 1:, :, :] - t_target_image[:, :, :-1, :, :] 255 | dz_real = t_target_image[:, :, :, 1:, :] - t_target_image[:, :, :, :-1, :] 256 | dx_fake = net_gen.outputs[:, 1:, :, :, :] - net_gen.outputs[:, :-1, :, :, :] 257 | dy_fake = net_gen.outputs[:, :, 1:, :, :] - net_gen.outputs[:, :, :-1, :, :] 258 | dz_fake = net_gen.outputs[:, :, :, 1:, :] - net_gen.outputs[:, :, :, :-1, :] 259 | 260 | gd_loss = tf.reduce_sum(tf.square(tf.abs(dx_real) - tf.abs(dx_fake))) + \ 261 | tf.reduce_sum(tf.square(tf.abs(dy_real) - tf.abs(dy_fake))) + \ 262 | tf.reduce_sum(tf.square(tf.abs(dz_real) - tf.abs(dz_fake))) 263 | 264 | g_gan_loss = 10e-2 * tf.reduce_mean(tf.square(disc_out_fake - smooth_gan_labels(tf.ones_like(disc_out_real))), 265 | name='g_loss_gan') 266 | 267 | g_loss = mse_loss + g_gan_loss + gd_loss 268 | 269 | g_vars = tl.layers.get_variables_with_name('SRGAN_g', True, True) 270 | d_vars = tl.layers.get_variables_with_name('SRGAN_d', True, True) 271 | 272 | with tf.variable_scope('learning_rate'): 273 | lr_v = tf.Variable(1e-4, trainable=False) 274 | global_step = tf.Variable(0, trainable=False) 275 | decay_rate = 0.5 276 | decay_steps = 4920 # every 2 epochs (more or less) 277 | learning_rate = tf.train.inverse_time_decay(lr_v, global_step=global_step, decay_rate=decay_rate, 278 | decay_steps=decay_steps) 279 | 280 | # Optimizers 281 | g_optim = 
tf.train.AdamOptimizer(learning_rate).minimize(g_loss, var_list=g_vars) 282 | d_optim = tf.train.AdamOptimizer(learning_rate).minimize(d_loss, var_list=d_vars) 283 | 284 | session = tf.Session() 285 | tl.layers.initialize_global_variables(session) 286 | 287 | step = 0 288 | saver = tf.train.Saver() 289 | 290 | if restore is not None: 291 | saver.restore(session, tf.train.latest_checkpoint(restore)) 292 | val_restore = 0 * epochs 293 | else: 294 | val_restore = 0 295 | 296 | array_psnr = [] 297 | array_ssim = [] 298 | 299 | for j in range(val_restore, epochs + val_restore): 300 | for i in range(0, iterations_train): 301 | # ====================== LOAD DATA =========================== # 302 | xt_total = traindataset.patches_true(i) 303 | xm_total = traindataset.mask(i) 304 | for k in range(0, div_patches): 305 | print('{}'.format(k)) 306 | xt = xt_total[k * int((batch_size * num_patches) / div_patches):(int( 307 | (batch_size * num_patches) / div_patches) * k) + int( 308 | (batch_size * num_patches) / div_patches)] 309 | xm = xm_total[k * int((batch_size * num_patches) / div_patches):(int( 310 | (batch_size * num_patches) / div_patches) * k) + int( 311 | (batch_size * num_patches) / div_patches)] 312 | 313 | # NORMALIZING 314 | for t in range(0, xt.shape[0]): 315 | normfactor = (np.amax(xt[t])) / 2 316 | if normfactor != 0: 317 | xt[t] = ((xt[t] - normfactor) / normfactor) 318 | 319 | x_generator = gaussian_filter(xt, sigma=1) 320 | x_generator = zoom(x_generator, [1, (1 / upscaling_factor), (1 / upscaling_factor), 321 | (1 / upscaling_factor), 1], prefilter=False, order=0) 322 | xgenin = x_generator 323 | 324 | # ========================= train SRGAN ========================= # 325 | # update D 326 | errd, _ = session.run([d_loss, d_optim], {t_target_image: xt, t_input_gen: xgenin}) 327 | # update G 328 | errg, errmse, errgan, errgd, _ = session.run([g_loss, mse_loss, g_gan_loss, gd_loss, g_optim], 329 | {t_input_gen: xgenin, t_target_image: xt, 330 | t_input_mask: 
xm}) 331 | print( 332 | "Epoch [%2d/%2d] [%4d/%4d] [%4d/%4d]: d_loss: %.8f g_loss: %.8f (mse: %.6f gdl: %.6f adv: %.6f)" % ( 333 | j, epochs + val_restore, i, iterations_train, k, div_patches - 1, errd, errg, errmse, errgd, 334 | errgan)) 335 | 336 | # ========================= evaluate & save model ========================= # 337 | 338 | if k == 1 and i % 20 == 0: 339 | if j - val_restore == 0: 340 | x_true_img = xt[0] 341 | if normfactor != 0: 342 | x_true_img = ((x_true_img + 1) * normfactor) # denormalize 343 | img_true = nib.Nifti1Image(x_true_img, np.eye(4)) 344 | img_true.to_filename( 345 | os.path.join(path_prediction, str(j) + str(i) + 'true.nii.gz')) 346 | 347 | x_gen_img = xgenin[0] 348 | if normfactor != 0: 349 | x_gen_img = ((x_gen_img + 1) * normfactor) # denormalize 350 | img_gen = nib.Nifti1Image(x_gen_img, np.eye(4)) 351 | img_gen.to_filename( 352 | os.path.join(path_prediction, str(j) + str(i) + 'gen.nii.gz')) 353 | 354 | x_pred = session.run(gen_test.outputs, {t_input_gen: xgenin}) 355 | x_pred_img = x_pred[0] 356 | if normfactor != 0: 357 | x_pred_img = ((x_pred_img + 1) * normfactor) # denormalize 358 | img_pred = nib.Nifti1Image(x_pred_img, np.eye(4)) 359 | img_pred.to_filename( 360 | os.path.join(path_prediction, str(j) + str(i) + '.nii.gz')) 361 | 362 | max_gen = np.amax(x_pred_img) 363 | max_real = np.amax(x_true_img) 364 | if max_gen > max_real: 365 | val_max = max_gen 366 | else: 367 | val_max = max_real 368 | min_gen = np.amin(x_pred_img) 369 | min_real = np.amin(x_true_img) 370 | if min_gen < min_real: 371 | val_min = min_gen 372 | else: 373 | val_min = min_real 374 | val_psnr = psnr(np.multiply(x_true_img, xm[0]), np.multiply(x_pred_img, xm[0]), 375 | dynamic_range=val_max - val_min) 376 | val_ssim = ssim(np.multiply(x_true_img, xm[0]), np.multiply(x_pred_img, xm[0]), 377 | dynamic_range=val_max - val_min, multichannel=True) 378 | 379 | saver.save(sess=session, save_path=checkpoint_dir, global_step=step) 380 | print("Saved step: [%2d]" 
% step) 381 | step = step + 1 382 | 383 | 384 | def evaluate(upsampling_factor, residual_blocks, feature_size, checkpoint_dir_restore, path_volumes, nn, subpixel_NN, 385 | img_height, img_width, img_depth): 386 | traindataset = Train_dataset(1) 387 | iterations = math.ceil( 388 | (len(traindataset.subject_list) * 0.2)) 389 | print(len(traindataset.subject_list)) 390 | print(iterations) 391 | totalpsnr = 0 392 | totalssim = 0 393 | array_psnr = np.empty(iterations) 394 | array_ssim = np.empty(iterations) 395 | batch_size = 1 396 | div_patches = 4 397 | num_patches = traindataset.num_patches 398 | 399 | # define model 400 | t_input_gen = tf.placeholder('float32', [1, None, None, None, 1], 401 | name='t_image_input_to_SRGAN_generator') 402 | srgan_network = generator(input_gen=t_input_gen, kernel=3, nb=residual_blocks, 403 | upscaling_factor=upsampling_factor, feature_size=feature_size, subpixel_NN=subpixel_NN, 404 | img_height=img_height, img_width=img_width, img_depth=img_depth, nn=nn, 405 | is_train=False, reuse=False) 406 | 407 | # restore g 408 | sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)) 409 | 410 | saver = tf.train.Saver(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="SRGAN_g")) 411 | saver.restore(sess, tf.train.latest_checkpoint(checkpoint_dir_restore)) 412 | 413 | for i in range(0, iterations): 414 | # extract volumes 415 | xt_total = traindataset.data_true(654 + i) 416 | xt_mask = traindataset.mask(654 + i) 417 | normfactor = (np.amax(xt_total[0])) / 2 418 | x_generator = ((xt_total[0] - normfactor) / normfactor) 419 | res = 1 / upsampling_factor 420 | x_generator = x_generator[:, :, :, np.newaxis] 421 | x_generator = gaussian_filter(x_generator, sigma=1) 422 | x_generator = zoom(x_generator, [res, res, res, 1], prefilter=False) 423 | xg_generated = sess.run(srgan_network.outputs, {t_input_gen: x_generator[np.newaxis, :]}) 424 | xg_generated = ((xg_generated + 1) * normfactor) 425 | volume_real = 
xt_total[0] 426 | volume_real = volume_real[:, :, :, np.newaxis] 427 | volume_generated = xg_generated[0] 428 | volume_mask = aggregate(xt_mask) 429 | # compute metrics 430 | max_gen = np.amax(volume_generated) 431 | max_real = np.amax(volume_real) 432 | if max_gen > max_real: 433 | val_max = max_gen 434 | else: 435 | val_max = max_real 436 | min_gen = np.amin(volume_generated) 437 | min_real = np.amin(volume_real) 438 | if min_gen < min_real: 439 | val_min = min_gen 440 | else: 441 | val_min = min_real 442 | val_psnr = psnr(np.multiply(volume_real, volume_mask), np.multiply(volume_generated, volume_mask), 443 | dynamic_range=val_max - val_min) 444 | array_psnr[i] = val_psnr 445 | 446 | totalpsnr += val_psnr 447 | val_ssim = ssim(np.multiply(volume_real, volume_mask), np.multiply(volume_generated, volume_mask), 448 | dynamic_range=val_max - val_min, multichannel=True) 449 | array_ssim[i] = val_ssim 450 | totalssim += val_ssim 451 | print(val_psnr) 452 | print(val_ssim) 453 | # save volumes 454 | filename_gen = os.path.join(path_volumes, str(i) + 'gen.nii.gz') 455 | img_volume_gen = nib.Nifti1Image(volume_generated, np.eye(4)) 456 | img_volume_gen.to_filename(filename_gen) 457 | filename_real = os.path.join(path_volumes, str(i) + 'real.nii.gz') 458 | img_volume_real = nib.Nifti1Image(volume_real, np.eye(4)) 459 | img_volume_real.to_filename(filename_real) 460 | 461 | print('{}{}'.format('PSNR: ', array_psnr)) 462 | print('{}{}'.format('SSIM: ', array_ssim)) 463 | print('{}{}'.format('Mean PSNR: ', array_psnr.mean())) 464 | print('{}{}'.format('Mean SSIM: ', array_ssim.mean())) 465 | print('{}{}'.format('Variance PSNR: ', array_psnr.var())) 466 | print('{}{}'.format('Variance SSIM: ', array_ssim.var())) 467 | print('{}{}'.format('Max PSNR: ', array_psnr.max())) 468 | print('{}{}'.format('Min PSNR: ', array_psnr.min())) 469 | print('{}{}'.format('Max SSIM: ', array_ssim.max())) 470 | print('{}{}'.format('Min SSIM: ', array_ssim.min())) 471 | 
print('{}{}'.format('Median PSNR: ', np.median(array_psnr))) 472 | print('{}{}'.format('Median SSIM: ', np.median(array_ssim))) 473 | 474 | 475 | if __name__ == '__main__': 476 | parser = argparse.ArgumentParser(description='Predict script') 477 | parser.add_argument('-path_prediction', help='Path to save training predictions') 478 | parser.add_argument('-path_volumes', help='Path to save test volumes') 479 | parser.add_argument('-checkpoint_dir', help='Path to save checkpoints') 480 | parser.add_argument('-checkpoint_dir_restore', help='Path to restore checkpoints') 481 | parser.add_argument('-residual_blocks', default=6, help='Number of residual blocks') 482 | parser.add_argument('-upsampling_factor', default=4, help='Upsampling factor') 483 | parser.add_argument('-evaluate', default=False, help='Test the model') 484 | parser.add_argument('-subpixel_NN', default=False, help='Use subpixel nearest neighbour') 485 | parser.add_argument('-nn', default=False, help='Use Upsampling3D + nearest neighbour, RC') 486 | parser.add_argument('-feature_size', default=32, help='Number of filters') 487 | parser.add_argument('-restore', default=None, help='Checkpoint path to restore training') 488 | args = parser.parse_args() 489 | 490 | if args.evaluate: 491 | evaluate(upsampling_factor=int(args.upsampling_factor), feature_size=int(args.feature_size), 492 | residual_blocks=int(args.residual_blocks), checkpoint_dir_restore=args.checkpoint_dir_restore, 493 | path_volumes=args.path_volumes, subpixel_NN=args.subpixel_NN, nn=args.nn, img_width=224, 494 | img_height=224, img_depth=152) 495 | else: 496 | train(upscaling_factor=int(args.upsampling_factor), feature_size=int(args.feature_size), 497 | subpixel_NN=args.subpixel_NN, nn=args.nn, residual_blocks=int(args.residual_blocks), 498 | path_prediction=args.path_prediction, checkpoint_dir=args.checkpoint_dir, img_width=128, 499 | img_height=128, img_depth=92, batch_size=1, restore=args.restore) 500 | 
--------------------------------------------------------------------------------