├── LICENSE
├── README.md
├── img
│   ├── ASCnet.PNG
│   ├── YT.PNG
│   └── res.PNG
└── src
    ├── continue_training_stage_1.py
    ├── continue_training_stage_2.py
    └── create_networks.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Raunak Dey
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | ## ASC-Net summary
3 |
4 | ### Introduction
5 | ASC-Net is a framework that lets us define a Reference Distribution Set, take in any Input Image, compare it against the Reference Distribution, and throw out the anomalies present in the Input Image. This is useful when you have a set of images/signals whose contents you know, and you then receive new images and want to check whether they differ from the original set, aka anomaly/novelty detection.
6 |
7 | ### ArXiv Link
8 |
9 | https://arxiv.org/pdf/2103.03664.pdf
10 |
11 | ### Highlights
12 |
13 | 1. Solves the difficulty of defining a class/set of things deterministically down to the nitty-gritty details. The Reference Distribution can work on any combination of image sets and abstracts out the manifold encompassing them.
14 | 2. No need for perfect reconstruction. Unlike other existing algorithms, we care about the anomaly, not the reconstruction. State-of-the-art performance!
15 | 3. We can potentially define any manifold using the Reference Distribution and then compare any incoming input image to it.
16 | 4. Works on any image size. Simply adjust the encoder/decoder sets to match your input size and hardware capacity.
17 | 5. ***The claim of being "independent of the instability of GANs" holds because the final termination does not depend on the adversarial training. We terminate when the Iro output has split into distinct peaks.***
18 |
19 | ### Network Architecture
20 |
21 | 
22 |
23 | ### High-level Summary [Short Video]
24 |
25 |
26 | [](https://www.youtube.com/watch?v=oUeBNOYOheg)
27 |
28 | ### Important
29 |
30 | ***Always take the threshold on the reconstruction, i.e. ID3 in the code section, as it summarizes the two cuts in one place.***
31 |
32 |
33 | ## Code
34 |
35 | ### Dependencies/Environment used
36 |
37 | * [CUDA](https://developer.nvidia.com/cuda-90-download-archive) - CUDA-9.0.176
38 | * [CUDNN](https://developer.nvidia.com/cudnn-download-survey) - cuDNN Version 6.0.21 (Major 6; Minor 0; PatchLevel 21)
39 | * [Python](https://www.python.org/downloads/) - Version 2.7.12
40 | * [Tensorflow](https://www.tensorflow.org/install) - Version 1.10.0
41 | * [Keras](http://www.keras.io) - Version 2.2.2
42 | * [Numpy](http://www.numpy.org/) - Version 1.15.5
43 | * [Nibabel](https://nipy.org/nibabel/) - Version 2.2.0
44 | * [OpenCV](https://opencv.org/releases/) - Version 2.4.9.1
45 | * [Brats 2019](https://ipp.cbica.upenn.edu/) - Select Brats 2019
46 | * [LiTS](https://competitions.codalab.org/competitions/17094) - LiTS Website
47 | * [MS-SEG 2015](https://smart-stats-tools.org/lesion-challenge) - MS-SEG2015 website
48 | * 12 GB TitanX GPU
49 |
50 | ### Code Summary [Short Video. I haven't yet commented the code, so watch this for a walkthrough :>]
51 |
52 |
53 | [](https://www.youtube.com/watch?v=F53Grnmnpz0)
54 |
55 | ### Comments
56 |
57 |
58 |
59 | - ID1 is Ifc
60 | - ID2 is Iwc
61 | - ID3 is Iro. ***Please take the threshold on this*** (see the sketch below)
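
A minimal sketch of how these IDs map onto `generator.predict`, following the save order used in the training scripts (`generator` and `X_train` are as defined there; the 0.5 cut-off is an illustrative choice, not a value prescribed by the paper):

```python
import numpy as np

# generator outputs, in the order the training scripts save them:
# result[0] -> ID1 (Ifc), result[1] -> ID2 (Iwc), result[2] -> ID3 (Iro)
result = generator.predict([X_train[0:1000]], batch_size=16)

# min-max normalize ID3, as the scripts do before writing PNGs,
# then threshold it to obtain a binary anomaly mask
iro = result[2]
iro = (iro - np.amin(iro)) / (np.amax(iro) - np.amin(iro))
anomaly_mask = (iro > 0.5).astype(np.uint8)  # 0.5 is an illustrative cut-off
```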
62 |
63 | ### Data Files/Inputs
64 |
65 | 1. To make the framework function we require two files [Mandatory!] (a preparation sketch follows this list)
66 | - Reference Distribution - Named ***good_dic_to_train_disc.npy*** for our code
67 | > This is the image set which we know something about. This forms a manifold.
68 | - Input Images - Named ***input_for_generator.npy*** for our code
69 |       > These can contain anything; the framework will split each input into two halves, one consisting of the components of the input image that lie in the manifold of the Reference Distribution, and the other containing everything else, i.e. the anomaly.
70 |
71 | 2. Ground truth for the anomaly we want to test for [Optional used during testing]
72 | - Masks - Named ***tumor_mask_for_generator.npy*** for our code
73 |       > The framework throws out anomalies without needing any guidance from a ground truth. However, to check performance we may want to include a mask for the anomalies of the input image set used above. In real-life scenarios we won't have these, and we don't need them.
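
A rough sketch of how these files could be prepared. The training scripts simply `np.load` them, so each should be a float array of shape (N, H, W, 1) scaled to [0, 1]; the shapes and random data below are placeholders, not real inputs:

```python
import numpy as np

# placeholder arrays; replace with your own preprocessed images,
# shape (N, H, W, 1), float values scaled to [0, 1]
reference_images = np.random.rand(1000, 160, 160, 1).astype(np.float32)
input_images = np.random.rand(1000, 160, 160, 1).astype(np.float32)

np.save("good_dic_to_train_disc.npy", reference_images)  # Reference Distribution
np.save("input_for_generator.npy", input_images)         # Input Images
# optional, evaluation only:
# np.save("tumor_mask_for_generator.npy", masks)
```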
74 |
75 | ### Source File
76 |
77 | #### Initial Conditions
78 |
79 | - The framework is initialized with an input shape of 160x160x1 for the MS-SEG experiments. Please update this according to your needs (a small sketch follows this list).
80 | - Update the path variables for the output folders in case you want to visualize the network output while training it.
81 | - To change the base network, edit the build_generator and build_discriminator methods.
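
For example, a sketch of the kind of setup meant here. The folder names match the `cv2.imwrite` calls in the training scripts; 160x160x1 is the MS-SEG setting mentioned above, while the stage scripts ship with 240,240,1:

```python
import os

input_shape = 160, 160, 1  # MS-SEG setting; the stage scripts ship with 240, 240, 1

# create the folders the training scripts write their visualizations into
for d in ["outputs/id1", "outputs/id2", "outputs/id3",
          "outputs/norm/id1", "outputs/norm/id2", "outputs/norm/id3",
          "outputs/norm/in", "outputs/norm/op"]:
    if not os.path.exists(d):  # Python 2.7: os.makedirs has no exist_ok
        os.makedirs(d)
```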
82 |
83 | #### File Sequence to Run
84 |
85 | - create_networks.py
86 |        > This creates the network described in our paper. If you need a different architecture, edit this file accordingly and update the baseline structures of the encoder/decoder, but try to keep the final connections intact.
87 |    - After running this you will obtain three .h5 files (see the loading sketch after this list):
88 |        - disjoint_un_sup_mse_generator.h5 : the main module in the network diagram above
89 |        - disjoint_un_sup_mse_discriminator.h5 : the discriminator in the network diagram above
90 |        - disjoint_un_sup_mse_complete_gans.h5 : the complete version of the entire network diagram
91 |
92 |
93 | - continue_training_stage_1.py
94 | > Stage 1 training. Read the paper!
95 |
96 | - continue_training_stage_2.py
97 | > Stage 2 training. Read the paper!
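
Both stage scripts reload the models saved by create_networks.py by passing the custom losses through `custom_objects`; if you load the .h5 files in your own code, do the same (`dice_coef_loss` and `special_loss_disjoint` are defined in the stage scripts):

```python
from keras.models import load_model

generator = load_model('disjoint_un_sup_mse_generator.h5',
                       custom_objects={'dice_coef_loss': dice_coef_loss,
                                       'special_loss_disjoint': special_loss_disjoint})
discriminator = load_model('disjoint_un_sup_mse_discriminator.h5',
                           custom_objects={'dice_coef_loss': dice_coef_loss,
                                           'special_loss_disjoint': special_loss_disjoint})
```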
98 |
99 | ### Results
100 |
101 | 
102 |
--------------------------------------------------------------------------------
/img/ASCnet.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raun1/ASC-NET/d7616bb8c3da70a69287ef8aef982bb3a9d597e8/img/ASCnet.PNG
--------------------------------------------------------------------------------
/img/YT.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raun1/ASC-NET/d7616bb8c3da70a69287ef8aef982bb3a9d597e8/img/YT.PNG
--------------------------------------------------------------------------------
/img/res.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raun1/ASC-NET/d7616bb8c3da70a69287ef8aef982bb3a9d597e8/img/res.PNG
--------------------------------------------------------------------------------
/src/continue_training_stage_1.py:
--------------------------------------------------------------------------------
1 |
2 | import keras
3 | from keras import optimizers
4 | #from keras.utils import multi_gpu_model
5 | import scipy as sp
6 | import scipy.misc, scipy.ndimage.interpolation
7 | from medpy import metric
8 | import numpy as np
9 | import os
10 | from keras import losses
11 | import tensorflow as tf
12 | from keras.models import Model,Sequential
13 | from keras.layers import Input,merge, concatenate, Conv2D, MaxPooling2D, Activation, UpSampling2D,Dropout,Conv2DTranspose,add,multiply,Dense,Flatten
14 | from keras.layers.normalization import BatchNormalization as bn
15 | from keras.callbacks import ModelCheckpoint, TensorBoard
16 | from keras.optimizers import RMSprop
17 | from keras import regularizers
18 | from keras import backend as K
19 | from keras.optimizers import Adam
20 | from keras.callbacks import ModelCheckpoint
21 | import tensorflow as tf
22 | #from keras.applications import Xception
23 | from keras.utils import multi_gpu_model
24 | import random
25 | import numpy as np
26 | from keras.callbacks import EarlyStopping, ModelCheckpoint
27 | import nibabel as nib
28 | import cv2
29 | CUDA_VISIBLE_DEVICES = [0,1,2,3]
30 | os.environ['CUDA_VISIBLE_DEVICES']=','.join([str(x) for x in CUDA_VISIBLE_DEVICES])
31 | smooth=1.
32 | input_shape=240,240,1
33 | ########################################Losses#################################################
34 | def special_loss_disjoint(y_true,y_pred):
35 |
36 | y_true,y_pred=tf.split(y_pred, 2,axis=-1)
37 |
38 | thresholded_pred = tf.where( tf.greater( 0.0000000000000001, y_pred ), tf.ones_like( y_pred ), y_pred )# tf.where(cond, if_true, if_false): entries below a tiny epsilon are set to 1 (note: stage 2 flips this comparison)
39 |
40 | thresholded_true=tf.where( tf.greater( 0.0000000000000001, y_true ), tf.ones_like( y_true ), y_true )
41 |
42 | return dice_coef(thresholded_true,thresholded_pred)
43 |
44 | def dice_coef(y_true, y_pred):
45 |
46 | y_true_f = K.flatten(y_true)
47 | y_pred_f = K.flatten(y_pred)
48 | intersection = K.sum(y_true_f * y_pred_f)
49 | return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)
50 | def dice_coef_loss(y_true, y_pred):
51 | return dice_coef(y_true, y_pred)# NB: returns the Dice coefficient itself, not 1-Dice; the name must match the custom_objects used in load_model below
52 |
53 | ################################################################################################
54 |
55 |
56 |
57 | def build_discriminator(input_shape,learn_rate=1e-3):
58 | l2_lambda = 0.0002
59 | DropP = 0.3
60 | kernel_size=3
61 |
62 | inputs = Input(input_shape,name="disc_ip")
63 |
64 | conv0a = Conv2D( 32, (kernel_size, kernel_size), activation='relu', padding='same',
65 | kernel_regularizer=regularizers.l2(l2_lambda),name='disc_l2_conc15' )(inputs)
66 |
67 |
68 | conv0a = bn(name='disc_l2_bn1')(conv0a)
69 |
70 | conv0b = Conv2D(32, (kernel_size, kernel_size), activation='relu', padding='same',
71 | kernel_regularizer=regularizers.l2(l2_lambda),name='disc_l2_conc16' )(conv0a)
72 |
73 | conv0b = bn(name='disc_l2_bn2')(conv0b)
74 |
75 |
76 |
77 |
78 | pool0 = MaxPooling2D(pool_size=(2, 2),name='disc_l2_mp1')(conv0b)
79 |
80 | pool0 = Dropout(DropP,name='disc_l2_d1')(pool0)
81 |
82 |
83 |
84 |
85 |
86 |
87 | conv2a = Conv2D(64, (kernel_size, kernel_size), activation='relu', padding='same',
88 | kernel_regularizer=regularizers.l2(l2_lambda),name='disc_l2_conc17' )(pool0)
89 |
90 | conv2a = bn(name='disc_l2_bn3')(conv2a)
91 |
92 | conv2b = Conv2D(64, (kernel_size, kernel_size), activation='relu', padding='same',
93 | kernel_regularizer=regularizers.l2(l2_lambda) ,name='disc_l2_conc18')(conv2a)
94 |
95 | conv2b = bn(name='disc_l2_bn4')(conv2b)
96 |
97 | pool2 = MaxPooling2D(pool_size=(2, 2),name='disc_l2_mp2')(conv2b)
98 |
99 | pool2 = Dropout(DropP,name='disc_l2_d2')(pool2)
100 |
101 |
102 |
103 |
104 |
105 |
106 |
107 | conv3a = Conv2D(128, (kernel_size, kernel_size), activation='relu', padding='same',
108 | kernel_regularizer=regularizers.l2(l2_lambda),name='disc_l2_conc19' )(pool2)
109 |
110 | conv3a = bn(name='disc_l2_bn5')(conv3a)
111 |
112 | conv3b = Conv2D(128, (kernel_size, kernel_size), activation='relu', padding='same',
113 | kernel_regularizer=regularizers.l2(l2_lambda) ,name='disc_l2_conc20')(conv3a)
114 |
115 | conv3b = bn(name='disc_l2_bn6')(conv3b)
116 |
117 |
118 |
119 | pool3 = MaxPooling2D(pool_size=(2, 2),name='disc_l2_mp3')(conv3b)
120 |
121 | pool3 = Dropout(DropP,name='disc_l2_d3')(pool3)
122 |
123 |
124 | conv4a = Conv2D(256, (kernel_size, kernel_size), activation='relu', padding='same',
125 | kernel_regularizer=regularizers.l2(l2_lambda),name='disc_l2_conc21' )(pool3)
126 |
127 | conv4a = bn(name='disc_l2_bn7')(conv4a)
128 |
129 | conv4b = Conv2D(256, (kernel_size, kernel_size), activation='relu', padding='same',
130 | kernel_regularizer=regularizers.l2(l2_lambda),name='disc_l2_conc22' )(conv4a)
131 |
132 | conv4b = bn(name='disc_l2_bn8')(conv4b)
133 |
134 | pool4 = MaxPooling2D(pool_size=(2, 2),name='disc_l2_mp4')(conv4b)
135 |
136 | pool4 = Dropout(DropP,name='disc_l2_d4')(pool4)
137 |
138 |
139 |
140 |
141 |
142 | conv5a = Conv2D(512, (kernel_size, kernel_size), activation='relu', padding='same',
143 | kernel_regularizer=regularizers.l2(l2_lambda) ,name='disc_l2_conc23')(pool4)
144 |
145 | conv5a = bn(name='disc_l2_bn9')(conv5a)
146 |
147 | conv5b = Conv2D(512, (kernel_size, kernel_size), activation='relu', padding='same',
148 | kernel_regularizer=regularizers.l2(l2_lambda) ,name='disc_l2_conc24')(conv5a)
149 |
150 | conv5b = bn(name='disc_l2_bn10')(conv5b)
151 |
152 | flat=Flatten()(conv5b)
153 |
154 | output_disc=Dense(1,activation='tanh',name='disc_output')(flat)# tanh score in [-1,1]; real images are labelled -1 and generated ones +1 in train_disc below
155 |
156 | model=Model(inputs=[inputs],outputs=[output_disc])
157 | model.compile(loss='mae',
158 | optimizer=keras.optimizers.Adam(lr=5e-5),
159 | metrics=['accuracy'])
160 | #model.summary()
161 | return model
162 |
163 | input_shape=240,240,1
164 |
165 | from keras.models import load_model
166 | generator=load_model('disjoint_un_sup_mse_generator.h5', custom_objects={'dice_coef_loss':dice_coef_loss,'special_loss_disjoint':special_loss_disjoint})
167 | discriminator=load_model('disjoint_un_sup_mse_discriminator.h5', custom_objects={'dice_coef_loss':dice_coef_loss,'special_loss_disjoint':special_loss_disjoint})
168 |
169 | for layer in discriminator.layers: layer.trainable = False
170 | generator.compile(optimizer=keras.optimizers.Adam(lr=5e-5),loss={
171 |
172 | 'new_res_1_final_opa':'mse',
173 | 'x_u_net_opsp':special_loss_disjoint
174 |
175 | })
176 |
177 | discriminator.compile(loss='mae',
178 | optimizer=keras.optimizers.Adam(lr=5e-5),
179 | metrics=['accuracy'])
180 |
181 | final_input=generator.input
182 |
183 |
184 |
185 | x_u_net_opsp=(generator.get_layer('x_u_net_opsp').output)
186 | final_output_gans=discriminator(generator.get_layer('new_final_op').output)
187 | final_output_seg=(generator.get_layer('new_xfinal_op').output)
188 | final_output_res=(generator.get_layer('new_res_1_final_opa').output)
189 |
190 | #final_model.add(generator)
191 | #final_model.add(discriminator)
192 | final_model=Model(inputs=[final_input],outputs=[final_output_gans,final_output_seg,final_output_res,x_u_net_opsp])
193 |
194 | final_model.compile(optimizer=keras.optimizers.Adam(lr=5e-5),metrics=['mae'],loss={'model_2':'mae',
195 |
196 | 'new_res_1_final_opa':'mse',
197 | 'x_u_net_opsp':special_loss_disjoint})
198 |
199 |
200 | print("full gans")
201 | final_model.summary()
202 | print(final_model.input)
203 | print(final_model.output)
204 | print("============================================================================================================================================================")
205 | print("generator")
206 | generator.summary()
207 | print(generator.input)
208 | print(generator.output)
209 | print("============================================================================================================================================================")
210 |
211 | print("discriminator")
212 | discriminator.summary()
213 | print(discriminator.get_input_at(0))
214 | print(discriminator.get_input_at(1))
215 | #print(discriminator.output)
216 | print("============================================================================================================================================================")
217 | #print(discriminator.get_input_at(2))
218 | #print(discriminator.input[2])
219 | #X_train=np.ones((1,160,160,1))
220 | #final_model.fit([X_train],[1],batch_size=1,nb_epoch=1,shuffle=False)
221 | #print ("hi",final_model.predict([X_train],batch_size=1))
222 |
223 |
224 | def train_disc(real_data,fake_data,true_label,ep,loss_ch):
225 |
226 | discriminator=build_discriminator(input_shape)
227 | discriminator.name='model_2'
228 | for layer in discriminator.layers: layer.trainable = False
229 | generator.compile(optimizer=keras.optimizers.Adam(lr=5e-5),loss={
230 |
231 | 'new_res_1_final_opa':'mse',
232 | 'x_u_net_opsp':special_loss_disjoint
233 |
234 | })
235 |
236 | discriminator.compile(loss='mae',
237 | optimizer=keras.optimizers.Adam(lr=5e-5),
238 | metrics=['accuracy'])
239 |
240 | final_input=generator.input
241 |
242 |
243 |
244 | x_u_net_opsp=(generator.get_layer('x_u_net_opsp').output)
245 | final_output_gans=discriminator(generator.get_layer('new_final_op').output)
246 | final_output_seg=(generator.get_layer('new_xfinal_op').output)
247 | final_output_res=(generator.get_layer('new_res_1_final_opa').output)
248 |
249 | final_model=Model(inputs=[final_input],outputs=[final_output_gans,final_output_seg,final_output_res,x_u_net_opsp])
250 |
251 | final_model.compile(optimizer=keras.optimizers.Adam(lr=5e-5),metrics=['mae'],loss={'model_2':'mae',
252 |
253 | 'new_res_1_final_opa':'mse',
254 | 'x_u_net_opsp':special_loss_disjoint})
255 |
256 | for layer in discriminator.layers: layer.trainable = True
257 |
258 |
259 | generator.compile(optimizer=keras.optimizers.Adam(lr=5e-5),loss={
260 |
261 | 'new_res_1_final_opa':'mse',
262 |
263 | })
264 |
265 | discriminator.compile(loss='mae',
266 | optimizer=keras.optimizers.Adam(lr=5e-5),
267 | metrics=['accuracy'])
268 | multi_discriminator=multi_gpu_model(discriminator,gpus=4)
269 | multi_discriminator.compile(loss='mae',
270 | optimizer=keras.optimizers.Adam(lr=5e-5),
271 | metrics=['accuracy'])
272 | final_model.compile(optimizer=keras.optimizers.Adam(lr=5e-5),metrics=['accuracy'],loss={'model_2':'mae',
273 |
274 | 'new_res_1_final_opa':'mse',
275 | 'x_u_net_opsp':special_loss_disjoint})
276 |
277 | discriminator.summary()
278 |
279 |
280 | y_train_true=-np.ones(shape=len(real_data))
281 | y_train_true=y_train_true#-0.1
282 | print(y_train_true.shape)
283 |
284 |
285 |
286 |
287 |
288 | y_train_fake=np.ones(shape=len(fake_data))
289 | y_train_fake=y_train_fake#-0.1
290 |
291 | real_data=list(real_data)
292 | fake_data=list(fake_data)
293 | y_train_true=list(y_train_true)
294 |
295 | y_train_fake=list(y_train_fake)
296 | merged_inputs=[real_data+fake_data]
297 | real_data=[]
298 | fake_data=[]
299 | merged_gt=[y_train_true+y_train_fake]
300 | print('hi')
301 |
302 | y_train_fake=[]
303 | y_train_true=[]
304 | from sklearn.utils import shuffle
305 | merged_inputs,merged_gt=shuffle(merged_inputs,merged_gt)
306 |
307 | merged_inputs=np.array(merged_inputs)
308 | merged_gt=np.array(merged_gt)
309 | merged_inputs=np.squeeze(merged_inputs,axis=(0,))
310 | merged_gt=np.squeeze(merged_gt,axis=(0,))
311 |
312 | print("training_discriminator===============================================================================")
313 | while(True):
314 | xx=int(raw_input("press 1 to keep training"))
315 | ep=int(raw_input("enter updated number of epochs"))
316 | if(xx!=1):
317 | break
318 | #multi_discriminator.summary()
319 | multi_discriminator.fit([merged_inputs],[merged_gt],batch_size=72*4,epochs=ep,shuffle=True)
320 |
321 | return
322 |
323 |
324 |
325 | def train_generator(true_label,ep,loss_ch):
326 | for layer in discriminator.layers: layer.trainable = False
327 |
328 |
329 | generator.compile(optimizer=keras.optimizers.Adam(lr=5e-5),loss={
330 |
331 | 'new_res_1_final_opa':'mse',
332 | 'x_u_net_opsp':special_loss_disjoint
333 |
334 | })
335 |
336 | discriminator.compile(loss='mae',
337 | optimizer=keras.optimizers.Adam(lr=5e-5),
338 | metrics=['accuracy'])
339 | final_model.compile(optimizer=keras.optimizers.Adam(lr=5e-5),metrics=['accuracy'],loss={'model_2':'mae',
340 |
341 | 'new_res_1_final_opa':'mse',
342 | 'x_u_net_opsp':special_loss_disjoint})
343 |
344 | multi_final_model=multi_gpu_model(final_model,gpus=4)
345 | multi_final_model.compile(optimizer=keras.optimizers.Adam(lr=5e-5),metrics=['accuracy'],loss={'model_2':'mae',
346 |
347 | 'new_res_1_final_opa':'mse',
348 | 'x_u_net_opsp':special_loss_disjoint})
349 |
350 | #discriminator.summary()
351 | X_train=np.load("input_for_generator.npy")
352 | #X_train=np.load("input_for_generator.npy")
353 |
354 |
355 |
356 |
357 | y_train=[]
358 |
359 | for j in range(0,len(X_train)):
360 |
361 | y_train.append(-1)
362 | y_train=np.array(y_train)
363 | #print(multi_final_model.summary())
364 | y_empty=np.zeros(shape=(X_train.shape))
365 | while(True):
366 | xx=int(raw_input("press 1 to keep training"))
367 | ep=int(raw_input("enter updated number of epochs"))
368 | if(xx!=1):
369 | break
370 |
371 | multi_final_model.fit([X_train],[y_train,X_train,y_empty],batch_size=16*4,epochs=ep,shuffle=True)
372 | result=generator.predict([X_train[0:1000]],batch_size=16)
373 | result[1]=(result[1]-np.amin(result[1]))/((np.amax(result[1]))-(np.amin(result[1])))
374 | result[0]=(result[0]-np.amin(result[0]))/((np.amax(result[0]))-(np.amin(result[0])))
375 | result[2]=(result[2]-np.amin(result[2]))/((np.amax(result[2]))-(np.amin(result[2])))
376 | for i in range(0,1000):
377 | cv2.imwrite("outputs/norm/id1/"+str(i)+".png",(result[0][i])*255)
378 | cv2.imwrite("outputs/norm/id2/"+str(i)+".png",(result[1][i])*255)
379 | cv2.imwrite("outputs/norm/id3/"+str(i)+".png",(result[2][i])*255)
380 | cv2.imwrite("outputs/norm/in/"+str(i)+".png",(X_train[i])*255)
381 | #final_model.fit([X_train],[y_train,X_train,y_empty],batch_size=16,nb_epoch=ep,shuffle=True)
382 | return
383 |
384 | while(True):
385 | i_p=int(raw_input("press 0 to train disc, 1 to train gen, 2 to save models, 3 to check outputs, anything else to quit"))
386 |
387 | if(i_p==0):
388 | #'''
389 | discriminator=build_discriminator(input_shape)
390 | discriminator.name='model_2'
391 | for layer in discriminator.layers: layer.trainable = False
392 | generator.compile(optimizer=keras.optimizers.Adam(lr=5e-5),loss={
393 |
394 | 'new_res_1_final_opa':'mse',
395 | 'x_u_net_opsp':special_loss_disjoint
396 |
397 | })
398 |
399 | discriminator.compile(loss='mae',
400 | optimizer=keras.optimizers.Adam(lr=5e-5),
401 | metrics=['accuracy'])
402 |
403 | #discriminator.trainable=False
404 | final_input=generator.input
405 | #final_input_1=discriminator.input
406 | #connect the two
407 |
408 | #discriminator.input=generator.get_layer('output_gen').output
409 | x_u_net_opsp=(generator.get_layer('x_u_net_opsp').output)
410 | final_output_gans=discriminator(generator.get_layer('new_final_op').output)
411 | final_output_seg=(generator.get_layer('new_xfinal_op').output)
412 | final_output_res=(generator.get_layer('new_res_1_final_opa').output)
413 |
414 | #final_model.add(generator)
415 | #final_model.add(discriminator)
416 | final_model=Model(inputs=[final_input],outputs=[final_output_gans,final_output_seg,final_output_res,x_u_net_opsp])
417 |
418 | final_model.compile(optimizer=keras.optimizers.Adam(lr=5e-5),metrics=['mae'],loss={'model_2':'mae',
419 |
420 | 'new_res_1_final_opa':'mse',
421 | 'x_u_net_opsp':special_loss_disjoint})
422 | loss_ch=0
423 | #'''
424 | print ("training disc")
425 |
426 | ep=int(raw_input("enter number of epochs"))
427 | real_data=np.load("good_dic_to_train_disc.npy")
428 |
429 |
430 |
431 | X_train_tumors=np.load("input_for_generator.npy")
432 |
433 |
434 |
435 | fake_data=generator.predict([X_train_tumors])[0]
436 |
437 | print("fake_data_shape",fake_data.shape)
438 |
439 |
440 | true_label=1
441 |
442 | print((real_data.shape),(fake_data.shape),true_label,ep)
443 | proceed=int(raw_input("proceed press 1"))
444 | if(proceed==1):
445 | train_disc(real_data,fake_data,true_label,ep,loss_ch)
446 | else:
447 | continue
448 |
449 | elif(i_p==1):
450 | print("training gen")
451 | loss_ch=0
452 | ep=int(raw_input("enter number of epochs"))
453 | true_label=1
454 |
455 | proceed=int(raw_input("proceed press 1"))
456 | if(proceed==1):
457 | train_generator(true_label,ep,loss_ch)
458 | else:
459 | continue
460 | elif(i_p==2):
461 | import h5py
462 |
463 | final_model.save('disjoint_un_sup_mse_complete_gans.h5')
464 | generator.save("disjoint_un_sup_mse_generator.h5")
465 | discriminator.save("disjoint_un_sup_mse_discriminator.h5")
466 |
467 |
468 | elif(i_p==3):
469 | X_train=np.load("input_for_generator.npy")
470 | y_train=np.load("tumor_mask_for_generator.npy")
471 | result=generator.predict([X_train[0:1000]],batch_size=16)
472 | #result=np.array(result)
473 | #print (result.shape)
474 |
475 | print(np.amax(result[0]),np.amax(result[1]),np.amax(result[2]),np.amax(result[3]))
476 |
477 | print(np.amin(result[0]),np.amin(result[1]),np.amin(result[2]),np.amin(result[3]))
478 |
479 |
480 |
481 |
482 |
483 | for i in range(0,1000):
484 | cv2.imwrite("outputs/id1/"+str(i)+".png",(result[0][i])*255)
485 | cv2.imwrite("outputs/id2/"+str(i)+".png",(result[1][i])*255)
486 | cv2.imwrite("outputs/id3/"+str(i)+".png",(result[2][i])*255)
487 |
488 | cv2.imwrite("outputs/norm/in/"+str(i)+".png",X_train[i]*255)
489 | cv2.imwrite("outputs/norm/op/"+str(i)+".png",y_train[i]*255)
490 |
491 |
492 | result[1]=(result[1]-np.amin(result[1]))/((np.amax(result[1]))-(np.amin(result[1])))
493 | result[0]=(result[0]-np.amin(result[0]))/((np.amax(result[0]))-(np.amin(result[0])))
494 | result[2]=(result[2]-np.amin(result[2]))/((np.amax(result[2]))-(np.amin(result[2])))
495 | for i in range(0,1000):
496 | cv2.imwrite("outputs/norm/id1/"+str(i)+".png",(result[0][i])*255)
497 | cv2.imwrite("outputs/norm/id2/"+str(i)+".png",(result[1][i])*255)
498 | cv2.imwrite("outputs/norm/id3/"+str(i)+".png",(result[2][i])*255)
499 |
500 | else:
501 | break
502 |
503 |
504 |
505 |
506 |
507 |
508 |
509 |
510 |
511 |
512 |
513 |
514 |
515 |
516 |
517 |
518 |
519 |
--------------------------------------------------------------------------------
/src/continue_training_stage_2.py:
--------------------------------------------------------------------------------
1 | import keras
2 | from keras import optimizers
3 | #from keras.utils import multi_gpu_model
4 | import scipy as sp
5 | import scipy.misc, scipy.ndimage.interpolation
6 | from medpy import metric
7 | import numpy as np
8 | import os
9 | from keras import losses
10 | import tensorflow as tf
11 | from keras.models import Model,Sequential
12 | from keras.layers import Input,merge, concatenate, Conv2D, MaxPooling2D, Activation, UpSampling2D,Dropout,Conv2DTranspose,add,multiply,Dense,Flatten
13 | from keras.layers.normalization import BatchNormalization as bn
14 | from keras.callbacks import ModelCheckpoint, TensorBoard
15 | from keras.optimizers import RMSprop
16 | from keras import regularizers
17 | from keras import backend as K
18 | from keras.optimizers import Adam
19 | from keras.callbacks import ModelCheckpoint
20 | import tensorflow as tf
21 | #from keras.applications import Xception
22 | from keras.utils import multi_gpu_model
23 | import random
24 | import numpy as np
25 | from keras.callbacks import EarlyStopping, ModelCheckpoint
26 | import nibabel as nib
27 | import cv2
28 | CUDA_VISIBLE_DEVICES = [0,1,2,3]
29 | os.environ['CUDA_VISIBLE_DEVICES']=','.join([str(x) for x in CUDA_VISIBLE_DEVICES])
30 | smooth=1.
31 | input_shape=240,240,1
32 | ########################################Losses#################################################
33 | def special_loss_disjoint(y_true,y_pred):
34 |
35 | y_true,y_pred=tf.split(y_pred, 2,axis=-1)
36 |
37 | thresholded_pred = tf.where( tf.greater( y_pred ,0.0000000000000001 ), tf.ones_like( y_pred ), y_pred )# tf.where(cond, if_true, if_false): entries above a tiny epsilon are set to 1 (note: stage 1 flips this comparison)
38 |
39 | thresholded_true=tf.where( tf.greater( y_true,0.0000000000000001 ), tf.ones_like( y_true ), y_true )
40 |
41 | return dice_coef(thresholded_true,thresholded_pred)
42 |
43 | def dice_coef(y_true, y_pred):
44 |
45 | y_true_f = K.flatten(y_true)
46 | y_pred_f = K.flatten(y_pred)
47 | intersection = K.sum(y_true_f * y_pred_f)
48 | return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)
49 | def dice_coef_loss(y_true, y_pred):
50 | return dice_coef(y_true, y_pred)# NB: returns the Dice coefficient itself, not 1-Dice; the name must match the custom_objects used in load_model below
51 |
52 | ################################################################################################
53 |
54 |
55 |
56 | def build_discriminator(input_shape,learn_rate=1e-3):
57 | l2_lambda = 0.0002
58 | DropP = 0.3
59 | kernel_size=3
60 |
61 | inputs = Input(input_shape,name="disc_ip")
62 |
63 | conv0a = Conv2D( 32, (kernel_size, kernel_size), activation='relu', padding='same',
64 | kernel_regularizer=regularizers.l2(l2_lambda),name='disc_l2_conc15' )(inputs)
65 |
66 |
67 | conv0a = bn(name='disc_l2_bn1')(conv0a)
68 |
69 | conv0b = Conv2D(32, (kernel_size, kernel_size), activation='relu', padding='same',
70 | kernel_regularizer=regularizers.l2(l2_lambda),name='disc_l2_conc16' )(conv0a)
71 |
72 | conv0b = bn(name='disc_l2_bn2')(conv0b)
73 |
74 |
75 |
76 |
77 | pool0 = MaxPooling2D(pool_size=(2, 2),name='disc_l2_mp1')(conv0b)
78 |
79 | pool0 = Dropout(DropP,name='disc_l2_d1')(pool0)
80 |
81 |
82 |
83 |
84 |
85 |
86 | conv2a = Conv2D(64, (kernel_size, kernel_size), activation='relu', padding='same',
87 | kernel_regularizer=regularizers.l2(l2_lambda),name='disc_l2_conc17' )(pool0)
88 |
89 | conv2a = bn(name='disc_l2_bn3')(conv2a)
90 |
91 | conv2b = Conv2D(64, (kernel_size, kernel_size), activation='relu', padding='same',
92 | kernel_regularizer=regularizers.l2(l2_lambda) ,name='disc_l2_conc18')(conv2a)
93 |
94 | conv2b = bn(name='disc_l2_bn4')(conv2b)
95 |
96 | pool2 = MaxPooling2D(pool_size=(2, 2),name='disc_l2_mp2')(conv2b)
97 |
98 | pool2 = Dropout(DropP,name='disc_l2_d2')(pool2)
99 |
100 |
101 |
102 |
103 |
104 |
105 |
106 | conv3a = Conv2D(128, (kernel_size, kernel_size), activation='relu', padding='same',
107 | kernel_regularizer=regularizers.l2(l2_lambda),name='disc_l2_conc19' )(pool2)
108 |
109 | conv3a = bn(name='disc_l2_bn5')(conv3a)
110 |
111 | conv3b = Conv2D(128, (kernel_size, kernel_size), activation='relu', padding='same',
112 | kernel_regularizer=regularizers.l2(l2_lambda) ,name='disc_l2_conc20')(conv3a)
113 |
114 | conv3b = bn(name='disc_l2_bn6')(conv3b)
115 |
116 |
117 |
118 | pool3 = MaxPooling2D(pool_size=(2, 2),name='disc_l2_mp3')(conv3b)
119 |
120 | pool3 = Dropout(DropP,name='disc_l2_d3')(pool3)
121 |
122 |
123 | conv4a = Conv2D(256, (kernel_size, kernel_size), activation='relu', padding='same',
124 | kernel_regularizer=regularizers.l2(l2_lambda),name='disc_l2_conc21' )(pool3)
125 |
126 | conv4a = bn(name='disc_l2_bn7')(conv4a)
127 |
128 | conv4b = Conv2D(256, (kernel_size, kernel_size), activation='relu', padding='same',
129 | kernel_regularizer=regularizers.l2(l2_lambda),name='disc_l2_conc22' )(conv4a)
130 |
131 | conv4b = bn(name='disc_l2_bn8')(conv4b)
132 |
133 | pool4 = MaxPooling2D(pool_size=(2, 2),name='disc_l2_mp4')(conv4b)
134 |
135 | pool4 = Dropout(DropP,name='disc_l2_d4')(pool4)
136 |
137 |
138 |
139 |
140 |
141 | conv5a = Conv2D(512, (kernel_size, kernel_size), activation='relu', padding='same',
142 | kernel_regularizer=regularizers.l2(l2_lambda) ,name='disc_l2_conc23')(pool4)
143 |
144 | conv5a = bn(name='disc_l2_bn9')(conv5a)
145 |
146 | conv5b = Conv2D(512, (kernel_size, kernel_size), activation='relu', padding='same',
147 | kernel_regularizer=regularizers.l2(l2_lambda) ,name='disc_l2_conc24')(conv5a)
148 |
149 | conv5b = bn(name='disc_l2_bn10')(conv5b)
150 |
151 | flat=Flatten()(conv5b)
152 |
153 | output_disc=Dense(1,activation='tanh',name='disc_output')(flat)# tanh score in [-1,1]; real images are labelled -1 and generated ones +1 in train_disc below
154 |
155 | model=Model(inputs=[inputs],outputs=[output_disc])
156 | model.compile(loss='mae',
157 | optimizer=keras.optimizers.Adam(lr=5e-5),
158 | metrics=['accuracy'])
159 | #model.summary()
160 | return model
161 |
162 | input_shape=240,240,1
163 |
164 | from keras.models import load_model
165 | generator=load_model('disjoint_un_sup_mse_generator.h5', custom_objects={'dice_coef_loss':dice_coef_loss,'special_loss_disjoint':special_loss_disjoint})
166 | discriminator=load_model('disjoint_un_sup_mse_discriminator.h5', custom_objects={'dice_coef_loss':dice_coef_loss,'special_loss_disjoint':special_loss_disjoint})
167 |
168 | for layer in discriminator.layers: layer.trainable = False
169 | generator.compile(optimizer=keras.optimizers.Adam(lr=5e-5),loss={
170 |
171 | 'new_res_1_final_opa':'mse',
172 | 'x_u_net_opsp':special_loss_disjoint
173 |
174 | })
175 |
176 | discriminator.compile(loss='mae',
177 | optimizer=keras.optimizers.Adam(lr=5e-5),
178 | metrics=['accuracy'])
179 |
180 | final_input=generator.input
181 |
182 |
183 |
184 | x_u_net_opsp=(generator.get_layer('x_u_net_opsp').output)
185 | final_output_gans=discriminator(generator.get_layer('new_final_op').output)
186 | final_output_seg=(generator.get_layer('new_xfinal_op').output)
187 | final_output_res=(generator.get_layer('new_res_1_final_opa').output)
188 |
189 | #final_model.add(generator)
190 | #final_model.add(discriminator)
191 | final_model=Model(inputs=[final_input],outputs=[final_output_gans,final_output_seg,final_output_res,x_u_net_opsp])
192 |
193 | final_model.compile(optimizer=keras.optimizers.Adam(lr=5e-5),metrics=['mae'],loss={'model_2':'mae',
194 |
195 | 'new_res_1_final_opa':'mse',
196 | 'x_u_net_opsp':special_loss_disjoint})
197 |
198 |
199 | print("full gans")
200 | final_model.summary()
201 | print(final_model.input)
202 | print(final_model.output)
203 | print("============================================================================================================================================================")
204 | print("generator")
205 | generator.summary()
206 | print(generator.input)
207 | print(generator.output)
208 | print("============================================================================================================================================================")
209 |
210 | print("discriminator")
211 | discriminator.summary()
212 | print(discriminator.get_input_at(0))
213 | print(discriminator.get_input_at(1))
214 | #print(discriminator.output)
215 | print("============================================================================================================================================================")
216 | #print(discriminator.get_input_at(2))
217 | #print(discriminator.input[2])
218 | #X_train=np.ones((1,160,160,1))
219 | #final_model.fit([X_train],[1],batch_size=1,nb_epoch=1,shuffle=False)
220 | #print ("hi",final_model.predict([X_train],batch_size=1))
221 |
222 |
223 | def train_disc(real_data,fake_data,true_label,ep,loss_ch):
224 |
225 | discriminator=build_discriminator(input_shape)
226 | discriminator.name='model_2'
227 | for layer in discriminator.layers: layer.trainable = False
228 | generator.compile(optimizer=keras.optimizers.Adam(lr=5e-5),loss={
229 |
230 | 'new_res_1_final_opa':'mse',
231 | 'x_u_net_opsp':special_loss_disjoint
232 |
233 | })
234 |
235 | discriminator.compile(loss='mae',
236 | optimizer=keras.optimizers.Adam(lr=5e-5),
237 | metrics=['accuracy'])
238 |
239 | final_input=generator.input
240 |
241 |
242 |
243 | x_u_net_opsp=(generator.get_layer('x_u_net_opsp').output)
244 | final_output_gans=discriminator(generator.get_layer('new_final_op').output)
245 | final_output_seg=(generator.get_layer('new_xfinal_op').output)
246 | final_output_res=(generator.get_layer('new_res_1_final_opa').output)
247 |
248 | final_model=Model(inputs=[final_input],outputs=[final_output_gans,final_output_seg,final_output_res,x_u_net_opsp])
249 |
250 | final_model.compile(optimizer=keras.optimizers.Adam(lr=5e-5),metrics=['mae'],loss={'model_2':'mae',
251 |
252 | 'new_res_1_final_opa':'mse',
253 | 'x_u_net_opsp':special_loss_disjoint})
254 |
255 | for layer in discriminator.layers: layer.trainable = True
256 |
257 |
258 | generator.compile(optimizer=keras.optimizers.Adam(lr=5e-5),loss={
259 |
260 | 'new_res_1_final_opa':'mse',
261 |
262 | })
263 |
264 | discriminator.compile(loss='mae',
265 | optimizer=keras.optimizers.Adam(lr=5e-5),
266 | metrics=['accuracy'])
267 | multi_discriminator=multi_gpu_model(discriminator,gpus=4)
268 | multi_discriminator.compile(loss='mae',
269 | optimizer=keras.optimizers.Adam(lr=5e-5),
270 | metrics=['accuracy'])
271 | final_model.compile(optimizer=keras.optimizers.Adam(lr=5e-5),metrics=['accuracy'],loss={'model_2':'mae',
272 |
273 | 'new_res_1_final_opa':'mse',
274 | 'x_u_net_opsp':special_loss_disjoint})
275 |
276 | discriminator.summary()
277 |
278 |
279 | y_train_true=-np.ones(shape=len(real_data))
280 | y_train_true=y_train_true#-0.1
281 | print(y_train_true.shape)
282 |
283 |
284 |
285 |
286 |
287 | y_train_fake=np.ones(shape=len(fake_data))
288 | y_train_fake=y_train_fake#-0.1
289 |
290 | real_data=list(real_data)
291 | fake_data=list(fake_data)
292 | y_train_true=list(y_train_true)
293 |
294 | y_train_fake=list(y_train_fake)
295 | merged_inputs=[real_data+fake_data]
296 | real_data=[]
297 | fake_data=[]
298 | merged_gt=[y_train_true+y_train_fake]
299 | print('hi')
300 |
301 | y_train_fake=[]
302 | y_train_true=[]
303 | from sklearn.utils import shuffle
304 | merged_inputs,merged_gt=shuffle(merged_inputs,merged_gt)
305 |
306 | merged_inputs=np.array(merged_inputs)
307 | merged_gt=np.array(merged_gt)
308 | merged_inputs=np.squeeze(merged_inputs,axis=(0,))
309 | merged_gt=np.squeeze(merged_gt,axis=(0,))
310 |
311 | print("training_discriminator===============================================================================")
312 | while(True):
313 | xx=int(raw_input("press 1 to keep training"))
314 | ep=int(raw_input("enter updated number of epochs"))
315 | if(xx!=1):
316 | break
317 | #multi_discriminator.summary()
318 | multi_discriminator.fit([merged_inputs],[merged_gt],batch_size=72*4,epochs=ep,shuffle=True)
319 |
320 | return
321 |
322 |
323 |
324 | def train_generator(true_label,ep,loss_ch):
325 | for layer in discriminator.layers: layer.trainable = False
326 |
327 |
328 | generator.compile(optimizer=keras.optimizers.Adam(lr=5e-5),loss={
329 |
330 | 'new_res_1_final_opa':'mse',
331 | 'x_u_net_opsp':special_loss_disjoint
332 |
333 | })
334 |
335 | discriminator.compile(loss='mae',
336 | optimizer=keras.optimizers.Adam(lr=5e-5),
337 | metrics=['accuracy'])
338 | final_model.compile(optimizer=keras.optimizers.Adam(lr=5e-5),metrics=['accuracy'],loss={'model_2':'mae',
339 |
340 | 'new_res_1_final_opa':'mse',
341 | 'x_u_net_opsp':special_loss_disjoint})
342 |
343 | multi_final_model=multi_gpu_model(final_model,gpus=4)
344 | multi_final_model.compile(optimizer=keras.optimizers.Adam(lr=5e-5),metrics=['accuracy'],loss={'model_2':'mae',
345 |
346 | 'new_res_1_final_opa':'mse',
347 | 'x_u_net_opsp':special_loss_disjoint})
348 |
349 | #discriminator.summary()
350 | X_train=np.load("input_for_generator.npy")
351 | X_train=np.concatenate((X_train,X_train),axis=0) #double
352 | #X_train=np.load("input_for_generator.npy")
353 |
354 |
355 |
356 |
357 | y_train=[]
358 |
359 | for j in range(0,len(X_train)):
360 |
361 | y_train.append(-1)
362 | y_train=np.array(y_train)
363 | #print(multi_final_model.summary())
364 | y_empty=np.zeros(shape=(X_train.shape))
365 | while(True):
366 | xx=int(raw_input("press 1 to keep training"))
367 | ep=int(raw_input("enter updated number of epochs"))
368 | if(xx!=1):
369 | break
370 |
371 | multi_final_model.fit([X_train],[y_train,X_train,y_empty],batch_size=16*4,epochs=ep,shuffle=True)
372 | result=generator.predict([X_train[0:1000]],batch_size=16)
373 | result[1]=(result[1]-np.amin(result[1]))/((np.amax(result[1]))-(np.amin(result[1])))
374 | result[0]=(result[0]-np.amin(result[0]))/((np.amax(result[0]))-(np.amin(result[0])))
375 | result[2]=(result[2]-np.amin(result[2]))/((np.amax(result[2]))-(np.amin(result[2])))
376 | for i in range(0,1000):
377 | cv2.imwrite("outputs/norm/id1/"+str(i)+".png",(result[0][i])*255)
378 | cv2.imwrite("outputs/norm/id2/"+str(i)+".png",(result[1][i])*255)
379 | cv2.imwrite("outputs/norm/id3/"+str(i)+".png",(result[2][i])*255)
380 | cv2.imwrite("outputs/norm/in/"+str(i)+".png",(X_train[i])*255)
381 | #final_model.fit([X_train],[y_train,X_train,y_empty],batch_size=16,nb_epoch=ep,shuffle=True)
382 | return
383 |
384 | while(True):
385 | i_p=int(raw_input("press 0 to train disc, 1 to train gen, 2 to save models, 3 to check outputs, anything else to quit"))
386 |
387 | if(i_p==0):
388 | #'''
389 | discriminator=build_discriminator(input_shape)
390 | discriminator.name='model_2'
391 | for layer in discriminator.layers: layer.trainable = False
392 | generator.compile(optimizer=keras.optimizers.Adam(lr=5e-5),loss={
393 |
394 | 'new_res_1_final_opa':'mse',
395 | 'x_u_net_opsp':special_loss_disjoint
396 |
397 | })
398 |
399 | discriminator.compile(loss='mae',
400 | optimizer=keras.optimizers.Adam(lr=5e-5),
401 | metrics=['accuracy'])
402 |
403 | #discriminator.trainable=False
404 | final_input=generator.input
405 | #final_input_1=discriminator.input
406 | #connect the two
407 |
408 | #discriminator.input=generator.get_layer('output_gen').output
409 | x_u_net_opsp=(generator.get_layer('x_u_net_opsp').output)
410 | final_output_gans=discriminator(generator.get_layer('new_final_op').output)
411 | final_output_seg=(generator.get_layer('new_xfinal_op').output)
412 | final_output_res=(generator.get_layer('new_res_1_final_opa').output)
413 |
414 | #final_model.add(generator)
415 | #final_model.add(discriminator)
416 | final_model=Model(inputs=[final_input],outputs=[final_output_gans,final_output_seg,final_output_res,x_u_net_opsp])
417 |
418 | final_model.compile(optimizer=keras.optimizers.Adam(lr=5e-5),metrics=['mae'],loss={'model_2':'mae',
419 |
420 | 'new_res_1_final_opa':'mse',
421 | 'x_u_net_opsp':special_loss_disjoint})
422 | loss_ch=0
423 | #'''
424 | print ("training disc")
425 |
426 | ep=int(raw_input("enter number of epochs"))
427 | real_data=generator.predict(np.load("good_dic_to_train_disc.npy"))[0]
428 | real_data=np.concatenate((real_data,np.load("good_dic_to_train_disc.npy")),axis=0)
429 |
430 |
431 | X_train_tumors=np.load("input_for_generator.npy")
432 | X_train_tumors=np.concatenate((X_train_tumors,X_train_tumors),axis=0)
433 |
434 |
435 | fake_data=generator.predict([X_train_tumors])[0]
436 |
437 | print("fake_data_shape",fake_data.shape)
438 |
439 |
440 | true_label=1
441 |
442 | print((real_data.shape),(fake_data.shape),true_label,ep)
443 | proceed=int(raw_input("proceed press 1"))
444 | if(proceed==1):
445 | train_disc(real_data,fake_data,true_label,ep,loss_ch)
446 | else:
447 | continue
448 |
449 | elif(i_p==1):
450 | print("training gen")
451 | loss_ch=0
452 | ep=int(raw_input("enter number of epochs"))
453 | true_label=1
454 |
455 | proceed=int(raw_input("proceed press 1"))
456 | if(proceed==1):
457 | train_generator(true_label,ep,loss_ch)
458 | else:
459 | continue
460 | elif(i_p==2):
461 | import h5py
462 |
463 | final_model.save('disjoint_un_sup_mse_complete_gans.h5')
464 | generator.save("disjoint_un_sup_mse_generator.h5")
465 | discriminator.save("disjoint_un_sup_mse_discriminator.h5")
466 |
467 |
468 | elif(i_p==3):
469 | X_train=np.load("input_for_generator.npy")
470 | y_train=np.load("tumor_mask_for_generator.npy")
471 | result=generator.predict([X_train[0:1000]],batch_size=16)
472 | #result=np.array(result)
473 | #print (result.shape)
474 |
475 | print(np.amax(result[0]),np.amax(result[1]),np.amax(result[2]),np.amax(result[3]))
476 |
477 | print(np.amin(result[0]),np.amin(result[1]),np.amin(result[2]),np.amin(result[3]))
478 |
479 |
480 |
481 |
482 |
483 | for i in range(0,1000):
484 | cv2.imwrite("outputs/id1/"+str(i)+".png",(result[0][i])*255)
485 | cv2.imwrite("outputs/id2/"+str(i)+".png",(result[1][i])*255)
486 | cv2.imwrite("outputs/id3/"+str(i)+".png",(result[2][i])*255)
487 |
488 | cv2.imwrite("outputs/norm/in/"+str(i)+".png",X_train[i]*255)
489 | cv2.imwrite("outputs/norm/op/"+str(i)+".png",y_train[i]*255)
490 |
491 |
492 | result[1]=(result[1]-np.amin(result[1]))/((np.amax(result[1]))-(np.amin(result[1])))
493 | result[0]=(result[0]-np.amin(result[0]))/((np.amax(result[0]))-(np.amin(result[0])))
494 | result[2]=(result[2]-np.amin(result[2]))/((np.amax(result[2]))-(np.amin(result[2])))
495 | for i in range(0,1000):
496 | cv2.imwrite("outputs/norm/id1/"+str(i)+".png",(result[0][i])*255)
497 | cv2.imwrite("outputs/norm/id2/"+str(i)+".png",(result[1][i])*255)
498 | cv2.imwrite("outputs/norm/id3/"+str(i)+".png",(result[2][i])*255)
499 |
500 | else:
501 | break
502 |
503 |
504 |
505 |
506 |
507 |
--------------------------------------------------------------------------------
/src/create_networks.py:
--------------------------------------------------------------------------------
1 | #ps aux --sort=-%mem | awk 'NR<=10{print $0}'   (shell one-liner: list the top-10 processes by memory use)
2 |
3 | import keras
4 | from keras import optimizers
5 | #from keras.utils import multi_gpu_model
6 | import scipy as sp
7 | import scipy.misc, scipy.ndimage.interpolation
8 | from medpy import metric
9 | import numpy as np
10 | import os
11 | from keras import losses
12 | import tensorflow as tf
13 | from keras.models import Model,Sequential
14 | from keras.layers import Input,merge, concatenate, Conv2D, MaxPooling2D, Activation, UpSampling2D,Dropout,Conv2DTranspose,add,multiply,Dense,Flatten
15 | from keras.layers.normalization import BatchNormalization as bn
16 | from keras.callbacks import ModelCheckpoint, TensorBoard
17 | from keras.optimizers import RMSprop
18 | from keras import regularizers
19 | from keras import backend as K
20 | from keras.optimizers import Adam
21 | from keras.callbacks import ModelCheckpoint
22 | import tensorflow as tf
23 | #from keras.applications import Xception
24 | from keras.utils import multi_gpu_model
25 | import random
26 | import numpy as np
27 | from keras.callbacks import EarlyStopping, ModelCheckpoint
28 | import nibabel as nib
29 | CUDA_VISIBLE_DEVICES = [1]
30 | os.environ['CUDA_VISIBLE_DEVICES']=','.join([str(x) for x in CUDA_VISIBLE_DEVICES])
31 | smooth=1.
32 | def special_loss_disjoint(y_true,y_pred):
33 |
34 | y_true,y_pred=tf.split(y_pred, 2,axis=-1)
35 | thresholded_pred = tf.where( tf.greater( 0.0, y_pred ), tf.ones_like( y_pred ), y_pred )# tf.where(cond, if_true, if_false): entries below 0.0 are set to 1
36 | thresholded_true=tf.where( tf.greater( 0.0, y_true ), tf.ones_like( y_true ), y_true )
37 | #tf.keras.backend.print_tensor(first)
38 | return dice_coef(thresholded_true,thresholded_pred)
39 |
40 | def dice_coef(y_true, y_pred):
41 |
42 | y_true_f = K.flatten(y_true)
43 | y_pred_f = K.flatten(y_pred)
44 | intersection = K.sum(y_true_f * y_pred_f)
45 | return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)
46 | def dice_coef_loss(y_true, y_pred):
47 | return dice_coef(y_true, y_pred)# NB: returns the Dice coefficient itself, not 1-Dice; the stage scripts reference this name via custom_objects
48 |
49 |
50 | def build_generator(input_shape,learn_rate=1e-3):
51 |
52 |
53 |
54 | l2_lambda = 0.0002
55 | DropP = 0.3
56 | kernel_size=3
57 |
58 | inputs = Input(input_shape)
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 | conv0a = Conv2D( 32, (kernel_size, kernel_size), activation='relu', padding='same',
68 | kernel_regularizer=regularizers.l2(l2_lambda),name='l2_conc15' )(inputs)
69 |
70 |
71 | conv0a = bn(name='l2_bn1')(conv0a)
72 |
73 | conv0b = Conv2D(32, (kernel_size, kernel_size), activation='relu', padding='same',
74 | kernel_regularizer=regularizers.l2(l2_lambda),name='l2_conc16' )(conv0a)
75 |
76 | conv0b = bn(name='l2_bn2')(conv0b)
77 |
78 |
79 |
80 |
81 | pool0 = MaxPooling2D(pool_size=(2, 2),name='l2_mp1')(conv0b)
82 |
83 | pool0 = Dropout(DropP,name='l2_d1')(pool0)
84 |
85 |
86 |
87 |
88 |
89 |
90 | conv2a = Conv2D(64, (kernel_size, kernel_size), activation='relu', padding='same',
91 | kernel_regularizer=regularizers.l2(l2_lambda),name='l2_conc17' )(pool0)
92 |
93 | conv2a = bn(name='l2_bn3')(conv2a)
94 |
95 | conv2b = Conv2D(64, (kernel_size, kernel_size), activation='relu', padding='same',
96 | kernel_regularizer=regularizers.l2(l2_lambda) ,name='l2_conc18')(conv2a)
97 |
98 | conv2b = bn(name='l2_bn4')(conv2b)
99 |
100 | pool2 = MaxPooling2D(pool_size=(2, 2),name='l2_mp2')(conv2b)
101 |
102 | pool2 = Dropout(DropP,name='l2_d2')(pool2)
103 |
104 |
105 |
106 |
107 |
108 |
109 |
110 | conv3a = Conv2D(128, (kernel_size, kernel_size), activation='relu', padding='same',
111 | kernel_regularizer=regularizers.l2(l2_lambda),name='l2_conc19' )(pool2)
112 |
113 | conv3a = bn(name='l2_bn5')(conv3a)
114 |
115 | conv3b = Conv2D(128, (kernel_size, kernel_size), activation='relu', padding='same',
116 | kernel_regularizer=regularizers.l2(l2_lambda) ,name='l2_conc20')(conv3a)
117 |
118 | conv3b = bn(name='l2_bn6')(conv3b)
119 |
120 |
121 |
122 | pool3 = MaxPooling2D(pool_size=(2, 2),name='l2_mp3')(conv3b)
123 |
124 | pool3 = Dropout(DropP,name='l2_d3')(pool3)
125 |
126 |
127 | conv4a = Conv2D(256, (kernel_size, kernel_size), activation='relu', padding='same',
128 | kernel_regularizer=regularizers.l2(l2_lambda),name='l2_conc21' )(pool3)
129 |
130 | conv4a = bn(name='l2_bn7')(conv4a)
131 |
132 | conv4b = Conv2D(256, (kernel_size, kernel_size), activation='relu', padding='same',
133 | kernel_regularizer=regularizers.l2(l2_lambda),name='l2_conc22' )(conv4a)
134 |
135 | conv4b = bn(name='l2_bn8')(conv4b)
136 |
137 | pool4 = MaxPooling2D(pool_size=(2, 2),name='l2_mp4')(conv4b)
138 |
139 | pool4 = Dropout(DropP,name='l2_d4')(pool4)
140 |
141 |
142 |
143 |
144 |
145 | conv5a = Conv2D(512, (kernel_size, kernel_size), activation='relu', padding='same',
146 | kernel_regularizer=regularizers.l2(l2_lambda) ,name='l2_conc23')(pool4)
147 |
148 | conv5a = bn(name='l2_bn9')(conv5a)
149 |
150 | conv5b = Conv2D(512, (kernel_size, kernel_size), activation='relu', padding='same',
151 | kernel_regularizer=regularizers.l2(l2_lambda) ,name='l2_conc24')(conv5a)
152 |
153 | conv5b = bn(name='l2_bn10')(conv5b)
154 |
155 |
156 |
157 |
158 |
159 | up6 = concatenate([Conv2DTranspose(256,(2, 2), strides=(2, 2), padding='same',name='l2_conc25')(conv5b), (conv4b)], axis=3,name='l2_conc1')
160 |
161 |
162 | up6 = Dropout(DropP,name='l2_d5')(up6)
163 |
164 | conv6a = Conv2D(256, (kernel_size, kernel_size), activation='relu', padding='same',
165 | kernel_regularizer=regularizers.l2(l2_lambda) ,name='l2_conc26')(up6)
166 |
167 | conv6a = bn(name='l2_bn11')(conv6a)
168 |
169 | conv6b = Conv2D(256, (kernel_size, kernel_size), activation='relu', padding='same',
170 | kernel_regularizer=regularizers.l2(l2_lambda),name='l2_conc27' )(conv6a)
171 |
172 | conv6b = bn(name='l2_bn12')(conv6b)
173 |
174 |
175 |
176 |
177 |
178 | up7 = concatenate([Conv2DTranspose(128,(2, 2), strides=(2, 2), padding='same',name='l2_conc28')(conv6b),(conv3b)], axis=3,name='l2_conc2')
179 |
180 | up7 = Dropout(DropP,name='l2_d6')(up7)
181 | #add second output here
182 |
183 | conv7a = Conv2D(128, (kernel_size, kernel_size), activation='relu', padding='same',
184 | kernel_regularizer=regularizers.l2(l2_lambda) ,name='l2_conc29')(up7)
185 |
186 | conv7a = bn(name='l2_bn13')(conv7a)
187 |
188 |
189 |
190 | conv7b = Conv2D(128, (kernel_size, kernel_size), activation='relu', padding='same',
191 | kernel_regularizer=regularizers.l2(l2_lambda) ,name='l2_conc30')(conv7a)
192 |
193 | conv7b = bn(name='l2_bn14')(conv7b)
194 |
195 |
196 |
197 |
198 |
199 |
200 | up8 = concatenate([Conv2DTranspose(64,(2, 2), strides=(2, 2), padding='same',name='l2_conc31')(conv7b), (conv2b)], axis=3,name='l2_conc3')
201 |
202 | up8 = Dropout(DropP,name='l2_d7')(up8)
203 |
204 | conv8a = Conv2D(64, (kernel_size, kernel_size), activation='relu', padding='same',
205 | kernel_regularizer=regularizers.l2(l2_lambda) ,name='l2_conc32')(up8)
206 |
207 | conv8a = bn(name='l2_bn15')(conv8a)
208 |
209 |
210 | conv8b = Conv2D(64, (kernel_size, kernel_size), activation='relu', padding='same',
211 | kernel_regularizer=regularizers.l2(l2_lambda),name='l2_conc33' )(conv8a)
212 |
213 | conv8b = bn(name='l2_bn16')(conv8b)
214 |
215 |
216 |
217 | up10 = concatenate([Conv2DTranspose(32,(2, 2), strides=(2, 2), padding='same',name='l2_conc34')(conv8b),(conv0b)],axis=3,name='l2_conc4')
218 |
219 | conv10a = Conv2D(32, (kernel_size, kernel_size), activation='relu', padding='same',
220 | kernel_regularizer=regularizers.l2(l2_lambda) ,name='l2_conc35')(up10)
221 |
222 | conv10a = bn(name='l2_bn17')(conv10a)
223 |
224 |
225 |
226 | conv10b = Conv2D(32, (kernel_size, kernel_size), activation='relu', padding='same',
227 | kernel_regularizer=regularizers.l2(l2_lambda),name='l2_conc36' )(conv10a)
228 |
229 | conv10b = bn(name='l2_bn18')(conv10b)
230 |
231 |
232 |
233 | new_final_op=Conv2D(1, (1, 1), activation='sigmoid',name='new_final_op')(conv10b)
234 |
235 |
236 |
237 | #--------------------------------------------------------------------------------------
238 |
239 |
240 | xup6 = concatenate([Conv2DTranspose(256,(2, 2), strides=(2, 2), padding='same',name='l2_conc38')(conv5b), (conv4b)], axis=3,name='l2_conc5')
241 |
242 |
243 |
244 | xup6 = Dropout(DropP,name='l2_d8')(xup6)
245 |
246 | xconv6a = Conv2D(256, (kernel_size, kernel_size), activation='relu', padding='same',
247 | kernel_regularizer=regularizers.l2(l2_lambda),name='l2_conc39' )(xup6)
248 |
249 | xconv6a = bn(name='l2_bn19')(xconv6a)
250 |
251 |
252 |
253 | xconv6b = Conv2D(256, (kernel_size, kernel_size), activation='relu', padding='same',
254 | kernel_regularizer=regularizers.l2(l2_lambda),name='l2_conc40' )(xconv6a)
255 |
256 | xconv6b = bn(name='l2_bn20')(xconv6b)
257 |
258 |
259 |
260 |
261 |
262 | xup7 = concatenate([Conv2DTranspose(128,(2, 2), strides=(2, 2), padding='same',name='l2_conc41')(xconv6b),(conv3b)], axis=3,name='l2_conc6')#xconv6b
263 |
264 | xup7 = Dropout(DropP,name='l2_d9')(xup7)
265 |
266 | xconv7a = Conv2D(128, (kernel_size, kernel_size), activation='relu', padding='same',
267 | kernel_regularizer=regularizers.l2(l2_lambda) ,name='l2_conc42')(xup7)
268 |
269 | xconv7a = bn(name='l2_bn21')(xconv7a)
270 |
271 |
272 | xconv7b = Conv2D(128, (kernel_size, kernel_size), activation='relu', padding='same',
273 | kernel_regularizer=regularizers.l2(l2_lambda) ,name='l2_conc43')(xconv7a)
274 |
275 | xconv7b = bn(name='l2_bn22')(xconv7b)
276 |
277 |
278 | xup8 = concatenate([Conv2DTranspose(64,(2, 2), strides=(2, 2), padding='same',name='l2_conc44')(xconv7b),(conv2b)], axis=3,name='l2_conc7')
279 |
280 | xup8 = Dropout(DropP,name='l2_d10')(xup8)
281 | #add third xoutxout here
282 |
283 | xconv8a = Conv2D(64, (kernel_size, kernel_size), activation='relu', padding='same',
284 | kernel_regularizer=regularizers.l2(l2_lambda) ,name='l2_conc45')(xup8)
285 |
286 | xconv8a = bn(name='l2_bn23')(xconv8a)
287 |
288 |
289 | xconv8b = Conv2D(64, (kernel_size, kernel_size), activation='relu', padding='same',
290 | kernel_regularizer=regularizers.l2(l2_lambda),name='l2_conc46' )(xconv8a)
291 |
292 | xconv8b = bn(name='l2_bn24')(xconv8b)
293 |
294 |
295 |
296 |
297 |
298 |
299 | xup10 = concatenate([Conv2DTranspose(32,(2, 2), strides=(2, 2), padding='same',name='l2_conc47')(xconv8b), (conv0b)],axis=3,name='l2_conc8')
300 |
301 | xup10 = Dropout(DropP,name='l2_d11')(xup10)
302 |
303 |
304 | xconv10a = Conv2D(32, (kernel_size, kernel_size), activation='relu', padding='same',
305 | kernel_regularizer=regularizers.l2(l2_lambda) ,name='l2_conc48')(xup10)
306 |
307 | xconv10a = bn(name='l2_bn25')(xconv10a)
308 |
309 |
310 | xconv10b = Conv2D(32, (kernel_size, kernel_size), activation='relu', padding='same',
311 | kernel_regularizer=regularizers.l2(l2_lambda) ,name='l2_conc49')(xconv10a)
312 |
313 | xconv10b = bn(name='l2_bn26')(xconv10b)
314 |
315 |
316 |
317 |
318 |
319 |
320 | new_xfinal_op=Conv2D(1, (1, 1), activation='sigmoid',name='new_xfinal_op')(xconv10b)#tan
321 |
322 |
323 |
324 |
325 | #-----------------------------third branch
326 |
327 |
328 |
329 | #Concatenation fed to the reconstruction layer of all 3
330 |
331 | x_u_net_op0=keras.layers.concatenate([new_final_op,new_xfinal_op],name='l2_conc9')# fed into the reconstruction layer below
332 | x_u_net_opsp=keras.layers.concatenate([new_final_op,new_xfinal_op],name='x_u_net_opsp')# same concatenation under a second name, so special_loss_disjoint can attach to this output
333 |
334 |
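335 | # Reconstruction head: a 1x1 sigmoid convolution maps the fused cuts back
336 | # to a single-channel reconstruction of the input.
337 |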
347 | new_res_1_final_opa=Conv2D(1, (1, 1), activation='sigmoid',name='new_res_1_final_opa')(x_u_net_op0)
348 |
349 | model = Model(inputs=[inputs], outputs=[new_final_op,
350 |                                         new_xfinal_op,
351 |                                         new_res_1_final_opa,
352 |                                         x_u_net_opsp])
356 | model.compile(optimizer=keras.optimizers.Adam(lr=5e-5),
357 |               loss={'new_res_1_final_opa': 'mse',
358 |                     'x_u_net_opsp': special_loss_disjoint})
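359 | # Only two of the four outputs are assigned losses; in Keras 2.x the
360 | # remaining outputs are excluded from training (a warning is printed).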
362 |
363 | return model
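364 |
365 | # Discriminator: a plain convolutional encoder (conv/BN/pool blocks growing
366 | # from 32 to 256 filters, no skip connections) ending in a single tanh unit,
367 | # trained stand-alone with an MAE loss.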
371 |
372 | def build_discriminator(input_shape, learn_rate=1e-3):  # note: learn_rate is unused; the Adam optimizer below hard-codes lr=5e-5
373 | l2_lambda = 0.0002
374 | DropP = 0.3
375 | kernel_size=3
376 |
377 | inputs = Input(input_shape,name="disc_ip")
378 |
379 | conv0a = Conv2D( 32, (kernel_size, kernel_size), activation='relu', padding='same',
380 | kernel_regularizer=regularizers.l2(l2_lambda),name='disc_l2_conc15' )(inputs)
381 |
382 |
383 | conv0a = bn(name='disc_l2_bn1')(conv0a)
384 |
385 | conv0b = Conv2D(32, (kernel_size, kernel_size), activation='relu', padding='same',
386 | kernel_regularizer=regularizers.l2(l2_lambda),name='disc_l2_conc16' )(conv0a)
387 |
388 | conv0b = bn(name='disc_l2_bn2')(conv0b)
389 |
390 |
391 |
392 |
393 | pool0 = MaxPooling2D(pool_size=(2, 2),name='disc_l2_mp1')(conv0b)
394 |
395 | pool0 = Dropout(DropP,name='disc_l2_d1')(pool0)
396 |
402 | conv2a = Conv2D(64, (kernel_size, kernel_size), activation='relu', padding='same',
403 | kernel_regularizer=regularizers.l2(l2_lambda),name='disc_l2_conc17' )(pool0)
404 |
405 | conv2a = bn(name='disc_l2_bn3')(conv2a)
406 |
407 | conv2b = Conv2D(64, (kernel_size, kernel_size), activation='relu', padding='same',
408 | kernel_regularizer=regularizers.l2(l2_lambda) ,name='disc_l2_conc18')(conv2a)
409 |
410 | conv2b = bn(name='disc_l2_bn4')(conv2b)
411 |
412 | pool2 = MaxPooling2D(pool_size=(2, 2),name='disc_l2_mp2')(conv2b)
413 |
414 | pool2 = Dropout(DropP,name='disc_l2_d2')(pool2)
415 |
422 | conv3a = Conv2D(128, (kernel_size, kernel_size), activation='relu', padding='same',
423 | kernel_regularizer=regularizers.l2(l2_lambda),name='disc_l2_conc19' )(pool2)
424 |
425 | conv3a = bn(name='disc_l2_bn5')(conv3a)
426 |
427 | conv3b = Conv2D(128, (kernel_size, kernel_size), activation='relu', padding='same',
428 | kernel_regularizer=regularizers.l2(l2_lambda) ,name='disc_l2_conc20')(conv3a)
429 |
430 | conv3b = bn(name='disc_l2_bn6')(conv3b)
431 |
432 |
433 |
434 | pool3 = MaxPooling2D(pool_size=(2, 2),name='disc_l2_mp3')(conv3b)
435 |
436 | pool3 = Dropout(DropP,name='disc_l2_d3')(pool3)
437 |
438 |
439 | conv4a = Conv2D(256, (kernel_size, kernel_size), activation='relu', padding='same',
440 | kernel_regularizer=regularizers.l2(l2_lambda),name='disc_l2_conc21' )(pool3)
441 |
442 | conv4a = bn(name='disc_l2_bn7')(conv4a)
443 |
444 | conv4b = Conv2D(256, (kernel_size, kernel_size), activation='relu', padding='same',
445 | kernel_regularizer=regularizers.l2(l2_lambda),name='disc_l2_conc22' )(conv4a)
446 |
447 | conv4b = bn(name='disc_l2_bn8')(conv4b)
448 |
449 | pool4 = MaxPooling2D(pool_size=(2, 2),name='disc_l2_mp4')(conv4b)
450 |
451 | pool4 = Dropout(DropP,name='disc_l2_d4')(pool4)
452 |
453 |
454 |
455 |
456 |
457 | conv5a = Conv2D(256, (kernel_size, kernel_size), activation='relu', padding='same',
458 | kernel_regularizer=regularizers.l2(l2_lambda) ,name='disc_l2_conc23')(pool4)
459 |
460 | conv5a = bn(name='disc_l2_bn9')(conv5a)
461 |
462 | conv5b = Conv2D(256, (kernel_size, kernel_size), activation='relu', padding='same',
463 | kernel_regularizer=regularizers.l2(l2_lambda) ,name='disc_l2_conc24')(conv5a)
464 |
465 | conv5b = bn(name='disc_l2_bn10')(conv5b)
466 |
467 | flat=Flatten()(conv5b)
468 |
469 | output_disc = Dense(1, activation='tanh', name='disc_output')(flat)  # scalar real/fake score in [-1, 1], paired with the MAE loss
470 |
471 | model=Model(inputs=[inputs],outputs=[output_disc])
472 | model.compile(loss='mae',
473 | optimizer=keras.optimizers.Adam(lr=5e-5),
474 | metrics=['accuracy'])
475 | #model.summary()
476 | return model
477 |
478 |
479 | input_shape = (240, 240, 1)
480 |
481 |
482 | generator=build_generator(input_shape)
483 | discriminator=build_discriminator(input_shape)
484 | # Keras captures each layer's trainable flag at compile time. Both sub-models
485 | # were compiled, fully trainable, inside their build functions, so the
486 | # discriminator is frozen only now, for the combined GAN model compiled below;
487 | # its stand-alone compiled state remains trainable.
488 | discriminator.trainable = False
489 |
497 | final_input=generator.input
498 |
499 | # Connect the two: the discriminator scores the generator's first cut
500 | # (new_final_op); the remaining generator outputs pass through unchanged.
501 |
502 | x_u_net_opsp=(generator.get_layer('x_u_net_opsp').output)
503 | final_output_gans=discriminator(generator.get_layer('new_final_op').output)
504 | final_output_seg=(generator.get_layer('new_xfinal_op').output)
505 | final_output_res=(generator.get_layer('new_res_1_final_opa').output)
506 |
509 | final_model=Model(inputs=[final_input],outputs=[final_output_gans,final_output_seg,final_output_res,x_u_net_opsp])
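510 | # Outputs: GAN score on cut 1, cut 2, the reconstruction, and the concatenated cuts for the disjointness loss.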
511 | final_model.compile(optimizer=keras.optimizers.Adam(lr=5e-5),
512 |                     metrics=['mae'],
513 |                     loss={'model_2': 'mae',  # 'model_2' is the auto-generated name of the nested discriminator (the GAN output)
514 |                           'new_res_1_final_opa': 'mse',
515 |                           'x_u_net_opsp': special_loss_disjoint})
516 |
517 | print("full gans")
518 | final_model.summary()
519 | print(final_model.input)
520 | print(final_model.output)
521 | print("============================================================================================================================================================")
522 | print("generator")
523 | generator.summary()
524 | print(generator.input)
525 | print(generator.output)
526 | print("============================================================================================================================================================")
527 |
528 | print("discriminator")
529 | discriminator.summary()
530 | print(discriminator.get_input_at(0))  # node 0: stand-alone input
531 | print(discriminator.get_input_at(1))  # node 1: input when wired to the generator's first cut
533 | print("============================================================================================================================================================")
534 |
539 | import h5py  # h5py provides the HDF5 backend used by model.save below
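540 | # Persist all three networks for later training runs (e.g. the continue_training_stage_* scripts).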
541 | final_model.save('disjoint_un_sup_mse_complete_gans.h5')
542 | generator.save("disjoint_un_sup_mse_generator.h5")
543 | discriminator.save("disjoint_un_sup_mse_discriminator.h5")
544 |
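545 | # Minimal reload sketch (illustrative, not part of the original pipeline):
546 | # the custom loss must be supplied via custom_objects when loading, e.g.
547 | #
548 | #   from keras.models import load_model
549 | #   generator = load_model('disjoint_un_sup_mse_generator.h5',
550 | #                          custom_objects={'special_loss_disjoint': special_loss_disjoint})
551 |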
--------------------------------------------------------------------------------