├── LICENSE
├── README.md
├── img
│   ├── ASCnet.PNG
│   ├── YT.PNG
│   └── res.PNG
└── src
    ├── continue_training_stage_1.py
    ├── continue_training_stage_2.py
    └── create_networks.py

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2021 Raunak Dey

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

## ASC Net summary

### Introduction
ASC-Net is a framework that lets us define a Reference Distribution Set, take in any Input Image, compare it with the Reference Distribution, and output the anomalies present in the Input Image.
This is useful when you have a set of images/signals whose contents you know, then receive a set of new images and want to see whether they differ from the original set, i.e. anomaly/novelty detection.

### arXiv Link

https://arxiv.org/pdf/2103.03664.pdf

### Highlights

1. Solves the difficulty of defining a class/set of things deterministically, down to the nitty-gritty details. The Reference Distribution can be any combination of images; the framework abstracts out the manifold encompassing them.
2. No need for perfect reconstruction. Unlike other existing algorithms, we care about the anomaly, not the reconstruction. State-of-the-art performance (see Results below)!
3. We can potentially define any manifold using the Reference Distribution and then compare any incoming input image to it.
4. Works on any image size. Simply adjust the size of the encoder/decoder sets to match your input size and hardware capacity.
5. ***The claim of being "independent of the instability of GANs" holds because the final termination does not depend on the adversarial training: we terminate when the I(ro) output has split into distinct peaks.***

### Network Architecture

![Network Architecture](img/ASCnet.PNG)

### High-level Summary [Short Video]

[![Click for a short video](img/YT.PNG)](https://www.youtube.com/watch?v=oUeBNOYOheg)

### Important

***Always take the threshold on the reconstruction, i.e. ID3 in the code, as it summarizes the two cuts in one place.***

## Code

### Dependencies/Environment used

* [CUDA](https://developer.nvidia.com/cuda-90-download-archive) - CUDA-9.0.176
* [CUDNN](https://developer.nvidia.com/cudnn-download-survey) - CUDNN - Major 6; Minor 0; PatchLevel 21
* [Python](https://www.python.org/downloads/) - Version 2.7.12
* [Tensorflow](https://www.tensorflow.org/install) - Version 1.10.0
* [Keras](http://www.keras.io) - Version 2.2.2
* [Numpy](http://www.numpy.org/) - Version 1.15.5
* [Nibabel](https://nipy.org/nibabel/) - Version 2.2.0
* [Open-CV](https://opencv.org/releases/) - Version 2.4.9.1
* SciPy, MedPy and scikit-learn are also imported by the training scripts (versions not specified)
* [Brats 2019](https://ipp.cbica.upenn.edu/) - Select Brats 2019
* [LiTS](https://competitions.codalab.org/competitions/17094) - LiTS Website
* [MS-SEG 2015](https://smart-stats-tools.org/lesion-challenge) - MS-SEG2015 website
* 12 GB TitanX. The training scripts call `multi_gpu_model(..., gpus=4)` with `CUDA_VISIBLE_DEVICES = [0,1,2,3]`, so they expect four visible GPUs; change both to match your hardware.

### Code Summary [Short Video. I haven't commented the code YET, so watch this for a walkthrough :>]

[![Click for a short video](img/YT.PNG)](https://www.youtube.com/watch?v=F53Grnmnpz0)

### Comments

- ID1 is Ifc
- ID2 is Iwc
- ID3 is Iro. ***Please take the threshold on this***

In the training scripts these are `result[0]`, `result[1]` and `result[2]` of `generator.predict(...)`, written to `outputs/id1`, `outputs/id2` and `outputs/id3` respectively.

### Data Files/Inputs

1. To make the framework function we require 2 files [Mandatory!!!!]
    - Reference Distribution - Named ***good_dic_to_train_disc.npy*** in our code
        > This is the image set we know something about. It forms a manifold.
    - Input Images - Named ***input_for_generator.npy*** in our code
        > These can contain anything. The framework splits each input into two halves: one consisting of the components of the input image that lie in the manifold of the Reference Distribution, and the other being everything else, i.e. the anomaly.

    Both files are loaded with `np.load` and fed to the networks directly, so each must be an array of shape (N, H, W, 1) matching the network input size.

2. Ground truth for the anomaly we want to test for [Optional, used during testing]
    - Masks - Named ***tumor_mask_for_generator.npy*** in our code
        > The framework can throw out anomalies without any guidance from a ground truth. However, to check performance you may want to include masks for the anomalies in the input image set above. In real-life scenarios you won't have these, and you don't need them.

### Source File

#### Initial Conditions

- The framework is initialized with input shape 160x160x1 for the MS-SEG experiments, while `continue_training_stage_1.py` and `continue_training_stage_2.py` set `input_shape=240,240,1`. Please make these consistent with your data.
- Update the output folder paths if you want to visualize the network outputs while training. The scripts write PNGs to `outputs/id1`, `outputs/id2`, `outputs/id3` and `outputs/norm/{id1,id2,id3,in,op}`; create these folders beforehand, since `cv2.imwrite` does not create missing directories.
- To change the base network, change the build_generator and build_discriminator methods.

#### File Sequence to Run

- create_networks.py
    > This creates the network described in our paper. If you need a network with a different architecture, edit this file accordingly and update the baseline structures of the encoder/decoder. Try to keep the final connections intact.
    - After running it you will obtain three h5 files:
        - disjoint_un_sup_mse_generator.h5: the main module in the network diagram above
        - disjoint_un_sup_mse_discriminator.h5: the discriminator in the network diagram above
        - disjoint_un_sup_mse_complete_gans.h5: the complete version of the entire network diagram

- continue_training_stage_1.py
    > Stage 1 training. Read the paper! The script is interactive: at the prompt, enter 0 to train the discriminator, 1 to train the generator, 2 to save the models (overwriting the three h5 files above) or 3 to write the current outputs to disk; anything else quits.

- continue_training_stage_2.py
    > Stage 2 training. Read the paper!
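The two mandatory input files and the final thresholding on ID3 can be sketched roughly as below. This is a minimal sketch, not part of the released code: `prepare_slices` and `threshold_iro` are hypothetical helpers, the 240x240 size mirrors `input_shape=240,240,1` in the training scripts, the [0, 1] scaling is an assumption inferred from the scripts writing `X_train[i]*255` to PNG, and the default threshold of 0.5 is only a starting value to tune on your data, not a value from the paper.

```python
import numpy as np


def prepare_slices(slices, size=(240, 240)):
    """Hypothetical helper: stack 2D slices into an (N, H, W, 1) float32 array.

    `size` mirrors input_shape=240,240,1 in the training scripts (assumption:
    change it if your network uses another size). Values are scaled to [0, 1]
    with one global min-max over the whole set (also an assumption).
    """
    arr = np.asarray(slices, dtype=np.float32)
    if arr.ndim != 3 or arr.shape[1:] != tuple(size):
        raise ValueError("expected (N, %d, %d) slices, got %s" % (size[0], size[1], arr.shape))
    lo, hi = arr.min(), arr.max()
    if hi > lo:
        arr = (arr - lo) / (hi - lo)
    else:
        arr = np.zeros_like(arr)
    return arr[..., np.newaxis]  # add the channel axis -> (N, H, W, 1)


def threshold_iro(iro, thresh=0.5):
    """Hypothetical helper: binarize ID3 (Iro) after the same global min-max
    scaling the scripts apply before writing outputs/norm/id3.
    `thresh` is an assumed starting value to tune."""
    iro = np.asarray(iro, dtype=np.float32)
    lo, hi = iro.min(), iro.max()
    norm = (iro - lo) / (hi - lo) if hi > lo else np.zeros_like(iro)
    return (norm > thresh).astype(np.uint8)
```

For example, `np.save("good_dic_to_train_disc.npy", prepare_slices(reference_slices))` and `np.save("input_for_generator.npy", prepare_slices(input_slices))`, where `reference_slices` and `input_slices` are your own lists of 2D arrays; after training, `threshold_iro(result[2])` binarizes the ID3 maps returned by `generator.predict`. Scaling each set with its own global min/max is a simplification; use whatever intensity normalization suits your modality.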
### Results

![Results](img/res.PNG)

--------------------------------------------------------------------------------
/img/ASCnet.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raun1/ASC-NET/d7616bb8c3da70a69287ef8aef982bb3a9d597e8/img/ASCnet.PNG
--------------------------------------------------------------------------------
/img/YT.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raun1/ASC-NET/d7616bb8c3da70a69287ef8aef982bb3a9d597e8/img/YT.PNG
--------------------------------------------------------------------------------
/img/res.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raun1/ASC-NET/d7616bb8c3da70a69287ef8aef982bb3a9d597e8/img/res.PNG
--------------------------------------------------------------------------------
/src/continue_training_stage_1.py:
--------------------------------------------------------------------------------

import os
# Select the GPUs before Keras/TensorFlow create a session; the multi_gpu_model
# calls below assume all four are visible.
CUDA_VISIBLE_DEVICES = [0,1,2,3]
os.environ['CUDA_VISIBLE_DEVICES']=','.join([str(x) for x in CUDA_VISIBLE_DEVICES])

import numpy as np
import cv2
import tensorflow as tf
import keras
from keras import regularizers
from keras import backend as K
from keras.models import Model
from keras.layers import Input, Conv2D, MaxPooling2D, Dropout, Dense, Flatten
from keras.layers.normalization import BatchNormalization as bn
from keras.utils import multi_gpu_model

# Smoothing term that keeps the Dice coefficient defined when both maps are empty.
smooth=1.
input_shape=240,240,1  # must match the generator's input size

########################################Losses#################################################
def special_loss_disjoint(y_true,y_pred):
    # The 'x_u_net_opsp' output holds two maps stacked on the channel axis; split
    # them and ignore the y_true that Keras passes in.
    y_true,y_pred=tf.split(y_pred, 2,axis=-1)
    # Stage 1: entries below 1e-16 are set to 1, all others keep their value.
    thresholded_pred = tf.where(tf.greater(1e-16, y_pred), tf.ones_like(y_pred), y_pred)  # where(cond, value_if_true, value_if_false)
    thresholded_true = tf.where(tf.greater(1e-16, y_true), tf.ones_like(y_true), y_true)
    # Returned as-is (not 1 - Dice), so minimising it penalises overlap between the two maps.
    return dice_coef(thresholded_true,thresholded_pred)

def dice_coef(y_true, y_pred):
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)

def dice_coef_loss(y_true, y_pred):
    # Also returns the plain Dice coefficient; passed to load_model via custom_objects.
    return dice_coef(y_true, y_pred)

################################################################################################

def build_discriminator(input_shape,learn_rate=1e-3):
    # Note: learn_rate is unused; the model is compiled with Adam(lr=5e-5).
    l2_lambda = 0.0002
    DropP = 0.3
    kernel_size=3

    inputs = Input(input_shape,name="disc_ip")

    conv0a = Conv2D(32, (kernel_size, kernel_size), activation='relu', padding='same',
                    kernel_regularizer=regularizers.l2(l2_lambda),name='disc_l2_conc15')(inputs)
    conv0a = bn(name='disc_l2_bn1')(conv0a)
    conv0b = Conv2D(32, (kernel_size, kernel_size), activation='relu', padding='same',
                    kernel_regularizer=regularizers.l2(l2_lambda),name='disc_l2_conc16')(conv0a)
    conv0b = bn(name='disc_l2_bn2')(conv0b)
    pool0 = MaxPooling2D(pool_size=(2, 2),name='disc_l2_mp1')(conv0b)
    pool0 = Dropout(DropP,name='disc_l2_d1')(pool0)

    conv2a = Conv2D(64, (kernel_size, kernel_size), activation='relu', padding='same',
                    kernel_regularizer=regularizers.l2(l2_lambda),name='disc_l2_conc17')(pool0)
    conv2a = bn(name='disc_l2_bn3')(conv2a)
    conv2b = Conv2D(64, (kernel_size, kernel_size), activation='relu', padding='same',
                    kernel_regularizer=regularizers.l2(l2_lambda),name='disc_l2_conc18')(conv2a)
    conv2b = bn(name='disc_l2_bn4')(conv2b)
    pool2 = MaxPooling2D(pool_size=(2, 2),name='disc_l2_mp2')(conv2b)
    pool2 = Dropout(DropP,name='disc_l2_d2')(pool2)

    conv3a = Conv2D(128, (kernel_size, kernel_size), activation='relu', padding='same',
                    kernel_regularizer=regularizers.l2(l2_lambda),name='disc_l2_conc19')(pool2)
    conv3a = bn(name='disc_l2_bn5')(conv3a)
    conv3b = Conv2D(128, (kernel_size, kernel_size), activation='relu', padding='same',
                    kernel_regularizer=regularizers.l2(l2_lambda),name='disc_l2_conc20')(conv3a)
    conv3b = bn(name='disc_l2_bn6')(conv3b)
    pool3 = MaxPooling2D(pool_size=(2, 2),name='disc_l2_mp3')(conv3b)
    pool3 = Dropout(DropP,name='disc_l2_d3')(pool3)

    conv4a = Conv2D(256, (kernel_size, kernel_size), activation='relu', padding='same',
                    kernel_regularizer=regularizers.l2(l2_lambda),name='disc_l2_conc21')(pool3)
    conv4a = bn(name='disc_l2_bn7')(conv4a)
    conv4b = Conv2D(256, (kernel_size, kernel_size), activation='relu', padding='same',
                    kernel_regularizer=regularizers.l2(l2_lambda),name='disc_l2_conc22')(conv4a)
    conv4b = bn(name='disc_l2_bn8')(conv4b)
    pool4 = MaxPooling2D(pool_size=(2, 2),name='disc_l2_mp4')(conv4b)
    pool4 = Dropout(DropP,name='disc_l2_d4')(pool4)

    conv5a = Conv2D(512, (kernel_size, kernel_size), activation='relu', padding='same',
                    kernel_regularizer=regularizers.l2(l2_lambda),name='disc_l2_conc23')(pool4)
    conv5a = bn(name='disc_l2_bn9')(conv5a)
    conv5b = Conv2D(512, (kernel_size, kernel_size), activation='relu', padding='same',
                    kernel_regularizer=regularizers.l2(l2_lambda),name='disc_l2_conc24')(conv5a)
    conv5b = bn(name='disc_l2_bn10')(conv5b)

    flat=Flatten()(conv5b)
    # tanh output: trained towards -1 for Reference Distribution images (real)
    # and +1 for generator outputs (fake), see train_disc below.
    output_disc=Dense(1,activation='tanh',name='disc_output')(flat)

    model=Model(inputs=[inputs],outputs=[output_disc])
    model.compile(loss='mae',
                  optimizer=keras.optimizers.Adam(lr=5e-5),
                  metrics=['accuracy'])
    return model

from keras.models import load_model
generator=load_model('disjoint_un_sup_mse_generator.h5', custom_objects={'dice_coef_loss':dice_coef_loss,'special_loss_disjoint':special_loss_disjoint})
discriminator=load_model('disjoint_un_sup_mse_discriminator.h5', custom_objects={'dice_coef_loss':dice_coef_loss,'special_loss_disjoint':special_loss_disjoint})

# Freeze the discriminator before compiling final_model, so that generator
# updates through final_model leave it unchanged (Keras reads the trainable
# flags at compile time).
for layer in discriminator.layers: layer.trainable = False
generator.compile(optimizer=keras.optimizers.Adam(lr=5e-5),loss={
    'new_res_1_final_opa':'mse',
    'x_u_net_opsp':special_loss_disjoint
})
discriminator.compile(loss='mae',
                      optimizer=keras.optimizers.Adam(lr=5e-5),
                      metrics=['accuracy'])

# Chain generator -> discriminator into one model with four outputs.
final_input=generator.input
x_u_net_opsp=(generator.get_layer('x_u_net_opsp').output)
final_output_gans=discriminator(generator.get_layer('new_final_op').output)
final_output_seg=(generator.get_layer('new_xfinal_op').output)
final_output_res=(generator.get_layer('new_res_1_final_opa').output)

final_model=Model(inputs=[final_input],outputs=[final_output_gans,final_output_seg,final_output_res,x_u_net_opsp])
# 'model_2' is the discriminator model's name and keys the adversarial loss;
# 'new_xfinal_op' has no loss entry, so Keras expects no target for it.
final_model.compile(optimizer=keras.optimizers.Adam(lr=5e-5),metrics=['mae'],loss={'model_2':'mae',
                    'new_res_1_final_opa':'mse',
                    'x_u_net_opsp':special_loss_disjoint})

print("full gans")
final_model.summary()
print(final_model.input)
print(final_model.output)
print("=" * 80)
print("generator")
generator.summary()
print(generator.input)
print(generator.output)
print("=" * 80)

print("discriminator")
discriminator.summary()
print(discriminator.get_input_at(0))
print(discriminator.get_input_at(1))
print("=" * 80)


def train_disc(real_data,fake_data,true_label,ep,loss_ch):
    # Rebind the module-level models. Without this, the freshly built and
    # trained discriminator (and the final_model wired to it) were local and
    # discarded on return, so generator training and saving kept using the
    # untrained copy.
    global discriminator, final_model
    discriminator=build_discriminator(input_shape)
    discriminator.name='model_2'
    for layer in discriminator.layers: layer.trainable = False
    generator.compile(optimizer=keras.optimizers.Adam(lr=5e-5),loss={
        'new_res_1_final_opa':'mse',
        'x_u_net_opsp':special_loss_disjoint
    })
    discriminator.compile(loss='mae',
                          optimizer=keras.optimizers.Adam(lr=5e-5),
                          metrics=['accuracy'])

    final_input=generator.input
    x_u_net_opsp=(generator.get_layer('x_u_net_opsp').output)
    final_output_gans=discriminator(generator.get_layer('new_final_op').output)
    final_output_seg=(generator.get_layer('new_xfinal_op').output)
    final_output_res=(generator.get_layer('new_res_1_final_opa').output)

    final_model=Model(inputs=[final_input],outputs=[final_output_gans,final_output_seg,final_output_res,x_u_net_opsp])
    final_model.compile(optimizer=keras.optimizers.Adam(lr=5e-5),metrics=['mae'],loss={'model_2':'mae',
                        'new_res_1_final_opa':'mse',
                        'x_u_net_opsp':special_loss_disjoint})

    # Unfreeze the discriminator for its own training step.
    for layer in discriminator.layers: layer.trainable = True

    generator.compile(optimizer=keras.optimizers.Adam(lr=5e-5),loss={
        'new_res_1_final_opa':'mse',
    })
    discriminator.compile(loss='mae',
                          optimizer=keras.optimizers.Adam(lr=5e-5),
                          metrics=['accuracy'])
    # Data-parallel copy of the discriminator across the 4 visible GPUs.
    multi_discriminator=multi_gpu_model(discriminator,gpus=4)
    multi_discriminator.compile(loss='mae',
                                optimizer=keras.optimizers.Adam(lr=5e-5),
                                metrics=['accuracy'])
    final_model.compile(optimizer=keras.optimizers.Adam(lr=5e-5),metrics=['accuracy'],loss={'model_2':'mae',
                        'new_res_1_final_opa':'mse',
                        'x_u_net_opsp':special_loss_disjoint})

    discriminator.summary()

    # Labels: -1 for Reference Distribution images (real), +1 for generator outputs (fake).
    y_train_true=-np.ones(shape=len(real_data))
    print(y_train_true.shape)
    y_train_fake=np.ones(shape=len(fake_data))

    real_data=list(real_data)
    fake_data=list(fake_data)
    y_train_true=list(y_train_true)
    y_train_fake=list(y_train_fake)
    merged_inputs=[real_data+fake_data]
    merged_gt=[y_train_true+y_train_fake]
    # Drop the per-class lists to free memory.
    real_data=[]
    fake_data=[]
    y_train_fake=[]
    y_train_true=[]
    # The outer lists have length 1, so this shuffle is a no-op;
    # fit(..., shuffle=True) below shuffles the samples every epoch.
    from sklearn.utils import shuffle
    merged_inputs,merged_gt=shuffle(merged_inputs,merged_gt)

    merged_inputs=np.array(merged_inputs)
    merged_gt=np.array(merged_gt)
    merged_inputs=np.squeeze(merged_inputs,axis=(0,))
    merged_gt=np.squeeze(merged_gt,axis=(0,))

    print("training discriminator")
    while(True):
        xx=int(raw_input("press 1 to keep training"))
        if(xx!=1):
            break
        ep=int(raw_input("enter updated number of epochs"))
        # 72 samples per GPU x 4 GPUs.
        multi_discriminator.fit([merged_inputs],[merged_gt],batch_size=72*4,epochs=ep,shuffle=True)

    return


def train_generator(true_label,ep,loss_ch):
    # Freeze the discriminator so that only the generator is updated.
    for layer in discriminator.layers: layer.trainable = False
    generator.compile(optimizer=keras.optimizers.Adam(lr=5e-5),loss={
        'new_res_1_final_opa':'mse',
        'x_u_net_opsp':special_loss_disjoint
    })
    discriminator.compile(loss='mae',
                          optimizer=keras.optimizers.Adam(lr=5e-5),
                          metrics=['accuracy'])
    final_model.compile(optimizer=keras.optimizers.Adam(lr=5e-5),metrics=['accuracy'],loss={'model_2':'mae',
                        'new_res_1_final_opa':'mse',
                        'x_u_net_opsp':special_loss_disjoint})

    multi_final_model=multi_gpu_model(final_model,gpus=4)
    multi_final_model.compile(optimizer=keras.optimizers.Adam(lr=5e-5),metrics=['accuracy'],loss={'model_2':'mae',
                              'new_res_1_final_opa':'mse',
                              'x_u_net_opsp':special_loss_disjoint})

    X_train=np.load("input_for_generator.npy")

    # Targets, in output order: -1 ("real") for the discriminator output so the
    # generator learns to fool it, the input itself for the reconstruction
    # output, and a dummy array for x_u_net_opsp (special_loss_disjoint ignores it).
    y_train=-np.ones(len(X_train))
    y_empty=np.zeros(shape=(X_train.shape))
    while(True):
        xx=int(raw_input("press 1 to keep training"))
        if(xx!=1):
            break
        ep=int(raw_input("enter updated number of epochs"))

        # 16 samples per GPU x 4 GPUs.
        multi_final_model.fit([X_train],[y_train,X_train,y_empty],batch_size=16*4,epochs=ep,shuffle=True)

        # Write min-max normalised ID1/ID2/ID3 for (at most) the first 1000 inputs.
        result=generator.predict([X_train[0:1000]],batch_size=16)
        result[1]=(result[1]-np.amin(result[1]))/((np.amax(result[1]))-(np.amin(result[1])))
        result[0]=(result[0]-np.amin(result[0]))/((np.amax(result[0]))-(np.amin(result[0])))
        result[2]=(result[2]-np.amin(result[2]))/((np.amax(result[2]))-(np.amin(result[2])))
        for i in range(len(result[0])):
            cv2.imwrite("outputs/norm/id1/"+str(i)+".png",(result[0][i])*255)
            cv2.imwrite("outputs/norm/id2/"+str(i)+".png",(result[1][i])*255)
            cv2.imwrite("outputs/norm/id3/"+str(i)+".png",(result[2][i])*255)
            cv2.imwrite("outputs/norm/in/"+str(i)+".png",(X_train[i])*255)
    return


# Interactive driver: 0 = train the discriminator, 1 = train the generator,
# 2 = save the models (overwrites the three .h5 files), 3 = write the current
# outputs to disk; anything else quits.
while(True):
    i_p=int(raw_input("press 0 to train disc and 1 to train gen 2 to save models 3 to check outputs anything else to quit"))

    if(i_p==0):
        # Build a fresh discriminator (named 'model_2' so the loss key matches)
        # and wire it, frozen, into a new final_model.
        discriminator=build_discriminator(input_shape)
        discriminator.name='model_2'
        for layer in discriminator.layers: layer.trainable = False
        generator.compile(optimizer=keras.optimizers.Adam(lr=5e-5),loss={
            'new_res_1_final_opa':'mse',
            'x_u_net_opsp':special_loss_disjoint
        })
        discriminator.compile(loss='mae',
                              optimizer=keras.optimizers.Adam(lr=5e-5),
                              metrics=['accuracy'])

        final_input=generator.input
        x_u_net_opsp=(generator.get_layer('x_u_net_opsp').output)
        final_output_gans=discriminator(generator.get_layer('new_final_op').output)
        final_output_seg=(generator.get_layer('new_xfinal_op').output)
        final_output_res=(generator.get_layer('new_res_1_final_opa').output)

        final_model=Model(inputs=[final_input],outputs=[final_output_gans,final_output_seg,final_output_res,x_u_net_opsp])
        final_model.compile(optimizer=keras.optimizers.Adam(lr=5e-5),metrics=['mae'],loss={'model_2':'mae',
                            'new_res_1_final_opa':'mse',
                            'x_u_net_opsp':special_loss_disjoint})
        loss_ch=0
        print("training disc")

        ep=int(raw_input("enter number of epochs"))
        # Real samples: the Reference Distribution.
        real_data=np.load("good_dic_to_train_disc.npy")
        X_train_tumors=np.load("input_for_generator.npy")
        # Fake samples: the first generator output for every input image.
        fake_data=generator.predict([X_train_tumors])[0]
        print("fake_data_shape",fake_data.shape)

        true_label=1
        print(real_data.shape,fake_data.shape,true_label,ep)
        proceed=int(raw_input("proceed press 1"))
        if(proceed==1):
            train_disc(real_data,fake_data,true_label,ep,loss_ch)
        else:
            continue

    elif(i_p==1):
        print("training gen")
        loss_ch=0
        ep=int(raw_input("enter number of epochs"))
        true_label=1

        proceed=int(raw_input("proceed press 1"))
        if(proceed==1):
            train_generator(true_label,ep,loss_ch)
        else:
            continue

    elif(i_p==2):
        final_model.save('disjoint_un_sup_mse_complete_gans.h5')
        generator.save("disjoint_un_sup_mse_generator.h5")
        discriminator.save("disjoint_un_sup_mse_discriminator.h5")

    elif(i_p==3):
        X_train=np.load("input_for_generator.npy")
        y_train=np.load("tumor_mask_for_generator.npy")
        result=generator.predict([X_train[0:1000]],batch_size=16)

        print(np.amax(result[0]),np.amax(result[1]),np.amax(result[2]),np.amax(result[3]))
        print(np.amin(result[0]),np.amin(result[1]),np.amin(result[2]),np.amin(result[3]))

        # Raw outputs (ID1 = Ifc, ID2 = Iwc, ID3 = Iro), the inputs and the ground-truth masks.
        for i in range(len(result[0])):
            cv2.imwrite("outputs/id1/"+str(i)+".png",(result[0][i])*255)
            cv2.imwrite("outputs/id2/"+str(i)+".png",(result[1][i])*255)
            cv2.imwrite("outputs/id3/"+str(i)+".png",(result[2][i])*255)
            cv2.imwrite("outputs/norm/in/"+str(i)+".png",X_train[i]*255)
            cv2.imwrite("outputs/norm/op/"+str(i)+".png",y_train[i]*255)

        # Min-max normalise each output over the whole batch, then write again.
        result[1]=(result[1]-np.amin(result[1]))/((np.amax(result[1]))-(np.amin(result[1])))
        result[0]=(result[0]-np.amin(result[0]))/((np.amax(result[0]))-(np.amin(result[0])))
        result[2]=(result[2]-np.amin(result[2]))/((np.amax(result[2]))-(np.amin(result[2])))
        for i in range(len(result[0])):
            cv2.imwrite("outputs/norm/id1/"+str(i)+".png",(result[0][i])*255)
            cv2.imwrite("outputs/norm/id2/"+str(i)+".png",(result[1][i])*255)
            cv2.imwrite("outputs/norm/id3/"+str(i)+".png",(result[2][i])*255)

    else:
        break

--------------------------------------------------------------------------------
/src/continue_training_stage_2.py:
--------------------------------------------------------------------------------
import os
# Select the GPUs before Keras/TensorFlow create a session; the multi_gpu_model
# calls below assume all four are visible.
CUDA_VISIBLE_DEVICES = [0,1,2,3]
os.environ['CUDA_VISIBLE_DEVICES']=','.join([str(x) for x in CUDA_VISIBLE_DEVICES])

import random
import numpy as np
import cv2
import nibabel as nib
import scipy as sp
import scipy.misc, scipy.ndimage.interpolation
from medpy import metric
import tensorflow as tf
import keras
from keras import optimizers, losses, regularizers
from keras import backend as K
from keras.models import Model,Sequential
from keras.layers import Input,merge, concatenate, Conv2D, MaxPooling2D, Activation, UpSampling2D,Dropout,Conv2DTranspose,add,multiply,Dense,Flatten
from keras.layers.normalization import BatchNormalization as bn
from keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping
from keras.optimizers import RMSprop, Adam
from keras.utils import multi_gpu_model

# Smoothing term that keeps the Dice coefficient defined when both maps are empty.
smooth=1.
input_shape=240,240,1  # must match the generator's input size

########################################Losses#################################################
def special_loss_disjoint(y_true,y_pred):
    # The 'x_u_net_opsp' output holds two maps stacked on the channel axis; split
    # them and ignore the y_true that Keras passes in.
    y_true,y_pred=tf.split(y_pred, 2,axis=-1)
    # Stage 2: entries above 1e-16 are set to 1, all others keep their value.
    thresholded_pred = tf.where(tf.greater(y_pred, 1e-16), tf.ones_like(y_pred), y_pred)  # where(cond, value_if_true, value_if_false)
    thresholded_true = tf.where(tf.greater(y_true, 1e-16), tf.ones_like(y_true), y_true)
    # Returned as-is (not 1 - Dice), so minimising it penalises overlap between the two maps.
    return dice_coef(thresholded_true,thresholded_pred)

def dice_coef(y_true, y_pred):
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)

def dice_coef_loss(y_true, y_pred):
    # Also returns the plain Dice coefficient; passed to load_model via custom_objects.
    return dice_coef(y_true, y_pred)

################################################################################################

def build_discriminator(input_shape,learn_rate=1e-3):
    # Note: learn_rate is unused; the model is compiled with Adam(lr=5e-5).
    l2_lambda = 0.0002
    DropP = 0.3
    kernel_size=3

    inputs = Input(input_shape,name="disc_ip")

    conv0a = Conv2D(32, (kernel_size, kernel_size), activation='relu', padding='same',
                    kernel_regularizer=regularizers.l2(l2_lambda),name='disc_l2_conc15')(inputs)
    conv0a = bn(name='disc_l2_bn1')(conv0a)
    conv0b = Conv2D(32, (kernel_size, kernel_size), activation='relu', padding='same',
                    kernel_regularizer=regularizers.l2(l2_lambda),name='disc_l2_conc16')(conv0a)
    conv0b = bn(name='disc_l2_bn2')(conv0b)
    pool0 = MaxPooling2D(pool_size=(2, 2),name='disc_l2_mp1')(conv0b)
    pool0 = Dropout(DropP,name='disc_l2_d1')(pool0)

    conv2a = Conv2D(64, (kernel_size, kernel_size), activation='relu', padding='same',
                    kernel_regularizer=regularizers.l2(l2_lambda),name='disc_l2_conc17')(pool0)
    conv2a = bn(name='disc_l2_bn3')(conv2a)
    conv2b = Conv2D(64, (kernel_size, kernel_size), activation='relu', padding='same',
                    kernel_regularizer=regularizers.l2(l2_lambda),name='disc_l2_conc18')(conv2a)
    conv2b = bn(name='disc_l2_bn4')(conv2b)
    pool2 = MaxPooling2D(pool_size=(2, 2),name='disc_l2_mp2')(conv2b)
    pool2 = Dropout(DropP,name='disc_l2_d2')(pool2)

    conv3a = Conv2D(128, (kernel_size, kernel_size), activation='relu', padding='same',
                    kernel_regularizer=regularizers.l2(l2_lambda),name='disc_l2_conc19')(pool2)
    conv3a = bn(name='disc_l2_bn5')(conv3a)
    conv3b = Conv2D(128, (kernel_size, kernel_size), activation='relu', padding='same',
                    kernel_regularizer=regularizers.l2(l2_lambda),name='disc_l2_conc20')(conv3a)
    conv3b = bn(name='disc_l2_bn6')(conv3b)
    pool3 = MaxPooling2D(pool_size=(2, 2),name='disc_l2_mp3')(conv3b)
    pool3 = Dropout(DropP,name='disc_l2_d3')(pool3)

    conv4a = Conv2D(256, (kernel_size, kernel_size), activation='relu', padding='same',
                    kernel_regularizer=regularizers.l2(l2_lambda),name='disc_l2_conc21')(pool3)
    conv4a = bn(name='disc_l2_bn7')(conv4a)
    conv4b = Conv2D(256, (kernel_size, kernel_size), activation='relu', padding='same',
                    kernel_regularizer=regularizers.l2(l2_lambda),name='disc_l2_conc22')(conv4a)
    conv4b = bn(name='disc_l2_bn8')(conv4b)
    pool4 = MaxPooling2D(pool_size=(2, 2),name='disc_l2_mp4')(conv4b)
    pool4 = Dropout(DropP,name='disc_l2_d4')(pool4)

    conv5a = Conv2D(512, (kernel_size, kernel_size), activation='relu', padding='same',
                    kernel_regularizer=regularizers.l2(l2_lambda),name='disc_l2_conc23')(pool4)
    conv5a = bn(name='disc_l2_bn9')(conv5a)
    conv5b = Conv2D(512, (kernel_size, kernel_size), activation='relu', padding='same',
                    kernel_regularizer=regularizers.l2(l2_lambda),name='disc_l2_conc24')(conv5a)
    conv5b = bn(name='disc_l2_bn10')(conv5b)

    flat=Flatten()(conv5b)
    # tanh output: trained towards -1 for Reference Distribution images (real)
    # and +1 for generator outputs (fake), see train_disc below.
    output_disc=Dense(1,activation='tanh',name='disc_output')(flat)

    model=Model(inputs=[inputs],outputs=[output_disc])
    model.compile(loss='mae',
                  optimizer=keras.optimizers.Adam(lr=5e-5),
                  metrics=['accuracy'])
    return model

from keras.models import load_model
generator=load_model('disjoint_un_sup_mse_generator.h5', custom_objects={'dice_coef_loss':dice_coef_loss,'special_loss_disjoint':special_loss_disjoint})
discriminator=load_model('disjoint_un_sup_mse_discriminator.h5', custom_objects={'dice_coef_loss':dice_coef_loss,'special_loss_disjoint':special_loss_disjoint})

# Freeze the discriminator before compiling final_model, so that generator
# updates through final_model leave it unchanged (Keras reads the trainable
# flags at compile time).
for layer in discriminator.layers: layer.trainable = False
generator.compile(optimizer=keras.optimizers.Adam(lr=5e-5),loss={
    'new_res_1_final_opa':'mse',
    'x_u_net_opsp':special_loss_disjoint
})
discriminator.compile(loss='mae',
                      optimizer=keras.optimizers.Adam(lr=5e-5),
                      metrics=['accuracy'])

# Chain generator -> discriminator into one model with four outputs.
final_input=generator.input
x_u_net_opsp=(generator.get_layer('x_u_net_opsp').output)
final_output_gans=discriminator(generator.get_layer('new_final_op').output)
final_output_seg=(generator.get_layer('new_xfinal_op').output)
final_output_res=(generator.get_layer('new_res_1_final_opa').output)

final_model=Model(inputs=[final_input],outputs=[final_output_gans,final_output_seg,final_output_res,x_u_net_opsp])
# 'model_2' is the discriminator model's name and keys the adversarial loss;
# 'new_xfinal_op' has no loss entry, so Keras expects no target for it.
final_model.compile(optimizer=keras.optimizers.Adam(lr=5e-5),metrics=['mae'],loss={'model_2':'mae',
                    'new_res_1_final_opa':'mse',
                    'x_u_net_opsp':special_loss_disjoint})

print("full gans")
final_model.summary()
print(final_model.input)
print(final_model.output)
print("=" * 80)
print("generator")
generator.summary()
print(generator.input)
print(generator.output)
print("=" * 80)

print("discriminator")
discriminator.summary()
print(discriminator.get_input_at(0))
print(discriminator.get_input_at(1))
print("=" * 80)


def train_disc(real_data,fake_data,true_label,ep,loss_ch):
    # Rebind the module-level models. Without this, the freshly built and
    # trained discriminator (and the final_model wired to it) were local and
    # discarded on return, so generator training and saving kept using the
    # untrained copy.
    global discriminator, final_model
    discriminator=build_discriminator(input_shape)
    discriminator.name='model_2'
    for layer in discriminator.layers: layer.trainable = False
    generator.compile(optimizer=keras.optimizers.Adam(lr=5e-5),loss={
        'new_res_1_final_opa':'mse',
        'x_u_net_opsp':special_loss_disjoint
    })
    discriminator.compile(loss='mae',
                          optimizer=keras.optimizers.Adam(lr=5e-5),
                          metrics=['accuracy'])

    final_input=generator.input
    x_u_net_opsp=(generator.get_layer('x_u_net_opsp').output)
    final_output_gans=discriminator(generator.get_layer('new_final_op').output)
final_output_seg=(generator.get_layer('new_xfinal_op').output) 246 | final_output_res=(generator.get_layer('new_res_1_final_opa').output) 247 | 248 | final_model=Model(inputs=[final_input],outputs=[final_output_gans,final_output_seg,final_output_res,x_u_net_opsp]) 249 | 250 | final_model.compile(optimizer=keras.optimizers.Adam(lr=5e-5),metrics=['mae'],loss={'model_2':'mae', 251 | 252 | 'new_res_1_final_opa':'mse', 253 | 'x_u_net_opsp':special_loss_disjoint}) 254 | 255 | for layer in discriminator.layers: layer.trainable = True 256 | 257 | 258 | generator.compile(optimizer=keras.optimizers.Adam(lr=5e-5),loss={ 259 | 260 | 'new_res_1_final_opa':'mse', 261 | 262 | }) 263 | 264 | discriminator.compile(loss='mae', 265 | optimizer=keras.optimizers.Adam(lr=5e-5), 266 | metrics=['accuracy']) 267 | multi_discriminator=multi_gpu_model(discriminator,gpus=4) 268 | multi_discriminator.compile(loss='mae', 269 | optimizer=keras.optimizers.Adam(lr=5e-5), 270 | metrics=['accuracy']) 271 | final_model.compile(optimizer=keras.optimizers.Adam(lr=5e-5),metrics=['accuracy'],loss={'model_2':'mae', 272 | 273 | 'new_res_1_final_opa':'mse', 274 | 'x_u_net_opsp':special_loss_disjoint}) 275 | 276 | discriminator.summary() 277 | 278 | 279 | y_train_true=-np.ones(shape=len(real_data)) 280 | y_train_true=y_train_true#-0.1 281 | print(y_train_true.shape) 282 | 283 | 284 | 285 | 286 | 287 | y_train_fake=np.ones(shape=len(fake_data)) 288 | y_train_fake=y_train_fake#-0.1 289 | 290 | real_data=(list)(real_data) 291 | fake_data=(list)(fake_data) 292 | y_train_true=(list)(y_train_true) 293 | 294 | y_train_fake=(list)(y_train_fake) 295 | merged_inputs=[real_data+fake_data] 296 | real_data=[] 297 | fake_data=[] 298 | merged_gt=[y_train_true+y_train_fake] 299 | print('hi') 300 | 301 | y_train_fake=[] 302 | y_train_true=[] 303 | from sklearn.utils import shuffle 304 | merged_inputs,merged_gt=shuffle(merged_inputs,merged_gt) #no-op here: both arguments are 1-element outer lists, so samples are not reordered; fit(shuffle=True) below does the shuffling 305 | 306 | merged_inputs=np.array(merged_inputs) 307 |
merged_gt=np.array(merged_gt) 308 | merged_inputs=np.squeeze(merged_inputs,axis=(0,)) 309 | merged_gt=np.squeeze(merged_gt,axis=(0,)) 310 | 311 | print("training_discriminator===============================================================================") 312 | while(True): 313 | xx=(int)((raw_input)("press 1 to keep training")) 314 | ep=(int)((raw_input)("enter updated number of epochs")) 315 | if(xx!=1): 316 | break 317 | #multi_discriminator.summary() 318 | multi_discriminator.fit([merged_inputs],[merged_gt],batch_size=72*4,nb_epoch=ep,shuffle=True) 319 | 320 | return 321 | 322 | 323 | 324 | def train_generator(true_label,ep,loss_ch): 325 | for layer in discriminator.layers: layer.trainable = False 326 | 327 | 328 | generator.compile(optimizer=keras.optimizers.Adam(lr=5e-5),loss={ 329 | 330 | 'new_res_1_final_opa':'mse', 331 | 'x_u_net_opsp':special_loss_disjoint 332 | 333 | }) 334 | 335 | discriminator.compile(loss='mae', 336 | optimizer=keras.optimizers.Adam(lr=5e-5), 337 | metrics=['accuracy']) 338 | final_model.compile(optimizer=keras.optimizers.Adam(lr=5e-5),metrics=['accuracy'],loss={'model_2':'mae', 339 | 340 | 'new_res_1_final_opa':'mse', 341 | 'x_u_net_opsp':special_loss_disjoint}) 342 | 343 | multi_final_model=multi_gpu_model(final_model,gpus=4) 344 | multi_final_model.compile(optimizer=keras.optimizers.Adam(lr=5e-5),metrics=['accuracy'],loss={'model_2':'mae', 345 | 346 | 'new_res_1_final_opa':'mse', 347 | 'x_u_net_opsp':special_loss_disjoint}) 348 | 349 | #discriminator.summary() 350 | X_train=np.load("input_for_generator.npy") 351 | X_train=np.concatenate((X_train,X_train),axis=0) #double 352 | #X_train=np.load("input_for_generator.npy") 353 | 354 | 355 | 356 | 357 | y_train=[] 358 | 359 | for j in range(0,len(X_train)): 360 | 361 | y_train.append(-1) 362 | y_train=np.array(y_train) 363 | #print(multi_final_model.summary()) 364 | y_empty=np.zeros(shape=(X_train.shape)) 365 | while(True): 366 | xx=(int)((raw_input)("press 1 to keep training")) 367 | 
ep=(int)((raw_input)("enter updated number of epochs")) 368 | if(xx!=1): 369 | break 370 | 371 | multi_final_model.fit([X_train],[y_train,X_train,y_empty],batch_size=16*4,nb_epoch=ep,shuffle=True) 372 | result=generator.predict([X_train[0:1000]],batch_size=16) 373 | result[1]=(result[1]-np.amin(result[1]))/((np.amax(result[1]))-(np.amin(result[1]))) 374 | result[0]=(result[0]-np.amin(result[0]))/((np.amax(result[0]))-(np.amin(result[0]))) 375 | result[2]=(result[2]-np.amin(result[2]))/((np.amax(result[2]))-(np.amin(result[2]))) 376 | for i in range(0,1000): 377 | cv2.imwrite("outputs/norm/id1/"+str(i)+".png",(result[0][i])*255) 378 | cv2.imwrite("outputs/norm/id2/"+str(i)+".png",(result[1][i])*255) 379 | cv2.imwrite("outputs/norm/id3/"+str(i)+".png",(result[2][i])*255) 380 | cv2.imwrite("outputs/norm/in/"+str(i)+".png",(X_train[i])*255) 381 | #final_model.fit([X_train],[y_train,X_train,y_empty],batch_size=16,nb_epoch=ep,shuffle=True) 382 | return 383 | 384 | while(True): 385 | i_p=(int)(raw_input("press 0 to train disc and 1 to train gen 2 to save models 3 to check outputs anything else to quit")) 386 | 387 | if(i_p==0): 388 | #''' 389 | discriminator=build_discriminator(input_shape) 390 | discriminator.name='model_2' 391 | for layer in discriminator.layers: layer.trainable = False 392 | generator.compile(optimizer=keras.optimizers.Adam(lr=5e-5),loss={ 393 | 394 | 'new_res_1_final_opa':'mse', 395 | 'x_u_net_opsp':special_loss_disjoint 396 | 397 | }) 398 | 399 | discriminator.compile(loss='mae', 400 | optimizer=keras.optimizers.Adam(lr=5e-5), 401 | metrics=['accuracy']) 402 | 403 | #discriminator.trainable=False 404 | final_input=generator.input 405 | #final_input_1=discriminator.input 406 | #connect the two 407 | 408 | #discriminator.input=generator.get_layer('output_gen').output 409 | x_u_net_opsp=(generator.get_layer('x_u_net_opsp').output) 410 | final_output_gans=discriminator(generator.get_layer('new_final_op').output) 411 | 
final_output_seg=(generator.get_layer('new_xfinal_op').output) 412 | final_output_res=(generator.get_layer('new_res_1_final_opa').output) 413 | 414 | #final_model.add(generator) 415 | #final_model.add(discriminator) 416 | final_model=Model(inputs=[final_input],outputs=[final_output_gans,final_output_seg,final_output_res,x_u_net_opsp]) 417 | 418 | final_model.compile(optimizer=keras.optimizers.Adam(lr=5e-5),metrics=['mae'],loss={'model_2':'mae', 419 | 420 | 'new_res_1_final_opa':'mse', 421 | 'x_u_net_opsp':special_loss_disjoint}) 422 | loss_ch=0 423 | #''' 424 | print ("training disc") 425 | 426 | ep=(int)(raw_input("enter number of epochs")) 427 | real_data=generator.predict(np.load("good_dic_to_train_disc.npy"))[0] 428 | real_data=np.concatenate((real_data,np.load("good_dic_to_train_disc.npy")),axis=0) 429 | 430 | 431 | X_train_tumors=np.load("input_for_generator.npy") 432 | X_train_tumors=np.concatenate((X_train_tumors,X_train_tumors),axis=0) 433 | 434 | 435 | fake_data=generator.predict([X_train_tumors])[0] 436 | 437 | print("fake_data_shape",fake_data.shape) 438 | 439 | 440 | true_label=1 441 | 442 | print((real_data.shape),(fake_data.shape),true_label,ep) 443 | proceed=(int)((raw_input)("proceed press 1")) 444 | if(proceed==1): 445 | train_disc(real_data,fake_data,true_label,ep,loss_ch) 446 | else: 447 | continue 448 | 449 | elif(i_p==1): 450 | print("training gen") 451 | loss_ch=0 452 | ep=(int)(raw_input("enter number of epochs")) 453 | true_label=1 454 | 455 | proceed=(int)((raw_input)("proceed press 1")) 456 | if(proceed==1): 457 | train_generator(true_label,ep,loss_ch) 458 | else: 459 | continue 460 | elif(i_p==2): 461 | import h5py 462 | 463 | final_model.save('disjoint_un_sup_mse_complete_gans.h5') 464 | generator.save("disjoint_un_sup_mse_generator.h5") 465 | discriminator.save("disjoint_un_sup_mse_discriminator.h5") 466 | 467 | 468 | elif(i_p==3): 469 | X_train=np.load("input_for_generator.npy") 470 | y_train=np.load("tumor_mask_for_generator.npy") 
471 | result=generator.predict([X_train[0:1000]],batch_size=16) 472 | #result=np.array(result) 473 | #print (result.shape) 474 | 475 | print(np.amax(result[0]),np.amax(result[1]),np.amax(result[2]),np.amax(result[3])) 476 | 477 | print(np.amin(result[0]),np.amin(result[1]),np.amin(result[2]),np.amin(result[3])) 478 | 479 | 480 | 481 | 482 | 483 | for i in range(0,1000): 484 | cv2.imwrite("outputs/id1/"+str(i)+".png",(result[0][i])*255) 485 | cv2.imwrite("outputs/id2/"+str(i)+".png",(result[1][i])*255) 486 | cv2.imwrite("outputs/id3/"+str(i)+".png",(result[2][i])*255) 487 | 488 | cv2.imwrite("outputs/norm/in/"+str(i)+".png",X_train[i]*255) 489 | cv2.imwrite("outputs/norm/op/"+str(i)+".png",y_train[i]*255) 490 | 491 | 492 | result[1]=(result[1]-np.amin(result[1]))/((np.amax(result[1]))-(np.amin(result[1]))) 493 | result[0]=(result[0]-np.amin(result[0]))/((np.amax(result[0]))-(np.amin(result[0]))) 494 | result[2]=(result[2]-np.amin(result[2]))/((np.amax(result[2]))-(np.amin(result[2]))) 495 | for i in range(0,1000): 496 | cv2.imwrite("outputs/norm/id1/"+str(i)+".png",(result[0][i])*255) 497 | cv2.imwrite("outputs/norm/id2/"+str(i)+".png",(result[1][i])*255) 498 | cv2.imwrite("outputs/norm/id3/"+str(i)+".png",(result[2][i])*255) 499 | 500 | else: 501 | break 502 | 503 | 504 | 505 | 506 | 507 | -------------------------------------------------------------------------------- /src/create_networks.py: -------------------------------------------------------------------------------- 1 | #ps aux --sort=-%mem | awk 'NR<=10{print $0}' 2 | 3 | import keras 4 | from keras import optimizers 5 | #from keras.utils import multi_gpu_model 6 | import scipy as sp 7 | import scipy.misc, scipy.ndimage.interpolation 8 | from medpy import metric 9 | import numpy as np 10 | import os 11 | from keras import losses 12 | import tensorflow as tf 13 | from keras.models import Model,Sequential 14 | from keras.layers import Input,merge, concatenate, Conv2D, MaxPooling2D, Activation, 
UpSampling2D,Dropout,Conv2DTranspose,add,multiply,Dense,Flatten 15 | from keras.layers.normalization import BatchNormalization as bn 16 | from keras.callbacks import ModelCheckpoint, TensorBoard 17 | from keras.optimizers import RMSprop 18 | from keras import regularizers 19 | from keras import backend as K 20 | from keras.optimizers import Adam 21 | from keras.callbacks import ModelCheckpoint 22 | import tensorflow as tf 23 | #from keras.applications import Xception 24 | from keras.utils import multi_gpu_model 25 | import random 26 | import numpy as np 27 | from keras.callbacks import EarlyStopping, ModelCheckpoint 28 | import nibabel as nib 29 | CUDA_VISIBLE_DEVICES = [1] 30 | os.environ['CUDA_VISIBLE_DEVICES']=','.join([str(x) for x in CUDA_VISIBLE_DEVICES]) 31 | smooth=1. 32 | def special_loss_disjoint(y_true,y_pred): #y_true is ignored; y_pred is the 2-channel x_u_net_opsp output 33 | 34 | y_true,y_pred=tf.split(y_pred, 2,axis=-1) #channel 0 = new_final_op, channel 1 = new_xfinal_op 35 | thresholded_pred = tf.where( tf.greater( 0.0, y_pred ), 1 * tf.ones_like( y_pred ), y_pred ) #maps negative values to 1; sigmoid outputs are never negative, so they pass through unchanged 36 | thresholded_true=tf.where( tf.greater( 0.0, y_true ), 1 * tf.ones_like( y_true ), y_true ) 37 | #tf.keras.backend.print_tensor(first) 38 | return dice_coef(thresholded_true,thresholded_pred) #Dice overlap of the two branch masks (not 1-Dice), so minimizing it drives the branches apart 39 | 40 | def dice_coef(y_true, y_pred): 41 | 42 | y_true_f = K.flatten(y_true) 43 | y_pred_f = K.flatten(y_pred) 44 | intersection = K.sum(y_true_f * y_pred_f) 45 | return (2.
* intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth) 46 | def dice_coef_loss(y_true, y_pred): 47 | return dice_coef(y_true, y_pred) 48 | 49 | 50 | def build_generator(input_shape,learn_rate=1e-3): 51 | 52 | 53 | 54 | l2_lambda = 0.0002 55 | DropP = 0.3 56 | kernel_size=3 57 | 58 | inputs = Input(input_shape) 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | conv0a = Conv2D( 32, (kernel_size, kernel_size), activation='relu', padding='same', 68 | kernel_regularizer=regularizers.l2(l2_lambda),name='l2_conc15' )(inputs) 69 | 70 | 71 | conv0a = bn(name='l2_bn1')(conv0a) 72 | 73 | conv0b = Conv2D(32, (kernel_size, kernel_size), activation='relu', padding='same', 74 | kernel_regularizer=regularizers.l2(l2_lambda),name='l2_conc16' )(conv0a) 75 | 76 | conv0b = bn(name='l2_bn2')(conv0b) 77 | 78 | 79 | 80 | 81 | pool0 = MaxPooling2D(pool_size=(2, 2),name='l2_mp1')(conv0b) 82 | 83 | pool0 = Dropout(DropP,name='l2_d1')(pool0) 84 | 85 | 86 | 87 | 88 | 89 | 90 | conv2a = Conv2D(64, (kernel_size, kernel_size), activation='relu', padding='same', 91 | kernel_regularizer=regularizers.l2(l2_lambda),name='l2_conc17' )(pool0) 92 | 93 | conv2a = bn(name='l2_bn3')(conv2a) 94 | 95 | conv2b = Conv2D(64, (kernel_size, kernel_size), activation='relu', padding='same', 96 | kernel_regularizer=regularizers.l2(l2_lambda) ,name='l2_conc18')(conv2a) 97 | 98 | conv2b = bn(name='l2_bn4')(conv2b) 99 | 100 | pool2 = MaxPooling2D(pool_size=(2, 2),name='l2_mp2')(conv2b) 101 | 102 | pool2 = Dropout(DropP,name='l2_d2')(pool2) 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | conv3a = Conv2D(128, (kernel_size, kernel_size), activation='relu', padding='same', 111 | kernel_regularizer=regularizers.l2(l2_lambda),name='l2_conc19' )(pool2) 112 | 113 | conv3a = bn(name='l2_bn5')(conv3a) 114 | 115 | conv3b = Conv2D(128, (kernel_size, kernel_size), activation='relu', padding='same', 116 | kernel_regularizer=regularizers.l2(l2_lambda) ,name='l2_conc20')(conv3a) 117 | 118 | conv3b = 
bn(name='l2_bn6')(conv3b) 119 | 120 | 121 | 122 | pool3 = MaxPooling2D(pool_size=(2, 2),name='l2_mp3')(conv3b) 123 | 124 | pool3 = Dropout(DropP,name='l2_d3')(pool3) 125 | 126 | 127 | conv4a = Conv2D(256, (kernel_size, kernel_size), activation='relu', padding='same', 128 | kernel_regularizer=regularizers.l2(l2_lambda),name='l2_conc21' )(pool3) 129 | 130 | conv4a = bn(name='l2_bn7')(conv4a) 131 | 132 | conv4b = Conv2D(256, (kernel_size, kernel_size), activation='relu', padding='same', 133 | kernel_regularizer=regularizers.l2(l2_lambda),name='l2_conc22' )(conv4a) 134 | 135 | conv4b = bn(name='l2_bn8')(conv4b) 136 | 137 | pool4 = MaxPooling2D(pool_size=(2, 2),name='l2_mp4')(conv4b) 138 | 139 | pool4 = Dropout(DropP,name='l2_d4')(pool4) 140 | 141 | 142 | 143 | 144 | 145 | conv5a = Conv2D(512, (kernel_size, kernel_size), activation='relu', padding='same', 146 | kernel_regularizer=regularizers.l2(l2_lambda) ,name='l2_conc23')(pool4) 147 | 148 | conv5a = bn(name='l2_bn9')(conv5a) 149 | 150 | conv5b = Conv2D(512, (kernel_size, kernel_size), activation='relu', padding='same', 151 | kernel_regularizer=regularizers.l2(l2_lambda) ,name='l2_conc24')(conv5a) 152 | 153 | conv5b = bn(name='l2_bn10')(conv5b) 154 | 155 | 156 | 157 | 158 | 159 | up6 = concatenate([Conv2DTranspose(256,(2, 2), strides=(2, 2), padding='same',name='l2_conc25')(conv5b), (conv4b)], axis=3,name='l2_conc1') 160 | 161 | 162 | up6 = Dropout(DropP,name='l2_d5')(up6) 163 | 164 | conv6a = Conv2D(256, (kernel_size, kernel_size), activation='relu', padding='same', 165 | kernel_regularizer=regularizers.l2(l2_lambda) ,name='l2_conc26')(up6) 166 | 167 | conv6a = bn(name='l2_bn11')(conv6a) 168 | 169 | conv6b = Conv2D(256, (kernel_size, kernel_size), activation='relu', padding='same', 170 | kernel_regularizer=regularizers.l2(l2_lambda),name='l2_conc27' )(conv6a) 171 | 172 | conv6b = bn(name='l2_bn12')(conv6b) 173 | 174 | 175 | 176 | 177 | 178 | up7 = concatenate([Conv2DTranspose(128,(2, 2), strides=(2, 2), 
padding='same',name='l2_conc28')(conv6b),(conv3b)], axis=3,name='l2_conc2') 179 | 180 | up7 = Dropout(DropP,name='l2_d6')(up7) 181 | #add second output here 182 | 183 | conv7a = Conv2D(128, (kernel_size, kernel_size), activation='relu', padding='same', 184 | kernel_regularizer=regularizers.l2(l2_lambda) ,name='l2_conc29')(up7) 185 | 186 | conv7a = bn(name='l2_bn13')(conv7a) 187 | 188 | 189 | 190 | conv7b = Conv2D(128, (kernel_size, kernel_size), activation='relu', padding='same', 191 | kernel_regularizer=regularizers.l2(l2_lambda) ,name='l2_conc30')(conv7a) 192 | 193 | conv7b = bn(name='l2_bn14')(conv7b) 194 | 195 | 196 | 197 | 198 | 199 | 200 | up8 = concatenate([Conv2DTranspose(64,(2, 2), strides=(2, 2), padding='same',name='l2_conc31')(conv7b), (conv2b)], axis=3,name='l2_conc3') 201 | 202 | up8 = Dropout(DropP,name='l2_d7')(up8) 203 | 204 | conv8a = Conv2D(64, (kernel_size, kernel_size), activation='relu', padding='same', 205 | kernel_regularizer=regularizers.l2(l2_lambda) ,name='l2_conc32')(up8) 206 | 207 | conv8a = bn(name='l2_bn15')(conv8a) 208 | 209 | 210 | conv8b = Conv2D(64, (kernel_size, kernel_size), activation='relu', padding='same', 211 | kernel_regularizer=regularizers.l2(l2_lambda),name='l2_conc33' )(conv8a) 212 | 213 | conv8b = bn(name='l2_bn16')(conv8b) 214 | 215 | 216 | 217 | up10 = concatenate([Conv2DTranspose(32,(2, 2), strides=(2, 2), padding='same',name='l2_conc34')(conv8b),(conv0b)],axis=3,name='l2_conc4') 218 | 219 | conv10a = Conv2D(32, (kernel_size, kernel_size), activation='relu', padding='same', 220 | kernel_regularizer=regularizers.l2(l2_lambda) ,name='l2_conc35')(up10) 221 | 222 | conv10a = bn(name='l2_bn17')(conv10a) 223 | 224 | 225 | 226 | conv10b = Conv2D(32, (kernel_size, kernel_size), activation='relu', padding='same', 227 | kernel_regularizer=regularizers.l2(l2_lambda),name='l2_conc36' )(conv10a) 228 | 229 | conv10b = bn(name='l2_bn18')(conv10b) 230 | 231 | 232 | 233 | new_final_op=Conv2D(1, (1, 1), 
activation='sigmoid',name='new_final_op')(conv10b) 234 | 235 | 236 | 237 | #-------------------------------------------------------------------------------------- 238 | 239 | 240 | xup6 = concatenate([Conv2DTranspose(256,(2, 2), strides=(2, 2), padding='same',name='l2_conc38')(conv5b), (conv4b)], axis=3,name='l2_conc5') 241 | 242 | 243 | 244 | xup6 = Dropout(DropP,name='l2_d8')(xup6) 245 | 246 | xconv6a = Conv2D(256, (kernel_size, kernel_size), activation='relu', padding='same', 247 | kernel_regularizer=regularizers.l2(l2_lambda),name='l2_conc39' )(xup6) 248 | 249 | xconv6a = bn(name='l2_bn19')(xconv6a) 250 | 251 | 252 | 253 | xconv6b = Conv2D(256, (kernel_size, kernel_size), activation='relu', padding='same', 254 | kernel_regularizer=regularizers.l2(l2_lambda),name='l2_conc40' )(xconv6a) 255 | 256 | xconv6b = bn(name='l2_bn20')(xconv6b) 257 | 258 | 259 | 260 | 261 | 262 | xup7 = concatenate([Conv2DTranspose(128,(2, 2), strides=(2, 2), padding='same',name='l2_conc41')(xconv6b),(conv3b)], axis=3,name='l2_conc6')#xconv6b 263 | 264 | xup7 = Dropout(DropP,name='l2_d9')(xup7) 265 | 266 | xconv7a = Conv2D(128, (kernel_size, kernel_size), activation='relu', padding='same', 267 | kernel_regularizer=regularizers.l2(l2_lambda) ,name='l2_conc42')(xup7) 268 | 269 | xconv7a = bn(name='l2_bn21')(xconv7a) 270 | 271 | 272 | xconv7b = Conv2D(128, (kernel_size, kernel_size), activation='relu', padding='same', 273 | kernel_regularizer=regularizers.l2(l2_lambda) ,name='l2_conc43')(xconv7a) 274 | 275 | xconv7b = bn(name='l2_bn22')(xconv7b) 276 | 277 | 278 | xup8 = concatenate([Conv2DTranspose(64,(2, 2), strides=(2, 2), padding='same',name='l2_conc44')(xconv7b),(conv2b)], axis=3,name='l2_conc7') 279 | 280 | xup8 = Dropout(DropP,name='l2_d10')(xup8) 281 | #add third xoutxout here 282 | 283 | xconv8a = Conv2D(64, (kernel_size, kernel_size), activation='relu', padding='same', 284 | kernel_regularizer=regularizers.l2(l2_lambda) ,name='l2_conc45')(xup8) 285 | 286 | xconv8a = 
bn(name='l2_bn23')(xconv8a) 287 | 288 | 289 | xconv8b = Conv2D(64, (kernel_size, kernel_size), activation='relu', padding='same', 290 | kernel_regularizer=regularizers.l2(l2_lambda),name='l2_conc46' )(xconv8a) 291 | 292 | xconv8b = bn(name='l2_bn24')(xconv8b) 293 | 294 | 295 | 296 | 297 | 298 | 299 | xup10 = concatenate([Conv2DTranspose(32,(2, 2), strides=(2, 2), padding='same',name='l2_conc47')(xconv8b), (conv0b)],axis=3,name='l2_conc8') 300 | 301 | xup10 = Dropout(DropP,name='l2_d11')(xup10) 302 | 303 | 304 | xconv10a = Conv2D(32, (kernel_size, kernel_size), activation='relu', padding='same', 305 | kernel_regularizer=regularizers.l2(l2_lambda) ,name='l2_conc48')(xup10) 306 | 307 | xconv10a = bn(name='l2_bn25')(xconv10a) 308 | 309 | 310 | xconv10b = Conv2D(32, (kernel_size, kernel_size), activation='relu', padding='same', 311 | kernel_regularizer=regularizers.l2(l2_lambda) ,name='l2_conc49')(xconv10a) 312 | 313 | xconv10b = bn(name='l2_bn26')(xconv10b) 314 | 315 | 316 | 317 | 318 | 319 | 320 | new_xfinal_op=Conv2D(1, (1, 1), activation='sigmoid',name='new_xfinal_op')(xconv10b)#tan 321 | 322 | 323 | 324 | 325 | #-----------------------------third branch 326 | 327 | 328 | 329 | #Concatenation fed to the reconstruction layer of all 3 330 | 331 | x_u_net_op0=keras.layers.concatenate([new_final_op,new_xfinal_op],name='l2_conc9') 332 | x_u_net_opsp=keras.layers.concatenate([new_final_op,new_xfinal_op],name='x_u_net_opsp') 333 | 334 | #res_1_conv0a = Conv2D(32, (kernel_size, kernel_size), activation='relu', padding='same', 335 | # kernel_regularizer=regularizers.l2(l2_lambda) ,name='mixer_conv')(x_u_net_op0) 336 | 337 | #res_1_conv0a = bn()(res_1_conv0a) 338 | 339 | #res_1_conv0b = Conv2D(32, (kernel_size, kernel_size), activation='relu', padding='same', 340 | # kernel_regularizer=regularizers.l2(l2_lambda) )(res_1_conv0a) 341 | #res_1_conv0c = Conv2D(64, (kernel_size, kernel_size), activation='relu', padding='same', 342 | # kernel_regularizer=regularizers.l2(l2_lambda) 
)(res_1_conv0b) 343 | #res_1_conv0d = Conv2D(64, (kernel_size, kernel_size), activation='relu', padding='same', 344 | # kernel_regularizer=regularizers.l2(l2_lambda) )(res_1_conv0c) 345 | 346 | 347 | new_res_1_final_opa=Conv2D(1, (1, 1), activation='sigmoid',name='new_res_1_final_opa')(x_u_net_op0) 348 | 349 | model=Model(inputs=[inputs],outputs=[new_final_op, 350 | new_xfinal_op, 351 | new_res_1_final_opa, 352 | x_u_net_opsp 353 | 354 | 355 | ]) 356 | model.compile(optimizer=keras.optimizers.Adam(lr=5e-5),loss={ 357 | 358 | 'new_res_1_final_opa':'mse', 359 | 'x_u_net_opsp':special_loss_disjoint 360 | 361 | }) 362 | 363 | return model 364 | 365 | 366 | 367 | 368 | 369 | #model.summary() 370 | #return model 371 | 372 | def build_discriminator(input_shape,learn_rate=1e-3): 373 | l2_lambda = 0.0002 374 | DropP = 0.3 375 | kernel_size=3 376 | 377 | inputs = Input(input_shape,name="disc_ip") 378 | 379 | conv0a = Conv2D( 32, (kernel_size, kernel_size), activation='relu', padding='same', 380 | kernel_regularizer=regularizers.l2(l2_lambda),name='disc_l2_conc15' )(inputs) 381 | 382 | 383 | conv0a = bn(name='disc_l2_bn1')(conv0a) 384 | 385 | conv0b = Conv2D(32, (kernel_size, kernel_size), activation='relu', padding='same', 386 | kernel_regularizer=regularizers.l2(l2_lambda),name='disc_l2_conc16' )(conv0a) 387 | 388 | conv0b = bn(name='disc_l2_bn2')(conv0b) 389 | 390 | 391 | 392 | 393 | pool0 = MaxPooling2D(pool_size=(2, 2),name='disc_l2_mp1')(conv0b) 394 | 395 | pool0 = Dropout(DropP,name='disc_l2_d1')(pool0) 396 | 397 | 398 | 399 | 400 | 401 | 402 | conv2a = Conv2D(64, (kernel_size, kernel_size), activation='relu', padding='same', 403 | kernel_regularizer=regularizers.l2(l2_lambda),name='disc_l2_conc17' )(pool0) 404 | 405 | conv2a = bn(name='disc_l2_bn3')(conv2a) 406 | 407 | conv2b = Conv2D(64, (kernel_size, kernel_size), activation='relu', padding='same', 408 | kernel_regularizer=regularizers.l2(l2_lambda) ,name='disc_l2_conc18')(conv2a) 409 | 410 | conv2b = 
bn(name='disc_l2_bn4')(conv2b) 411 | 412 | pool2 = MaxPooling2D(pool_size=(2, 2),name='disc_l2_mp2')(conv2b) 413 | 414 | pool2 = Dropout(DropP,name='disc_l2_d2')(pool2) 415 | 416 | 417 | 418 | 419 | 420 | 421 | 422 | conv3a = Conv2D(128, (kernel_size, kernel_size), activation='relu', padding='same', 423 | kernel_regularizer=regularizers.l2(l2_lambda),name='disc_l2_conc19' )(pool2) 424 | 425 | conv3a = bn(name='disc_l2_bn5')(conv3a) 426 | 427 | conv3b = Conv2D(128, (kernel_size, kernel_size), activation='relu', padding='same', 428 | kernel_regularizer=regularizers.l2(l2_lambda) ,name='disc_l2_conc20')(conv3a) 429 | 430 | conv3b = bn(name='disc_l2_bn6')(conv3b) 431 | 432 | 433 | 434 | pool3 = MaxPooling2D(pool_size=(2, 2),name='disc_l2_mp3')(conv3b) 435 | 436 | pool3 = Dropout(DropP,name='disc_l2_d3')(pool3) 437 | 438 | 439 | conv4a = Conv2D(256, (kernel_size, kernel_size), activation='relu', padding='same', 440 | kernel_regularizer=regularizers.l2(l2_lambda),name='disc_l2_conc21' )(pool3) 441 | 442 | conv4a = bn(name='disc_l2_bn7')(conv4a) 443 | 444 | conv4b = Conv2D(256, (kernel_size, kernel_size), activation='relu', padding='same', 445 | kernel_regularizer=regularizers.l2(l2_lambda),name='disc_l2_conc22' )(conv4a) 446 | 447 | conv4b = bn(name='disc_l2_bn8')(conv4b) 448 | 449 | pool4 = MaxPooling2D(pool_size=(2, 2),name='disc_l2_mp4')(conv4b) 450 | 451 | pool4 = Dropout(DropP,name='disc_l2_d4')(pool4) 452 | 453 | 454 | 455 | 456 | 457 | conv5a = Conv2D(256, (kernel_size, kernel_size), activation='relu', padding='same', 458 | kernel_regularizer=regularizers.l2(l2_lambda) ,name='disc_l2_conc23')(pool4) 459 | 460 | conv5a = bn(name='disc_l2_bn9')(conv5a) 461 | 462 | conv5b = Conv2D(256, (kernel_size, kernel_size), activation='relu', padding='same', 463 | kernel_regularizer=regularizers.l2(l2_lambda) ,name='disc_l2_conc24')(conv5a) 464 | 465 | conv5b = bn(name='disc_l2_bn10')(conv5b) 466 | 467 | flat=Flatten()(conv5b) 468 | 469 | 
output_disc=Dense(1,activation='tanh',name='disc_output')(flat)#placeholder 470 | 471 | model=Model(inputs=[inputs],outputs=[output_disc]) 472 | model.compile(loss='mae', 473 | optimizer=keras.optimizers.Adam(lr=5e-5), 474 | metrics=['accuracy']) 475 | #model.summary() 476 | return model 477 | 478 | 479 | input_shape=240,240,1 480 | #final_model = Sequential() 481 | 482 | generator=build_generator(input_shape) 483 | discriminator=build_discriminator(input_shape) 484 | for layer in discriminator.layers: layer.trainable = False 485 | generator.compile(optimizer=keras.optimizers.Adam(lr=5e-5),loss={ 486 | 487 | 'new_res_1_final_opa':'mse', 488 | 'x_u_net_opsp':special_loss_disjoint 489 | 490 | }) 491 | 492 | discriminator.compile(loss='mae', 493 | optimizer=keras.optimizers.Adam(lr=5e-5), 494 | metrics=['accuracy']) 495 | 496 | #discriminator.trainable=False 497 | final_input=generator.input 498 | #final_input_1=discriminator.input 499 | #connect the two 500 | 501 | #discriminator.input=generator.get_layer('output_gen').output 502 | x_u_net_opsp=(generator.get_layer('x_u_net_opsp').output) 503 | final_output_gans=discriminator(generator.get_layer('new_final_op').output) 504 | final_output_seg=(generator.get_layer('new_xfinal_op').output) 505 | final_output_res=(generator.get_layer('new_res_1_final_opa').output) 506 | 507 | #final_model.add(generator) 508 | #final_model.add(discriminator) 509 | final_model=Model(inputs=[final_input],outputs=[final_output_gans,final_output_seg,final_output_res,x_u_net_opsp]) 510 | 511 | final_model.compile(optimizer=keras.optimizers.Adam(lr=5e-5),metrics=['mae'],loss={'model_2':'mae', 512 | 513 | 'new_res_1_final_opa':'mse', 514 | 'x_u_net_opsp':special_loss_disjoint}) 515 | 516 | 517 | print("full gans") 518 | final_model.summary() 519 | print(final_model.input) 520 | print(final_model.output) 521 | 
print("============================================================================================================================================================") 522 | print("generator") 523 | generator.summary() 524 | print(generator.input) 525 | print(generator.output) 526 | print("============================================================================================================================================================") 527 | 528 | print("discriminator") 529 | discriminator.summary() 530 | print(discriminator.get_input_at(0)) 531 | print(discriminator.get_input_at(1)) 532 | #print(discriminator.output) 533 | print("============================================================================================================================================================") 534 | #print(discriminator.get_input_at(2)) 535 | #print(discriminator.input[2]) 536 | #X_train=np.ones((1,240,240,1)) 537 | #final_model.fit([X_train],[1],batch_size=1,nb_epoch=1,shuffle=False) 538 | #print ("hi",final_model.predict([X_train],batch_size=1)) 539 | import h5py 540 | 541 | final_model.save('disjoint_un_sup_mse_complete_gans.h5') 542 | generator.save("disjoint_un_sup_mse_generator.h5") 543 | discriminator.save("disjoint_un_sup_mse_discriminator.h5") 544 | 545 | 546 | 547 | 548 | 549 | 550 | 551 | 552 | 553 | 554 | 555 | 556 | 557 | 558 | --------------------------------------------------------------------------------
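The two training signals wired up above, the disjointness term `special_loss_disjoint` and the -1/+1 discriminator targets built in `train_disc`, can be sanity-checked outside Keras. The sketch below is an illustrative NumPy re-implementation of the forward pass only, assuming the same `smooth=1.` and the channel order of `x_u_net_opsp` (channel 0 from `new_final_op`, channel 1 from `new_xfinal_op`). The names `dice_coef_np`, `disjoint_loss_np` and `make_disc_batch` are hypothetical helpers, not part of the repository, and the code is not differentiable, so it cannot stand in for the TensorFlow loss during training.

```python
import numpy as np

SMOOTH = 1.0  # same additive smoothing as `smooth=1.` in create_networks.py


def dice_coef_np(a, b, smooth=SMOOTH):
    """Smoothed Dice coefficient over all elements (forward pass of dice_coef)."""
    a = np.ravel(a)
    b = np.ravel(b)
    intersection = np.sum(a * b)
    return (2.0 * intersection + smooth) / (np.sum(a) + np.sum(b) + smooth)


def disjoint_loss_np(two_channel):
    """Forward pass of special_loss_disjoint for a (batch, H, W, 2) array.

    Channel 0 stands in for new_final_op and channel 1 for new_xfinal_op,
    i.e. the concatenated x_u_net_opsp output. The Keras y_true argument is
    ignored by the original loss, so it is omitted here.
    """
    first, second = np.split(two_channel, 2, axis=-1)
    # Mirrors tf.where(tf.greater(0.0, x), 1, x): only negative values change.
    first = np.where(0.0 > first, 1.0, first)
    second = np.where(0.0 > second, 1.0, second)
    # The raw Dice overlap is returned, so minimizing it separates the masks.
    return dice_coef_np(first, second)


def make_disc_batch(real, fake, seed=0):
    """Hypothetical helper: stack real (-1) and fake (+1) samples, then shuffle
    images and labels with one shared permutation so pairs stay aligned."""
    x = np.concatenate([real, fake], axis=0)
    y = np.concatenate([-np.ones(len(real)), np.ones(len(fake))])
    perm = np.random.default_rng(seed).permutation(len(x))
    return x[perm], y[perm]


if __name__ == "__main__":
    # Two 4x4 masks: `left` covers columns 0-1, `right` covers columns 2-3.
    left = np.zeros((1, 4, 4, 1))
    left[:, :, :2, :] = 1.0
    right = 1.0 - left

    disjoint = np.concatenate([left, right], axis=-1)
    overlapping = np.concatenate([left, left], axis=-1)
    print(round(disjoint_loss_np(disjoint), 4))     # 0.0588 (= 1/17)
    print(round(disjoint_loss_np(overlapping), 4))  # 1.0

    x, y = make_disc_batch(np.zeros((3, 4, 4, 1)), np.ones((2, 4, 4, 1)))
    print(x.shape, y.tolist())
```

Identical masks score 1.0 and non-overlapping masks approach 0 (here 1/17, because of the smoothing term), which is why minimizing this value pushes the two branches toward disjoint outputs. In `train_generator` the discriminator target is fixed at -1 for every sample, so the generator is rewarded when the frozen discriminator scores its `new_final_op` output like the real (-1) samples.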