# Driver script (/code/main.py): train a T1 -> FA CycleGAN, then synthesize FA
# maps from the test T1 data with every saved generator checkpoint.
#
# Repository layout:
#   images/T1_FA_MD.jpg
#       https://raw.githubusercontent.com/xuagu37/CycleGAN/HEAD/images/T1_FA_MD.jpg
#   images/T1_subject_1_1000_slice_66.png
#       https://raw.githubusercontent.com/xuagu37/CycleGAN/HEAD/images/T1_subject_1_1000_slice_66.png
#   code/main.py
#   code/CycleGAN.py
#   README.md

import os

# Import the class explicitly: the previous `import CycleGAN` followed by
# `from CycleGAN import *` rebound the module name to the class via a wildcard.
from CycleGAN import CycleGAN

# Create a CycleGAN on GPU 0
myCycleGAN = CycleGAN(0)

# Training inputs / outputs
trainA_dir = '/home/xuagu37/CycleGAN/data/T1_training.nii.gz'
trainB_dir = '/home/xuagu37/CycleGAN/data/FA_training.nii.gz'
models_dir = '/home/xuagu37/CycleGAN/train_T1_FA/models'
output_sample_dir = '/home/xuagu37/CycleGAN/train_T1_FA/output_sample.png'

# Training parameters
batch_size = 10
epochs = 200
normalization_factor_A = 1000
normalization_factor_B = 1

myCycleGAN.train(trainA_dir, normalization_factor_A, trainB_dir, normalization_factor_B, models_dir, batch_size, epochs, output_sample_dir=output_sample_dir, output_sample_channels=1)

# Synthesis inputs / outputs (identical for every checkpoint, so set once)
test_X_dir = '/home/xuagu37/CycleGAN/data/T1_test.nii.gz'
synthetic_dir = '/home/xuagu37/CycleGAN/train_T1_FA/synthetic'
normalization_factor_X = 1000
normalization_factor_Y = 1

# The output folder is not created by training; without it the first save
# would fail only after the whole training run has finished.
if not os.path.exists(synthetic_dir):
    os.makedirs(synthetic_dir)

# CycleGAN.train() stores generator weights every 20 epochs.
for epoch in range(20, epochs + 1, 20):
    G_X2Y_dir = '/home/xuagu37/CycleGAN/train_T1_FA/models/G_A2B_weights_epoch_' + str(epoch) + '.hdf5'
    print(G_X2Y_dir)
    synthetic_Y_dir = '/home/xuagu37/CycleGAN/train_T1_FA/synthetic/FA_synthetic_epoch_' + str(epoch) + '.nii.gz'
    myCycleGAN.synthesize(G_X2Y_dir, test_X_dir, normalization_factor_X, synthetic_Y_dir, normalization_factor_Y)
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # A Keras CycleGAN for nifti data 2 | We provide a Keras implementation for unpaired image-to-image translation, i.e. CycleGAN [1], for nifti data. 3 | 4 | We proposed an application of CycleGAN to generate synthetic diffusion MRI scalar maps from structural T1-weighted images, see our paper [2]. 5 | 6 | ![T1 to FA/MD](https://github.com/xuagu37/CycleGAN/blob/master/images/T1_FA_MD.jpg) 7 | 8 | ## Getting started 9 | ### Prepare training data 10 | We prepare the training data by stacking subjects along the fourth dimension. 11 | For example, we extract 1 slice from each subject for 1000 subjects and then stack all slices along the fourth dimension. 12 | The created training data will have the size of [X, Y, 1, 1000]. 13 | ![T1_subject_1_1000_slice_66](https://github.com/xuagu37/CycleGAN/blob/master/images/T1_subject_1_1000_slice_66.png) 14 | 15 | ### Training 16 | ```python 17 | # Create a CycleGAN on GPU 0 18 | myCycleGAN = CycleGAN(0) 19 | 20 | # Set directories 21 | trainA_dir = '/home/xuagu37/CycleGAN/data/T1_training.nii.gz' 22 | trainB_dir = '/home/xuagu37/CycleGAN/data/FA_training.nii.gz' 23 | models_dir = '/home/xuagu37/CycleGAN/train_T1_FA/models' 24 | output_sample_dir = '/home/xuagu37/CycleGAN/train_T1_FA/output_sample.png' 25 | 26 | # Set training parameters 27 | batch_size = 10 28 | epochs = 200 29 | normalization_factor_A = 1000 30 | normalization_factor_B = 1 31 | 32 | # Start training 33 | myCycleGAN.train(trainA_dir, normalization_factor_A, trainB_dir, normalization_factor_B, models_dir, batch_size, epochs, output_sample_dir=output_sample_dir, output_sample_channels=1) 34 | ``` 35 | 36 | ### Synthesize 37 | ```python 38 | # Set directory to the trained model 39 | G_X2Y_dir = '/home/xuagu37/CycleGAN/train_T1_FA/models/G_A2B_weights_epoch_100.hdf5' 40 | 41 | # Set directory to 
the test data 42 | test_X_dir = '/home/xuagu37/CycleGAN/data/T1_test.nii.gz' 43 | 44 | # Set directory to save the synthetic data 45 | synthetic_Y_dir ='/home/xuagu37/CycleGAN/train_T1_FA/synthetic/FA_synthetic.nii.gz' 46 | 47 | # Synthesize 48 | normalization_factor_X = 1000 49 | normalization_factor_Y = 1 50 | myCycleGAN.synthesize(G_X2Y_dir, test_X_dir, normalization_factor_X, synthetic_Y_dir, normalization_factor_Y) 51 | ``` 52 | 53 | 54 | ## References 55 | [1] Zhu, J.Y., Park, T., Isola, P. and Efros, A.A., 2017. Unpaired Image-to-Image Translation Using Cycle-Consistent Adversarial Networks. In 2017 IEEE International Conference on Computer Vision (ICCV) (pp. 2242-2251). IEEE. 56 | [2] Gu, X., Knutsson, H., Nilsson, M. and Eklund, A., 2019. Generating diffusion MRI scalar maps from T1 weighted images using generative adversarial networks. In Scandinavian Conference on Image Analysis (pp. 489-498). Springer, Cham. 57 | 58 | -------------------------------------------------------------------------------- /code/CycleGAN.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" 3 | # The GPU id to use, usually either "0" or "1" 4 | # os.environ["CUDA_VISIBLE_DEVICES"]="0" 5 | import nibabel as nib 6 | from keras.layers import Dropout, Layer, Input, Conv2D, Activation, add, BatchNormalization, Conv2DTranspose, UpSampling2D 7 | from keras_contrib.layers.normalization import InstanceNormalization, InputSpec 8 | from keras.layers.advanced_activations import LeakyReLU 9 | from keras.optimizers import Adam 10 | from keras.backend import mean 11 | from keras.models import Model 12 | from keras.engine.topology import Network 13 | from scipy.misc import imsave, toimage 14 | import numpy as np 15 | import random 16 | import datetime 17 | import time 18 | import math 19 | import sys 20 | import keras.backend as K 21 | import tensorflow as tf 22 | import datetime 23 | 24 | 25 | 26 | 
tf.logging.set_verbosity(tf.logging.ERROR)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

print('CycleGAN loaded...')


class CycleGAN():
    """CycleGAN that translates between two sets of 2D slices stored in NIfTI files.

    Typical use: construct with a GPU id, call ``train`` to fit the two
    generators (A->B and B->A) and save their weights, then call
    ``synthesize`` (or ``dropout_sample``) with a saved generator.
    """

    # Values accepted for the ``cycle_loss_type`` argument of ``train``.
    _CYCLE_LOSS_TYPES = ('L1', 'L2', 'SSIM', 'L1_SSIM', 'L2_SSIM', 'L1_L2_SSIM')

    def __init__(self, selected_gpu):
        """Select the GPU, set the hyper parameters and create the TF session.

        Args:
            selected_gpu: id written (as a string) to CUDA_VISIBLE_DEVICES.
        """
        os.environ["CUDA_VISIBLE_DEVICES"]=str(selected_gpu)
        print('Initializing a CycleGAN on GPU ' + os.environ["CUDA_VISIBLE_DEVICES"])
        self.normalization = InstanceNormalization
        # Hyper parameters
        self.lr_D = 2e-4
        self.lr_G = 2e-4
        self.beta_1 = 0.5
        self.beta_2 = 0.999
        self.lambda_1 = 10.0  # Cyclic loss weight A_2_B
        self.lambda_2 = 10.0  # Cyclic loss weight B_2_A
        self.lambda_D = 1.0  # Weight for loss from discriminator guess on synthetic images
        self.supervised_weight = 10.0
        self.synthetic_pool_size = 50
        # Optimizers; note that opt_D is shared by both discriminators
        self.opt_D = Adam(self.lr_D, self.beta_1, self.beta_2)
        self.opt_G = Adam(self.lr_G, self.beta_1, self.beta_2)

        # TensorFlow session configuration
        config = tf.ConfigProto()
        # Don't pre-allocate memory; allocate as-needed
        config.gpu_options.allow_growth = True
        # Create a session with the above options specified.
        session = tf.Session(config=config)
        K.tensorflow_backend.set_session(session)

    def create_discriminator_and_generator(self):
        """Build and compile D_A, D_B and the combined generator model G_model.

        Reads ``self.data_shape`` and ``self.use_supervised_learning`` (both
        set by ``train``) and stores the models on ``self``.
        """
        print('Creating Discriminator and Generator ...')
        # Discriminator
        D_A = self.Discriminator()
        D_B = self.Discriminator()
        loss_weights_D = [0.5]
        image_A = Input(shape=self.data_shape)
        image_B = Input(shape=self.data_shape)
        guess_A = D_A(image_A)
        guess_B = D_B(image_B)
        self.D_A = Model(inputs=image_A, outputs=guess_A, name='D_A')
        self.D_B = Model(inputs=image_B, outputs=guess_B, name='D_B')
        self.D_A.compile(optimizer=self.opt_D, loss=self.lse, loss_weights=loss_weights_D)
        self.D_B.compile(optimizer=self.opt_D, loss=self.lse, loss_weights=loss_weights_D)
        # Use containers to avoid a spurious keras error about weight discrepancies
        self.D_A_static = Network(inputs=image_A, outputs=guess_A, name='D_A_static')
        self.D_B_static = Network(inputs=image_B, outputs=guess_B, name='D_B_static')
        # Do not update discriminator weights during generator training
        self.D_A_static.trainable = False
        self.D_B_static.trainable = False

        # Generators
        self.G_A2B = self.Generator(name='G_A2B')
        self.G_B2A = self.Generator(name='G_B2A')
        real_A = Input(shape=self.data_shape, name='real_A')
        real_B = Input(shape=self.data_shape, name='real_B')
        synthetic_B = self.G_A2B(real_A)
        synthetic_A = self.G_B2A(real_B)
        dA_guess_synthetic = self.D_A_static(synthetic_A)
        dB_guess_synthetic = self.D_B_static(synthetic_B)
        reconstructed_A = self.G_B2A(synthetic_B)
        reconstructed_B = self.G_A2B(synthetic_A)
        # Output order matters: train() builds its targets and print_info()
        # indexes the returned losses in exactly this order.
        model_outputs = [reconstructed_A, reconstructed_B]
        compile_losses = [self.cycle_loss, self.cycle_loss, self.lse, self.lse]
        compile_weights = [self.lambda_1, self.lambda_2, self.lambda_D, self.lambda_D]
        model_outputs.append(dA_guess_synthetic)
        model_outputs.append(dB_guess_synthetic)
        if self.use_supervised_learning:
            model_outputs.append(synthetic_A)
            model_outputs.append(synthetic_B)
            compile_losses.append('MAE')
            compile_losses.append('MAE')
            compile_weights.append(self.supervised_weight)
            compile_weights.append(self.supervised_weight)
        self.G_model = Model(inputs=[real_A, real_B], outputs=model_outputs, name='G_model')
        self.G_model.compile(optimizer=self.opt_G, loss=compile_losses, loss_weights=compile_weights)

    def ck(self, x, k, use_normalization, stride):
        """Discriminator layer: 4x4 convolution, optional normalization, LeakyReLU(0.2)."""
        x = Conv2D(filters=k, kernel_size=4, strides=stride, padding='same', use_bias=True)(x)
        # Normalization is not done on the first discriminator layer
        if use_normalization:
            x = self.normalization(axis=3, center=True, epsilon=1e-5)(x, training=True)
        x = LeakyReLU(alpha=0.2)(x)
        return x

    def c7Ak(self, x, k):
        """Generator layer: 7x7 'valid' convolution with k filters, normalization, ReLU."""
        x = Conv2D(filters=k, kernel_size=7, strides=1, padding='valid', use_bias=True)(x)
        x = self.normalization(axis=3, center=True, epsilon=1e-5)(x, training=True)
        x = Activation('relu')(x)
        return x

    def dk(self, x, k):
        """Generator downsampling layer: 3x3 stride-2 convolution, normalization, ReLU."""
        x = Conv2D(filters=k, kernel_size=3, strides=2, padding='same', use_bias=True)(x)
        x = self.normalization(axis=3, center=True, epsilon=1e-5)(x, training=True)
        x = Activation('relu')(x)
        return x

    def Rk(self, x0):
        """Residual block: two reflection-padded 3x3 convolutions added to the input."""
        k = int(x0.shape[-1])  # keep the channel count of the input
        # first layer
        x = ReflectionPadding2D((1, 1))(x0)
        x = Conv2D(filters=k, kernel_size=3, strides=1, padding='valid', use_bias=True)(x)
        x = self.normalization(axis=3, center=True, epsilon=1e-5)(x, training=True)
        x = Activation('relu')(x)
        # second layer
        x = ReflectionPadding2D((1, 1))(x)
        x = Conv2D(filters=k, kernel_size=3, strides=1, padding='valid', use_bias=True)(x)
        x = self.normalization(axis=3, center=True, epsilon=1e-5)(x, training=True)
        # merge
        x = add([x, x0])
        return x

    def uk(self, x, k):
        """Generator upsampling layer (x2), followed by normalization and ReLU.

        Uses upsampling + convolution when ``self.use_resize_convolution`` is
        set, otherwise a stride-2 transposed convolution.
        """
        if self.use_resize_convolution:
            x = UpSampling2D(size=(2, 2))(x)  # Nearest neighbor upsampling
            x = ReflectionPadding2D((1, 1))(x)
            x = Conv2D(filters=k, kernel_size=3, strides=1, padding='valid', use_bias=True)(x)
        else:
            # Fractionally strided convolution with stride 1/2
            x = Conv2DTranspose(filters=k, kernel_size=3, strides=2, padding='same', use_bias=True)(x)
        x = self.normalization(axis=3, center=True, epsilon=1e-5)(x, training=True)
        x = Activation('relu')(x)
        return x

    def Discriminator(self, name=None):
        """Return a discriminator Model mapping an image to a 1-channel sigmoid map."""
        # Specify input
        input_img = Input(shape=self.data_shape)
        # Layer 1 (instance normalization is not used for this layer)
        x = self.ck(input_img, 64, False, 2)
        # Layer 2
        x = self.ck(x, 128, True, 2)
        # Layer 3
        x = self.ck(x, 256, True, 2)
        # Layer 4
        x = self.ck(x, 512, True, 1)
        # Output layer
        x = Conv2D(filters=1, kernel_size=4, strides=1, padding='same', use_bias=True)(x)
        x = Activation('sigmoid')(x)
        return Model(inputs=input_img, outputs=x, name=name)

    def Generator(self, name=None):
        """Return a generator Model whose tanh output has the input's shape.

        The two stride-2 layers are undone by the two upsampling layers, which
        is why ``load_data`` pads the slices to a multiple of 4. Dropout is
        applied with ``training=True``, i.e. also at prediction time.
        """
        input_img = Input(shape=self.data_shape)
        # Layer 1
        x = ReflectionPadding2D((3, 3))(input_img)
        x = self.c7Ak(x, 32)
        # Layer 2
        x = self.dk(x, 64)
        # Layer 3
        x = self.dk(x, 128)
        # Layer 4-12: Residual layer
        for _ in range(4, 13):
            x = self.Rk(x)
            x = Dropout(self.dropout_rate)(x, training=True)
        # Layer 13
        x = self.uk(x, 64)
        # Layer 14
        x = self.uk(x, 32)
        x = ReflectionPadding2D((3, 3))(x)
        x = Conv2D(filters=self.data_shape[2], kernel_size=7, strides=1, padding='valid', use_bias=True)(x)
        x = Activation('tanh')(x)  # They say they use Relu but really they do not
        return Model(inputs=input_img, outputs=x, name=name)

    def train(self, train_A_dir, normalization_factor_A, train_B_dir, normalization_factor_B, models_dir,
              batch_size=10, epochs=200, cycle_loss_type='L1', use_resize_convolution=False,
              use_supervised_learning=False, output_sample_flag=True, output_sample_dir=None,
              output_sample_channels=1, dropout_rate=0):
        """Train both generators and discriminators and save generator weights.

        Args:
            train_A_dir, train_B_dir: paths of the NIfTI training files.
            normalization_factor_A, normalization_factor_B: scalar or
                per-channel divisors passed to ``load_data``.
            models_dir: folder for the weight files (created if missing).
            batch_size: number of slices per training step.
            epochs: number of epochs; the learning rates decay linearly during
                the second half.
            cycle_loss_type: one of 'L1', 'L2', 'SSIM', 'L1_SSIM', 'L2_SSIM',
                'L1_L2_SSIM'.
            use_resize_convolution: see ``uk``.
            use_supervised_learning: add an MAE loss between synthetic and
                real images.
            output_sample_flag: write a preview image every 5 training steps.
            output_sample_dir: file path of the preview image; when None no
                preview is written.
            output_sample_channels: number of channels shown in the preview.
            dropout_rate: dropout rate used in the generator's residual layers.

        Raises:
            ValueError: if ``cycle_loss_type`` is not supported.
        """
        # Fail before any data is loaded instead of during model compilation.
        if cycle_loss_type not in self._CYCLE_LOSS_TYPES:
            raise ValueError('Unknown cycle_loss_type: {!r}, expected one of {}'.format(
                cycle_loss_type, ', '.join(self._CYCLE_LOSS_TYPES)))
        self.batch_size = batch_size
        self.epochs = epochs
        self.decay_epoch = self.epochs//2  # the epoch where linear decay of the learning rates starts
        self.cycle_loss_type = cycle_loss_type
        self.use_resize_convolution = use_resize_convolution
        self.use_supervised_learning = use_supervised_learning
        self.dropout_rate = dropout_rate
        # Data dir
        self.train_A_dir = train_A_dir
        self.train_B_dir = train_B_dir
        if not os.path.exists(models_dir):
            os.makedirs(models_dir)
        self.models_dir = models_dir
        self.train_A = load_data(self.train_A_dir, normalization_factor_A)
        self.train_B = load_data(self.train_B_dir, normalization_factor_B)
        self.data_shape = self.train_A.shape[1:4]
        self.data_num = self.train_A.shape[0]
        self.loop_num = self.data_num // self.batch_size
        print('Number of epochs: {}, number of loops per epoch: {}'.format(self.epochs, self.loop_num))
        self.create_discriminator_and_generator()

        # Image pools used to update the discriminators
        self.synthetic_A_pool = ImagePool(self.synthetic_pool_size)
        self.synthetic_B_pool = ImagePool(self.synthetic_pool_size)

        label_shape = (self.batch_size,) + self.D_A.output_shape[1:]
        ones = np.ones(shape=label_shape)
        zeros = ones * 0
        decay_D, decay_G = self.get_lr_linear_decay_rate()

        # The default output_sample_dir (None) used to crash in .save(None).
        save_samples = output_sample_flag and output_sample_dir is not None
        if save_samples:
            _ensure_parent_dir(output_sample_dir)

        start_time = time.time()
        print("Dropout rate: {}".format(dropout_rate))
        print('Training ...')
        for epoch_i in range(self.epochs):
            # Update learning rates
            if epoch_i > self.decay_epoch:
                self.update_lr(self.D_A, decay_D)
                self.update_lr(self.D_B, decay_D)
                self.update_lr(self.G_model, decay_G)
            random_indices = np.random.permutation(self.data_num)
            for loop_j in range(self.loop_num):
                # Training data batches.
                # NOTE(review): A and B are indexed with the same positions in
                # both the supervised and the unsupervised case (as before), so
                # train_B must hold at least as many slices as train_A.
                batch_indices = random_indices[loop_j*self.batch_size:(loop_j+1)*self.batch_size]
                train_A_batch = self.train_A[batch_indices]
                train_B_batch = self.train_B[batch_indices]
                # Synthetic data for training data batches
                fresh_synthetic_B = self.G_A2B.predict(train_A_batch)
                fresh_synthetic_A = self.G_B2A.predict(train_B_batch)
                # The discriminators see a mix of current and earlier synthetic images
                synthetic_A_batch = self.synthetic_A_pool.query(fresh_synthetic_A)
                synthetic_B_batch = self.synthetic_B_pool.query(fresh_synthetic_B)

                # Train Discriminator
                DA_loss_train = self.D_A.train_on_batch(x=train_A_batch, y=ones)
                DB_loss_train = self.D_B.train_on_batch(x=train_B_batch, y=ones)
                DA_loss_synthetic = self.D_A.train_on_batch(x=synthetic_A_batch, y=zeros)
                DB_loss_synthetic = self.D_B.train_on_batch(x=synthetic_B_batch, y=zeros)
                D_loss = DA_loss_train + DA_loss_synthetic + DB_loss_train + DB_loss_synthetic

                # Targets in the output order of G_model
                target_data = [train_A_batch, train_B_batch, ones, ones]
                if self.use_supervised_learning:
                    target_data.append(train_A_batch)
                    target_data.append(train_B_batch)
                # Train Generator
                G_loss = self.G_model.train_on_batch(x=[train_A_batch, train_B_batch], y=target_data)
                self.print_info(start_time, epoch_i, loop_j, D_loss, G_loss, DA_loss_train + DA_loss_synthetic, DB_loss_train + DB_loss_synthetic)
                if save_samples and (loop_j+1) % 5 == 0:
                    # Show the synthetic image of the displayed input, not a pooled one
                    self._save_output_sample(train_A_batch, train_B_batch, fresh_synthetic_B, output_sample_dir, output_sample_channels)
            if (epoch_i+1) % 20 == 0:
                self.save_model(epoch_i)
        # Move the cursor below the status block written by print_info
        print("\u001b[12B")
        print("\u001b[1000D")
        print('Done')

    def _save_output_sample(self, train_A_batch, train_B_batch, synthetic_B_batch, output_sample_dir, output_sample_channels):
        """Write a preview image: rows are real A, real B and synthetic B of the
        first batch element, columns are the displayed channels."""
        channel_count = max(output_sample_channels, 1)  # channel 0 is always shown
        rows = [
            np.concatenate([np.rot90(batch[0, :, :, channel_i]) for channel_i in range(channel_count)], axis=1)
            for batch in (train_A_batch, train_B_batch, synthetic_B_batch)
        ]
        output_sample = np.concatenate(rows, axis=0)
        toimage(output_sample, cmin=-1, cmax=1).save(output_sample_dir)

    def synthesize(self, G_X2Y_dir, test_X_dir, normalization_factor_X, synthetic_Y_dir, normalization_factor_Y, use_resize_convolution=False, dropout_rate=0):
        """Translate a NIfTI file with a saved generator and write the result.

        Args:
            G_X2Y_dir: path of the generator weight file.
            test_X_dir: path of the NIfTI input file.
            normalization_factor_X: divisor(s) used to normalize the input.
            synthetic_Y_dir: path of the NIfTI output file; its folder is
                created if missing.
            normalization_factor_Y: factor(s) used to denormalize the output.
            use_resize_convolution, dropout_rate: generator options; they have
                to describe the architecture the weights were trained with.
        """
        test_X_img = nib.load(test_X_dir)
        test_X = load_data(test_X_dir, normalization_factor_X)
        self.data_shape = test_X.shape[1:4]
        self.data_num = test_X.shape[0]
        self.use_resize_convolution = use_resize_convolution
        self.dropout_rate = dropout_rate
        print('Synthesizing ...')
        print("Dropout rate: {}".format(dropout_rate))
        self.G_X2Y = self.Generator(name='G_X2Y')
        self.G_X2Y.load_weights(G_X2Y_dir)
        synthetic_Y = self.G_X2Y.predict(test_X)
        self._save_synthetic(synthetic_Y, test_X_img, normalization_factor_Y, synthetic_Y_dir)
        print('Done\n')

    def dropout_sample(self, G_X2Y_dir, test_X_dir, normalization_factor_X, synthetic_Y_dir, normalization_factor_Y, use_resize_convolution=False, dropout_rate=0, dropout_num=1):
        """Run the generator ``dropout_num`` times and save every prediction.

        Same arguments as ``synthesize``, except that ``synthetic_Y_dir`` is a
        path prefix: sample i is written to ``synthetic_Y_dir + "_<i>.nii.gz"``.
        The predictions differ between runs because the generator's dropout
        layers stay active at prediction time.
        """
        test_X_img = nib.load(test_X_dir)
        test_X = load_data(test_X_dir, normalization_factor_X)
        self.data_shape = test_X.shape[1:4]
        self.data_num = test_X.shape[0]
        self.use_resize_convolution = use_resize_convolution
        self.dropout_rate = dropout_rate
        self.G_X2Y = self.Generator(name='G_X2Y')
        self.G_X2Y.load_weights(G_X2Y_dir)
        print("Dropout rate: {}".format(dropout_rate))
        print("Dropout number: {}".format(dropout_num))
        for dropout_i in range(dropout_num):
            print("Dropout sample {}/{}".format(str(dropout_i+1), dropout_num))
            # Move the cursor back up so the progress line is overwritten
            print("\u001b[3A")
            print("\u001b[1000D")
            sys.stdout.flush()
            synthetic_Y = self.G_X2Y.predict(test_X)
            self._save_synthetic(synthetic_Y, test_X_img, normalization_factor_Y, synthetic_Y_dir + "_" + str(dropout_i) + ".nii.gz")
        print("\u001b[1000D")
        print('Done\n')

    def _save_synthetic(self, synthetic_Y, test_X_img, normalization_factor_Y, output_path):
        """Denormalize a generator prediction and save it as NIfTI next to the
        input's affine and header."""
        synthetic_Y = np.transpose(synthetic_Y, (1, 2, 3, 0))  # back to the on-disk axis order
        synthetic_Y = denormalize_data(synthetic_Y, normalization_factor_Y)
        synthetic_Y[synthetic_Y<0] = 0
        synthetic_Y = synthetic_Y[0:test_X_img.shape[0], 0:test_X_img.shape[1], :, :]  # Remove padding added by load_data
        synthetic_Y_img = nib.Nifti1Image(synthetic_Y, test_X_img.affine, test_X_img.header)
        # Create the output folder so a long prediction is not lost on save
        _ensure_parent_dir(output_path)
        nib.save(synthetic_Y_img, output_path)

    def lse(self, y_true, y_pred):
        """Mean squared difference between prediction and target."""
        loss = tf.reduce_mean(tf.squared_difference(y_pred, y_true))
        return loss

    def cycle_loss(self, y_true, y_pred):
        """Cycle-consistency loss selected by ``self.cycle_loss_type``.

        Raises:
            ValueError: if ``self.cycle_loss_type`` is not supported
                (previously this surfaced as an UnboundLocalError).
        """
        # NOTE(review): the SSIM terms use only element [0] of tf.image.ssim's
        # result and max_val=1.0; kept unchanged, confirm this is intended.
        if self.cycle_loss_type == 'L1':
            # L1 norm
            loss = tf.reduce_mean(tf.abs(y_pred - y_true))
        elif self.cycle_loss_type == 'L2':
            # L2 norm
            loss = tf.reduce_mean(tf.squared_difference(y_pred, y_true))
        elif self.cycle_loss_type == 'SSIM':
            # SSIM
            loss = 1 - tf.image.ssim(y_pred,y_true, max_val=1.0)[0]
        elif self.cycle_loss_type == 'L1_SSIM':
            # L1 + SSIM
            loss = 0.5*(1 - tf.image.ssim(y_pred,y_true, max_val=1.0)[0]) + 0.5*tf.reduce_mean(tf.abs(y_pred - y_true))
        elif self.cycle_loss_type == 'L2_SSIM':
            # L2 + SSIM
            loss = 0.5*(1 - tf.image.ssim(y_pred,y_true, max_val=1.0)[0]) + 0.5*tf.reduce_mean(tf.squared_difference(y_pred, y_true))
        elif self.cycle_loss_type == 'L1_L2_SSIM':
            # L1 + L2 + SSIM
            loss = 1/3*(1 - tf.image.ssim(y_pred,y_true, max_val=1.0)[0]) + 1/3*tf.reduce_mean(tf.abs(y_pred - y_true)) + 1/3*tf.reduce_mean(tf.squared_difference(y_pred, y_true))
        else:
            raise ValueError('Unknown cycle_loss_type: {!r}, expected one of {}'.format(
                self.cycle_loss_type, ', '.join(self._CYCLE_LOSS_TYPES)))
        return loss

    def get_lr_linear_decay_rate(self):
        """Return the amounts (decay_D, decay_G) subtracted per ``update_lr`` call.

        ``train`` calls ``update_lr`` once per epoch and model after
        ``decay_epoch``, so the step is sized per call. The previous version
        divided by the number of training slices as well, which left the
        learning rates practically constant.
        """
        decay_epochs = max(self.epochs - self.decay_epoch, 1)  # avoid division by zero
        # D_A and D_B are compiled with the same optimizer object (opt_D), so
        # the discriminator learning rate is lowered twice per epoch.
        updates_per_epoch_D = 2
        updates_per_epoch_G = 1
        decay_D = self.lr_D / (decay_epochs * updates_per_epoch_D)
        decay_G = self.lr_G / (decay_epochs * updates_per_epoch_G)
        return decay_D, decay_G

    def update_lr(self, model, decay):
        """Lower the learning rate of ``model``'s optimizer by ``decay``, not below 0."""
        new_lr = max(K.get_value(model.optimizer.lr) - decay, 0)
        K.set_value(model.optimizer.lr, new_lr)

    def print_info(self, start_time, epoch_i, loop_j, D_loss, G_loss, DA_loss, DB_loss):
        """Print the training status block and move the cursor back to its top.

        ``G_loss`` is the list returned by ``G_model.train_on_batch``: the
        total loss followed by one loss per model output, in output order.
        """
        print("\n")
        print("Epoch : {:d}/{:d}{}".format(epoch_i + 1, self.epochs, " "))
        print("Loop : {:d}/{:d}{}".format(loop_j + 1, self.loop_num, " "))
        print("D_loss : {:5.4f}{}".format(D_loss, " "))
        print("G_loss : {:5.4f}{}".format(G_loss[0], " "))
        # The reconstructed images are the first two outputs of G_model, hence
        # indices 1 and 2 (indices 3 and 4 are the discriminator-guess losses).
        print("reconstruction_loss : {:5.4f}{}".format(G_loss[1] + G_loss[2], " "))
        print("DA_loss : {:5.4f}{}".format(DA_loss, " "))
        print("DB_loss : {:5.4f}{}".format(DB_loss, " "))
        passed_time = (time.time() - start_time)
        loops_finished = epoch_i * self.loop_num + loop_j
        loops_total = self.epochs * self.loop_num
        loops_left = loops_total - loops_finished
        # 1e-5 avoids a division by zero on the very first loop
        remaining_time = (passed_time / (loops_finished + 1e-5) * loops_left)
        passed_time_string = str(datetime.timedelta(seconds=round(passed_time)))
        remaining_time_string = str(datetime.timedelta(seconds=round(remaining_time)))
        print("Time passed : {}{}".format(passed_time_string, " "))
        print("Time remaining : {}{}".format(remaining_time_string, " "))
        print("\u001b[13A")
        print("\u001b[1000D")
        sys.stdout.flush()

    def save_model(self, epoch_i):
        """Save the weights of both generators as
        ``<models_dir>/<generator name>_weights_epoch_<epoch_i + 1>.hdf5``."""
        for generator in (self.G_A2B, self.G_B2A):
            weights_path = os.path.join(self.models_dir, '{}_weights_epoch_{}.hdf5'.format(generator.name, epoch_i+1))
            generator.save_weights(weights_path)


def _ensure_parent_dir(file_path):
    """Create the folder that will contain ``file_path`` if it does not exist."""
    parent_dir = os.path.dirname(file_path)
    if parent_dir and not os.path.exists(parent_dir):
        os.makedirs(parent_dir)


def normalize_data(data, normalization_factor):
    """Map data from [0, normalization_factor] to [-1, 1].

    ``normalization_factor`` is either a single number or a sequence with one
    factor per channel (third axis of ``data``). In the per-channel case the
    division is done in place on ``data``.
    """
    if np.array(normalization_factor).size == 1:
        data = data/normalization_factor
    else:
        for i in range(data.shape[2]):
            data[:,:,i,:] = data[:,:,i,:]/normalization_factor[i]  # normalize data for each channel
    data = data*2-1
    return data


def denormalize_data(data, normalization_factor):
    """Inverse of ``normalize_data``: map data from [-1, 1] back to
    [0, normalization_factor]. The input array is not modified."""
    data = (data+1)/2
    if np.array(normalization_factor).size == 1:
        data = data*normalization_factor
    else:
        for i in range(data.shape[2]):
            data[:,:,i,:] = data[:,:,i,:]*normalization_factor[i]  # denormalize data for each channel
    return data


def load_data(data_dir, normalization_factor):
    """Load a NIfTI file and return normalized slices ordered (N, X, Y, C).

    Negative values are set to 0 before normalization. The first two image
    axes are padded with -1 (the normalized value of 0) up to a multiple of 4.
    """
    data = nib.load(data_dir).get_fdata()
    data[data<0] = 0
    if data.ndim == 2:
        # A single 2D slice: add channel and sample axes
        data = data[:,:,np.newaxis, np.newaxis]
    data = normalize_data(data, normalization_factor)
    data = np.transpose(data, (3, 0, 1, 2))
    print('Loading data, data size: {}, number of data: {}'.format(data.shape[1:4], data.shape[0]))
    # Make sure that the slice size is a multiple of 4
    if (data.shape[1]%4 != 0):
        data = np.append(data, np.zeros((data.shape[0], 4-data.shape[1]%4, data.shape[2], data.shape[3]))-1, axis=1)
    if (data.shape[2]%4 != 0):
        data = np.append(data, np.zeros((data.shape[0], data.shape[1], 4-data.shape[2]%4, data.shape[3]))-1, axis=2)
    return data


def pad_data(data_A, data_B):
    """Pad two (N, X, Y, C) arrays with -1 to a common X, Y and C size.

    The original data is centered along X and Y and placed in the first
    channels. The number of samples is taken from ``data_A``.
    """
    size_n = data_A.shape[0]
    size_x_A, size_y_A, size_c_A = data_A.shape[1:4]
    size_x_B, size_y_B, size_c_B = data_B.shape[1:4]
    size_x_new = np.maximum(size_x_A, size_x_B)
    size_y_new = np.maximum(size_y_A, size_y_B)
    size_c_new = np.maximum(size_c_A, size_c_B)

    data_A_new = -np.ones((size_n, size_x_new, size_y_new, size_c_new))
    data_B_new = -np.ones((size_n, size_x_new, size_y_new, size_c_new))
    x_start_A = int((size_x_new-size_x_A)/2)
    y_start_A = int((size_y_new-size_y_A)/2)
    x_start_B = int((size_x_new-size_x_B)/2)
    y_start_B = int((size_y_new-size_y_B)/2)
    data_A_new[:, x_start_A:x_start_A+size_x_A, y_start_A:y_start_A+size_y_A, 0:size_c_A] = data_A
    data_B_new[:, x_start_B:x_start_B+size_x_B, y_start_B:y_start_B+size_y_B, 0:size_c_B] = data_B

    return data_A_new, data_B_new


class ReflectionPadding2D(Layer):
    """Keras layer that reflection-pads axes 1 and 2 of a 4D tensor.

    ``padding[0]`` is applied on both sides of axis 1 and ``padding[1]`` on
    both sides of axis 2.
    """

    def __init__(self, padding=(1, 1), **kwargs):
        self.padding = tuple(padding)
        self.input_spec = [InputSpec(ndim=4)]
        super(ReflectionPadding2D, self).__init__(**kwargs)

    def compute_output_shape(self, s):
        return (s[0], s[1] + 2 * self.padding[0], s[2] + 2 * self.padding[1], s[3])

    def call(self, x, mask=None):
        # Same axis assignment as compute_output_shape; the two used to be
        # swapped, which only went unnoticed because all paddings are symmetric.
        pad_axis_1, pad_axis_2 = self.padding
        return tf.pad(x, [[0, 0], [pad_axis_1, pad_axis_1], [pad_axis_2, pad_axis_2], [0, 0]], 'REFLECT')


class ImagePool():
    """History buffer of synthetic images used to train the discriminators.

    Until the pool holds ``pool_size`` images every queried image is stored
    and returned. Afterwards each queried image is, with 50% chance, swapped
    with a randomly chosen stored image (the stored one is returned);
    otherwise it is returned unchanged. A ``pool_size`` of 0 disables the pool.
    """

    def __init__(self, pool_size):
        self.pool_size = pool_size
        if self.pool_size > 0:
            self.num_imgs = 0
            self.images = []

    def query(self, images):
        """Return a batch with as many images as ``images``, see class docstring."""
        if self.pool_size == 0:
            return images
        selected = []
        for image in images:
            if len(image.shape) == 3:
                image = image[np.newaxis, :, :, :]

            if self.num_imgs < self.pool_size:  # fill up the image pool
                self.num_imgs = self.num_imgs + 1
                if len(self.images) == 0:
                    # Copy so the pool never aliases the caller's batch
                    self.images = image.copy()
                else:
                    self.images = np.vstack((self.images, image))
                selected.append(image)
            else:  # 50% chance that we replace an old synthetic image
                p = random.uniform(0, 1)
                if p > 0.5:
                    random_id = random.randint(0, self.pool_size - 1)
                    # Copy before overwriting: slicing yields a view, so the
                    # "old" image used to turn into the new one on assignment.
                    old_image = self.images[random_id, :, :, :].copy()
                    self.images[random_id, :, :, :] = image[0, :, :, :]
                    selected.append(old_image[np.newaxis, :, :, :])
                else:
                    selected.append(image)
        if not selected:
            return selected
        return np.vstack(selected)