├── Data ├── C_Elegans │ ├── Raw │ │ └── README.txt │ ├── Synthetic │ │ └── README.txt │ ├── Original │ │ └── README.txt │ └── README.txt └── Example │ ├── Genuine │ ├── raw_image_0.png │ ├── raw_image_1.png │ ├── raw_image_10.png │ ├── raw_image_11.png │ ├── raw_image_12.png │ ├── raw_image_13.png │ ├── raw_image_14.png │ ├── raw_image_15.png │ ├── raw_image_2.png │ ├── raw_image_3.png │ ├── raw_image_4.png │ ├── raw_image_5.png │ ├── raw_image_6.png │ ├── raw_image_7.png │ ├── raw_image_8.png │ └── raw_image_9.png │ └── Synthetic │ ├── syn_image_0.png │ ├── syn_image_1.png │ ├── syn_image_10.png │ ├── syn_image_11.png │ ├── syn_image_12.png │ ├── syn_image_13.png │ ├── syn_image_14.png │ ├── syn_image_15.png │ ├── syn_image_2.png │ ├── syn_image_3.png │ ├── syn_image_4.png │ ├── syn_image_5.png │ ├── syn_image_6.png │ ├── syn_image_7.png │ ├── syn_image_8.png │ └── syn_image_9.png ├── Utilities ├── Utilities.pyc └── Utilities.py ├── Scripts ├── simple_example.sh ├── README.txt ├── train_UDCT_colored_live_neurons.sh ├── train_UDCT_nanowire.sh ├── train_UDCT_dead_live_neurons.sh └── train_UDCT_c_elegans.sh ├── Discriminator ├── MultiPatch.py ├── HisDis.py ├── PatchGAN34.py ├── PatchGAN70.py └── PatchGAN142.py ├── notebooks ├── make_raw_nanowire.ipynb ├── make_raw_dead_live_neurons.ipynb ├── make_synthetic_wires.ipynb └── VGG_Synthetic_by_localization.ipynb ├── README.md ├── create_h5_dataset.py ├── main.py ├── Generator └── Res_Gen.py └── cycleGAN.py /Data/C_Elegans/Raw/README.txt: -------------------------------------------------------------------------------- 1 | 1) Execute: /notebooks/make_raw_c_elegans.ipynb 2 | -------------------------------------------------------------------------------- /Data/C_Elegans/Synthetic/README.txt: -------------------------------------------------------------------------------- 1 | 1) Execute: /notebooks/make_synthetic_c_elegans.ipynb 2 | -------------------------------------------------------------------------------- 
/Utilities/Utilities.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UDCTGAN/UDCT/HEAD/Utilities/Utilities.pyc -------------------------------------------------------------------------------- /Data/Example/Genuine/raw_image_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UDCTGAN/UDCT/HEAD/Data/Example/Genuine/raw_image_0.png -------------------------------------------------------------------------------- /Data/Example/Genuine/raw_image_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UDCTGAN/UDCT/HEAD/Data/Example/Genuine/raw_image_1.png -------------------------------------------------------------------------------- /Data/Example/Genuine/raw_image_10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UDCTGAN/UDCT/HEAD/Data/Example/Genuine/raw_image_10.png -------------------------------------------------------------------------------- /Data/Example/Genuine/raw_image_11.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UDCTGAN/UDCT/HEAD/Data/Example/Genuine/raw_image_11.png -------------------------------------------------------------------------------- /Data/Example/Genuine/raw_image_12.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UDCTGAN/UDCT/HEAD/Data/Example/Genuine/raw_image_12.png -------------------------------------------------------------------------------- /Data/Example/Genuine/raw_image_13.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UDCTGAN/UDCT/HEAD/Data/Example/Genuine/raw_image_13.png 
-------------------------------------------------------------------------------- /Data/Example/Genuine/raw_image_14.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UDCTGAN/UDCT/HEAD/Data/Example/Genuine/raw_image_14.png -------------------------------------------------------------------------------- /Data/Example/Genuine/raw_image_15.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UDCTGAN/UDCT/HEAD/Data/Example/Genuine/raw_image_15.png -------------------------------------------------------------------------------- /Data/Example/Genuine/raw_image_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UDCTGAN/UDCT/HEAD/Data/Example/Genuine/raw_image_2.png -------------------------------------------------------------------------------- /Data/Example/Genuine/raw_image_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UDCTGAN/UDCT/HEAD/Data/Example/Genuine/raw_image_3.png -------------------------------------------------------------------------------- /Data/Example/Genuine/raw_image_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UDCTGAN/UDCT/HEAD/Data/Example/Genuine/raw_image_4.png -------------------------------------------------------------------------------- /Data/Example/Genuine/raw_image_5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UDCTGAN/UDCT/HEAD/Data/Example/Genuine/raw_image_5.png -------------------------------------------------------------------------------- /Data/Example/Genuine/raw_image_6.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/UDCTGAN/UDCT/HEAD/Data/Example/Genuine/raw_image_6.png -------------------------------------------------------------------------------- /Data/Example/Genuine/raw_image_7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UDCTGAN/UDCT/HEAD/Data/Example/Genuine/raw_image_7.png -------------------------------------------------------------------------------- /Data/Example/Genuine/raw_image_8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UDCTGAN/UDCT/HEAD/Data/Example/Genuine/raw_image_8.png -------------------------------------------------------------------------------- /Data/Example/Genuine/raw_image_9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UDCTGAN/UDCT/HEAD/Data/Example/Genuine/raw_image_9.png -------------------------------------------------------------------------------- /Data/Example/Synthetic/syn_image_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UDCTGAN/UDCT/HEAD/Data/Example/Synthetic/syn_image_0.png -------------------------------------------------------------------------------- /Data/Example/Synthetic/syn_image_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UDCTGAN/UDCT/HEAD/Data/Example/Synthetic/syn_image_1.png -------------------------------------------------------------------------------- /Data/Example/Synthetic/syn_image_10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UDCTGAN/UDCT/HEAD/Data/Example/Synthetic/syn_image_10.png -------------------------------------------------------------------------------- /Data/Example/Synthetic/syn_image_11.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/UDCTGAN/UDCT/HEAD/Data/Example/Synthetic/syn_image_11.png -------------------------------------------------------------------------------- /Data/Example/Synthetic/syn_image_12.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UDCTGAN/UDCT/HEAD/Data/Example/Synthetic/syn_image_12.png -------------------------------------------------------------------------------- /Data/Example/Synthetic/syn_image_13.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UDCTGAN/UDCT/HEAD/Data/Example/Synthetic/syn_image_13.png -------------------------------------------------------------------------------- /Data/Example/Synthetic/syn_image_14.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UDCTGAN/UDCT/HEAD/Data/Example/Synthetic/syn_image_14.png -------------------------------------------------------------------------------- /Data/Example/Synthetic/syn_image_15.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UDCTGAN/UDCT/HEAD/Data/Example/Synthetic/syn_image_15.png -------------------------------------------------------------------------------- /Data/Example/Synthetic/syn_image_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UDCTGAN/UDCT/HEAD/Data/Example/Synthetic/syn_image_2.png -------------------------------------------------------------------------------- /Data/Example/Synthetic/syn_image_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UDCTGAN/UDCT/HEAD/Data/Example/Synthetic/syn_image_3.png 
-------------------------------------------------------------------------------- /Data/Example/Synthetic/syn_image_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UDCTGAN/UDCT/HEAD/Data/Example/Synthetic/syn_image_4.png -------------------------------------------------------------------------------- /Data/Example/Synthetic/syn_image_5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UDCTGAN/UDCT/HEAD/Data/Example/Synthetic/syn_image_5.png -------------------------------------------------------------------------------- /Data/Example/Synthetic/syn_image_6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UDCTGAN/UDCT/HEAD/Data/Example/Synthetic/syn_image_6.png -------------------------------------------------------------------------------- /Data/Example/Synthetic/syn_image_7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UDCTGAN/UDCT/HEAD/Data/Example/Synthetic/syn_image_7.png -------------------------------------------------------------------------------- /Data/Example/Synthetic/syn_image_8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UDCTGAN/UDCT/HEAD/Data/Example/Synthetic/syn_image_8.png -------------------------------------------------------------------------------- /Data/Example/Synthetic/syn_image_9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UDCTGAN/UDCT/HEAD/Data/Example/Synthetic/syn_image_9.png -------------------------------------------------------------------------------- /Scripts/simple_example.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # Create the sample 
dataset 4 | cd .. 5 | python create_h5_dataset.py ./Data/Example/Genuine/ ./Data/Example/Synthetic/ ./Data/Example/example_dataset.h5 6 | 7 | # Create a Directory for model data 8 | mkdir -p Models 9 | 10 | # Train the network 11 | python main.py --dataset=./Data/Example/example_dataset.h5 --name=example_model 12 | -------------------------------------------------------------------------------- /Data/C_Elegans/Original/README.txt: -------------------------------------------------------------------------------- 1 | 1) Download the original images of the C. elegans dataset from the Broad Bioimage Benchmark Collection: 2 | 3 | https://data.broadinstitute.org/bbbc/BBBC010/BBBC010_v1_images.zip 4 | 5 | Accession number: BBBC010, Version 1 6 | 7 | 2) Extract the images into this directory. Afterwards, this directory should contain 195 tif images in total. 8 | 9 | 3) Execute /notebooks/transform_c_elegans_tif_to_png.ipynb 10 | -------------------------------------------------------------------------------- /Scripts/README.txt: -------------------------------------------------------------------------------- 1 | This directory contains scripts that can be executed to reproduce the results of the publication. 2 | 3 | simple_example.sh 4 | A simple example that checks whether the code runs correctly. 5 | 6 | train_UDCT_c_elegans.sh 7 | This script downloads the C. elegans dataset from the Broad Bioimage Benchmark Collection, creates the Raw/Syn dataset and trains a network. Attention: This script deletes some data in the Data/C_Elegans/Original directory. 8 | 9 | train_UDCT_dead_live_neurons.sh 10 | This script creates the dead vs. live dataset for the neurons. Afterwards, it trains a UDCT cycleGAN on the dataset.
11 | 12 | 13 | -------------------------------------------------------------------------------- /Utilities/Utilities.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def produce_tiled_images(im_A,im_B,fake_A,fake_B,cyc_A,cyc_B): 4 | """Tile the first image (index 0) of six rank-4 [batch, H, W, C] arrays into one 2x3 grid: row 1 is (im_A, fake_B, cyc_A), row 2 is (im_B, fake_A, cyc_B). Single-channel inputs are repeated to 3 channels and every tile gets a 20-pixel border of value 0.5. NOTE(review): assumes all tiles share H, W and channel count after tiling (required by vstack/hstack).""" 5 | list_of_images=[im_A,im_B,fake_A,fake_B,cyc_A,cyc_B] 6 | for i in range(6): 7 | if np.shape(list_of_images[i])[-1]==1: 8 | list_of_images[i]=np.tile(list_of_images[i],[1,1,1,3]) 9 | list_of_images[i]=np.pad(list_of_images[i][0,:,:,:], ((20,20),(20,20),(0,0)), mode='constant', constant_values=[0.5]) 10 | im_A,im_B,fake_A,fake_B,cyc_A,cyc_B=list_of_images 11 | a=np.vstack( (im_A,im_B)) 12 | b=np.vstack( (fake_B,fake_A)) 13 | c=np.vstack( (cyc_A,cyc_B)) 14 | return np.hstack((a,b,c)) 15 | -------------------------------------------------------------------------------- /Scripts/train_UDCT_colored_live_neurons.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # Create the raw dataset 4 | echo "Creating the raw dataset" 5 | cd ../notebooks/ 6 | jupyter nbconvert --to notebook --execute make_raw_colored_live_neurons.ipynb --ExecutePreprocessor.timeout=1800 7 | 8 | # Create the hdf5 file 9 | cd ../ 10 | mkdir -p Models 11 | echo "Creating the hdf5 dataset" 12 | python create_h5_dataset.py ./Data/Neuron_Col_Live/Raw/ ./Data/Neuron_Col_Live/Synthetic/ ./Data/Neuron_Col_Live/colored_live_neuron_dataset.h5 13 | 14 | # Train the network 15 | echo "Training the network" 16 | python main.py --dataset=./Data/Neuron_Col_Live/colored_live_neuron_dataset.h5 --name=live_colored_neuron_new 17 | 18 | # Create the generated synthetic images 19 | python main.py --dataset=./Data/Neuron_Col_Live/colored_live_neuron_dataset.h5 --name=live_colored_neuron_new --mode=gen_B 20 | -------------------------------------------------------------------------------- /Data/C_Elegans/README.txt:
-------------------------------------------------------------------------------- 1 | Disclaimer: 'We used the C.elegans infection live/dead image set version 1 provided by Fred Ausubel and available from the Broad Bioimage Benchmark Collection [Ljosa et al., Nature Methods, 2012].' 2 | https://data.broadinstitute.org/bbbc/BBBC010/ 3 | 4 | For copyright reasons, the C. elegans dataset must be created by downloading the data from the Broad Bioimage Benchmark Collection. 5 | 6 | To do so, please follow the README.txt instructions in the following order: 7 | 8 | 1) README.txt in Original/ 9 | 10 | 2) README.txt in Raw/ 11 | 12 | 3) README.txt in Synthetic/ 13 | 14 | 4) Execute in the root directory: python create_h5_dataset.py ./Data/C_Elegans/Raw/ ./Data/C_Elegans/Synthetic/ ./Data/C_Elegans/c_elegans_dataset.h5 15 | 16 | Alternatively, you can execute the script 'train_UDCT_c_elegans.sh' in the directory ../Scripts, which performs all of these steps automatically. 17 | -------------------------------------------------------------------------------- /Scripts/train_UDCT_nanowire.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # Create the raw dataset 4 | echo "Creating the raw dataset" 5 | mkdir -p ../Data/Nanowire/Raw 6 | cd ../Data/Nanowire/Raw/ 7 | wget https://downloads.lbb.ethz.ch/Data/lbb_raw_nanowire_images.h5 8 | cd ../../../notebooks/ 9 | jupyter nbconvert --to notebook --execute make_raw_nanowire.ipynb --ExecutePreprocessor.timeout=1800 10 | 11 | # Create the synthetic images 12 | mkdir -p ../Data/Nanowire/Synthetic 13 | echo "Creating the synthetic dataset" 14 | jupyter nbconvert --to notebook --execute make_synthetic_wires.ipynb --ExecutePreprocessor.timeout=1800 15 | 16 | # Create the hdf5 file 17 | cd ../ 18 | mkdir -p Models 19 | echo "Creating the hdf5 dataset" 20 | python create_h5_dataset.py ./Data/Nanowire/Raw/ ./Data/Nanowire/Synthetic/ ./Data/Nanowire/nanowire_dataset.h5 21 | 22 | # Train the network
23 | echo "Training the network" 24 | python main.py --dataset=./Data/Nanowire/nanowire_dataset.h5 --name=nanowire_new 25 | 26 | # Create the generated synthetic images 27 | python main.py --dataset=./Data/Nanowire/nanowire_dataset.h5 --name=nanowire_new --mode=gen_B 28 | -------------------------------------------------------------------------------- /Scripts/train_UDCT_dead_live_neurons.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | set -e # Abort at the first failing step (e.g. a failed wget or notebook run) instead of building the dataset / training on incomplete data 3 | # Create the synthetic images for both live_vs_dead and colored_live neuron datasets 4 | cd ../notebooks/ 5 | echo "Creating the synthetic dataset" 6 | jupyter nbconvert --to notebook --execute make_synthetic_livedead_neurons_and_colored.ipynb --ExecutePreprocessor.timeout=1800 7 | 8 | # Create the raw dataset 9 | echo "Creating the raw dataset" 10 | mkdir -p ../Data/Neuron_Dead_Live/Raw 11 | cd ../Data/Neuron_Dead_Live/Raw/ 12 | wget https://downloads.lbb.ethz.ch/Data/lbb_raw_neuron_images.h5 13 | cd ../../../notebooks/ 14 | jupyter nbconvert --to notebook --execute make_raw_dead_live_neurons.ipynb --ExecutePreprocessor.timeout=1800 15 | 16 | # Create the hdf5 file 17 | cd ../ 18 | mkdir -p Models 19 | echo "Creating the hdf5 dataset" 20 | python create_h5_dataset.py ./Data/Neuron_Dead_Live/Raw/ ./Data/Neuron_Dead_Live/Synthetic/ ./Data/Neuron_Dead_Live/live_dead_neuron_dataset.h5 21 | 22 | # Train the network 23 | echo "Training the network" 24 | python main.py --dataset=./Data/Neuron_Dead_Live/live_dead_neuron_dataset.h5 --name=live_dead_neuron_new 25 | 26 | # Create the generated synthetic images 27 | python main.py --dataset=./Data/Neuron_Dead_Live/live_dead_neuron_dataset.h5 --name=live_dead_neuron_new --mode=gen_B 28 | -------------------------------------------------------------------------------- /Scripts/train_UDCT_c_elegans.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # Download the original BBBC dataset
and extract it (NOTE: lines 5-6 first delete any 00*.png and 1649_1109_0003*.tif files in the working directory; if the cd on line 4 fails, that directory is not Data/C_Elegans/Original) 4 | cd ../Data/C_Elegans/Original 5 | rm ./00*.png 6 | rm ./1649_1109_0003*.tif 7 | wget https://data.broadinstitute.org/bbbc/BBBC010/BBBC010_v1_images.zip 8 | unzip BBBC010_v1_images.zip -d ./ 9 | rm BBBC010_v1_images.zip 10 | mv BBBC010_v1_images/* ./ 11 | rm -R BBBC010_v1_images 12 | 13 | # Transform the images to pngs 14 | cd ../../../notebooks/ 15 | jupyter nbconvert --to notebook --execute transform_c_elegans_tif_to_png.ipynb --ExecutePreprocessor.timeout=1800 16 | 17 | # Create the raw dataset 18 | echo "Creating the raw dataset" 19 | jupyter nbconvert --to notebook --execute make_raw_c_elegans.ipynb --ExecutePreprocessor.timeout=1800 20 | 21 | # Create the synthetic dataset 22 | echo "Creating the synthetic dataset" 23 | jupyter nbconvert --to notebook --execute make_synthetic_c_elegans.ipynb --ExecutePreprocessor.timeout=1800 24 | 25 | # Create the hdf5 file 26 | cd ../ 27 | mkdir -p Models 28 | echo "Creating the hdf5 dataset" 29 | python create_h5_dataset.py ./Data/C_Elegans/Raw/ ./Data/C_Elegans/Synthetic/ ./Data/C_Elegans/c_elegans_dataset.h5 30 | 31 | # Train the network 32 | echo "Training the network" 33 | python main.py --dataset=./Data/C_Elegans/c_elegans_dataset.h5 --name=c_elegans_new 34 | 35 | # Create the generated synthetic images 36 | python main.py --dataset=./Data/C_Elegans/c_elegans_dataset.h5 --name=c_elegans_new --mode=gen_B 37 | -------------------------------------------------------------------------------- /Discriminator/MultiPatch.py: -------------------------------------------------------------------------------- 1 | from __future__ import division, print_function, unicode_literals 2 | # NOTE(review): PatchGAN34/70/142 are imported as top-level modules, so the Discriminator directory must be importable directly (on sys.path, or via Python 2 implicit relative import) -- confirm how main.py sets this up. 3 | import tensorflow as tf 4 | import PatchGAN34 5 | import PatchGAN70 6 | import PatchGAN142 7 | 8 | class MultiPatch: 9 | """ 10 | Multi-scale discriminator that combines PatchGAN34, PatchGAN70 and PatchGAN142
11 | (PatchGAN-style, as described by Zhu et al. 2018), all applied to the same input in create(). 12 | create() flattens each sub-discriminator's patch map to [-1, H*W, 1] and concatenates the three 13 | results along axis 1, so every patch of every scale contributes one prediction. 14 | Nothing is saved to or loaded from disk by this class. 15 | NOTE(review): the flattening assumes each sub-discriminator returns a 1-channel map (true for PatchGAN34). 16 | 17 | Only the following functions should be called from outside: 18 | -) create() 19 | -) constructor 20 | """ 21 | 22 | def __init__(self,dis_name,noise=0.25): 23 | """ 24 | Store the configuration and instantiate the three sub-discriminators. No TensorFlow layers are 25 | built here; the graph is only constructed when create() is called. 26 | Nothing is loaded from or saved to disk by this class. 27 | 28 | INPUT: dis_name - Name of the discriminator; forwarded to every sub-discriminator, which uses it to 29 | prefix its variable names (e.g. 'dis_<dis_name>_34_conv_1' in PatchGAN34). 30 | noise - Noise std forwarded to every sub-discriminator (PatchGAN34 adds Gaussian noise of this std to its input). 31 | OUTPUT: - The model 32 | """ 33 | self.dis_name = dis_name 34 | self.noise = noise 35 | 36 | self.Patch34 = PatchGAN34.PatchGAN34(self.dis_name,noise=self.noise) 37 | self.Patch70 = PatchGAN70.PatchGAN70(self.dis_name,noise=self.noise) 38 | self.Patch142 = PatchGAN142.PatchGAN142(self.dis_name,noise=self.noise) 39 | 40 | def create(self,X,reuse=True): 41 | """Build the three PatchGANs on X (reuse is forwarded to each) and return their flattened patch predictions concatenated along axis 1.""" 42 | self.out34 = self.Patch34.create(X,reuse) 43 | self.out70 = self.Patch70.create(X,reuse) 44 | self.out142 = self.Patch142.create(X,reuse) 45 | # Flatten each patch map to [-1, H*W, 1]; H and W are read at run time via tf.shape. 46 | reshaped34 = tf.reshape(self.out34,[-1,tf.shape(self.out34)[1]*tf.shape(self.out34)[2],1]) 47 | reshaped70 = tf.reshape(self.out70,[-1,tf.shape(self.out70)[1]*tf.shape(self.out70)[2],1]) 48 | reshaped142 = tf.reshape(self.out142,[-1,tf.shape(self.out142)[1]*tf.shape(self.out142)[2],1]) 49 | 50 | self.prediction = tf.concat([reshaped34,reshaped70,reshaped142],axis=1) 51 | return self.prediction -------------------------------------------------------------------------------- /Discriminator/HisDis.py:
-------------------------------------------------------------------------------- 1 | from __future__ import division, print_function, unicode_literals 2 | 3 | import tensorflow as tf 4 | 5 | class HisDis: 6 | """ 7 | Histogram discriminator: a small fully connected network that scores a (histogram) feature vector X. 8 | Layers (see create()): dense(64, tanh) -> dense(64, tanh) -> dense(1, no activation), plus a constant +0.5. 9 | Before the first layer the input is shifted by -0.5 and Gaussian noise with std `noise` is added. 10 | Dropout with `keep_prob` is applied to the input of every dense layer on every call. 11 | NOTE(review): there is no training/inference switch, so noise and dropout are also active at inference. 12 | Nothing is saved to or loaded from disk by this class; layer names are prefixed with 'dis_<dis_name>_'. 13 | 14 | Only the following functions should be called from outside: 15 | -) create() 16 | -) constructor 17 | """ 18 | 19 | def __init__(self,dis_name,noise=0.1,keep_prob=0.5): 20 | """ 21 | Store the configuration of a histogram discriminator; no layers are built until create() is called. 22 | 23 | INPUT: dis_name - Name of the discriminator, used to prefix the layer names ('dis_<dis_name>_hidden_1', ...). 24 | noise - Std of the Gaussian noise added to the input in create() (default 0.1). 25 | keep_prob - Dropout keep probability used before every dense layer (default 0.5). 26 | OUTPUT: - The model 27 | """ 28 | self.dis_name = dis_name 29 | self.noise = noise 30 | self.keep_prob = keep_prob 31 | 32 | 33 | def create(self,X,reuse=True): 34 | """ 35 | Build the histogram discriminator graph on X (reuse is passed to every dense layer). 36 | 37 | INPUT: X - [None,256*n_chan] 38 | 39 | OUTPUT: - The HisDis prediction, shape [None, 1] (linear output shifted by +0.5) 40 | """ 41 | self.hidden_1 = tf.layers.dense(tf.nn.dropout(X-0.5+tf.random_normal(tf.shape(X),0.,self.noise),keep_prob=self.keep_prob), 42 | 64, 43 | reuse=reuse, 44 | name='dis_'+self.dis_name+'_hidden_1', 45 | activation=tf.nn.tanh) 46 | 47 | self.hidden_2 = tf.layers.dense(tf.nn.dropout(self.hidden_1,keep_prob=self.keep_prob), 48 | 64, 49 | reuse=reuse, 50 | name='dis_'+self.dis_name+'_hidden_2', 51 | activation=tf.nn.tanh) 52 | 53 | self.out = tf.layers.dense(tf.nn.dropout(self.hidden_2,keep_prob=self.keep_prob), 54 | 1, 55 | reuse=reuse, 56 | name='dis_'+self.dis_name+'_h_out', 57 | activation=None) + 0.5 58 | 59 | return self.out 60 | 61 | -------------------------------------------------------------------------------- /notebooks/make_raw_nanowire.ipynb:
-------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "Creates the raw images for the nanowire dataset. The raw images can be downloaded from:\n", 8 | "\n", 9 | "https://downloads.lbb.ethz.ch/Data/lbb_raw_nanowire_images.h5" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "path_raw_images = '../Data/Nanowire/Raw/lbb_raw_nanowire_images.h5'" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": null, 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "import numpy as np\n", 28 | "import matplotlib\n", 29 | "import matplotlib.pyplot as plt\n", 30 | "import h5py as h5\n", 31 | "\n", 32 | "import random\n", 33 | "\n", 34 | "import imageio as io\n", 35 | "\n", 36 | "from skimage.filters import gaussian" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": null, 42 | "metadata": {}, 43 | "outputs": [], 44 | "source": [ 45 | "f = h5.File(path_raw_images,\"r\")" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": null, 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "data = f['Raw/data'][...]/(2**8-1.)" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": null, 60 | "metadata": {}, 61 | "outputs": [], 62 | "source": [ 63 | "np.max(data)" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": null, 69 | "metadata": {}, 70 | "outputs": [], 71 | "source": [ 72 | "plt.figure(figsize=(10,10))\n", 73 | "plt.imshow(data[0,:,:,0],cmap='gray')\n", 74 | "plt.show()" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": null, 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "for i in range(data.shape[0]):\n", 84 | " io.imsave('../Data/Nanowire/Raw/'+str(i).zfill(5)+'.png',data[i,:,:,0])" 85 | ] 86 | } 87 | ], 88 | 
"metadata": { 89 | "kernelspec": { 90 | "display_name": "Python 2", 91 | "language": "python", 92 | "name": "python2" 93 | }, 94 | "language_info": { 95 | "codemirror_mode": { 96 | "name": "ipython", 97 | "version": 2 98 | }, 99 | "file_extension": ".py", 100 | "mimetype": "text/x-python", 101 | "name": "python", 102 | "nbconvert_exporter": "python", 103 | "pygments_lexer": "ipython2", 104 | "version": "2.7.15" 105 | } 106 | }, 107 | "nbformat": 4, 108 | "nbformat_minor": 2 109 | } 110 | -------------------------------------------------------------------------------- /notebooks/make_raw_dead_live_neurons.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "Creates the raw images for the dead vs live neuron dataset. The raw brightfield images can be downloaded from:\n", 8 | "\n", 9 | "https://downloads.lbb.ethz.ch/Data/lbb_raw_neuron_images.h5" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "path_raw_images = '../Data/Neuron_Dead_Live/Raw/lbb_raw_neuron_images.h5'" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": null, 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "import numpy as np\n", 28 | "import matplotlib\n", 29 | "import matplotlib.pyplot as plt\n", 30 | "import h5py as h5\n", 31 | "\n", 32 | "import random\n", 33 | "\n", 34 | "import imageio as io\n", 35 | "\n", 36 | "from skimage.filters import gaussian" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": null, 42 | "metadata": {}, 43 | "outputs": [], 44 | "source": [ 45 | "f = h5.File(path_raw_images,\"r\")" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": null, 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "data = f['Raw/data'][...]/(2**16-1.)" 55 | ] 56 | }, 57 | { 58 | "cell_type": 
"code", 59 | "execution_count": null, 60 | "metadata": {}, 61 | "outputs": [], 62 | "source": [ 63 | "np.max(data)" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": null, 69 | "metadata": {}, 70 | "outputs": [], 71 | "source": [ 72 | "plt.figure(figsize=(10,10))\n", 73 | "plt.imshow(data[1,:,:,0],cmap='gray')\n", 74 | "plt.show()" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": null, 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "for i in range(data.shape[0]):\n", 84 | " io.imsave('../Data/Neuron_Dead_Live/Raw/'+str(i).zfill(4)+'.png',data[i,:,:,0])" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": null, 90 | "metadata": {}, 91 | "outputs": [], 92 | "source": [] 93 | } 94 | ], 95 | "metadata": { 96 | "kernelspec": { 97 | "display_name": "Python 2", 98 | "language": "python", 99 | "name": "python2" 100 | }, 101 | "language_info": { 102 | "codemirror_mode": { 103 | "name": "ipython", 104 | "version": 2 105 | }, 106 | "file_extension": ".py", 107 | "mimetype": "text/x-python", 108 | "name": "python", 109 | "nbconvert_exporter": "python", 110 | "pygments_lexer": "ipython2", 111 | "version": "2.7.15" 112 | } 113 | }, 114 | "nbformat": 4, 115 | "nbformat_minor": 2 116 | } 117 | -------------------------------------------------------------------------------- /Discriminator/PatchGAN34.py: -------------------------------------------------------------------------------- 1 | from __future__ import division, print_function, unicode_literals 2 | 3 | import tensorflow as tf 4 | 5 | class PatchGAN34: 6 | """ 7 | This class is creating a PatchGAN discriminator as described by Zhu et al. 2018. 
8 | -) save() - Save the current model parameter 9 | -) create() - Create the model layers (graph construction) 10 | -) init() - Initialize the model (load model if exists) 11 | -) load() - Load the parameters from the file 12 | -) run() - ToDo write this 13 | 14 | Only the following functions should be called from outside: 15 | -) create() 16 | -) constructor 17 | """ 18 | 19 | def __init__(self,dis_name,noise=0.25): 20 | """ 21 | Create a PatchGAN model (init). It will check, if a model with such a name has already been saved. If so, the model 22 | is being loaded. Otherwise, a new model with this name will be created. It will only be saved, if the save function 23 | is being called. The describtion of every parameter is given in the code below. 24 | 25 | INPUT: dis_name - This is the name of the discriminator. It is mainly used to establish the place, where the model 26 | is being saved. 27 | 28 | OUTPUT: - The model 29 | """ 30 | self.dis_name = dis_name 31 | self.noise = noise 32 | 33 | 34 | def create(self,X,reuse=True): 35 | 36 | # C128 37 | # To add noise: 38 | self.C128_c = tf.layers.conv2d(tf.pad(X+tf.random_normal(tf.shape(X),0.,self.noise),[[0,0],[1,1],[1,1],[0,0]],"Reflect"), 39 | filters=128, 40 | kernel_size=4, 41 | kernel_initializer=tf.initializers.random_normal(stddev=0.02), 42 | strides=(2,2), 43 | padding='valid', 44 | reuse=reuse, 45 | name='dis_'+self.dis_name+'_34_conv_1') 46 | self.C128_n = tf.contrib.layers.instance_norm(self.C128_c,reuse=reuse,scope='dis_'+self.dis_name+'_34_bnorm_1',trainable=False) 47 | self.C128 = tf.nn.leaky_relu(self.C128_n) 48 | 49 | # C256 50 | self.C256_c = tf.layers.conv2d(tf.pad(self.C128,[[0,0],[1,1],[1,1],[0,0]],"Reflect"), 51 | filters=256, 52 | kernel_size=4, 53 | kernel_initializer=tf.initializers.random_normal(stddev=0.02), 54 | strides=(2,2), 55 | padding='valid', 56 | reuse=reuse, 57 | name='dis_'+self.dis_name+'_34_conv_2') 58 | self.C256_n = 
tf.contrib.layers.instance_norm(self.C256_c,reuse=reuse,scope='dis_'+self.dis_name+'_34_bnorm_2',trainable=False) 59 | self.C256 = tf.nn.leaky_relu(self.C256_n) 60 | 61 | # C512 62 | self.C512_c = tf.layers.conv2d(tf.pad(self.C256,[[0,0],[1,1],[1,1],[0,0]],"Reflect"), 63 | filters=512, 64 | kernel_size=4, 65 | kernel_initializer=tf.initializers.random_normal(stddev=0.02), 66 | strides=(1,1), 67 | padding='valid', 68 | reuse=reuse, 69 | name='dis_'+self.dis_name+'_34_conv_3') 70 | self.C512_n = tf.contrib.layers.instance_norm(self.C512_c,reuse=reuse,scope='dis_'+self.dis_name+'_34_bnorm_3',trainable=False) 71 | self.C512 = tf.nn.leaky_relu(self.C512_n) 72 | 73 | # c1 74 | self.c1_c = tf.layers.conv2d(tf.pad(self.C512,[[0,0],[1,1],[1,1],[0,0]],"Reflect"), 75 | filters=1, 76 | kernel_size=4, 77 | kernel_initializer=tf.initializers.random_normal(stddev=0.02), 78 | strides=(1,1), 79 | padding='valid', 80 | reuse=reuse, 81 | name='dis_'+self.dis_name+'_34_conv_4') 82 | return self.c1_c -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # UDCT 2 | You can find here the Cycle-GAN network (Tensorflow, Python 2.7+) with our histogram loss. Additionally, we provide the scripts we used to generate the synthetic datasets. 3 | 4 | Our results can be found at https://www.biorxiv.org/content/biorxiv/early/2019/03/01/563734.full.pdf 5 | 6 | ## How to use 7 | 1. Clone or download the repository 8 | 2. Create a synthetic dataset similar to your real dataset or use example dataset in ./Data/Example 9 | 3. Execute:
python create_h5_dataset.py <directory_of_raw_images> <directory_of_syn_images> <filename_of_hdf5_file>  
10 | Example:
python create_h5_dataset.py ./Data/Example/Genuine/ \\
./Data/Example/Synthetic/ ./Data/Example/example_dataset.h5
11 | 4. Create the directory 'Models' in the root directory 12 | 5. Execute:
 python main.py --dataset=./Data/..../dataset.h5 --name=name_of_model 
13 | Example:
 python main.py --dataset=./Data/Example/example_dataset.h5 --name=example_model 
14 | 6. This creates a network that is saved in ./Models/ along with a parameter text file. The average loss terms for each epoch are also saved in this directory. 15 | 7. To generate the results after training, use: 16 |
 python main.py --dataset=./Data/Example/example_dataset.h5 --name=example_model --mode=gen_B 
17 | The generated synthetic images are saved in ./Models/<name_of_model>_gen_B.h5 18 | 19 | ### Parameters 20 | 21 | All parameters have the form: --<parameter_name>=<value>
22 | Below is the list of all parameters that can be set. The default value, used when a parameter is not specified, is given in parentheses. 23 | 24 | name ('unnamed')
25 | Name of the model. This value should be unique so that existing models are not loaded or overwritten by mistake. Its value should be changed from the default! 26 |
27 | 28 | dataset ('pathtodata.h5')
29 | Path to the h5 dataset file to be used. Its value must be changed from the default; main.py raises an error if the file does not exist. 30 |
31 | 32 | architecture ('Res6')
33 | The network architecture of the generators. Currently, you can choose between 'Res6' and 'Res9', which correspond to 6 and 9 residual layers, respectively. 34 |
35 | 36 | deconv ('transpose')
37 | Upsampling method used in the generators. You can choose either transposed convolutions ('transpose') or image resizing followed by a convolution ('resize'). 38 |
39 | 40 | PatchGAN ('Patch70')
41 | PatchGAN discriminator architecture: 'Patch34', 'Patch70', or 'Patch142'. A mixture of these ('MultiPatch') is also possible (experimental). 42 |
43 | 44 | mode ('training')
45 | Determines what is done with the network: either train it ('training') or generate images, using 'gen_A' to create raw images from synthetic images or 'gen_B' to create synthetic images from raw images. 46 |
47 | 48 | 49 | 50 | dataset ('pathtodata.h5')
51 | Path to the h5 dataset file to be used (same parameter as the dataset entry above). Its value must be changed from the default! 52 |
53 | 54 | lambda_c (10.)
55 | The loss multiplier of the cycle consistency term used while training the generators. 56 |
57 | 58 | lambda_h (1.)
59 | The loss multiplier for the histogram discriminators. To disable the histogram loss, set this value to 0. 60 |
61 | 62 | dis_noise (0.1)
63 | To make training more stable, noise that slowly decays over time is added to the inputs of the discriminators. This value is the standard deviation of that Gaussian noise. 64 |
65 | 66 | syn_noise (0.)
67 | Optional Gaussian noise added to the synthetic images (e.g. to create non-flat backgrounds). Default: 0 (not used). 68 |
69 | 70 | real_noise (0.)
71 | Optional Gaussian noise added to the real images (e.g. to create non-flat backgrounds). Default: 0 (not used). 72 |
73 | 74 | epoch (200)
75 | Number of training epochs. 76 |
77 | 78 | batch_size (4)
79 | Batch size during training. 80 |
81 | 82 | buffer_size (50)
83 | Size of the history buffer used to train the discriminators. Using a history makes training more stable. 84 |
85 | 86 | save (1)
87 | If the value is not 0, the model is saved at the end of each epoch. 88 |
89 | 90 | gpu (0)
91 | If multiple GPUs exist, this parameter chooses which GPU is used (it sets CUDA_VISIBLE_DEVICES). Only one GPU can currently be used. To run on the CPU only, choose a nonexistent GPU ID. 92 |
93 | 94 | verbose (0)
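95 | If the value is not 0, the network is more verbose: it prints the number of parameters and creates a graph file. Note: in the current main.py this branch is part of the same if/elif chain as the mode selection, so a non-zero value only prints this information and does not train the network or generate images. 96 |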
97 | -------------------------------------------------------------------------------- /Discriminator/PatchGAN70.py: -------------------------------------------------------------------------------- 1 | from __future__ import division, print_function, unicode_literals 2 | 3 | import tensorflow as tf 4 | 5 | class PatchGAN70: 6 | """ 7 | This class is creating a PatchGAN discriminator as described by Zhu et al. 2018. 8 | -) save() - Save the current model parameter 9 | -) create() - Create the model layers (graph construction) 10 | -) init() - Initialize the model (load model if exists) 11 | -) load() - Load the parameters from the file 12 | -) run() - ToDo write this 13 | 14 | Only the following functions should be called from outside: 15 | -) create() 16 | -) constructor 17 | """ 18 | 19 | def __init__(self,dis_name,noise=0.25): 20 | """ 21 | Create a PatchGAN model (init). It will check, if a model with such a name has already been saved. If so, the model 22 | is being loaded. Otherwise, a new model with this name will be created. It will only be saved, if the save function 23 | is being called. The describtion of every parameter is given in the code below. 24 | 25 | INPUT: dis_name - This is the name of the discriminator. It is mainly used to establish the place, where the model 26 | is being saved. 
27 | 28 | OUTPUT: - The model 29 | """ 30 | self.dis_name = dis_name 31 | self.noise = noise 32 | 33 | 34 | def create(self,X,reuse=True): 35 | 36 | # C64 37 | # To add noise: 38 | self.C64_c = tf.layers.conv2d(tf.pad(X+tf.random_normal(tf.shape(X),0.,self.noise),[[0,0],[1,1],[1,1],[0,0]],"Reflect"), 39 | filters=64, 40 | kernel_size=4, 41 | kernel_initializer=tf.initializers.random_normal(stddev=0.02), 42 | strides=(2,2), 43 | padding='valid', 44 | reuse=reuse, 45 | name='dis_'+self.dis_name+'_70_conv_0') 46 | self.C64 = tf.nn.leaky_relu(self.C64_c) 47 | 48 | # C128 49 | self.C128_c = tf.layers.conv2d(tf.pad(self.C64,[[0,0],[1,1],[1,1],[0,0]],"Reflect"), 50 | filters=128, 51 | kernel_size=4, 52 | kernel_initializer=tf.initializers.random_normal(stddev=0.02), 53 | strides=(2,2), 54 | padding='valid', 55 | reuse=reuse, 56 | name='dis_'+self.dis_name+'_70_conv_1') 57 | self.C128_n = tf.contrib.layers.instance_norm(self.C128_c,reuse=reuse,scope='dis_'+self.dis_name+'_70_bnorm_1',trainable=False) 58 | self.C128 = tf.nn.leaky_relu(self.C128_n) 59 | 60 | # C256 61 | self.C256_c = tf.layers.conv2d(tf.pad(self.C128,[[0,0],[1,1],[1,1],[0,0]],"Reflect"), 62 | filters=256, 63 | kernel_size=4, 64 | kernel_initializer=tf.initializers.random_normal(stddev=0.02), 65 | strides=(2,2), 66 | padding='valid', 67 | reuse=reuse, 68 | name='dis_'+self.dis_name+'_70_conv_2') 69 | self.C256_n = tf.contrib.layers.instance_norm(self.C256_c,reuse=reuse,scope='dis_'+self.dis_name+'_70_bnorm_2',trainable=False) 70 | self.C256 = tf.nn.leaky_relu(self.C256_n) 71 | 72 | # C512 73 | self.C512_c = tf.layers.conv2d(tf.pad(self.C256,[[0,0],[1,1],[1,1],[0,0]],"Reflect"), 74 | filters=512, 75 | kernel_size=4, 76 | kernel_initializer=tf.initializers.random_normal(stddev=0.02), 77 | strides=(1,1), 78 | padding='valid', 79 | reuse=reuse, 80 | name='dis_'+self.dis_name+'_70_conv_3') 81 | self.C512_n = 
tf.contrib.layers.instance_norm(self.C512_c,reuse=reuse,scope='dis_'+self.dis_name+'_70_bnorm_3',trainable=False) 82 | self.C512 = tf.nn.leaky_relu(self.C512_n) 83 | 84 | # c1 85 | self.c1_c = tf.layers.conv2d(tf.pad(self.C512,[[0,0],[1,1],[1,1],[0,0]],"Reflect"), 86 | filters=1, 87 | kernel_size=4, 88 | kernel_initializer=tf.initializers.random_normal(stddev=0.02), 89 | strides=(1,1), 90 | padding='valid', 91 | reuse=reuse, 92 | name='dis_'+self.dis_name+'_70_conv_4') 93 | return self.c1_c -------------------------------------------------------------------------------- /create_h5_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | 3 | import numpy as np 4 | import h5py 5 | import cv2 6 | import os 7 | import sys 8 | 9 | def get_file_list(data_path): 10 | """ 11 | This function returns the list of all png images in a given directory, such as their dimensions. It returns an error, if the image dimensions are not consistent. 12 | 13 | Arguments: 14 | data_path (string) Path to directory of the set of images to be extracted 15 | 16 | Returns: 17 | file_list (List of strings) Filenames of all png images in dataset 18 | dimensions (3 x integer) Dimensions of images: height, width, channels 19 | flag (boolean) Is true, iff the images are grayscale 20 | """ 21 | 22 | # Create list of all png files 23 | file_list = [] 24 | for element in os.listdir(data_path): 25 | if element[-4:] == ".png": 26 | file_list.append(data_path + element) 27 | 28 | 29 | # Compare sizes 30 | dimensions = cv2.imread(file_list[0],cv2.IMREAD_UNCHANGED).shape 31 | for i in range(1,len(file_list)): 32 | if not np.array_equal(dimensions,cv2.imread(file_list[i],cv2.IMREAD_UNCHANGED).shape): 33 | raise Exception('The following two images have different dimensions. 
Please make sure all images in this directory have the same size \n\r ' +\ 34 | file_list[0] + '\n\r' +\ 35 | file_list[i]) 36 | 37 | # Add a 3rd value two dimensions, if it does not exist (this means it has 1 data channel) 38 | flag = False 39 | if len(dimensions) == 2: 40 | dimensions = np.array([dimensions[0],dimensions[1],1]) 41 | flag = True 42 | 43 | return file_list,dimensions,flag 44 | 45 | def main(): 46 | # Check if the right amount of arguments has been given to the program 47 | if len(sys.argv[1:]) != 3: 48 | print('This script recuires three arguments in order to work:') 49 | print('1: Path to directory containing the genuine/raw images (only png images in directory are used!)') 50 | print('2: Path to directory containing the synthetic images (only png images in directory are used!)') 51 | print('3: Output hdf5 filename') 52 | print(' ') 53 | print('Example: python create_h5_dataset.py ./Data/Example/Genuine/ ./Data/Example/Synthetic/ ./Data/Example/example_dataset.h5') 54 | print(' ') 55 | print('Script aborted') 56 | return -1 57 | 58 | # get the addresses 59 | raw_path = sys.argv[1] 60 | syn_path = sys.argv[2] 61 | filename = sys.argv[3] 62 | 63 | # Create the output hdf5 file 64 | f = h5py.File(filename,"w") 65 | 66 | # Save the raw dataset into the file 67 | raw_files, raw_dimensions, raw_flag = get_file_list(raw_path) 68 | 69 | num_samples = len(raw_files) 70 | num_channel = raw_dimensions[2] 71 | 72 | group = f.create_group('A') 73 | group.create_dataset(name='num_samples', data=num_samples) 74 | group.create_dataset(name='num_channel', data=num_channel) 75 | dtype = np.uint8 76 | 77 | data_A = np.zeros([num_samples,\ 78 | raw_dimensions[0],\ 79 | raw_dimensions[1],\ 80 | num_channel], dtype=dtype) 81 | 82 | for idx,fname in enumerate(raw_files): 83 | if raw_flag: # This means, the images are gray scale 84 | data_A[idx,:,:,0] = np.array(cv2.imread(fname,cv2.IMREAD_GRAYSCALE)) 85 | else: 86 | data_A[idx,:,:,:] = 
np.flip(np.array(cv2.imread(fname,cv2.IMREAD_COLOR)),2) 87 | 88 | print('Genuine dataset: ', group.create_dataset(name='data', data=(data_A),dtype=dtype)) 89 | 90 | 91 | 92 | 93 | 94 | # Save the syn dataset into the file 95 | syn_files, syn_dimensions, syn_flag = get_file_list(syn_path) 96 | 97 | num_samples = len(syn_files) 98 | num_channel = syn_dimensions[2] 99 | 100 | group = f.create_group('B') 101 | group.create_dataset(name='num_samples', data=num_samples) 102 | group.create_dataset(name='num_channel', data=num_channel) 103 | dtype = np.uint8 104 | 105 | data_B = np.zeros([num_samples,\ 106 | syn_dimensions[0],\ 107 | syn_dimensions[1],\ 108 | num_channel], dtype=dtype) 109 | 110 | for idx,fname in enumerate(syn_files): 111 | if syn_flag: # This means, the images are gray scale 112 | data_B[idx,:,:,0] = np.array(cv2.imread(fname,cv2.IMREAD_GRAYSCALE)) 113 | else: 114 | data_B[idx,:,:,:] = np.flip(np.array(cv2.imread(fname,cv2.IMREAD_COLOR)),2) 115 | 116 | print('Synthetic dataset: ', group.create_dataset(name='data', data=(data_B),dtype=dtype)) 117 | 118 | 119 | # Close the file 120 | f.close() 121 | 122 | if __name__ == "__main__": 123 | main() 124 | -------------------------------------------------------------------------------- /notebooks/make_synthetic_wires.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "import matplotlib\n", 11 | "import matplotlib.pyplot as plt\n", 12 | "%matplotlib inline\n", 13 | "import os\n", 14 | "import colorsys\n", 15 | "import imageio as io\n", 16 | "from scipy.ndimage import binary_dilation\n", 17 | "from skimage.filters import gaussian\n", 18 | "from numpy import random\n", 19 | "from scipy.stats import skewnorm" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": null, 25 | "metadata": {}, 26 | 
"outputs": [], 27 | "source": [ 28 | "def gen_fake_img(imsize,length_var,width_var,rng_wire_bounds,rng_length_loc_scale,fix_length,fix_width):\n", 29 | " #image = np.zeros(( imsize,imsize, 3)) #replace with noise for stability\n", 30 | " image = np.random.normal(loc=0.1,scale=0.05,size=(imsize,imsize,3))\n", 31 | " saturation_map=np.random.normal(loc=0.8,scale=0.2,size=(imsize,imsize))\n", 32 | " saturation_map=gaussian(saturation_map,sigma=3) #introduce long range noise correlations\n", 33 | " hue_map=np.random.normal(loc=0.8,scale=0.2,size=(imsize,imsize))\n", 34 | " hue_map=gaussian(hue_map,sigma=2)\n", 35 | " n_wires=np.random.randint(rng_wire_bounds[0],rng_wire_bounds[1])\n", 36 | " for n in range(n_wires):\n", 37 | " if length_var:\n", 38 | " length=skewnorm.rvs(a=3,loc=rng_length_loc_scale[0],scale=rng_length_loc_scale[1],size=1)\n", 39 | " if not length_var:\n", 40 | " length=np.array([fix_length])\n", 41 | "\n", 42 | " x_loc = np.array([np.random.rand() * imsize]) \n", 43 | " y_loc = np.array(np.random.rand() * imsize) \n", 44 | " angle = np.random.random() * np.pi\n", 45 | " vec_x=np.cos(angle)\n", 46 | " vec_y=np.sin(angle)\n", 47 | " x_wires = (vec_x.reshape(1, 1) * length / 2. * np.linspace(-1, 1, 200) + x_loc.reshape(1,1)).astype(int)\n", 48 | " y_wires = (vec_y.reshape(1, 1) * length / 2. 
* np.linspace(-1, 1, 200) + y_loc.reshape(1,1)).astype(int)\n", 49 | " indices = (x_wires * imsize + y_wires).flatten() \n", 50 | " to_rem=np.where(indices>=imsize*imsize)[0] #make sure wire doesnt go out of image\n", 51 | " indices=[indices[i] for i in range(len(indices)) if i not in to_rem]\n", 52 | "\n", 53 | " wire = np.zeros([imsize*imsize], dtype=np.double)\n", 54 | " wire[indices] = 1\n", 55 | " wire = wire.reshape( imsize, imsize)\n", 56 | "\n", 57 | " if width_var:\n", 58 | " it=np.random.choice(np.arange(1,6),p=[0.7,0.2,0.07,0.02,0.01])\n", 59 | " if not width_var:\n", 60 | " it=fix_width\n", 61 | " wire = binary_dilation(wire,iterations=it)\n", 62 | " wire=wire.astype(bool)\n", 63 | " image[:,:,0][wire]=angle/np.pi #hue\n", 64 | " image[:,:,1][wire]=hue_map[wire] #saturation\n", 65 | " image[:,:,2][wire]=saturation_map[wire] #brightness\n", 66 | " image=np.clip(image,0,1)\n", 67 | " image=matplotlib.colors.hsv_to_rgb(image) #Better visualization, also probably better use of all 3 channels (hsv varies only little in channel 2 and 3)\n", 68 | " return image" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": null, 74 | "metadata": {}, 75 | "outputs": [], 76 | "source": [ 77 | "params={'imsize': 256,\n", 78 | " 'length_var': True,\n", 79 | " 'width_var': True,\n", 80 | " 'rng_wire_bounds': [10,40],\n", 81 | " 'rng_length_loc_scale': [20,60],\n", 82 | " 'fix_length': 50,\n", 83 | " 'fix_width': 1}\n", 84 | "N_images=500" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": null, 90 | "metadata": {}, 91 | "outputs": [], 92 | "source": [ 93 | "example=gen_fake_img(**params)\n", 94 | "plt.imshow(example)" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": null, 100 | "metadata": {}, 101 | "outputs": [], 102 | "source": [ 103 | "directory='../Data/Nanowire/Synthetic/'" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": null, 109 | "metadata": {}, 110 | "outputs": [], 111 | 
"source": [ 112 | "for i in range(N_images):\n", 113 | " name=directory+str(i).zfill(5)+'.png'\n", 114 | " image=(gen_fake_img(**params)*255).astype(np.uint8) #map to 0-255,8 bit image\n", 115 | " io.imsave(name,image)\n", 116 | " if i%50==0:\n", 117 | " print(i)" 118 | ] 119 | } 120 | ], 121 | "metadata": { 122 | "kernelspec": { 123 | "display_name": "Python 2", 124 | "language": "python", 125 | "name": "python2" 126 | }, 127 | "language_info": { 128 | "codemirror_mode": { 129 | "name": "ipython", 130 | "version": 2 131 | }, 132 | "file_extension": ".py", 133 | "mimetype": "text/x-python", 134 | "name": "python", 135 | "nbconvert_exporter": "python", 136 | "pygments_lexer": "ipython2", 137 | "version": "2.7.15" 138 | } 139 | }, 140 | "nbformat": 4, 141 | "nbformat_minor": 2 142 | } 143 | -------------------------------------------------------------------------------- /Discriminator/PatchGAN142.py: -------------------------------------------------------------------------------- 1 | from __future__ import division, print_function, unicode_literals 2 | 3 | import tensorflow as tf 4 | 5 | class PatchGAN142: 6 | """ 7 | This class is creating a PatchGAN discriminator as described by Zhu et al. 2018. 8 | -) save() - Save the current model parameter 9 | -) create() - Create the model layers (graph construction) 10 | -) init() - Initialize the model (load model if exists) 11 | -) load() - Load the parameters from the file 12 | -) run() - ToDo write this 13 | 14 | Only the following functions should be called from outside: 15 | -) create() 16 | -) constructor 17 | """ 18 | 19 | def __init__(self,dis_name,noise=0.25): 20 | """ 21 | Create a PatchGAN model (init). It will check, if a model with such a name has already been saved. If so, the model 22 | is being loaded. Otherwise, a new model with this name will be created. It will only be saved, if the save function 23 | is being called. The describtion of every parameter is given in the code below. 
24 | 25 | INPUT: dis_name - This is the name of the discriminator. It is mainly used to establish the place, where the model 26 | is being saved. 27 | 28 | OUTPUT: - The model 29 | """ 30 | self.dis_name = dis_name 31 | self.noise = noise 32 | 33 | 34 | def create(self,X,reuse=True): 35 | 36 | # C32 37 | # To add noise: 38 | self.C32_c = tf.layers.conv2d(tf.pad(X+tf.random_normal(tf.shape(X),0.,self.noise),[[0,0],[1,1],[1,1],[0,0]],"Reflect"), 39 | filters=32, 40 | kernel_size=4, 41 | kernel_initializer=tf.initializers.random_normal(stddev=0.02), 42 | strides=(2,2), 43 | padding='valid', 44 | reuse=reuse, 45 | name='dis_'+self.dis_name+'_142_conv_0') 46 | self.C32 = tf.nn.leaky_relu(self.C32_c) 47 | 48 | # C128 49 | self.C64_c = tf.layers.conv2d(tf.pad(self.C32,[[0,0],[1,1],[1,1],[0,0]],"Reflect"), 50 | filters=64, 51 | kernel_size=4, 52 | kernel_initializer=tf.initializers.random_normal(stddev=0.02), 53 | strides=(2,2), 54 | padding='valid', 55 | reuse=reuse, 56 | name='dis_'+self.dis_name+'_142_conv_1') 57 | self.C64_n = tf.contrib.layers.instance_norm(self.C64_c,reuse=reuse,scope='dis_'+self.dis_name+'_142_bnorm_1',trainable=False) 58 | self.C64 = tf.nn.leaky_relu(self.C64_n) 59 | 60 | # C128 61 | self.C128_c = tf.layers.conv2d(tf.pad(self.C64,[[0,0],[1,1],[1,1],[0,0]],"Reflect"), 62 | filters=128, 63 | kernel_size=4, 64 | kernel_initializer=tf.initializers.random_normal(stddev=0.02), 65 | strides=(2,2), 66 | padding='valid', 67 | reuse=reuse, 68 | name='dis_'+self.dis_name+'_142_conv_2') 69 | self.C128_n = tf.contrib.layers.instance_norm(self.C128_c,reuse=reuse,scope='dis_'+self.dis_name+'_142_bnorm_2',trainable=False) 70 | self.C128 = tf.nn.leaky_relu(self.C128_n) 71 | 72 | # C256 73 | self.C256_c = tf.layers.conv2d(tf.pad(self.C128,[[0,0],[1,1],[1,1],[0,0]],"Reflect"), 74 | filters=256, 75 | kernel_size=4, 76 | kernel_initializer=tf.initializers.random_normal(stddev=0.02), 77 | strides=(2,2), 78 | padding='valid', 79 | reuse=reuse, 80 | 
name='dis_'+self.dis_name+'_142_conv_3') 81 | self.C256_n = tf.contrib.layers.instance_norm(self.C256_c,reuse=reuse,scope='dis_'+self.dis_name+'_142_bnorm_3',trainable=False) 82 | self.C256 = tf.nn.leaky_relu(self.C256_n) 83 | 84 | # C512 85 | self.C512_c = tf.layers.conv2d(tf.pad(self.C256,[[0,0],[1,1],[1,1],[0,0]],"Reflect"), 86 | filters=512, 87 | kernel_size=4, 88 | kernel_initializer=tf.initializers.random_normal(stddev=0.02), 89 | strides=(1,1), 90 | padding='valid', 91 | reuse=reuse, 92 | name='dis_'+self.dis_name+'_142_conv_4') 93 | self.C512_n = tf.contrib.layers.instance_norm(self.C512_c,reuse=reuse,scope='dis_'+self.dis_name+'_142_bnorm_4',trainable=False) 94 | self.C512 = tf.nn.leaky_relu(self.C512_n) 95 | 96 | # c1 97 | self.c1_c = tf.layers.conv2d(tf.pad(self.C512,[[0,0],[1,1],[1,1],[0,0]],"Reflect"), 98 | filters=1, 99 | kernel_size=4, 100 | kernel_initializer=tf.initializers.random_normal(stddev=0.02), 101 | strides=(1,1), 102 | padding='valid', 103 | reuse=reuse, 104 | name='dis_'+self.dis_name+'_142_conv_5') 105 | return self.c1_c -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import cycleGAN 2 | import re 3 | import sys 4 | from os import environ as cuda_environment 5 | import os 6 | import numpy as np 7 | 8 | if __name__ == "__main__": 9 | # List of floats 10 | sub_value_f = {} 11 | sub_value_f['lambda_c'] = 10. # Loss multiplier for cycle 12 | sub_value_f['lambda_h'] = 1. # Loss multiplier for histogram 13 | sub_value_f['dis_noise'] = 0.1 # Std of gauss noise added to Dis 14 | sub_value_f['syn_noise'] = 0. # Add gaussian noise to syn images to make non-flat backgrounds 15 | sub_value_f['real_noise'] = 0. 
# Add gaussian noise to real images to make non-flat backgrounds 16 | 17 | # List of ints 18 | sub_value_i = {} 19 | sub_value_i['epoch'] = 200 # Number of epochs to be trained 20 | sub_value_i['batch_size'] = 4 # Batch size for training 21 | sub_value_i['buffer_size'] = 50 # Number of history elements used for Dis 22 | sub_value_i['save'] = 1 # If not 0, model is saved 23 | sub_value_i['gpu'] = 0 # Choose the GPU ID (if only CPU training, choose nonexistent number) 24 | sub_value_i['verbose'] = 0 # If not 0, some network information is being plotted 25 | 26 | # List of strings 27 | sub_string = {} 28 | sub_string['name'] = 'unnamed' # Name of model (should be unique). Is used to save/load models 29 | sub_string['dataset'] = 'pathtodata.h5' # Describes which h5 file is used 30 | sub_string['architecture'] = 'Res6' # Network architecture: 'Res6' or 'Res9' 31 | sub_string['deconv'] = 'transpose' # Upsampling method: 'transpose' or 'resize' 32 | sub_string['PatchGAN'] = 'Patch70' # Choose the Gan type: 'Patch34', 'Patch70', 'Patch142', 'MultiPatch' 33 | sub_string['mode'] = 'training' # 'train', 'gen_A', 'gen_B' 34 | 35 | # Create complete dictonary 36 | var_dict = sub_string.copy() 37 | var_dict.update(sub_value_i) 38 | var_dict.update(sub_value_f) 39 | 40 | # Update all defined parameters in dictionary 41 | for arg_i in sys.argv[1:]: 42 | var = re.search('(.*)\=', arg_i) # everything before the '=' 43 | g_var = var.group(1)[2:] 44 | if g_var in sub_value_i: 45 | dtype = 'int' 46 | elif g_var in sub_value_f: 47 | dtype = 'float' 48 | elif g_var in sub_string: 49 | dtype = 'string' 50 | else: 51 | print("Unknown key word: " + g_var) 52 | print("Write parameters as: =") 53 | print("Example: 'python main.py buffer_size=32'") 54 | print("Possible key words: " + str(var_dict.keys())) 55 | continue 56 | 57 | content = re.search('\=(.*)',arg_i) # everything after the '=' 58 | g_content = content.group(1) 59 | if dtype == 'int': 60 | var_dict[g_var] = int(g_content) 61 | 
elif dtype == 'float': 62 | var_dict[g_var] = float(g_content) 63 | else: 64 | var_dict[g_var] = g_content 65 | if not os.path.isfile(var_dict['dataset']): 66 | raise ValueError('Dataset does not exist. Specify loation of an existing h5 file.') 67 | # Get the dataset filename 68 | 69 | 70 | # Restrict usage of GPUs 71 | cuda_environment["CUDA_VISIBLE_DEVICES"]=str(var_dict['gpu']) 72 | with open('Models/'+var_dict['name']+"_params.txt", "w") as myfile: 73 | for key in sorted(var_dict): 74 | myfile.write(key + "," + str(var_dict[key]) + "\n") 75 | 76 | # Find out, if whole network is needed or only the generators 77 | gen_only = False 78 | if 'gen' in var_dict['mode']: 79 | gen_only = True 80 | 81 | # Define the model 82 | model = cycleGAN.Model(\ 83 | mod_name=var_dict['name'],\ 84 | data_file=var_dict['dataset'],\ 85 | buffer_size=var_dict['buffer_size'],\ 86 | dis_noise=var_dict['dis_noise'],\ 87 | architecture=var_dict['architecture'],\ 88 | lambda_c=var_dict['lambda_c'],\ 89 | lambda_h=var_dict['lambda_h'],\ 90 | deconv=var_dict['deconv'],\ 91 | patchgan=var_dict['PatchGAN'],\ 92 | verbose=(var_dict['verbose']!=0),\ 93 | gen_only=gen_only) 94 | 95 | # Plot parameter properties, if applicable 96 | if var_dict['verbose']: 97 | # Print the number of parameters 98 | model.print_count_variables() 99 | model.print_train_and_not_train_variables() 100 | 101 | # Create a graph file 102 | model.save_graph() 103 | 104 | elif var_dict['mode'] == 'training': 105 | # Train the model 106 | loss_gen_A = [] 107 | loss_gen_B = [] 108 | loss_dis_A = [] 109 | loss_dis_B = [] 110 | 111 | for i in range(var_dict['epoch']): 112 | print('') 113 | print('Epoch: ' + str(i+1)) 114 | print('') 115 | lgA,lgB,ldA,ldB = \ 116 | model.train(batch_size=var_dict['batch_size'],\ 117 | lambda_c=var_dict['lambda_c'],\ 118 | lambda_h=var_dict['lambda_h'],\ 119 | save=bool(var_dict['save']),\ 120 | epoch=i,\ 121 | syn_noise=var_dict['syn_noise'],\ 122 | real_noise=var_dict['real_noise']) 123 | 
loss_gen_A.append(lgA) 124 | loss_gen_B.append(lgB) 125 | loss_dis_A.append(ldA) 126 | loss_dis_B.append(ldB) 127 | np.save("./Models/" + var_dict['name'] + '_loss_gen_A.npy',np.array(loss_gen_A).T) 128 | np.save("./Models/" + var_dict['name'] + '_loss_gen_B.npy',np.array(loss_gen_B).T) 129 | np.save("./Models/" + var_dict['name'] + '_loss_dis_A.npy',np.array(loss_dis_A).T) 130 | np.save("./Models/" + var_dict['name'] + '_loss_dis_B.npy',np.array(loss_dis_B).T) 131 | 132 | elif var_dict['mode'] == 'gen_A': 133 | model.generator_A(batch_size=var_dict['batch_size'],\ 134 | lambda_c=var_dict['lambda_c'],\ 135 | lambda_h=var_dict['lambda_h']) 136 | 137 | elif var_dict['mode'] == 'gen_B': 138 | model.generator_B(batch_size=var_dict['batch_size'],\ 139 | lambda_c=var_dict['lambda_c'],\ 140 | lambda_h=var_dict['lambda_h']) 141 | -------------------------------------------------------------------------------- /Generator/Res_Gen.py: -------------------------------------------------------------------------------- 1 | from __future__ import division, print_function, unicode_literals 2 | 3 | import tensorflow as tf 4 | 5 | class ResGen: 6 | """ 7 | This class is creating a ResNet generator as described by Zhu et al. 2018. 
8 | -) save() - Save the current model parameter 9 | -) create() - Create the model layers (graph construction) 10 | -) init() - Initialize the model (load model if exists) 11 | -) load() - Load the parameters from the file 12 | -) run() - ToDo write this 13 | 14 | Only the following functions should be called from outside: 15 | -) create() 16 | -) constructor 17 | """ 18 | 19 | def __init__(self, 20 | gen_name, 21 | out_dim, 22 | gen_dim= [32, 64,128, 128,128,128,128,128,128, 64,32 ], 23 | kernel_size=[7, 3,3, 3,3,3,3,3,3, 3,3, 7], 24 | deconv='transpose', 25 | verbose=False): 26 | 27 | # gen_dim= [64, 128,256, 256,256,256,256,256,256,256,256,256, 128,64 ], 28 | # kernel_size=[7, 3,3, 3,3,3,3,3,3,3,3,3, 3,3, 7]): 29 | """ 30 | Create a generator model (init). It will check, if a model with such a name has already been saved. If so, the model 31 | is being loaded. Otherwise, a new model with this name will be created. It will only be saved, if the save function 32 | is being called. The describtion of every parameter is given in the code below. 33 | 34 | INPUT: gen_name - This is the name of the generator. It is mainly used to establish the place, where the model 35 | is being saved. 36 | gen_dim - The number of channels in every layer. The first two elements are stride-2 convolutional layers. 37 | The last two layers (1 value) are fractional stride 1/2 convolutional layers. Everything in between 38 | is a ResNet layer. 39 | 40 | kernel_size - The kernel sizes of all layers. The dimension of this list must be the same as of gen_dim + 1. 
41 | 42 | OUTPUT: - The model 43 | """ 44 | self.gen_name = gen_name 45 | self.out_dim = out_dim 46 | self.gen_dim = gen_dim 47 | self.kernel_size = kernel_size 48 | self.deconv = deconv 49 | self.verbose = verbose 50 | 51 | if len(gen_dim) + 1 != len(kernel_size): 52 | raise NameError('The dimensions of the ResGenerator are wrong') 53 | 54 | 55 | 56 | def create(self,X,reuse=True): 57 | num_layers = len(self.kernel_size) 58 | 59 | layer_list = [] 60 | layer_list.append(X) 61 | 62 | if self.verbose: 63 | print('-------------------------------') 64 | print(' ') 65 | print('Create generator ' + self.gen_name) 66 | print(' ') 67 | print('Number of layers: ' + str(num_layers)) 68 | print('Input diminesion: ' + str(tf.shape(X))) 69 | print(' ') 70 | 71 | for i in range(num_layers): 72 | if (i < (num_layers-3)) or (i==(num_layers - 1)): 73 | ps = int((self.kernel_size[i]-1)/2) # pad size 74 | new_pad = tf.pad(layer_list[-1],[[0,0],[ps,ps],[ps,ps],[0,0]],"Reflect") 75 | if self.verbose: 76 | print('- - - - - - - - - - - - - - - -') 77 | print('Load last layer: Do padding with size ' + str(ps)) 78 | else: 79 | new_pad = layer_list[-1] 80 | if self.verbose: 81 | print('- - - - - - - - - - - - - - - -') 82 | print('Load last layer: No padding') 83 | if i==0 or i==(num_layers - 1): 84 | if i==0: 85 | filters = self.gen_dim[i] 86 | else: 87 | filters = self.out_dim 88 | new_conv = tf.layers.conv2d(new_pad, 89 | filters=filters, 90 | kernel_size=self.kernel_size[i], 91 | strides=(1,1), 92 | kernel_initializer=tf.initializers.random_normal(stddev=0.02), 93 | padding='valid', 94 | reuse=reuse, 95 | name='gen_'+self.gen_name+'_conv_'+str(i)) 96 | if self.verbose: 97 | print('Conv layer stride 1 with kernel size ' + str(self.kernel_size[i]) + ' and number of filters ' + str(filters)) 98 | elif i < 3: 99 | # Conv layers 100 | new_conv = tf.layers.conv2d(new_pad, 101 | filters=self.gen_dim[i], 102 | kernel_size=self.kernel_size[i], 103 | strides=(2,2), 104 | 
kernel_initializer=tf.initializers.random_normal(stddev=0.02), 105 | padding='valid', 106 | reuse=reuse, 107 | name='gen_'+self.gen_name+'_conv_'+str(i)) 108 | if self.verbose: 109 | print('Conv layer stride 2 with kernel size ' + str(self.kernel_size[i]) + ' and number of filters ' + str(self.gen_dim[i])) 110 | elif i < num_layers-3: 111 | # Res layers 112 | new_conv_0 = tf.layers.conv2d(new_pad, 113 | filters=self.gen_dim[i], 114 | kernel_size=self.kernel_size[i], 115 | strides=(1,1), 116 | kernel_initializer=tf.initializers.random_normal(stddev=0.02), 117 | padding='valid', 118 | reuse=reuse, 119 | name='gen_'+self.gen_name+'_conv0_'+str(i)) 120 | # new_norm_0 = tf.contrib.layers.instance_norm(new_conv_0,reuse=reuse,scale=False,center=False,scope='gen_'+self.gen_name+'_bnorm0_'+str(i),trainable=False) 121 | new_mean_0, new_var_0 = tf.nn.moments(new_conv_0,axes=(1,2),keep_dims=True) 122 | new_norm_0 = (new_conv_0 - new_mean_0) / tf.sqrt(new_var_0 + 1e-5) 123 | new_layer_0= tf.nn.relu(new_norm_0) 124 | new_pad_0 = tf.pad(new_layer_0,[[0,0],[ps,ps],[ps,ps],[0,0]],"Reflect") 125 | 126 | new_conv = tf.layers.conv2d(new_pad_0, 127 | filters=self.gen_dim[i], 128 | kernel_size=self.kernel_size[i], 129 | strides=(1,1), 130 | kernel_initializer=tf.initializers.random_normal(stddev=0.02), 131 | padding='valid', 132 | reuse=reuse, 133 | name='gen_'+self.gen_name+'_conv_'+str(i)) 134 | 135 | if self.verbose: 136 | print('Conv layer stride 1 with kernel size ' + str(self.kernel_size[i]) + ' and number of filters ' + str(self.gen_dim[i])) 137 | print('Instance Normalization') 138 | print('Relu') 139 | print('Padding with pad size of ' + str(ps)) 140 | print('Conv layer stride 1 with kernel size ' + str(self.kernel_size[i]) + ' and number of filters ' + str(self.gen_dim[i])) 141 | else: 142 | # Deconv layers 143 | if self.deconv == 'transpose': 144 | new_conv = tf.layers.conv2d_transpose(new_pad, 145 | filters=self.gen_dim[i], 146 | kernel_size=self.kernel_size[i], 147 | 
strides=(2,2), 148 | kernel_initializer=tf.initializers.random_normal(stddev=0.02), 149 | padding='same', 150 | reuse=reuse, 151 | name='gen_'+self.gen_name+'_conv_'+str(i)) 152 | 153 | if self.verbose: 154 | print('Conv transpose layer stride 2 with kernel size ' + str(self.kernel_size[i]) + ' and number of filters ' + str(self.gen_dim[i])) 155 | elif self.deconv == 'resize': 156 | if i == num_layers-3: 157 | new_resize = tf.image.resize_images(new_pad,[tf.cast(tf.shape(X)[1]/2,dtype=tf.int32),\ 158 | tf.cast(tf.shape(X)[2]/2,dtype=tf.int32)]) 159 | 160 | else: 161 | new_resize = tf.image.resize_images(new_pad,[tf.shape(X)[1],tf.shape(X)[2]]) 162 | ps = int((self.kernel_size[i]-1)/2) # pad size 163 | new_pad_0 = tf.pad(new_resize,[[0,0],[ps,ps],[ps,ps],[0,0]],"Reflect") 164 | new_conv = tf.layers.conv2d(new_pad_0, 165 | filters=self.gen_dim[i], 166 | kernel_size=self.kernel_size[i], 167 | strides=(1,1), 168 | kernel_initializer=tf.initializers.random_normal(stddev=0.02), 169 | padding='valid', 170 | reuse=reuse, 171 | name='gen_'+self.gen_name+'_conv_'+str(i)) 172 | if self.verbose: 173 | print('Resize image') 174 | print('Padding with pad size of ' + str(ps)) 175 | print('Conv layer stride 1 with kernel size ' + str(self.kernel_size[i]) + ' and number of filters ' + str(self.gen_dim[i])) 176 | else: 177 | print('Unknown deconvolution method') 178 | if i < num_layers - 1: 179 | # new_norm = tf.contrib.layers.instance_norm(new_conv,reuse=reuse,scale=False,center=False,\ 180 | # scope='gen_'+self.gen_name+'_bnorm_'+str(i),trainable=False) 181 | new_mean, new_var = tf.nn.moments(new_conv,axes=(1,2),keep_dims=True) 182 | new_norm = (new_conv - new_mean) / tf.sqrt(new_var + 1e-5) 183 | 184 | if self.verbose: 185 | print('Instance normalization') 186 | else: 187 | new_norm = new_conv 188 | 189 | if i>=3 and i < num_layers-3: 190 | new_layer = new_norm + layer_list[-1] 191 | if self.verbose: 192 | print('Make residual layer (linear activation function)') 193 | elif i < 
num_layers-1: 194 | new_layer = tf.nn.relu(new_norm) 195 | if self.verbose: 196 | print('ReLu') 197 | else: 198 | self.bef_layer = new_norm 199 | new_layer = (tf.nn.tanh(new_norm)+1)/2. 200 | if self.verbose: 201 | print('[tanh(x)+1]/2') 202 | 203 | layer_list.append(new_layer) 204 | if self.verbose: 205 | print(' ') 206 | print('Final layer: ',new_layer) 207 | 208 | self.layer_list = layer_list 209 | return new_layer 210 | -------------------------------------------------------------------------------- /cycleGAN.py: -------------------------------------------------------------------------------- 1 | from __future__ import division, print_function, unicode_literals 2 | 3 | 4 | import tensorflow as tf 5 | 6 | import h5py 7 | import numpy as np 8 | 9 | import os 10 | 11 | import sys 12 | sys.path.append('./Discriminator') 13 | sys.path.append('./Generator') 14 | sys.path.append('./Utilities/') 15 | import Res_Gen 16 | import PatchGAN34 17 | import PatchGAN70 18 | import PatchGAN142 19 | import MultiPatch 20 | import HisDis 21 | import Utilities 22 | import cv2 23 | class Model: 24 | """ 25 | ToDo 26 | -) save() - Save the current model parameter 27 | -) create() - Create the model layers 28 | -) init() - Initialize the model (load model if exists) 29 | -) load() - Load the parameters from the file 30 | -) ToDo 31 | 32 | Only the following functions should be called from outside: 33 | -) ToDo 34 | -) constructor 35 | """ 36 | 37 | def __init__(self, 38 | mod_name, 39 | data_file, 40 | buffer_size=32, 41 | architecture='Res6', 42 | lambda_h=10.,\ 43 | lambda_c=10.,\ 44 | dis_noise=0.25,\ 45 | deconv='transpose',\ 46 | patchgan='Patch70',\ 47 | verbose=False,\ 48 | gen_only=False): 49 | """ 50 | Create a Model (init). It will check, if a model with such a name has already been saved. If so, the model is being 51 | loaded. Otherwise, a new model with this name will be created. It will only be saved, if the save function is being 52 | called. 
The describtion of every parameter is given in the code below. 53 | 54 | INPUT: mod_name - This is the name of the model. It is mainly used to establish the place, where the model is being 55 | saved. 56 | data_file - hdf5 file that contains the dataset 57 | imsize - The dimension of the input images 58 | 59 | OUTPUT: - The model 60 | """ 61 | 62 | self.mod_name = mod_name # Model name (see above) 63 | 64 | self.data_file = data_file # hdf5 data file 65 | 66 | f = h5py.File(self.data_file,"r") 67 | self.a_chan = int(np.array(f['A/num_channel'])) # Number channels in A 68 | self.b_chan = int(np.array(f['B/num_channel'])) # Number channels in B 69 | self.imsize = int(np.shape(f['A/data'][0,:,0,0])[0]) # Image size (squared) 70 | self.a_size = int(np.array(f['A/num_samples'])) # Number of samples in A 71 | self.b_size = int(np.array(f['B/num_samples'])) # Number of samples in B 72 | f.close() 73 | 74 | # Reset all current saved tf stuff 75 | tf.reset_default_graph() 76 | 77 | self.architecture = architecture 78 | self.lambda_h = lambda_h 79 | self.lambda_c = lambda_c 80 | self.dis_noise_0 = dis_noise # ATTENTION: Name change from dis_noise to dis_noise_0 81 | self.deconv = deconv 82 | self.patchgan = patchgan 83 | self.verbose = verbose 84 | self.gen_only = gen_only # If true, only the generator are used (and loaded) 85 | 86 | # Create the model that is built out of two discriminators and a generator 87 | self.create() 88 | 89 | # Image buffer 90 | self.buffer_size = buffer_size 91 | self.temp_b_s = 0. 
92 | self.buffer_real_a = np.zeros([self.buffer_size,self.imsize,self.imsize,self.a_chan]) 93 | self.buffer_real_b = np.zeros([self.buffer_size,self.imsize,self.imsize,self.b_chan]) 94 | self.buffer_fake_a = np.zeros([self.buffer_size,self.imsize,self.imsize,self.a_chan]) 95 | self.buffer_fake_b = np.zeros([self.buffer_size,self.imsize,self.imsize,self.b_chan]) 96 | 97 | # Create the model saver 98 | with self.graph.as_default(): 99 | if not self.gen_only: 100 | self.saver = tf.train.Saver() 101 | else: 102 | self.saver = tf.train.Saver(var_list=self.list_gen) 103 | 104 | def create(self): 105 | """ 106 | Create the model. ToDo 107 | """ 108 | # Create a graph and add all layers 109 | self.graph = tf.Graph() 110 | with self.graph.as_default(): 111 | # Define variable learning rate and dis_noise 112 | self.relative_lr = tf.placeholder_with_default([1.],[1],name="relative_lr") 113 | self.relative_lr = self.relative_lr[0] 114 | 115 | self.rel_dis_noise = tf.placeholder_with_default([1.],[1],name="rel_dis_noise") 116 | self.rel_dis_noise = self.rel_dis_noise[0] 117 | self.dis_noise = self.rel_dis_noise * self.dis_noise_0 118 | 119 | 120 | # Create the generator and discriminator 121 | if self.architecture == 'Res6': 122 | gen_dim = [64, 128,256, 256,256,256,256,256,256, 128,64 ] 123 | kernel_size =[7, 3,3, 3,3,3,3,3,3, 3,3, 7] 124 | elif self.architecture == 'Res9': 125 | gen_dim= [64, 128,256, 256,256,256,256,256,256,256,256,256, 128,64 ] 126 | kernel_size=[7, 3,3, 3,3,3,3,3,3,3,3,3, 3,3, 7] 127 | else: 128 | print('Unknown generator architecture') 129 | return None 130 | 131 | self.genA = Res_Gen.ResGen('BtoA',self.a_chan,gen_dim=gen_dim,kernel_size=kernel_size,deconv=self.deconv,verbose=self.verbose) 132 | self.genB = Res_Gen.ResGen('AtoB',self.b_chan,gen_dim=gen_dim,kernel_size=kernel_size,deconv=self.deconv,verbose=self.verbose) 133 | 134 | if self.patchgan == 'Patch34': 135 | self.disA = PatchGAN34.PatchGAN34('A',noise=self.dis_noise) 136 | self.disB = 
PatchGAN34.PatchGAN34('B',noise=self.dis_noise) 137 | elif self.patchgan == 'Patch70': 138 | self.disA = PatchGAN70.PatchGAN70('A',noise=self.dis_noise) 139 | self.disB = PatchGAN70.PatchGAN70('B',noise=self.dis_noise) 140 | elif self.patchgan == 'Patch142': 141 | self.disA = PatchGAN142.PatchGAN142('A',noise=self.dis_noise) 142 | self.disB = PatchGAN142.PatchGAN142('B',noise=self.dis_noise) 143 | elif self.patchgan == 'MultiPatch': 144 | self.disA = MultiPatch.MultiPatch('A',noise=self.dis_noise) 145 | self.disB = MultiPatch.MultiPatch('B',noise=self.dis_noise) 146 | else: 147 | print('Unknown Patch discriminator type') 148 | return None 149 | 150 | self.disA_His = HisDis.HisDis('A',noise=self.dis_noise,keep_prob=1.) 151 | self.disB_His = HisDis.HisDis('B',noise=self.dis_noise,keep_prob=1.) 152 | 153 | # Create a placeholder for the input data 154 | self.A = tf.placeholder(tf.float32,[None, None, None, self.a_chan],name="a") 155 | self.B = tf.placeholder(tf.float32,[None, None, None, self.b_chan],name="b") 156 | 157 | if self.verbose: 158 | print('Size A: ' +str(self.a_chan)) # Often 1 --> Real 159 | print('Size B: ' +str(self.b_chan)) # Often 3 --> Syn 160 | 161 | # Create cycleGAN 162 | 163 | self.fake_A = self.genA.create(self.B,False) 164 | self.fake_B = self.genB.create(self.A,False) 165 | 166 | 167 | 168 | # Define the histogram loss 169 | t_A = tf.transpose(tf.reshape(self.A,[-1, self.a_chan]),[1,0]) 170 | t_B = tf.transpose(tf.reshape(self.B,[-1, self.b_chan]),[1,0]) 171 | t_fake_A = tf.transpose(tf.reshape(self.fake_A,[-1, self.a_chan]),[1,0]) 172 | t_fake_B = tf.transpose(tf.reshape(self.fake_B,[-1, self.b_chan]),[1,0]) 173 | 174 | self.s_A,_ = tf.nn.top_k(t_A,tf.shape(t_A)[1]) 175 | self.s_B,_ = tf.nn.top_k(t_B,tf.shape(t_B)[1]) 176 | self.s_fake_A,_ = tf.nn.top_k(t_fake_A,tf.shape(t_fake_A)[1]) 177 | self.s_fake_B,_ = tf.nn.top_k(t_fake_B,tf.shape(t_fake_B)[1]) 178 | 179 | self.m_A = tf.reshape(tf.reduce_mean(tf.reshape(self.s_A,[self.a_chan, 
self.imsize, -1]),axis=2),[1, -1]) 180 | self.m_B = tf.reshape(tf.reduce_mean(tf.reshape(self.s_B,[self.b_chan, self.imsize, -1]),axis=2),[1, -1]) 181 | self.m_fake_A = tf.reshape(tf.reduce_mean(tf.reshape(self.s_fake_A,[self.a_chan, self.imsize, -1]),axis=2),[1, -1]) 182 | self.m_fake_B = tf.reshape(tf.reduce_mean(tf.reshape(self.s_fake_B,[self.b_chan, self.imsize, -1]),axis=2),[1, -1]) 183 | 184 | # Define generator loss functions 185 | self.lambda_c = tf.placeholder_with_default([self.lambda_c],[1],name="lambda_c") 186 | self.lambda_c = self.lambda_c[0] 187 | self.lambda_h = tf.placeholder_with_default([self.lambda_h],[1],name="lambda_h") 188 | self.lambda_h = self.lambda_h[0] 189 | 190 | self.dis_real_A = self.disA.create(self.A,False) 191 | self.dis_real_Ah = self.disA_His.create(self.m_A,False) 192 | self.dis_real_B = self.disB.create(self.B,False) 193 | self.dis_real_Bh = self.disB_His.create(self.m_B,False) 194 | self.dis_fake_A = self.disA.create(self.fake_A,True) 195 | self.dis_fake_Ah = self.disA_His.create(self.m_fake_A,True) 196 | self.dis_fake_B = self.disB.create(self.fake_B,True) 197 | self.dis_fake_Bh = self.disB_His.create(self.m_fake_B,True) 198 | 199 | self.cyc_A = self.genA.create(self.fake_B,True) 200 | self.cyc_B = self.genB.create(self.fake_A,True) 201 | 202 | 203 | # Define cycle loss (eq. 2) 204 | self.loss_cyc_A = tf.reduce_mean(tf.abs(self.cyc_A-self.A)) 205 | self.loss_cyc_B = tf.reduce_mean(tf.abs(self.cyc_B-self.B)) 206 | 207 | self.loss_cyc = self.loss_cyc_A + self.loss_cyc_B 208 | 209 | # Define discriminator losses (eq. 
1) 210 | self.loss_dis_A = (tf.reduce_mean(tf.square(self.dis_real_A)) +\ 211 | tf.reduce_mean(tf.square(1-self.dis_fake_A)))*0.5 +\ 212 | (tf.reduce_mean(tf.square(self.dis_real_Ah)) +\ 213 | tf.reduce_mean(tf.square(1-self.dis_fake_Ah)))*0.5*self.lambda_h 214 | 215 | 216 | self.loss_dis_B = (tf.reduce_mean(tf.square(self.dis_real_B)) +\ 217 | tf.reduce_mean(tf.square(1-self.dis_fake_B)))*0.5 +\ 218 | (tf.reduce_mean(tf.square(self.dis_real_Bh)) +\ 219 | tf.reduce_mean(tf.square(1-self.dis_fake_Bh)))*0.5*self.lambda_h 220 | 221 | self.loss_gen_A = tf.reduce_mean(tf.square(self.dis_fake_A)) +\ 222 | self.lambda_h * tf.reduce_mean(tf.square(self.dis_fake_Ah)) +\ 223 | self.lambda_c * self.loss_cyc/2. 224 | self.loss_gen_B = tf.reduce_mean(tf.square(self.dis_fake_B)) +\ 225 | self.lambda_h * tf.reduce_mean(tf.square(self.dis_fake_Bh)) +\ 226 | self.lambda_c * self.loss_cyc/2. 227 | 228 | # Create the different optimizer 229 | with self.graph.as_default(): 230 | # Optimizer for Gen 231 | self.list_gen = [] 232 | for var in tf.trainable_variables(): 233 | if 'gen' in str(var): 234 | self.list_gen.append(var) 235 | optimizer_gen = tf.train.AdamOptimizer(learning_rate=self.relative_lr*0.0002,beta1=0.5) 236 | self.opt_gen = optimizer_gen.minimize(self.loss_gen_A+self.loss_gen_B,var_list=self.list_gen) 237 | 238 | # Optimizer for Dis 239 | self.list_dis = [] 240 | for var in tf.trainable_variables(): 241 | if 'dis' in str(var): 242 | self.list_dis.append(var) 243 | optimizer_dis = tf.train.AdamOptimizer(learning_rate=self.relative_lr*0.0002,beta1=0.5) 244 | self.opt_dis = optimizer_dis.minimize(self.loss_dis_A + self.loss_dis_B,var_list=self.list_dis) 245 | 246 | def save(self,sess): 247 | """ 248 | Save the model parameter in a ckpt file. 
The filename is as 249 | follows: 250 | ./Models/.ckpt 251 | 252 | INPUT: sess - The current running session 253 | """ 254 | self.saver.save(sess,"./Models/" + self.mod_name + ".ckpt") 255 | 256 | def init(self,sess): 257 | """ 258 | Init the model. If the model exists in a file, load the model. Otherwise, initalize the variables 259 | 260 | INPUT: sess - The current running session 261 | """ 262 | if not os.path.isfile(\ 263 | "./Models/" + self.mod_name + ".ckpt.meta"): 264 | sess.run(tf.global_variables_initializer()) 265 | return 0 266 | else: 267 | if self.gen_only: 268 | sess.run(tf.global_variables_initializer()) 269 | self.load(sess) 270 | return 1 271 | 272 | def load(self,sess): 273 | """ 274 | Load the model from the parameter file: 275 | ./Models/.ckpt 276 | 277 | INPUT: sess - The current running session 278 | """ 279 | self.saver.restore(sess, "./Models/" + self.mod_name + ".ckpt") 280 | 281 | def train(self,batch_size=32,lambda_c=0.,lambda_h=0.,epoch=0,save=True,syn_noise=0.,real_noise=0.): 282 | f = h5py.File(self.data_file,"r") 283 | 284 | num_samples = min(self.a_size,self.b_size) 285 | num_iterations = num_samples // batch_size 286 | 287 | a_order = np.random.permutation(self.a_size) 288 | b_order = np.random.permutation(self.b_size) 289 | 290 | if self.verbose: 291 | print('lambda_c: ' + str(lambda_c)) 292 | print('lambda_h: ' + str(lambda_h)) 293 | 294 | with tf.Session(graph=self.graph) as sess: 295 | # initialize variables 296 | self.init(sess) 297 | 298 | vec_lcA = [] 299 | vec_lcB = [] 300 | 301 | vec_ldrA = [] 302 | vec_ldrAh = [] 303 | vec_ldrB = [] 304 | vec_ldrBh = [] 305 | vec_ldfA = [] 306 | vec_ldfAh = [] 307 | vec_ldfB = [] 308 | vec_ldfBh = [] 309 | 310 | vec_l_dis_A = [] 311 | vec_l_dis_B = [] 312 | vec_l_gen_A = [] 313 | vec_l_gen_B = [] 314 | 315 | rel_lr = 1. 316 | if epoch > 100: 317 | rel_lr = 2. - epoch/100. 318 | 319 | if epoch < 100: 320 | rel_noise = 0.9**epoch 321 | else: 322 | rel_noise = 0. 
323 | 324 | for iteration in range(num_iterations): 325 | images_a = f['A/data'][np.sort(a_order[(iteration*batch_size):((iteration+1)*batch_size)]),:,:,:] 326 | images_b = f['B/data'][np.sort(b_order[(iteration*batch_size):((iteration+1)*batch_size)]),:,:,:] 327 | if images_a.dtype=='uint8': 328 | images_a=images_a/float(2**8-1) 329 | elif images_a.dtype=='uint16': 330 | images_a=images_a/float(2**16-1) 331 | else: 332 | raise ValueError('Dataset A is not int8 or int16') 333 | if images_b.dtype=='uint8': 334 | images_b=images_b/float(2**8-1) 335 | elif images_b.dtype=='uint16': 336 | images_b=images_b/float(2**16-1) 337 | else: 338 | raise ValueError('Dataset B is not int8 or int16') 339 | 340 | images_a += np.random.randn(*images_a.shape)*real_noise 341 | images_b += np.random.randn(*images_b.shape)*syn_noise 342 | 343 | _, l_gen_A, im_fake_A, l_gen_B, im_fake_B, cyc_A, cyc_B, sA, sB, sfA, sfB, lcA, lcB = sess.run([self.opt_gen,\ 344 | self.loss_gen_A,\ 345 | self.fake_A,\ 346 | self.loss_gen_B,\ 347 | self.fake_B,\ 348 | self.cyc_A,\ 349 | self.cyc_B,\ 350 | self.s_A,self.s_B,self.s_fake_A,self.s_fake_B,\ 351 | self.loss_cyc_A,\ 352 | self.loss_cyc_B],\ 353 | feed_dict={self.A: images_a,\ 354 | self.B: images_b,\ 355 | self.lambda_c: lambda_c,\ 356 | self.lambda_h: lambda_h,\ 357 | self.relative_lr: rel_lr,\ 358 | self.rel_dis_noise: rel_noise}) 359 | 360 | if self.temp_b_s >= self.buffer_size: 361 | rand_vec_a = np.random.permutation(self.buffer_size)[:batch_size] 362 | rand_vec_b = np.random.permutation(self.buffer_size)[:batch_size] 363 | 364 | self.buffer_real_a[rand_vec_a,...] = images_a 365 | self.buffer_real_b[rand_vec_b,...] = images_b 366 | self.buffer_fake_a[rand_vec_a,...] = im_fake_A 367 | self.buffer_fake_b[rand_vec_b,...] = im_fake_B 368 | else: 369 | low = int(self.temp_b_s) 370 | high = int(min(self.temp_b_s + batch_size,self.buffer_size)) 371 | self.temp_b_s = high 372 | 373 | self.buffer_real_a[low:high,...] = images_a[:(high-low),...] 
374 | self.buffer_real_b[low:high,...] = images_b[:(high-low),...] 375 | self.buffer_fake_a[low:high,...] = im_fake_A[:(high-low),...] 376 | self.buffer_fake_b[low:high,...] = im_fake_B[:(high-low),...] 377 | 378 | # Create dataset out of buffer and gen images to train dis 379 | dis_real_a = np.copy(images_a) 380 | dis_real_b = np.copy(images_b) 381 | dis_fake_a = np.copy(im_fake_A) 382 | dis_fake_b = np.copy(im_fake_B) 383 | 384 | half_b_s = int(batch_size/2) 385 | rand_vec_a = np.random.permutation(self.temp_b_s)[:half_b_s] 386 | rand_vec_b = np.random.permutation(self.temp_b_s)[:half_b_s] 387 | dis_real_a[:half_b_s,...] = self.buffer_real_a[rand_vec_a,...] 388 | dis_fake_a[:half_b_s,...] = self.buffer_fake_a[rand_vec_a,...] 389 | dis_real_b[:half_b_s,...] = self.buffer_real_b[rand_vec_b,...] 390 | dis_fake_b[:half_b_s,...] = self.buffer_fake_b[rand_vec_b,...] 391 | 392 | _, l_dis_A, l_dis_B, \ 393 | ldrA,ldrAh,ldfA,ldfAh,\ 394 | ldrB,ldrBh,ldfB,ldfBh = sess.run([\ 395 | self.opt_dis, 396 | self.loss_dis_A, 397 | self.loss_dis_B, 398 | self.dis_real_A, 399 | self.dis_real_Ah, 400 | self.dis_fake_A, 401 | self.dis_fake_Ah, 402 | self.dis_real_B, 403 | self.dis_real_Bh, 404 | self.dis_fake_B, 405 | self.dis_fake_Bh],feed_dict={self.A: dis_real_a,\ 406 | self.B: dis_real_b,\ 407 | self.fake_A: dis_fake_a,\ 408 | self.fake_B: dis_fake_b,\ 409 | self.lambda_c: lambda_c,\ 410 | self.lambda_h: lambda_h,\ 411 | self.relative_lr: rel_lr,\ 412 | self.rel_dis_noise: rel_noise}) 413 | 414 | vec_l_dis_A.append(l_dis_A) 415 | vec_l_dis_B.append(l_dis_B) 416 | vec_l_gen_A.append(l_gen_A) 417 | vec_l_gen_B.append(l_gen_B) 418 | 419 | vec_lcA.append(lcA) 420 | vec_lcB.append(lcB) 421 | 422 | vec_ldrA.append(ldrA) 423 | vec_ldrAh.append(ldrAh) 424 | vec_ldrB.append(ldrB) 425 | vec_ldrBh.append(ldrBh) 426 | vec_ldfA.append(ldfA) 427 | vec_ldfAh.append(ldfAh) 428 | vec_ldfB.append(ldfB) 429 | vec_ldfBh.append(ldfBh) 430 | 431 | if np.shape(images_b)[-1]==4: 432 | 433 | 
images_b=np.vstack((images_b[0,:,:,0:3],np.tile(images_b[0,:,:,3].reshape(320,320,1),[1,1,3]))) 434 | im_fake_B=np.vstack((im_fake_B[0,:,:,0:3],np.tile(im_fake_B[0,:,:,3].reshape(320,320,1),[1,1,3]))) 435 | cyc_B=np.vstack((cyc_B[0,:,:,0:3],np.tile(cyc_B[0,:,:,3].reshape(320,320,1),[1,1,3]))) 436 | images_b=images_b[np.newaxis,:,:,:] 437 | im_fake_B=im_fake_B[np.newaxis,:,:,:] 438 | cyc_B=cyc_B[np.newaxis,:,:,:] 439 | 440 | if iteration%5==0: 441 | sneak_peak=Utilities.produce_tiled_images(images_a,images_b,im_fake_A, im_fake_B,cyc_A,cyc_B) 442 | 443 | cv2.imshow("",sneak_peak[:,:,[2,1,0]]) 444 | cv2.waitKey(1) 445 | 446 | print("\rTrain: {}/{} ({:.1f}%)".format(iteration+1, num_iterations,(iteration) * 100 / (num_iterations-1)) + \ 447 | " Loss_dis_A={:.4f}, Loss_dis_B={:.4f}".format(np.mean(vec_l_dis_A),np.mean(vec_l_dis_B)) + \ 448 | ", Loss_gen_A={:.4f}, Loss_gen_B={:.4f}".format(np.mean(vec_l_gen_A),np.mean(vec_l_gen_B))\ 449 | ,end=" ") 450 | 451 | # Save model 452 | if save: 453 | self.save(sess) 454 | cv2.imwrite("./Models/Images/" + self.mod_name + "_Epoch_" + str(epoch) + ".png",sneak_peak[:,:,[2,1,0]]*255) 455 | print("") 456 | 457 | f.close() 458 | 459 | loss_gen_A = [np.mean(np.square(np.array(vec_ldfA))),np.mean(np.square(np.array(vec_ldfAh))),np.mean(np.array(lcA))] 460 | loss_gen_B = [np.mean(np.square(np.array(vec_ldfB))),np.mean(np.square(np.array(vec_ldfBh))),np.mean(np.array(lcB))] 461 | loss_dis_A = [np.mean(np.square(np.array(vec_ldrA))),np.mean(np.square(1.-np.array(vec_ldfA))),\ 462 | np.mean(np.square(np.array(vec_ldrAh))),np.mean(np.square(1.-np.array(vec_ldfAh)))] 463 | loss_dis_B = [np.mean(np.square(np.array(vec_ldrB))),np.mean(np.square(1.-np.array(vec_ldfB))),\ 464 | np.mean(np.square(np.array(vec_ldrBh))),np.mean(np.square(1.-np.array(vec_ldfBh)))] 465 | 466 | return [loss_gen_A,loss_gen_B,loss_dis_A,loss_dis_B] 467 | 468 | def predict(self,lambda_c=0.,lambda_h=0.): 469 | f = h5py.File(self.data_file,"r") 470 | 471 | rand_a = 
np.random.randint(self.a_size-32) 472 | rand_b = np.random.randint(self.b_size-32) 473 | 474 | images_a = f['A/data'][rand_a:(rand_a+32),:,:,:]/255. 475 | images_b = f['B/data'][rand_b:(rand_b+32),:,:,:]/255. 476 | with tf.Session(graph=self.graph) as sess: 477 | # initialize variables 478 | self.init(sess) 479 | 480 | fake_A, fake_B, cyc_A, cyc_B = \ 481 | sess.run([self.fake_A,self.fake_B,self.cyc_A,self.cyc_B],\ 482 | feed_dict={self.A: images_a,\ 483 | self.B: images_b,\ 484 | self.lambda_c: lambda_c,\ 485 | self.lambda_h: lambda_h}) 486 | 487 | f.close() 488 | return images_a, images_b, fake_A, fake_B, cyc_A, cyc_B 489 | 490 | def generator_A(self,batch_size=32,lambda_c=0.,lambda_h=0.): 491 | f = h5py.File(self.data_file,"r") 492 | f_save = h5py.File("./Models/" + self.mod_name + '_gen_A.h5',"w") 493 | 494 | # Find number of samples 495 | num_samples = self.b_size 496 | num_iterations = num_samples // batch_size 497 | 498 | gen_data = np.zeros((f['B/data'].shape[0],f['B/data'].shape[1],f['B/data'].shape[2],f['A/data'].shape[3]),dtype=np.uint16) 499 | 500 | with tf.Session(graph=self.graph) as sess: 501 | # initialize variables 502 | self.init(sess) 503 | 504 | for iteration in range(num_iterations): 505 | images_b = f['B/data'][(iteration*batch_size):((iteration+1)*batch_size),:,:,:] 506 | if images_b.dtype=='uint8': 507 | images_b=images_b/float(2**8-1) 508 | elif images_b.dtype=='uint16': 509 | images_b=images_b/float(2**16-1) 510 | else: 511 | raise ValueError('Dataset B is not int8 or int16') 512 | 513 | gen_A = sess.run(self.fake_A,feed_dict={self.B: images_b,\ 514 | self.lambda_c: lambda_c,\ 515 | self.lambda_h: lambda_h}) 516 | gen_data[(iteration*batch_size):((iteration+1)*batch_size),:,:,:] = (np.minimum(np.maximum(gen_A,0),1)*(2**16-1)).astype(np.uint16) 517 | 518 | print("\rGenerator A: {}/{} ({:.1f}%)".format(iteration+1, num_iterations, iteration*100/(num_iterations-1)),end=" ") 519 | 520 | group = f_save.create_group('A') 521 | 
group.create_dataset(name='data', data=gen_data,dtype=np.uint16) 522 | 523 | f_save.close() 524 | f.close() 525 | 526 | return None 527 | 528 | def generator_B(self,batch_size=32,lambda_c=0.,lambda_h=0.): 529 | f = h5py.File(self.data_file,"r") 530 | f_save = h5py.File("./Models/" + self.mod_name + '_gen_B.h5',"w") 531 | 532 | # Find number of samples 533 | num_samples = self.a_size 534 | num_iterations = num_samples // batch_size 535 | 536 | gen_data = np.zeros((f['A/data'].shape[0],f['A/data'].shape[1],f['A/data'].shape[2],f['B/data'].shape[3]),dtype=np.uint16) 537 | 538 | with tf.Session(graph=self.graph) as sess: 539 | # initialize variables 540 | self.init(sess) 541 | 542 | for iteration in range(num_iterations): 543 | images_a = f['A/data'][(iteration*batch_size):((iteration+1)*batch_size),:,:,:] 544 | if images_a.dtype=='uint8': 545 | images_a=images_a/float(2**8-1) 546 | elif images_a.dtype=='uint16': 547 | images_a=images_a/float(2**16-1) 548 | else: 549 | raise ValueError('Dataset A is not int8 or int16') 550 | 551 | gen_B = sess.run(self.fake_B,feed_dict={self.A: images_a,\ 552 | self.lambda_c: lambda_c,\ 553 | self.lambda_h: lambda_h}) 554 | gen_data[(iteration*batch_size):((iteration+1)*batch_size),:,:,:] = (np.minimum(np.maximum(gen_B,0),1)*(2**16-1)).astype(np.uint16) 555 | 556 | print("\rGenerator B: {}/{} ({:.1f}%)".format(iteration+1, num_iterations, iteration*100/(num_iterations-1)),end=" ") 557 | 558 | group = f_save.create_group('B') 559 | group.create_dataset(name='data', data=gen_data,dtype=np.uint16) 560 | 561 | f_save.close() 562 | f.close() 563 | 564 | return None 565 | 566 | 567 | def get_loss(self,lambda_c=0.,lambda_h=0.): 568 | f = h5py.File(self.data_file,"r") 569 | 570 | rand_a = np.random.randint(self.a_size-32) 571 | rand_b = np.random.randint(self.b_size-32) 572 | 573 | images_a = f['A/data'][rand_a:(rand_a+32),:,:,:]/255. 574 | images_b = f['B/data'][rand_b:(rand_b+32),:,:,:]/255. 
575 | with tf.Session(graph=self.graph) as sess: 576 | # initialize variables 577 | self.init(sess) 578 | 579 | l_rA,l_rB,l_fA,l_fB = \ 580 | sess.run([self.dis_real_A,self.dis_real_B,self.dis_fake_A,self.dis_fake_B,],\ 581 | feed_dict={self.A: images_a,\ 582 | self.B: images_b,\ 583 | self.lambda_c: lambda_c,\ 584 | self.lambda_h: lambda_h}) 585 | 586 | f.close() 587 | return l_rA,l_rB,l_fA,l_fB 588 | -------------------------------------------------------------------------------- /notebooks/VGG_Synthetic_by_localization.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import imageio as io\n", 10 | "import matplotlib.pyplot as plt\n", 11 | "%matplotlib inline\n", 12 | "import numpy as np\n", 13 | "import os\n", 14 | "from skimage.transform import rotate" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 2, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "min_cells,max_cells=70,310" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": 13, 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [ 32 | "pad=10\n", 33 | "xx,yy=np.meshgrid(np.arange(256+2*pad),np.arange(256+2*pad))\n", 34 | "#border pad\n", 35 | "fringe=np.pad(np.zeros((256,256)),((pad,pad),(pad,pad)),mode='constant',constant_values=1)\n", 36 | "fringe=fringe.astype(bool)\n", 37 | "\n", 38 | "\n", 39 | "def make_localization_synthetic_vgg():\n", 40 | " n_cell=np.random.randint(min_cells,max_cells)\n", 41 | " x,y=np.random.randint(0,255,size=(2,n_cell))+pad\n", 42 | " image=np.zeros((256+2*pad,256+2*pad,3))\n", 43 | " #homing_gaussian=np.zeros((256+2*pad,256+2*pad))\n", 44 | " #square_for_area=np.zeros((256+2*pad,256+2*pad))\n", 45 | " noise=np.clip(np.random.normal(loc=0.02,scale=0.03,size=(256+2*pad,256+2*pad)),0,1)\n", 46 | " 
image[:,:,0]=np.clip(np.random.normal(loc=0.02,scale=0.03,size=(256+2*pad,256+2*pad)),0,1)\n", 47 | " \n", 48 | " for i in range(n_cell):\n", 49 | " mask=(xx-x[i])**2+(yy-y[i])**2<(6)**2\n", 50 | " image[mask,0]=0.3*np.random.rand()+0.6+noise[mask]\n", 51 | " image[:,:,1]+=0.4*np.exp(-((yy-y[i])/3.)**2-((xx-x[i])/3.)**2)\n", 52 | " image[:,:,2]+=0.5*np.exp(-((yy-y[i])/1.)**2-((xx-x[i])/1.)**2)\n", 53 | " return image" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 14, 59 | "metadata": {}, 60 | "outputs": [ 61 | { 62 | "data": { 63 | "text/plain": [ 64 | "" 65 | ] 66 | }, 67 | "execution_count": 14, 68 | "metadata": {}, 69 | "output_type": "execute_result" 70 | }, 71 | { 72 | "data": { 73 | "image/png": "\n", 74 | "text/plain": [ 75 | "
" 76 | ] 77 | }, 78 | "metadata": { 79 | "needs_background": "light" 80 | }, 81 | "output_type": "display_data" 82 | } 83 | ], 84 | "source": [ 85 | "plt.imshow(make_localization_synthetic_vgg())" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": 17, 91 | "metadata": {}, 92 | "outputs": [ 93 | { 94 | "name": "stdout", 95 | "output_type": "stream", 96 | "text": [ 97 | "\n" 98 | ] 99 | } 100 | ], 101 | "source": [ 102 | "N_images=50\n", 103 | "dataname='VGG_localization'\n", 104 | "train_name='trainA'\n", 105 | "directory='../Data/'+dataname+'/'+train_name+'/'\n", 106 | "if not os.path.exists(directory):\n", 107 | " os.makedirs(directory)" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": null, 113 | "metadata": {}, 114 | "outputs": [ 115 | { 116 | "name": "stderr", 117 | "output_type": "stream", 118 | "text": [ 119 | "/usr/local/lib/python2.7/dist-packages/imageio/core/util.py:104: UserWarning: Conversion from float64 to uint8, range [0.0, 1.02268088151]\n", 120 | " 'range [{2}, {3}]'.format(dtype_str, out_type.__name__, mi, ma))\n" 121 | ] 122 | }, 123 | { 124 | "name": "stdout", 125 | "output_type": "stream", 126 | "text": [ 127 | "0" 128 | ] 129 | }, 130 | { 131 | "name": "stderr", 132 | "output_type": "stream", 133 | "text": [ 134 | "/usr/local/lib/python2.7/dist-packages/imageio/core/util.py:104: UserWarning: Conversion from float64 to uint8, range [0.0, 1.01464227872]\n", 135 | " 'range [{2}, {3}]'.format(dtype_str, out_type.__name__, mi, ma))\n", 136 | "/usr/local/lib/python2.7/dist-packages/imageio/core/util.py:104: UserWarning: Conversion from float64 to uint8, range [0.0, 1.02984934249]\n", 137 | " 'range [{2}, {3}]'.format(dtype_str, out_type.__name__, mi, ma))\n", 138 | "/usr/local/lib/python2.7/dist-packages/imageio/core/util.py:104: UserWarning: Conversion from float64 to uint8, range [0.0, 1.01064395335]\n", 139 | " 'range [{2}, {3}]'.format(dtype_str, out_type.__name__, mi, ma))\n", 140 | 
"/usr/local/lib/python2.7/dist-packages/imageio/core/util.py:104: UserWarning: Conversion from float64 to uint8, range [0.0, 1.01649309037]\n", 141 | " 'range [{2}, {3}]'.format(dtype_str, out_type.__name__, mi, ma))\n", 142 | "/usr/local/lib/python2.7/dist-packages/imageio/core/util.py:104: UserWarning: Conversion from float64 to uint8, range [0.0, 1.02021895102]\n", 143 | " 'range [{2}, {3}]'.format(dtype_str, out_type.__name__, mi, ma))\n", 144 | "/usr/local/lib/python2.7/dist-packages/imageio/core/util.py:104: UserWarning: Conversion from float64 to uint8, range [0.0, 1.00492180949]\n", 145 | " 'range [{2}, {3}]'.format(dtype_str, out_type.__name__, mi, ma))\n", 146 | "/usr/local/lib/python2.7/dist-packages/imageio/core/util.py:104: UserWarning: Conversion from float64 to uint8, range [0.0, 1.01165869974]\n", 147 | " 'range [{2}, {3}]'.format(dtype_str, out_type.__name__, mi, ma))\n", 148 | "/usr/local/lib/python2.7/dist-packages/imageio/core/util.py:104: UserWarning: Conversion from float64 to uint8, range [0.0, 1.01293333199]\n", 149 | " 'range [{2}, {3}]'.format(dtype_str, out_type.__name__, mi, ma))\n", 150 | "/usr/local/lib/python2.7/dist-packages/imageio/core/util.py:104: UserWarning: Conversion from float64 to uint8, range [0.0, 1.00414250075]\n", 151 | " 'range [{2}, {3}]'.format(dtype_str, out_type.__name__, mi, ma))\n", 152 | "/usr/local/lib/python2.7/dist-packages/imageio/core/util.py:104: UserWarning: Conversion from float64 to uint8, range [0.0, 1.00521218855]\n", 153 | " 'range [{2}, {3}]'.format(dtype_str, out_type.__name__, mi, ma))\n", 154 | "/usr/local/lib/python2.7/dist-packages/imageio/core/util.py:104: UserWarning: Conversion from float64 to uint8, range [0.0, 1.27627977337]\n", 155 | " 'range [{2}, {3}]'.format(dtype_str, out_type.__name__, mi, ma))\n", 156 | "/usr/local/lib/python2.7/dist-packages/imageio/core/util.py:104: UserWarning: Conversion from float64 to uint8, range [0.0, 1.00136268245]\n", 157 | " 'range [{2}, 
{3}]'.format(dtype_str, out_type.__name__, mi, ma))\n", 158 | "/usr/local/lib/python2.7/dist-packages/imageio/core/util.py:104: UserWarning: Conversion from float64 to uint8, range [0.0, 1.03352604424]\n", 159 | " 'range [{2}, {3}]'.format(dtype_str, out_type.__name__, mi, ma))\n", 160 | "/usr/local/lib/python2.7/dist-packages/imageio/core/util.py:104: UserWarning: Conversion from float64 to uint8, range [0.0, 1.01857872308]\n", 161 | " 'range [{2}, {3}]'.format(dtype_str, out_type.__name__, mi, ma))\n", 162 | "/usr/local/lib/python2.7/dist-packages/imageio/core/util.py:104: UserWarning: Conversion from float64 to uint8, range [0.0, 1.00677343811]\n", 163 | " 'range [{2}, {3}]'.format(dtype_str, out_type.__name__, mi, ma))\n", 164 | "/usr/local/lib/python2.7/dist-packages/imageio/core/util.py:104: UserWarning: Conversion from float64 to uint8, range [0.0, 1.13678852504]\n", 165 | " 'range [{2}, {3}]'.format(dtype_str, out_type.__name__, mi, ma))\n", 166 | "/usr/local/lib/python2.7/dist-packages/imageio/core/util.py:104: UserWarning: Conversion from float64 to uint8, range [0.0, 1.02354139996]\n", 167 | " 'range [{2}, {3}]'.format(dtype_str, out_type.__name__, mi, ma))\n", 168 | "/usr/local/lib/python2.7/dist-packages/imageio/core/util.py:104: UserWarning: Conversion from float64 to uint8, range [0.0, 1.00641660986]\n", 169 | " 'range [{2}, {3}]'.format(dtype_str, out_type.__name__, mi, ma))\n", 170 | "/usr/local/lib/python2.7/dist-packages/imageio/core/util.py:104: UserWarning: Conversion from float64 to uint8, range [0.0, 1.00318504915]\n", 171 | " 'range [{2}, {3}]'.format(dtype_str, out_type.__name__, mi, ma))\n", 172 | "/usr/local/lib/python2.7/dist-packages/imageio/core/util.py:104: UserWarning: Conversion from float64 to uint8, range [0.0, 1.01248495563]\n", 173 | " 'range [{2}, {3}]'.format(dtype_str, out_type.__name__, mi, ma))\n", 174 | "/usr/local/lib/python2.7/dist-packages/imageio/core/util.py:104: UserWarning: Conversion from float64 to uint8, range 
[0.0, 1.03589042459]\n", 175 | " 'range [{2}, {3}]'.format(dtype_str, out_type.__name__, mi, ma))\n", 176 | "/usr/local/lib/python2.7/dist-packages/imageio/core/util.py:104: UserWarning: Conversion from float64 to uint8, range [0.0, 1.00244328906]\n", 177 | " 'range [{2}, {3}]'.format(dtype_str, out_type.__name__, mi, ma))\n", 178 | "/usr/local/lib/python2.7/dist-packages/imageio/core/util.py:104: UserWarning: Conversion from float64 to uint8, range [0.0, 1.00883233411]\n", 179 | " 'range [{2}, {3}]'.format(dtype_str, out_type.__name__, mi, ma))\n" 180 | ] 181 | } 182 | ], 183 | "source": [ 184 | "for i in range(N_images):\n", 185 | " name=directory+str(i).zfill(5)+'.png'\n", 186 | " image=make_localization_synthetic_vgg()\n", 187 | " io.imsave(name,image)\n", 188 | " if i%50==0:\n", 189 | " print i," 190 | ] 191 | }, 192 | { 193 | "cell_type": "code", 194 | "execution_count": null, 195 | "metadata": {}, 196 | "outputs": [], 197 | "source": [] 198 | } 199 | ], 200 | "metadata": { 201 | "kernelspec": { 202 | "display_name": "Python 2", 203 | "language": "python", 204 | "name": "python2" 205 | }, 206 | "language_info": { 207 | "codemirror_mode": { 208 | "name": "ipython", 209 | "version": 2 210 | }, 211 | "file_extension": ".py", 212 | "mimetype": "text/x-python", 213 | "name": "python", 214 | "nbconvert_exporter": "python", 215 | "pygments_lexer": "ipython2", 216 | "version": "2.7.12" 217 | } 218 | }, 219 | "nbformat": 4, 220 | "nbformat_minor": 2 221 | } 222 | --------------------------------------------------------------------------------