├── images
│   ├── 3D SRGAN(D).png
│   ├── 3D SRGAN(G).png
│   └── Upsamplings.png
├── utils.py
├── README.md
├── dataset.py
└── model.py
/images/3D SRGAN(D).png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ashishpatel26/3D-GAN-superresolution/master/images/3D SRGAN(D).png
--------------------------------------------------------------------------------
/images/3D SRGAN(G).png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ashishpatel26/3D-GAN-superresolution/master/images/3D SRGAN(G).png
--------------------------------------------------------------------------------
/images/Upsamplings.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ashishpatel26/3D-GAN-superresolution/master/images/Upsamplings.png
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | from keras import backend as K  # resize_volumes lives in keras.backend
4 | from keras.utils import conv_utils
5 | from keras.engine import InputSpec, Layer
6 |
7 |
8 | class UpSampling3D(Layer):
9 |     def __init__(self, size=(2, 2, 2), data_format='channels_last', **kwargs):
10 |         self.size = conv_utils.normalize_tuple(size, 3, 'size')
11 |         self.data_format, self.input_spec = data_format, InputSpec(ndim=5)
12 |         super(UpSampling3D, self).__init__(**kwargs)
13 |
14 | def compute_output_shape(self, input_shape):
15 | dim1 = self.size[0] * input_shape[1] if input_shape[1] is not None else None
16 | dim2 = self.size[1] * input_shape[2] if input_shape[2] is not None else None
17 | dim3 = self.size[2] * input_shape[3] if input_shape[3] is not None else None
18 | return (input_shape[0],
19 | dim1,
20 | dim2,
21 | dim3,
22 | input_shape[4])
23 |
24 | def call(self, inputs):
25 | return K.resize_volumes(inputs,
26 | self.size[0], self.size[1], self.size[2],
27 | self.data_format)
28 |
29 | def get_config(self):
30 | config = {'size': self.size,
31 | 'data_format': self.data_format}
32 | base_config = super(UpSampling3D, self).get_config()
33 | return dict(list(base_config.items()) + list(config.items()))
34 |
35 |
36 | def smooth_gan_labels(y):
37 |     # One-sided label smoothing: 0 labels -> U(0.0, 0.3), 1 labels -> U(0.7, 1.2).
38 |     # `y == 0` is always False for a tf.Tensor, so smooth elementwise instead.
39 |     low = tf.random_uniform(shape=tf.shape(y), minval=0.0, maxval=0.3)
40 |     high = tf.random_uniform(shape=tf.shape(y), minval=0.7, maxval=1.2)
41 |     y_out = tf.where(tf.equal(y, 0.0), low, high)
42 |     return y_out
43 |
44 |
45 | def subPixelConv3d(net, img_width, img_height, img_depth, stepsToEnd, n_out_channel):
46 |     i = net  # tensor of shape (b, h, w, d, r**3 * n_out_channel)
47 |     r = 2  # upsampling factor per spatial axis
48 | a, b, z, c = int(img_width / (2 * stepsToEnd)), int(img_height / (2 * stepsToEnd)), int(
49 | img_depth / (2 * stepsToEnd)), tf.shape(i)[3]
50 | bsize = tf.shape(i)[0] # Handling Dimension(None) type for undefined batch dim
51 |     xs = tf.split(i, r, 4)  # r tensors of shape (b, h, w, d, c/r)
52 |     xr = tf.concat(xs, 3)  # (b, h, w, r*d, c/r)
53 |     xss = tf.split(xr, r, 4)  # r tensors of shape (b, h, w, r*d, c/r**2)
54 |     xrr = tf.concat(xss, 2)  # (b, h, r*w, r*d, c/r**2)
55 |     x = tf.reshape(xrr, (bsize, r * a, r * b, r * z, n_out_channel))  # (b, r*h, r*w, r*d, n_out_channel)
56 |
57 | return x
58 |
59 |
60 | def aggregate(patches):
61 |     margin = 16  # overlap voxels trimmed from the leading edge of each later patch
62 | volume = np.empty([224, 224, 152, 1])
63 | volume[0:112, 0:112, 0:76, :] = patches[0, 0:112, 0:112, 0:76, :]
64 | volume[0:112, 0:112, 76:, :] = patches[1, 0:112, 0:112, margin:, :]
65 | volume[0:112, 112:, 0:76, :] = patches[2, 0:112, margin:, 0:76, :]
66 | volume[0:112, 112:, 76:, :] = patches[3, 0:112, margin:, margin:, :]
67 | volume[112:, 0:112, 0:76, :] = patches[4, margin:, 0:112, 0:76, :]
68 | volume[112:, 0:112, 76:, :] = patches[5, margin:, 0:112, margin:, :]
69 | volume[112:, 112:, 0:76, :] = patches[6, margin:, margin:, 0:76, :]
70 | volume[112:, 112:, 76:, :] = patches[7, margin:, margin:, margin:, :]
71 | return volume
72 |
73 |
74 | def aggregate2(patches):
75 |     margin = 8  # same stitching as aggregate(), for half-size patches
76 | volume = np.empty([112, 112, 76, 1])
77 | volume[0:56, 0:56, 0:38, :] = patches[0, 0:56, 0:56, 0:38, :]
78 | volume[0:56, 0:56, 38:, :] = patches[1, 0:56, 0:56, margin:, :]
79 | volume[0:56, 56:, 0:38, :] = patches[2, 0:56, margin:, 0:38, :]
80 | volume[0:56, 56:, 38:, :] = patches[3, 0:56, margin:, margin:, :]
81 | volume[56:, 0:56, 0:38, :] = patches[4, margin:, 0:56, 0:38, :]
82 | volume[56:, 0:56, 38:, :] = patches[5, margin:, 0:56, margin:, :]
83 | volume[56:, 56:, 0:38, :] = patches[6, margin:, margin:, 0:38, :]
84 | volume[56:, 56:, 38:, :] = patches[7, margin:, margin:, margin:, :]
85 | return volume
86 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # 3D-GAN-superresolution
2 | Here we present the TensorFlow implementation of our work on generating high-resolution MRI scans from low-resolution images using Generative Adversarial Networks (GANs), accepted at the [Medical Imaging with Deep Learning (MIDL) conference, Amsterdam, 4-6 July 2018](https://midl.amsterdam/).
3 |
4 | Discriminator network
5 | ![Discriminator network](images/3D%20SRGAN%28D%29.png)
6 |
7 | Generator network
8 | ![Generator network](images/3D%20SRGAN%28G%29.png)
9 |
10 | In this work we propose an architecture for MRI super-resolution that completely exploits the available volumetric information contained in MRI scans, using 3D convolutions to process the volumes and taking advantage of an adversarial framework, improving the realism of the generated volumes.
11 | The model is based on the [SRGAN network](https://arxiv.org/abs/1609.04802). The adversarial loss uses least squares to stabilize training, and the generator loss, in addition to the adversarial term, contains a content term based on mean squared error and image gradients to improve the quality of the generated images. We explore three methods for the upsampling phase: an upsampling layer that replicates consecutive pixels via nearest neighbours, followed by a convolutional layer to refine the approximation; sub-pixel convolution layers as proposed in [Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network](https://arxiv.org/abs/1609.05158); and a modification of this method, [Checkerboard artifact free sub-pixel convolution](https://arxiv.org/pdf/1707.02937.pdf), which alleviates the checkerboard artifacts produced by sub-pixel convolution layers (see [Deconvolution and Checkerboard Artifacts](https://distill.pub/2016/deconv-checkerboard/) for more information). A short sketch of the channel-to-space rearrangement follows the figure below.
12 |
13 | Comparison of the upsampling methods used
14 | ![Comparison of the upsampling methods used](images/Upsamplings.png)
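
For intuition, below is a minimal NumPy sketch of the channel-to-space rearrangement behind sub-pixel convolution in 3D. It illustrates the idea implemented by `subPixelConv3d` in utils.py rather than reproducing its exact split/concat layout, and `pixel_shuffle_3d` is an illustrative name, not a function from this repo:

```
import numpy as np

def pixel_shuffle_3d(x, r):
    # Rearrange channels into space: (h, w, d, r**3 * c) -> (r*h, r*w, r*d, c).
    h, w, d, c_in = x.shape
    c = c_in // (r ** 3)
    x = x.reshape(h, w, d, r, r, r, c)
    x = x.transpose(0, 3, 1, 4, 2, 5, 6)  # interleave each r factor with a spatial axis
    return x.reshape(h * r, w * r, d * r, c)

x = np.random.rand(4, 4, 4, 8)  # r=2 needs 2**3 = 8 channels per output channel
print(pixel_shuffle_3d(x, 2).shape)  # (8, 8, 8, 1)
```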
15 |
16 | ### Data
17 | We used a set of normal control T1-weighted images from the Alzheimer’s Disease Neuroimaging Initiative (ADNI) database (see www.adni-info.org for details). Skull stripping is performed on all volumes and part of the background is removed, so the final volumes have dimensions 224x224x152. Due to memory constraints, training is patch-based: from each volume we extract patches of size 128x128x92 with a step of 96x96x60, giving 8 patches per volume that overlap by 32 voxels along each axis (equivalently, eight 112x112x76 tiles each extended by a 16-voxel margin; see the sketch below). We have a total of 589 volumes: 470 are used for training and 119 for testing. We use batches of two patches, so each volume takes 4 iterations. The code is set up for exactly this preprocessing and these dimensions.
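
A minimal sketch of the patch extraction; the window and step values mirror `patches_true` in dataset.py, and the zero volume is a stand-in for a loaded scan:

```
import numpy as np
from skimage.util import view_as_windows

volume = np.zeros((224, 224, 152))  # stand-in for a preprocessed scan
# window = tile + margin (112+16, 112+16, 76+16); step = tile - margin
patches = view_as_windows(volume, window_shape=(128, 128, 92), step=(96, 96, 60))
print(patches.shape)  # (2, 2, 2, 128, 128, 92) -> 8 patches per volume
```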
18 |
19 | The code expects the database inside the folder specified by `data_path` in the `Train_dataset` class (dataset.py). Inside there should be a folder for each of the patients containing a 'T1_brain_extractedBrainExtractionBrain.nii.gz' file (the skull-stripped T1 volume) and a 'T1_brain_extractedBrainExtractionMask.nii.gz' file (the corresponding brain mask). These files were created by taking the original images from ADNI and applying skull stripping to them. We use the nibabel library to load the images, as sketched below.
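
A minimal loading sketch, assuming the layout above ('/path/to/ADNI' stands in for your `data_path`):

```
import os
import numpy as np
import nibabel as nib

data_path = '/path/to/ADNI'  # one sub-folder per patient
subject = os.listdir(data_path)[0]
proxy = nib.load(os.path.join(data_path, subject,
                              'T1_brain_extractedBrainExtractionBrain.nii.gz'))
volume = np.array(proxy.dataobj)  # same lazy-proxy access as dataset.py
print(subject, volume.shape)
```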
20 |
21 | ### Training
22 | To train the network, the model.py script is used. When calling the script you should specify:
23 | + -path_prediction: Path to save training predictions.
24 | + -checkpoint_dir: Path to save checkpoints.
25 | + -residual_blocks: Number of residual blocks.
26 | + -upsampling_factor: Upsampling factor.
27 | + -subpixel_NN: Use subpixel nearest neighbour.
28 | + -nn: Use Upsampling3D + nearest neighbour, RC.
29 | + -feature_size: Number of filters.
30 |
31 | By default it will use the sub-pixel convolution layers, 32 filters, 6 residual blocks and an upsampling factor of 4.
32 |
33 | If you want to resume training, define the checkpoint to use with the restore argument:
34 | + -restore: Checkpoint path to restore training
35 |
36 | ```
37 | python model.py -path_prediction YOURPATH -checkpoint_dir YOURCHECKPOINTPATH -residual_blocks 8 -upsampling_factor 2 -subpixel_NN True -feature_size 64
38 | ```
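
For reference, the low-resolution generator inputs are produced on the fly from the high-resolution patches (see the training loop in model.py): each patch is normalized to [-1, 1], blurred with a Gaussian filter, and downsampled by the upsampling factor with nearest-neighbour interpolation. A standalone sketch of that degradation, using a random stand-in batch:

```
import numpy as np
from scipy.ndimage import gaussian_filter, zoom

upscaling_factor = 4
xt = np.random.rand(2, 128, 128, 92, 1)  # stand-in for a batch of HR patches

for t in range(xt.shape[0]):  # normalize each patch to [-1, 1]
    normfactor = np.amax(xt[t]) / 2
    if normfactor != 0:
        xt[t] = (xt[t] - normfactor) / normfactor

x_lr = gaussian_filter(xt, sigma=1)  # as in model.py, the blur spans the whole batch array
x_lr = zoom(x_lr, [1, 1 / upscaling_factor, 1 / upscaling_factor, 1 / upscaling_factor, 1],
            prefilter=False, order=0)  # order=0 -> nearest neighbour
print(x_lr.shape)  # (2, 32, 32, 23, 1)
```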
39 |
40 | ### Testing
41 | To test the network, the model.py script is also used. When calling it, specify the same arguments as before for the model configuration together with the new paths, and set the evaluate argument to True:
42 | + -path_volumes: Path to save test volumes.
43 | + -checkpoint_dir_restore: Path to restore checkpoints.
44 | + -residual_blocks: Number of residual blocks.
45 | + -upsampling_factor: Upsampling factor.
46 | + -subpixel_NN: Use subpixel nearest neighbour.
47 | + -nn: Use Upsampling3D + nearest neighbour, RC.
48 | + -feature_size: Number of filters.
49 | + -evaluate: Test the model.
50 |
51 | ```
52 | python model.py -path_volumes YOURPATH -checkpoint_dir_restore YOURCHECKPOINTPATH -residual_blocks 8 -upsampling_factor 2 -subpixel_NN True -feature_size 64 -evaluate True
53 | ```
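
Evaluation reports PSNR and SSIM on brain-masked volumes, with the dynamic range taken jointly over the real and generated volumes. model.py uses the old `skimage.measure.compare_psnr`/`compare_ssim` API; an equivalent sketch with the current scikit-image API, on random stand-in arrays:

```
import numpy as np
from skimage.metrics import peak_signal_noise_ratio, structural_similarity

real = np.random.rand(224, 224, 152, 1)  # stand-in volumes and mask
generated = np.random.rand(224, 224, 152, 1)
mask = (real > 0.1).astype(real.dtype)

data_range = max(real.max(), generated.max()) - min(real.min(), generated.min())
val_psnr = peak_signal_noise_ratio(real * mask, generated * mask, data_range=data_range)
val_ssim = structural_similarity(real * mask, generated * mask,
                                 data_range=data_range, channel_axis=-1)
print(val_psnr, val_ssim)
```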
54 |
55 | ### Contact
56 | If you have any general doubt about our work or code that may be of interest to other researchers, please use the public issues section of this GitHub repo.
57 |
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import nibabel as nib
3 | import math
4 | import os
5 | from skimage.util import view_as_windows
6 |
7 |
8 | class Train_dataset(object):
9 | def __init__(self, batch_size, overlapping=1):
10 | self.batch_size = batch_size
11 | self.data_path = '/imatge/isanchez/projects/neuro/ADNI-Screening-1.5T'
12 | self.subject_list = os.listdir(self.data_path)
13 |         self.subject_list = np.delete(self.subject_list, 120)  # exclude one subject from the list
14 |         self.height_patch = 112  # 128 with the 16-voxel margin
15 | self.width_patch = 112 # 128
16 | self.depth_patch = 76 # 92
17 | self.margin = 16
18 | self.overlapping = overlapping
19 |         self.num_patches = (math.ceil((224 / self.height_patch) / self.overlapping)) * (
20 |                 math.ceil((224 / self.width_patch) / self.overlapping)) * (
21 |                 math.ceil((152 / self.depth_patch) / self.overlapping))
22 |
23 | def mask(self, iteration):
24 | subject_batch = self.subject_list[iteration * self.batch_size:self.batch_size + (iteration * self.batch_size)]
25 | subjects_true = np.empty([self.batch_size, 256, 256, 184])
26 | i = 0
27 | for subject in subject_batch:
28 |             if subject != 'ADNI_SCREENING_CLINICAL_FILE_08_02_17.csv':  # skip the clinical CSV next to the subject folders
29 | filename = os.path.join(self.data_path, subject)
30 | filename = os.path.join(filename, 'T1_brain_extractedBrainExtractionMask.nii.gz')
31 | proxy = nib.load(filename)
32 | data = np.array(proxy.dataobj)
33 |
34 | paddwidthr = int((256 - proxy.shape[0]) / 2)
35 | paddheightr = int((256 - proxy.shape[1]) / 2)
36 | paddepthr = int((184 - proxy.shape[2]) / 2)
37 |
38 | if (paddwidthr * 2 + proxy.shape[0]) != 256:
39 | paddwidthl = paddwidthr + 1
40 | else:
41 | paddwidthl = paddwidthr
42 |
43 | if (paddheightr * 2 + proxy.shape[1]) != 256:
44 | paddheightl = paddheightr + 1
45 | else:
46 | paddheightl = paddheightr
47 |
48 | if (paddepthr * 2 + proxy.shape[2]) != 184:
49 | paddepthl = paddepthr + 1
50 | else:
51 | paddepthl = paddepthr
52 |
53 | data_padded = np.pad(data,
54 | [(paddwidthl, paddwidthr), (paddheightl, paddheightr), (paddepthl, paddepthr)],
55 | 'constant', constant_values=0)
56 | subjects_true[i] = data_padded
57 | i = i + 1
58 |         mask = np.empty(
59 |             [self.batch_size * self.num_patches, self.width_patch + self.margin, self.height_patch + self.margin,
60 |              self.depth_patch + self.margin, 1])
61 | i = 0
62 | for subject in subjects_true:
63 |             patch = view_as_windows(subject, window_shape=(
64 |                 (self.width_patch + self.margin), (self.height_patch + self.margin), (self.depth_patch + self.margin)),
65 |                 step=(self.width_patch - self.margin, self.height_patch - self.margin,
66 |                       self.depth_patch - self.margin))
67 | for d in range(patch.shape[0]):
68 | for v in range(patch.shape[1]):
69 | for h in range(patch.shape[2]):
70 | p = patch[d, v, h, :]
71 | p = p[:, np.newaxis]
72 | p = p.transpose((0, 2, 3, 1))
73 | mask[i] = p
74 | i = i + 1
75 | return mask
76 |
77 | def patches_true(self, iteration):
78 | subjects_true = self.data_true(iteration)
79 |         patches_true = np.empty(
80 |             [self.batch_size * self.num_patches, self.width_patch + self.margin, self.height_patch + self.margin,
81 |              self.depth_patch + self.margin, 1])
82 | i = 0
83 | for subject in subjects_true:
84 |             patch = view_as_windows(subject, window_shape=(
85 |                 (self.width_patch + self.margin), (self.height_patch + self.margin), (self.depth_patch + self.margin)),
86 |                 step=(self.width_patch - self.margin, self.height_patch - self.margin,
87 |                       self.depth_patch - self.margin))
88 | for d in range(patch.shape[0]):
89 | for v in range(patch.shape[1]):
90 | for h in range(patch.shape[2]):
91 | p = patch[d, v, h, :]
92 | p = p[:, np.newaxis]
93 | p = p.transpose((0, 2, 3, 1))
94 | patches_true[i] = p
95 | i = i + 1
96 | return patches_true
97 |
98 | def data_true(self, iteration):
99 | subject_batch = self.subject_list[iteration * self.batch_size:self.batch_size + (iteration * self.batch_size)]
100 | subjects = np.empty([self.batch_size, 224, 224, 152])
101 | i = 0
102 | for subject in subject_batch:
103 | if subject != 'ADNI_SCREENING_CLINICAL_FILE_08_02_17.csv':
104 | filename = os.path.join(self.data_path, subject)
105 | filename = os.path.join(filename, 'T1_brain_extractedBrainExtractionBrain.nii.gz')
106 | proxy = nib.load(filename)
107 | data = np.array(proxy.dataobj)
108 |
109 | paddwidthr = int((256 - proxy.shape[0]) / 2)
110 | paddheightr = int((256 - proxy.shape[1]) / 2)
111 | paddepthr = int((184 - proxy.shape[2]) / 2)
112 |
113 | if (paddwidthr * 2 + proxy.shape[0]) != 256:
114 | paddwidthl = paddwidthr + 1
115 | else:
116 | paddwidthl = paddwidthr
117 |
118 | if (paddheightr * 2 + proxy.shape[1]) != 256:
119 | paddheightl = paddheightr + 1
120 | else:
121 | paddheightl = paddheightr
122 |
123 | if (paddepthr * 2 + proxy.shape[2]) != 184:
124 | paddepthl = paddepthr + 1
125 | else:
126 | paddepthl = paddepthr
127 |
128 | data_padded = np.pad(data,
129 | [(paddwidthl, paddwidthr), (paddheightl, paddheightr), (paddepthl, paddepthr)],
130 | 'constant', constant_values=0)
131 |
132 | subjects[i] = data_padded[16:240, 16:240, 16:168] # remove background
133 | i = i + 1
134 | return subjects
135 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import tensorlayer as tl
3 | from tensorlayer.layers import *
4 | import numpy as np  # used directly below; don't rely on the star import
5 | from dataset import Train_dataset
6 | import math
7 | from scipy.ndimage.interpolation import zoom
8 | from scipy.ndimage.filters import gaussian_filter
9 | from utils import smooth_gan_labels, aggregate, subPixelConv3d
10 | import nibabel as nib
11 | import os
12 | from skimage.measure import compare_ssim as ssim
13 | from skimage.measure import compare_psnr as psnr
14 | from keras.layers.convolutional import UpSampling3D
15 | import argparse
16 |
17 | def lrelu1(x):
18 | return tf.maximum(x, 0.25 * x)
19 |
20 |
21 | def lrelu2(x):
22 | return tf.maximum(x, 0.3 * x)
23 |
24 |
25 | def discriminator(input_disc, kernel, reuse, is_train=True):
26 | w_init = tf.random_normal_initializer(stddev=0.02)
27 | batch_size = 1
28 | div_patches = 4
29 | num_patches = 8
30 | img_width = 128
31 | img_height = 128
32 | img_depth = 92
33 | with tf.variable_scope("SRGAN_d", reuse=reuse):
34 | tl.layers.set_name_reuse(reuse)
35 | input_disc.set_shape([int((batch_size * num_patches) / div_patches), img_width, img_height, img_depth, 1], )
36 | x = InputLayer(input_disc, name='in')
37 | x = Conv3dLayer(x, act=lrelu2, shape=[kernel, kernel, kernel, 1, 32], strides=[1, 1, 1, 1, 1],
38 | padding='SAME', W_init=w_init, name='conv1')
39 | x = Conv3dLayer(x, shape=[kernel, kernel, kernel, 32, 32], strides=[1, 2, 2, 2, 1],
40 | padding='SAME', W_init=w_init, name='conv2')
41 |
42 | x = BatchNormLayer(x, is_train=is_train, name='BN1-conv2', act=lrelu2)
43 |
44 | x = Conv3dLayer(x, shape=[kernel, kernel, kernel, 32, 64], strides=[1, 1, 1, 1, 1],
45 | padding='SAME', W_init=w_init, name='conv3')
46 | x = BatchNormLayer(x, is_train=is_train, name='BN1-conv3', act=lrelu2)
47 | x = Conv3dLayer(x, shape=[kernel, kernel, kernel, 64, 64], strides=[1, 2, 2, 2, 1],
48 | padding='SAME', W_init=w_init, name='conv4')
49 | x = BatchNormLayer(x, is_train=is_train, name='BN1-conv4', act=lrelu2)
50 |
51 | x = Conv3dLayer(x, shape=[kernel, kernel, kernel, 64, 128], strides=[1, 1, 1, 1, 1],
52 | padding='SAME', W_init=w_init, name='conv5')
53 | x = BatchNormLayer(x, is_train=is_train, name='BN1-conv5', act=lrelu2)
54 | x = Conv3dLayer(x, shape=[kernel, kernel, kernel, 128, 128], strides=[1, 2, 2, 2, 1],
55 | padding='SAME', W_init=w_init, name='conv6')
56 | x = BatchNormLayer(x, is_train=is_train, name='BN1-conv6', act=lrelu2)
57 |
58 | x = Conv3dLayer(x, shape=[kernel, kernel, kernel, 128, 256], strides=[1, 1, 1, 1, 1],
59 | padding='SAME', W_init=w_init, name='conv7')
60 | x = BatchNormLayer(x, is_train=is_train, name='BN1-conv7', act=lrelu2)
61 | x = Conv3dLayer(x, shape=[kernel, kernel, kernel, 256, 256], strides=[1, 2, 2, 2, 1],
62 | padding='SAME', W_init=w_init, name='conv8')
63 | x = BatchNormLayer(x, is_train=is_train, name='BN1-conv8', act=lrelu2)
64 |
65 | x = FlattenLayer(x, name='flatten')
66 | x = DenseLayer(x, n_units=1024, act=lrelu2, name='dense1')
67 | x = DenseLayer(x, n_units=1, name='dense2')
68 |
69 | logits = x.outputs
70 | x.outputs = tf.nn.sigmoid(x.outputs, name='output')
71 |
72 | return x, logits
73 |
74 |
75 | def generator(input_gen, kernel, nb, upscaling_factor, reuse, feature_size, img_width, img_height, img_depth,
76 | subpixel_NN, nn, is_train=True):
77 | w_init = tf.random_normal_initializer(stddev=0.02)
78 |
79 | w_init_subpixel1 = np.random.normal(scale=0.02, size=[3, 3, 3, 64, feature_size])
80 | w_init_subpixel1 = zoom(w_init_subpixel1, [2, 2, 2, 1, 1], order=0)
81 | w_init_subpixel1_last = tf.constant_initializer(w_init_subpixel1)
82 | w_init_subpixel2 = np.random.normal(scale=0.02, size=[3, 3, 3, 64, 64])
83 | w_init_subpixel2 = zoom(w_init_subpixel2, [2, 2, 2, 1, 1], order=0)
84 | w_init_subpixel2_last = tf.constant_initializer(w_init_subpixel2)
85 |
86 | with tf.variable_scope("SRGAN_g", reuse=reuse):
87 | tl.layers.set_name_reuse(reuse)
88 | x = InputLayer(input_gen, name='in')
89 | x = Conv3dLayer(x, shape=[kernel, kernel, kernel, 1, feature_size], strides=[1, 1, 1, 1, 1],
90 | padding='SAME', W_init=w_init, name='conv1')
91 | x = BatchNormLayer(x, act=lrelu1, is_train=is_train, name='BN-conv1')
92 | inputRB = x
93 | inputadd = x
94 |
95 | # residual blocks
96 | for i in range(nb):
97 | x = Conv3dLayer(x, shape=[kernel, kernel, kernel, feature_size, feature_size], strides=[1, 1, 1, 1, 1],
98 | padding='SAME', W_init=w_init, name='conv1-rb/%s' % i)
99 | x = BatchNormLayer(x, act=lrelu1, is_train=is_train, name='BN1-rb/%s' % i)
100 | x = Conv3dLayer(x, shape=[kernel, kernel, kernel, feature_size, feature_size], strides=[1, 1, 1, 1, 1],
101 | padding='SAME', W_init=w_init, name='conv2-rb/%s' % i)
102 | x = BatchNormLayer(x, is_train=is_train, name='BN2-rb/%s' % i, )
103 | # short skip connection
104 | x = ElementwiseLayer([x, inputadd], tf.add, name='add-rb/%s' % i)
105 | inputadd = x
106 |
107 | # large skip connection
108 | x = Conv3dLayer(x, shape=[kernel, kernel, kernel, feature_size, feature_size], strides=[1, 1, 1, 1, 1],
109 | padding='SAME', W_init=w_init, name='conv2')
110 | x = BatchNormLayer(x, is_train=is_train, name='BN-conv2')
111 | x = ElementwiseLayer([x, inputRB], tf.add, name='add-conv2')
112 |
113 | # ____________SUBPIXEL-NN______________#
114 |
115 | if subpixel_NN:
116 | # upscaling block 1
117 | if upscaling_factor == 4:
118 | img_height_deconv = int(img_height / 2)
119 | img_width_deconv = int(img_width / 2)
120 | img_depth_deconv = int(img_depth / 2)
121 | else:
122 | img_height_deconv = img_height
123 | img_width_deconv = img_width
124 | img_depth_deconv = img_depth
125 |
126 | x = DeConv3dLayer(x, shape=[kernel * 2, kernel * 2, kernel * 2, 64, feature_size],
127 | act=lrelu1, strides=[1, 2, 2, 2, 1],
128 | output_shape=[tf.shape(input_gen)[0], img_height_deconv, img_width_deconv,
129 | img_depth_deconv, 64],
130 | padding='SAME', W_init=w_init_subpixel1_last, name='conv1-ub-subpixelnn/1')
131 |
132 | # upscaling block 2
133 | if upscaling_factor == 4:
134 | x = DeConv3dLayer(x, shape=[kernel * 2, kernel * 2, kernel * 2, 64, 64],
135 | act=lrelu1, strides=[1, 2, 2, 2, 1], padding='SAME',
136 | output_shape=[tf.shape(input_gen)[0], img_height, img_width,
137 | img_depth, 64],
138 | W_init=w_init_subpixel2_last, name='conv1-ub-subpixelnn/2')
139 |
140 | x = Conv3dLayer(x, shape=[kernel, kernel, kernel, 64, 1], strides=[1, 1, 1, 1, 1],
141 | padding='SAME', W_init=w_init, name='convlast-subpixelnn')
142 |
143 | # ____________RC______________#
144 |
145 | elif nn:
146 | # upscaling block 1
147 | x = Conv3dLayer(x, shape=[kernel, kernel, kernel, feature_size, 64], act=lrelu1,
148 | strides=[1, 1, 1, 1, 1],
149 | padding='SAME', W_init=w_init, name='conv1-ub/1')
150 | x = UpSampling3D(name='UpSampling3D_1')(x.outputs)
151 | x = Conv3dLayer(InputLayer(x, name='in ub1 conv2'),
152 | shape=[kernel, kernel, kernel, 64, 64],
153 | act=lrelu1,
154 | strides=[1, 1, 1, 1, 1],
155 | padding='SAME', W_init=w_init, name='conv2-ub/1')
156 |
157 | # upscaling block 2
158 | if upscaling_factor == 4:
159 | x = Conv3dLayer(x, shape=[kernel, kernel, kernel, 64, 64], act=lrelu1,
160 | strides=[1, 1, 1, 1, 1],
161 | padding='SAME', W_init=w_init, name='conv1-ub/2')
162 |                 x = UpSampling3D(name='UpSampling3D_2')(x.outputs)  # distinct name from the first upsampling block
163 | x = Conv3dLayer(InputLayer(x, name='in ub2 conv2'), shape=[kernel, kernel, kernel, 64,
164 | 64], act=lrelu1,
165 | strides=[1, 1, 1, 1, 1],
166 | padding='SAME', W_init=w_init, name='conv2-ub/2')
167 |
168 | x = Conv3dLayer(x, shape=[kernel, kernel, kernel, 64, 1], strides=[1, 1, 1, 1, 1],
169 | act=tf.nn.tanh, padding='SAME', W_init=w_init, name='convlast')
170 |
171 | # ____________SUBPIXEL - BASELINE______________#
172 |
173 | else:
174 |
175 | if upscaling_factor == 4:
176 | steps_to_end = 2
177 | else:
178 | steps_to_end = 1
179 |
180 | # upscaling block 1
181 | x = Conv3dLayer(x, shape=[kernel, kernel, kernel, feature_size, 64], act=lrelu1,
182 | strides=[1, 1, 1, 1, 1],
183 | padding='SAME', W_init=w_init, name='conv1-ub/1')
184 | arguments = {'img_width': img_width, 'img_height': img_height, 'img_depth': img_depth,
185 | 'stepsToEnd': steps_to_end,
186 | 'n_out_channel': int(64 / 8)}
187 | x = LambdaLayer(x, fn=subPixelConv3d, fn_args=arguments, name='SubPixel1')
188 |
189 | # upscaling block 2
190 | if upscaling_factor == 4:
191 | x = Conv3dLayer(x, shape=[kernel, kernel, kernel, int((64) / 8), 64], act=lrelu1,
192 | strides=[1, 1, 1, 1, 1],
193 | padding='SAME', W_init=w_init, name='conv1-ub/2')
194 | arguments = {'img_width': img_width, 'img_height': img_height, 'img_depth': img_depth, 'stepsToEnd': 1,
195 | 'n_out_channel': int(64 / 8)}
196 | x = LambdaLayer(x, fn=subPixelConv3d, fn_args=arguments, name='SubPixel2')
197 |
198 | x = Conv3dLayer(x, shape=[kernel, kernel, kernel, int(64 / 8), 1], strides=[1, 1, 1, 1, 1],
199 | padding='SAME', W_init=w_init, name='convlast')
200 |
201 | return x
202 |
203 |
204 | def train(upscaling_factor, residual_blocks, feature_size, path_prediction, checkpoint_dir, img_width, img_height,
205 | img_depth, subpixel_NN, nn, restore, batch_size=1, div_patches=4, epochs=10):
206 | traindataset = Train_dataset(batch_size)
207 | iterations_train = math.ceil((len(traindataset.subject_list) * 0.8) / batch_size)
208 | num_patches = traindataset.num_patches
209 |
210 | # ##========================== DEFINE MODEL ============================##
211 | t_input_gen = tf.placeholder('float32', [int((batch_size * num_patches) / div_patches), None,
212 | None, None, 1],
213 | name='t_image_input_to_SRGAN_generator')
214 | t_target_image = tf.placeholder('float32', [int((batch_size * num_patches) / div_patches),
215 | img_width, img_height, img_depth, 1],
216 | name='t_target_image')
217 | t_input_mask = tf.placeholder('float32', [int((batch_size * num_patches) / div_patches),
218 | img_width, img_height, img_depth, 1],
219 | name='t_image_input_mask')
220 |
221 | net_gen = generator(input_gen=t_input_gen, kernel=3, nb=residual_blocks, upscaling_factor=upscaling_factor,
222 | img_height=img_height, img_width=img_width, img_depth=img_depth, subpixel_NN=subpixel_NN, nn=nn,
223 | feature_size=feature_size, is_train=True, reuse=False)
224 | net_d, disc_out_real = discriminator(input_disc=t_target_image, kernel=3, is_train=True, reuse=False)
225 | _, disc_out_fake = discriminator(input_disc=net_gen.outputs, kernel=3, is_train=True, reuse=True)
226 |
227 | # test
228 | gen_test = generator(t_input_gen, kernel=3, nb=residual_blocks, upscaling_factor=upscaling_factor,
229 | img_height=img_height, img_width=img_width, img_depth=img_depth, subpixel_NN=subpixel_NN,
230 | nn=nn,
231 | feature_size=feature_size, is_train=True, reuse=True)
232 |
233 | # ###========================== DEFINE TRAIN OPS ==========================###
234 |
235 |     if np.random.uniform() > 0.1:  # NOTE: sampled once at graph construction, so this choice is fixed for the whole run
236 |         # give correct classifications
237 | y_gan_real = tf.ones_like(disc_out_real)
238 | y_gan_fake = tf.zeros_like(disc_out_real)
239 | else:
240 | # give wrong classifications (noisy labels)
241 | y_gan_real = tf.zeros_like(disc_out_real)
242 | y_gan_fake = tf.ones_like(disc_out_real)
243 |
244 | d_loss_real = tf.reduce_mean(tf.square(disc_out_real - smooth_gan_labels(y_gan_real)),
245 | name='d_loss_real')
246 | d_loss_fake = tf.reduce_mean(tf.square(disc_out_fake - smooth_gan_labels(y_gan_fake)),
247 | name='d_loss_fake')
248 | d_loss = d_loss_real + d_loss_fake
249 |
250 | mse_loss = tf.reduce_sum(
251 | tf.square(net_gen.outputs - t_target_image), axis=[0, 1, 2, 3, 4], name='g_loss_mse')
252 |
253 | dx_real = t_target_image[:, 1:, :, :, :] - t_target_image[:, :-1, :, :, :]
254 | dy_real = t_target_image[:, :, 1:, :, :] - t_target_image[:, :, :-1, :, :]
255 | dz_real = t_target_image[:, :, :, 1:, :] - t_target_image[:, :, :, :-1, :]
256 | dx_fake = net_gen.outputs[:, 1:, :, :, :] - net_gen.outputs[:, :-1, :, :, :]
257 | dy_fake = net_gen.outputs[:, :, 1:, :, :] - net_gen.outputs[:, :, :-1, :, :]
258 | dz_fake = net_gen.outputs[:, :, :, 1:, :] - net_gen.outputs[:, :, :, :-1, :]
259 |
260 | gd_loss = tf.reduce_sum(tf.square(tf.abs(dx_real) - tf.abs(dx_fake))) + \
261 | tf.reduce_sum(tf.square(tf.abs(dy_real) - tf.abs(dy_fake))) + \
262 | tf.reduce_sum(tf.square(tf.abs(dz_real) - tf.abs(dz_fake)))
263 |
264 |     g_gan_loss = 1e-1 * tf.reduce_mean(tf.square(disc_out_fake - smooth_gan_labels(tf.ones_like(disc_out_real))),
265 | name='g_loss_gan')
266 |
267 | g_loss = mse_loss + g_gan_loss + gd_loss
268 |
269 | g_vars = tl.layers.get_variables_with_name('SRGAN_g', True, True)
270 | d_vars = tl.layers.get_variables_with_name('SRGAN_d', True, True)
271 |
272 | with tf.variable_scope('learning_rate'):
273 | lr_v = tf.Variable(1e-4, trainable=False)
274 | global_step = tf.Variable(0, trainable=False)
275 | decay_rate = 0.5
276 | decay_steps = 4920 # every 2 epochs (more or less)
277 | learning_rate = tf.train.inverse_time_decay(lr_v, global_step=global_step, decay_rate=decay_rate,
278 | decay_steps=decay_steps)
279 |
280 | # Optimizers
281 | g_optim = tf.train.AdamOptimizer(learning_rate).minimize(g_loss, var_list=g_vars)
282 | d_optim = tf.train.AdamOptimizer(learning_rate).minimize(d_loss, var_list=d_vars)
283 |
284 | session = tf.Session()
285 | tl.layers.initialize_global_variables(session)
286 |
287 | step = 0
288 | saver = tf.train.Saver()
289 |
290 | if restore is not None:
291 | saver.restore(session, tf.train.latest_checkpoint(restore))
292 |         val_restore = 0 * epochs  # epoch-numbering offset when resuming (kept at zero)
293 | else:
294 | val_restore = 0
295 |
296 | array_psnr = []
297 | array_ssim = []
298 |
299 | for j in range(val_restore, epochs + val_restore):
300 | for i in range(0, iterations_train):
301 | # ====================== LOAD DATA =========================== #
302 | xt_total = traindataset.patches_true(i)
303 | xm_total = traindataset.mask(i)
304 | for k in range(0, div_patches):
305 | print('{}'.format(k))
306 | xt = xt_total[k * int((batch_size * num_patches) / div_patches):(int(
307 | (batch_size * num_patches) / div_patches) * k) + int(
308 | (batch_size * num_patches) / div_patches)]
309 | xm = xm_total[k * int((batch_size * num_patches) / div_patches):(int(
310 | (batch_size * num_patches) / div_patches) * k) + int(
311 | (batch_size * num_patches) / div_patches)]
312 |
313 | # NORMALIZING
314 | for t in range(0, xt.shape[0]):
315 | normfactor = (np.amax(xt[t])) / 2
316 | if normfactor != 0:
317 | xt[t] = ((xt[t] - normfactor) / normfactor)
318 |
319 | x_generator = gaussian_filter(xt, sigma=1)
320 | x_generator = zoom(x_generator, [1, (1 / upscaling_factor), (1 / upscaling_factor),
321 | (1 / upscaling_factor), 1], prefilter=False, order=0)
322 | xgenin = x_generator
323 |
324 | # ========================= train SRGAN ========================= #
325 | # update D
326 | errd, _ = session.run([d_loss, d_optim], {t_target_image: xt, t_input_gen: xgenin})
327 | # update G
328 | errg, errmse, errgan, errgd, _ = session.run([g_loss, mse_loss, g_gan_loss, gd_loss, g_optim],
329 | {t_input_gen: xgenin, t_target_image: xt,
330 | t_input_mask: xm})
331 | print(
332 | "Epoch [%2d/%2d] [%4d/%4d] [%4d/%4d]: d_loss: %.8f g_loss: %.8f (mse: %.6f gdl: %.6f adv: %.6f)" % (
333 | j, epochs + val_restore, i, iterations_train, k, div_patches - 1, errd, errg, errmse, errgd,
334 | errgan))
335 |
336 | # ========================= evaluate & save model ========================= #
337 |
338 | if k == 1 and i % 20 == 0:
339 | if j - val_restore == 0:
340 | x_true_img = xt[0]
341 | if normfactor != 0:
342 | x_true_img = ((x_true_img + 1) * normfactor) # denormalize
343 | img_true = nib.Nifti1Image(x_true_img, np.eye(4))
344 | img_true.to_filename(
345 | os.path.join(path_prediction, str(j) + str(i) + 'true.nii.gz'))
346 |
347 | x_gen_img = xgenin[0]
348 | if normfactor != 0:
349 | x_gen_img = ((x_gen_img + 1) * normfactor) # denormalize
350 | img_gen = nib.Nifti1Image(x_gen_img, np.eye(4))
351 | img_gen.to_filename(
352 | os.path.join(path_prediction, str(j) + str(i) + 'gen.nii.gz'))
353 |
354 | x_pred = session.run(gen_test.outputs, {t_input_gen: xgenin})
355 | x_pred_img = x_pred[0]
356 | if normfactor != 0:
357 | x_pred_img = ((x_pred_img + 1) * normfactor) # denormalize
358 | img_pred = nib.Nifti1Image(x_pred_img, np.eye(4))
359 | img_pred.to_filename(
360 | os.path.join(path_prediction, str(j) + str(i) + '.nii.gz'))
361 |
362 | max_gen = np.amax(x_pred_img)
363 | max_real = np.amax(x_true_img)
364 | if max_gen > max_real:
365 | val_max = max_gen
366 | else:
367 | val_max = max_real
368 | min_gen = np.amin(x_pred_img)
369 | min_real = np.amin(x_true_img)
370 | if min_gen < min_real:
371 | val_min = min_gen
372 | else:
373 | val_min = min_real
374 | val_psnr = psnr(np.multiply(x_true_img, xm[0]), np.multiply(x_pred_img, xm[0]),
375 | dynamic_range=val_max - val_min)
376 | val_ssim = ssim(np.multiply(x_true_img, xm[0]), np.multiply(x_pred_img, xm[0]),
377 | dynamic_range=val_max - val_min, multichannel=True)
378 |
379 | saver.save(sess=session, save_path=checkpoint_dir, global_step=step)
380 | print("Saved step: [%2d]" % step)
381 | step = step + 1
382 |
383 |
384 | def evaluate(upsampling_factor, residual_blocks, feature_size, checkpoint_dir_restore, path_volumes, nn, subpixel_NN,
385 | img_height, img_width, img_depth):
386 | traindataset = Train_dataset(1)
387 | iterations = math.ceil(
388 | (len(traindataset.subject_list) * 0.2))
389 | print(len(traindataset.subject_list))
390 | print(iterations)
391 | totalpsnr = 0
392 | totalssim = 0
393 | array_psnr = np.empty(iterations)
394 | array_ssim = np.empty(iterations)
395 | batch_size = 1
396 | div_patches = 4
397 | num_patches = traindataset.num_patches
398 |
399 | # define model
400 | t_input_gen = tf.placeholder('float32', [1, None, None, None, 1],
401 | name='t_image_input_to_SRGAN_generator')
402 | srgan_network = generator(input_gen=t_input_gen, kernel=3, nb=residual_blocks,
403 | upscaling_factor=upsampling_factor, feature_size=feature_size, subpixel_NN=subpixel_NN,
404 | img_height=img_height, img_width=img_width, img_depth=img_depth, nn=nn,
405 | is_train=False, reuse=False)
406 |
407 | # restore g
408 | sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False))
409 |
410 | saver = tf.train.Saver(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="SRGAN_g"))
411 | saver.restore(sess, tf.train.latest_checkpoint(checkpoint_dir_restore))
412 |
413 | for i in range(0, iterations):
414 | # extract volumes
415 | xt_total = traindataset.data_true(654 + i)
416 | xt_mask = traindataset.mask(654 + i)
417 | normfactor = (np.amax(xt_total[0])) / 2
418 | x_generator = ((xt_total[0] - normfactor) / normfactor)
419 | res = 1 / upsampling_factor
420 | x_generator = x_generator[:, :, :, np.newaxis]
421 | x_generator = gaussian_filter(x_generator, sigma=1)
422 | x_generator = zoom(x_generator, [res, res, res, 1], prefilter=False)
423 | xg_generated = sess.run(srgan_network.outputs, {t_input_gen: x_generator[np.newaxis, :]})
424 | xg_generated = ((xg_generated + 1) * normfactor)
425 | volume_real = xt_total[0]
426 | volume_real = volume_real[:, :, :, np.newaxis]
427 | volume_generated = xg_generated[0]
428 | volume_mask = aggregate(xt_mask)
429 | # compute metrics
430 | max_gen = np.amax(volume_generated)
431 | max_real = np.amax(volume_real)
432 | if max_gen > max_real:
433 | val_max = max_gen
434 | else:
435 | val_max = max_real
436 | min_gen = np.amin(volume_generated)
437 | min_real = np.amin(volume_real)
438 | if min_gen < min_real:
439 | val_min = min_gen
440 | else:
441 | val_min = min_real
442 | val_psnr = psnr(np.multiply(volume_real, volume_mask), np.multiply(volume_generated, volume_mask),
443 | dynamic_range=val_max - val_min)
444 | array_psnr[i] = val_psnr
445 |
446 | totalpsnr += val_psnr
447 | val_ssim = ssim(np.multiply(volume_real, volume_mask), np.multiply(volume_generated, volume_mask),
448 | dynamic_range=val_max - val_min, multichannel=True)
449 | array_ssim[i] = val_ssim
450 | totalssim += val_ssim
451 | print(val_psnr)
452 | print(val_ssim)
453 | # save volumes
454 | filename_gen = os.path.join(path_volumes, str(i) + 'gen.nii.gz')
455 | img_volume_gen = nib.Nifti1Image(volume_generated, np.eye(4))
456 | img_volume_gen.to_filename(filename_gen)
457 | filename_real = os.path.join(path_volumes, str(i) + 'real.nii.gz')
458 | img_volume_real = nib.Nifti1Image(volume_real, np.eye(4))
459 | img_volume_real.to_filename(filename_real)
460 |
461 | print('{}{}'.format('PSNR: ', array_psnr))
462 | print('{}{}'.format('SSIM: ', array_ssim))
463 | print('{}{}'.format('Mean PSNR: ', array_psnr.mean()))
464 | print('{}{}'.format('Mean SSIM: ', array_ssim.mean()))
465 | print('{}{}'.format('Variance PSNR: ', array_psnr.var()))
466 | print('{}{}'.format('Variance SSIM: ', array_ssim.var()))
467 | print('{}{}'.format('Max PSNR: ', array_psnr.max()))
468 | print('{}{}'.format('Min PSNR: ', array_psnr.min()))
469 | print('{}{}'.format('Max SSIM: ', array_ssim.max()))
470 | print('{}{}'.format('Min SSIM: ', array_ssim.min()))
471 | print('{}{}'.format('Median PSNR: ', np.median(array_psnr)))
472 | print('{}{}'.format('Median SSIM: ', np.median(array_ssim)))
473 |
474 |
475 | if __name__ == '__main__':
476 | parser = argparse.ArgumentParser(description='Predict script')
477 | parser.add_argument('-path_prediction', help='Path to save training predictions')
478 | parser.add_argument('-path_volumes', help='Path to save test volumes')
479 | parser.add_argument('-checkpoint_dir', help='Path to save checkpoints')
480 | parser.add_argument('-checkpoint_dir_restore', help='Path to restore checkpoints')
481 | parser.add_argument('-residual_blocks', default=6, help='Number of residual blocks')
482 | parser.add_argument('-upsampling_factor', default=4, help='Upsampling factor')
483 |     parser.add_argument('-evaluate', default=False, type=lambda s: s.lower() == 'true', help='Test the model')
484 |     parser.add_argument('-subpixel_NN', default=False, type=lambda s: s.lower() == 'true', help='Use subpixel nearest neighbour')
485 |     parser.add_argument('-nn', default=False, type=lambda s: s.lower() == 'true', help='Use Upsampling3D + nearest neighbour, RC')
486 | parser.add_argument('-feature_size', default=32, help='Number of filters')
487 | parser.add_argument('-restore', default=None, help='Checkpoint path to restore training')
488 | args = parser.parse_args()
489 |
490 | if args.evaluate:
491 | evaluate(upsampling_factor=int(args.upsampling_factor), feature_size=int(args.feature_size),
492 | residual_blocks=int(args.residual_blocks), checkpoint_dir_restore=args.checkpoint_dir_restore,
493 | path_volumes=args.path_volumes, subpixel_NN=args.subpixel_NN, nn=args.nn, img_width=224,
494 | img_height=224, img_depth=152)
495 | else:
496 | train(upscaling_factor=int(args.upsampling_factor), feature_size=int(args.feature_size),
497 | subpixel_NN=args.subpixel_NN, nn=args.nn, residual_blocks=int(args.residual_blocks),
498 | path_prediction=args.path_prediction, checkpoint_dir=args.checkpoint_dir, img_width=128,
499 | img_height=128, img_depth=92, batch_size=1, restore=args.restore)
500 |
--------------------------------------------------------------------------------