├── Data
├── C_Elegans
│ ├── Raw
│ │ └── README.txt
│ ├── Synthetic
│ │ └── README.txt
│ ├── Original
│ │ └── README.txt
│ └── README.txt
└── Example
│ ├── Genuine
│ ├── raw_image_0.png
│ ├── raw_image_1.png
│ ├── raw_image_10.png
│ ├── raw_image_11.png
│ ├── raw_image_12.png
│ ├── raw_image_13.png
│ ├── raw_image_14.png
│ ├── raw_image_15.png
│ ├── raw_image_2.png
│ ├── raw_image_3.png
│ ├── raw_image_4.png
│ ├── raw_image_5.png
│ ├── raw_image_6.png
│ ├── raw_image_7.png
│ ├── raw_image_8.png
│ └── raw_image_9.png
│ └── Synthetic
│ ├── syn_image_0.png
│ ├── syn_image_1.png
│ ├── syn_image_10.png
│ ├── syn_image_11.png
│ ├── syn_image_12.png
│ ├── syn_image_13.png
│ ├── syn_image_14.png
│ ├── syn_image_15.png
│ ├── syn_image_2.png
│ ├── syn_image_3.png
│ ├── syn_image_4.png
│ ├── syn_image_5.png
│ ├── syn_image_6.png
│ ├── syn_image_7.png
│ ├── syn_image_8.png
│ └── syn_image_9.png
├── Utilities
├── Utilities.pyc
└── Utilities.py
├── Scripts
├── simple_example.sh
├── README.txt
├── train_UDCT_colored_live_neurons.sh
├── train_UDCT_nanowire.sh
├── train_UDCT_dead_live_neurons.sh
└── train_UDCT_c_elegans.sh
├── Discriminator
├── MultiPatch.py
├── HisDis.py
├── PatchGAN34.py
├── PatchGAN70.py
└── PatchGAN142.py
├── notebooks
├── make_raw_nanowire.ipynb
├── make_raw_dead_live_neurons.ipynb
├── make_synthetic_wires.ipynb
└── VGG_Synthetic_by_localization.ipynb
├── README.md
├── create_h5_dataset.py
├── main.py
├── Generator
└── Res_Gen.py
└── cycleGAN.py
/Data/C_Elegans/Raw/README.txt:
--------------------------------------------------------------------------------
1 | 1) Execute: /notebooks/make_raw_c_elegans.ipynb
2 |
--------------------------------------------------------------------------------
/Data/C_Elegans/Synthetic/README.txt:
--------------------------------------------------------------------------------
1 | 1) Execute: /notebooks/make_synthetic_c_elegans.ipynb
2 |
--------------------------------------------------------------------------------
/Utilities/Utilities.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UDCTGAN/UDCT/HEAD/Utilities/Utilities.pyc
--------------------------------------------------------------------------------
/Data/Example/Genuine/raw_image_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UDCTGAN/UDCT/HEAD/Data/Example/Genuine/raw_image_0.png
--------------------------------------------------------------------------------
/Data/Example/Genuine/raw_image_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UDCTGAN/UDCT/HEAD/Data/Example/Genuine/raw_image_1.png
--------------------------------------------------------------------------------
/Data/Example/Genuine/raw_image_10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UDCTGAN/UDCT/HEAD/Data/Example/Genuine/raw_image_10.png
--------------------------------------------------------------------------------
/Data/Example/Genuine/raw_image_11.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UDCTGAN/UDCT/HEAD/Data/Example/Genuine/raw_image_11.png
--------------------------------------------------------------------------------
/Data/Example/Genuine/raw_image_12.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UDCTGAN/UDCT/HEAD/Data/Example/Genuine/raw_image_12.png
--------------------------------------------------------------------------------
/Data/Example/Genuine/raw_image_13.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UDCTGAN/UDCT/HEAD/Data/Example/Genuine/raw_image_13.png
--------------------------------------------------------------------------------
/Data/Example/Genuine/raw_image_14.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UDCTGAN/UDCT/HEAD/Data/Example/Genuine/raw_image_14.png
--------------------------------------------------------------------------------
/Data/Example/Genuine/raw_image_15.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UDCTGAN/UDCT/HEAD/Data/Example/Genuine/raw_image_15.png
--------------------------------------------------------------------------------
/Data/Example/Genuine/raw_image_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UDCTGAN/UDCT/HEAD/Data/Example/Genuine/raw_image_2.png
--------------------------------------------------------------------------------
/Data/Example/Genuine/raw_image_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UDCTGAN/UDCT/HEAD/Data/Example/Genuine/raw_image_3.png
--------------------------------------------------------------------------------
/Data/Example/Genuine/raw_image_4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UDCTGAN/UDCT/HEAD/Data/Example/Genuine/raw_image_4.png
--------------------------------------------------------------------------------
/Data/Example/Genuine/raw_image_5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UDCTGAN/UDCT/HEAD/Data/Example/Genuine/raw_image_5.png
--------------------------------------------------------------------------------
/Data/Example/Genuine/raw_image_6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UDCTGAN/UDCT/HEAD/Data/Example/Genuine/raw_image_6.png
--------------------------------------------------------------------------------
/Data/Example/Genuine/raw_image_7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UDCTGAN/UDCT/HEAD/Data/Example/Genuine/raw_image_7.png
--------------------------------------------------------------------------------
/Data/Example/Genuine/raw_image_8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UDCTGAN/UDCT/HEAD/Data/Example/Genuine/raw_image_8.png
--------------------------------------------------------------------------------
/Data/Example/Genuine/raw_image_9.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UDCTGAN/UDCT/HEAD/Data/Example/Genuine/raw_image_9.png
--------------------------------------------------------------------------------
/Data/Example/Synthetic/syn_image_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UDCTGAN/UDCT/HEAD/Data/Example/Synthetic/syn_image_0.png
--------------------------------------------------------------------------------
/Data/Example/Synthetic/syn_image_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UDCTGAN/UDCT/HEAD/Data/Example/Synthetic/syn_image_1.png
--------------------------------------------------------------------------------
/Data/Example/Synthetic/syn_image_10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UDCTGAN/UDCT/HEAD/Data/Example/Synthetic/syn_image_10.png
--------------------------------------------------------------------------------
/Data/Example/Synthetic/syn_image_11.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UDCTGAN/UDCT/HEAD/Data/Example/Synthetic/syn_image_11.png
--------------------------------------------------------------------------------
/Data/Example/Synthetic/syn_image_12.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UDCTGAN/UDCT/HEAD/Data/Example/Synthetic/syn_image_12.png
--------------------------------------------------------------------------------
/Data/Example/Synthetic/syn_image_13.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UDCTGAN/UDCT/HEAD/Data/Example/Synthetic/syn_image_13.png
--------------------------------------------------------------------------------
/Data/Example/Synthetic/syn_image_14.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UDCTGAN/UDCT/HEAD/Data/Example/Synthetic/syn_image_14.png
--------------------------------------------------------------------------------
/Data/Example/Synthetic/syn_image_15.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UDCTGAN/UDCT/HEAD/Data/Example/Synthetic/syn_image_15.png
--------------------------------------------------------------------------------
/Data/Example/Synthetic/syn_image_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UDCTGAN/UDCT/HEAD/Data/Example/Synthetic/syn_image_2.png
--------------------------------------------------------------------------------
/Data/Example/Synthetic/syn_image_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UDCTGAN/UDCT/HEAD/Data/Example/Synthetic/syn_image_3.png
--------------------------------------------------------------------------------
/Data/Example/Synthetic/syn_image_4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UDCTGAN/UDCT/HEAD/Data/Example/Synthetic/syn_image_4.png
--------------------------------------------------------------------------------
/Data/Example/Synthetic/syn_image_5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UDCTGAN/UDCT/HEAD/Data/Example/Synthetic/syn_image_5.png
--------------------------------------------------------------------------------
/Data/Example/Synthetic/syn_image_6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UDCTGAN/UDCT/HEAD/Data/Example/Synthetic/syn_image_6.png
--------------------------------------------------------------------------------
/Data/Example/Synthetic/syn_image_7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UDCTGAN/UDCT/HEAD/Data/Example/Synthetic/syn_image_7.png
--------------------------------------------------------------------------------
/Data/Example/Synthetic/syn_image_8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UDCTGAN/UDCT/HEAD/Data/Example/Synthetic/syn_image_8.png
--------------------------------------------------------------------------------
/Data/Example/Synthetic/syn_image_9.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UDCTGAN/UDCT/HEAD/Data/Example/Synthetic/syn_image_9.png
--------------------------------------------------------------------------------
/Scripts/simple_example.sh:
--------------------------------------------------------------------------------
#!/bin/sh
# A simple example that tests if the code is executing properly: builds the small
# dataset shipped in Data/Example and trains a model on it.

# Stop at the first failing step instead of training on a missing dataset.
set -e

# Work from the repository root, independent of the caller's working directory
# (the original relied on being started from inside Scripts/).
cd "$(dirname "$0")/.."

# Create the sample dataset
python create_h5_dataset.py ./Data/Example/Genuine/ ./Data/Example/Synthetic/ ./Data/Example/example_dataset.h5

# Create a Directory for model data
mkdir -p Models

# Train the network
python main.py --dataset=./Data/Example/example_dataset.h5 --name=example_model
--------------------------------------------------------------------------------
/Data/C_Elegans/Original/README.txt:
--------------------------------------------------------------------------------
1 | 1) Download the original images of the C. elegans dataset from the Broad Bioimage Benchmark Collection:
2 |
3 | https://data.broadinstitute.org/bbbc/BBBC010/BBBC010_v1_images.zip
4 |
5 | Accession number: BBBC010, Version 1
6 |
7 | 2) Extract the images into this directory. Afterwards, there should be 195 tif images in total in this directory.
8 |
9 | 3) Execute /notebooks/transform_c_elegans_tif_to_png.ipynb
10 |
--------------------------------------------------------------------------------
/Scripts/README.txt:
--------------------------------------------------------------------------------
1 | In this directory you can find scripts that can be executed in order to reproduce the results of the publication.
2 |
3 | simple_example.sh
4 | A simple example that tests if the code is executing properly.
5 |
6 | train_UDCT_c_elegans.sh
7 | This script downloads the C. elegans dataset from the Broad Bioimage Benchmark Collection, creates the Raw/Syn dataset and trains a network. Attention: This script deletes files in the Data/C_Elegans/Original directory (./00*.png and ./1649_1109_0003*.tif) before downloading the dataset again.
8 |
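9 | train_UDCT_dead_live_neurons.sh
10 | This script creates the dead vs alive dataset for the neurons. Afterwards, it trains a UDCT cycleGAN on the dataset. It also creates the synthetic images for the colored live neuron dataset, so run it before train_UDCT_colored_live_neurons.sh.
11 |
12 | train_UDCT_colored_live_neurons.sh
13 | This script creates the raw colored live neuron dataset and trains a UDCT cycleGAN on it. It requires the synthetic images created by train_UDCT_dead_live_neurons.sh.
14 |
15 | train_UDCT_nanowire.sh
16 | This script downloads the raw nanowire images, creates the Raw/Syn dataset and trains a UDCT cycleGAN on it.
17 |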
--------------------------------------------------------------------------------
/Utilities/Utilities.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
def produce_tiled_images(im_A,im_B,fake_A,fake_B,cyc_A,cyc_B,pad_width=20,pad_value=0.5):
    """
    Arrange the first image of six batches into one 2x3 preview mosaic.

    Layout (rows x columns):
        im_A    fake_B    cyc_A
        im_B    fake_A    cyc_B

    INPUT: im_A, im_B, fake_A, fake_B, cyc_A, cyc_B - 4-D batches indexed as [batch,height,width,channels].
               Only element 0 of each batch is used. Single-channel batches are repeated to 3 channels
               so they can be stacked next to 3-channel images.
           pad_width - Border in pixels added on every side of each image (default 20).
           pad_value - Value used to fill the border (default 0.5).

    OUTPUT: Array of shape [2*(H+2*pad_width), 3*(W+2*pad_width), 3] when all images share height H and
            width W.
    """
    tiles = []
    for image in (im_A,im_B,fake_A,fake_B,cyc_A,cyc_B):
        if np.shape(image)[-1]==1:
            # Grayscale -> 3 channels, so it can be stacked with 3-channel images.
            image = np.tile(image,[1,1,1,3])
        # Keep only the first element of the batch and frame it with a constant border.
        tiles.append(np.pad(image[0,:,:,:],
                            ((pad_width,pad_width),(pad_width,pad_width),(0,0)),
                            mode='constant',
                            constant_values=pad_value))
    im_A,im_B,fake_A,fake_B,cyc_A,cyc_B = tiles
    left   = np.vstack((im_A,im_B))
    middle = np.vstack((fake_B,fake_A))
    right  = np.vstack((cyc_A,cyc_B))
    return np.hstack((left,middle,right))
15 |
--------------------------------------------------------------------------------
/Scripts/train_UDCT_colored_live_neurons.sh:
--------------------------------------------------------------------------------
#!/bin/sh
# Builds the colored live neuron dataset and trains a UDCT cycleGAN on it.
#
# Prerequisite: the synthetic images in Data/Neuron_Col_Live/Synthetic are NOT
# created here. They are created by train_UDCT_dead_live_neurons.sh
# (notebook make_synthetic_livedead_neurons_and_colored.ipynb).

# Stop at the first failing step instead of continuing with an incomplete dataset.
set -e

# Work from the repository root, independent of the caller's working directory.
cd "$(dirname "$0")/.."

# Fail fast (before the long raw-image step) if the synthetic images are missing.
if [ ! -d ./Data/Neuron_Col_Live/Synthetic ]; then
    echo "Missing ./Data/Neuron_Col_Live/Synthetic - run train_UDCT_dead_live_neurons.sh first" >&2
    exit 1
fi

# Create the raw dataset
echo "Creating the raw dataset"
cd ./notebooks/
jupyter nbconvert --to notebook --execute make_raw_colored_live_neurons.ipynb --ExecutePreprocessor.timeout=1800

# Create the hdf5 file
cd ../
mkdir -p Models
echo "Creating the hdf5 dataset"
python create_h5_dataset.py ./Data/Neuron_Col_Live/Raw/ ./Data/Neuron_Col_Live/Synthetic/ ./Data/Neuron_Col_Live/colored_live_neuron_dataset.h5

# Train the network
echo "Training the network"
python main.py --dataset=./Data/Neuron_Col_Live/colored_live_neuron_dataset.h5 --name=live_colored_neuron_new

# Create the generated synthetic images
python main.py --dataset=./Data/Neuron_Col_Live/colored_live_neuron_dataset.h5 --name=live_colored_neuron_new --mode=gen_B
20 |
--------------------------------------------------------------------------------
/Data/C_Elegans/README.txt:
--------------------------------------------------------------------------------
1 | Disclaimer: 'We used the C.elegans infection live/dead image set version 1 provided by Fred Ausubel and available from the Broad Bioimage Benchmark Collection [Ljosa et al., Nature Methods, 2012].'
2 | https://data.broadinstitute.org/bbbc/BBBC010/
3 |
4 | For copyright reasons, the C. elegans dataset needs to be created by downloading the data from the Broad Bioimage Benchmark Collection.
5 |
6 | To do so, please follow the README.txt instructions in the following order:
7 |
8 | 1) README.txt in Original/
9 |
10 | 2) README.txt in Raw/
11 |
12 | 3) README.txt in Synthetic/
13 |
14 | 4) Execute in root directory: python create_h5_dataset.py ./Data/C_Elegans/Raw/ ./Data/C_Elegans/Synthetic/ ./Data/C_Elegans/c_elegans_dataset.h5
15 |
16 | Instead, you can also execute the script 'train_UDCT_c_elegans.sh' in the directory ../Scripts. It does all these steps automatically.
17 |
--------------------------------------------------------------------------------
/Scripts/train_UDCT_nanowire.sh:
--------------------------------------------------------------------------------
#!/bin/sh
# Downloads the raw nanowire images, builds the raw and synthetic datasets and
# trains a UDCT cycleGAN on them.

# Stop at the first failing step instead of continuing with an incomplete dataset.
set -e

# Work from the repository root, independent of the caller's working directory.
cd "$(dirname "$0")/.."

# Create the raw dataset
echo "Creating the raw dataset"
mkdir -p ./Data/Nanowire/Raw
cd ./Data/Nanowire/Raw/
# -N: on re-runs, refresh/skip the existing file instead of saving a second copy
# named lbb_raw_nanowire_images.h5.1, which the notebook would never read.
wget -N https://downloads.lbb.ethz.ch/Data/lbb_raw_nanowire_images.h5
cd ../../../notebooks/
jupyter nbconvert --to notebook --execute make_raw_nanowire.ipynb --ExecutePreprocessor.timeout=1800

# Create the synthetic images
mkdir -p ../Data/Nanowire/Synthetic
echo "Creating the synthetic dataset"
jupyter nbconvert --to notebook --execute make_synthetic_wires.ipynb --ExecutePreprocessor.timeout=1800

# Create the hdf5 file
cd ../
mkdir -p Models
echo "Creating the hdf5 dataset"
python create_h5_dataset.py ./Data/Nanowire/Raw/ ./Data/Nanowire/Synthetic/ ./Data/Nanowire/nanowire_dataset.h5

# Train the network
echo "Training the network"
python main.py --dataset=./Data/Nanowire/nanowire_dataset.h5 --name=nanowire_new

# Create the generated synthetic images
python main.py --dataset=./Data/Nanowire/nanowire_dataset.h5 --name=nanowire_new --mode=gen_B
28 |
--------------------------------------------------------------------------------
/Scripts/train_UDCT_dead_live_neurons.sh:
--------------------------------------------------------------------------------
#!/bin/sh
# Builds the dead vs live neuron dataset and trains a UDCT cycleGAN on it.
# It also creates the synthetic images used by train_UDCT_colored_live_neurons.sh.

# Stop at the first failing step instead of continuing with an incomplete dataset.
set -e

# Work from the repository root, independent of the caller's working directory.
cd "$(dirname "$0")/.."

# Create the synthetic images for both live_vs_dead and colored_live neuron datasets
cd ./notebooks/
echo "Creating the synthetic dataset"
jupyter nbconvert --to notebook --execute make_synthetic_livedead_neurons_and_colored.ipynb --ExecutePreprocessor.timeout=1800

# Create the raw dataset
echo "Creating the raw dataset"
mkdir -p ../Data/Neuron_Dead_Live/Raw
cd ../Data/Neuron_Dead_Live/Raw/
# -N: on re-runs, refresh/skip the existing file instead of saving a second copy
# named lbb_raw_neuron_images.h5.1, which the notebook would never read.
wget -N https://downloads.lbb.ethz.ch/Data/lbb_raw_neuron_images.h5
cd ../../../notebooks/
jupyter nbconvert --to notebook --execute make_raw_dead_live_neurons.ipynb --ExecutePreprocessor.timeout=1800

# Create the hdf5 file
cd ../
mkdir -p Models
echo "Creating the hdf5 dataset"
python create_h5_dataset.py ./Data/Neuron_Dead_Live/Raw/ ./Data/Neuron_Dead_Live/Synthetic/ ./Data/Neuron_Dead_Live/live_dead_neuron_dataset.h5

# Train the network
echo "Training the network"
python main.py --dataset=./Data/Neuron_Dead_Live/live_dead_neuron_dataset.h5 --name=live_dead_neuron_new

# Create the generated synthetic images
python main.py --dataset=./Data/Neuron_Dead_Live/live_dead_neuron_dataset.h5 --name=live_dead_neuron_new --mode=gen_B
28 |
--------------------------------------------------------------------------------
/Scripts/train_UDCT_c_elegans.sh:
--------------------------------------------------------------------------------
#!/bin/sh
# Downloads the C. elegans images (BBBC010) from the Broad Bioimage Benchmark
# Collection, builds the raw and synthetic datasets and trains a UDCT cycleGAN.
# Attention: deletes ./00*.png and ./1649_1109_0003*.tif in Data/C_Elegans/Original.

# Stop at the first failing step instead of continuing with an incomplete dataset.
set -e

# Work from the repository root, independent of the caller's working directory.
cd "$(dirname "$0")/.."

# Download the original BBBC dataset and extract it
cd ./Data/C_Elegans/Original
# Remove files from a previous run. -f: do not fail when nothing matches (first run).
rm -f ./00*.png
rm -f ./1649_1109_0003*.tif
# -N: overwrite a zip left by an interrupted run instead of saving a second copy (.zip.1).
wget -N https://data.broadinstitute.org/bbbc/BBBC010/BBBC010_v1_images.zip
# -o: overwrite without prompting, so a leftover extraction cannot block the script.
unzip -o BBBC010_v1_images.zip -d ./
rm BBBC010_v1_images.zip
mv BBBC010_v1_images/* ./
rm -rf BBBC010_v1_images

# Transform the images to pngs
cd ../../../notebooks/
jupyter nbconvert --to notebook --execute transform_c_elegans_tif_to_png.ipynb --ExecutePreprocessor.timeout=1800

# Create the raw dataset
echo "Creating the raw dataset"
jupyter nbconvert --to notebook --execute make_raw_c_elegans.ipynb --ExecutePreprocessor.timeout=1800

# Create the synthetic dataset
echo "Creating the synthetic dataset"
jupyter nbconvert --to notebook --execute make_synthetic_c_elegans.ipynb --ExecutePreprocessor.timeout=1800

# Create the hdf5 file
cd ../
mkdir -p Models
echo "Creating the hdf5 dataset"
python create_h5_dataset.py ./Data/C_Elegans/Raw/ ./Data/C_Elegans/Synthetic/ ./Data/C_Elegans/c_elegans_dataset.h5

# Train the network
echo "Training the network"
python main.py --dataset=./Data/C_Elegans/c_elegans_dataset.h5 --name=c_elegans_new

# Create the generated synthetic images
python main.py --dataset=./Data/C_Elegans/c_elegans_dataset.h5 --name=c_elegans_new --mode=gen_B
37 |
--------------------------------------------------------------------------------
/Discriminator/MultiPatch.py:
--------------------------------------------------------------------------------
1 | from __future__ import division, print_function, unicode_literals
2 |
3 | import tensorflow as tf
4 | import PatchGAN34
5 | import PatchGAN70
6 | import PatchGAN142
7 |
class MultiPatch:
    """
    Multi-scale PatchGAN discriminator.

    Wraps three PatchGAN discriminators (PatchGAN34, PatchGAN70 and PatchGAN142), applies all of them to the
    same input and concatenates their flattened per-patch outputs into one tensor.

    Only the following functions should be called from outside:
    -) constructor
    -) create()
    """

    def __init__(self,dis_name,noise=0.25):
        """
        Create the three sub-discriminators. No graph is built and nothing is loaded from or saved to disk
        here; the layers are created by create().

        INPUT: dis_name - Name of the discriminator. It is passed unchanged to every sub-discriminator
                          (PatchGAN34 uses it to build its layer names).
               noise    - Forwarded as the `noise` argument of every sub-discriminator (PatchGAN34 uses it as
                          the standard deviation of Gaussian noise added to its input). Default 0.25.

        OUTPUT: - The model
        """
        self.dis_name = dis_name
        self.noise = noise

        self.Patch34 = PatchGAN34.PatchGAN34(self.dis_name,noise=self.noise)
        self.Patch70 = PatchGAN70.PatchGAN70(self.dis_name,noise=self.noise)
        self.Patch142 = PatchGAN142.PatchGAN142(self.dis_name,noise=self.noise)

    def create(self,X,reuse=True):
        """
        Build the three PatchGAN graphs on X and merge their outputs.

        INPUT: X     - Input image tensor (NHWC assumed, as in PatchGAN34 - TODO confirm for the other scales).
               reuse - Forwarded to every sub-discriminator's create().

        OUTPUT: Tensor [batch, P34+P70+P142, 1], where Pk = height*width of the k-pixel PatchGAN output map.
                Assumes every sub-discriminator outputs a single channel (true for PatchGAN34).
        """
        self.out34 = self.Patch34.create(X,reuse)
        self.out70 = self.Patch70.create(X,reuse)
        self.out142 = self.Patch142.create(X,reuse)

        self.prediction = tf.concat([self._flatten_patches(self.out34),
                                     self._flatten_patches(self.out70),
                                     self._flatten_patches(self.out142)],axis=1)
        return self.prediction

    @staticmethod
    def _flatten_patches(patch_out):
        """Reshape a [batch,height,width,1] patch map into [batch,height*width,1]."""
        dyn_shape = tf.shape(patch_out)
        return tf.reshape(patch_out,[-1,dyn_shape[1]*dyn_shape[2],1])
--------------------------------------------------------------------------------
/Discriminator/HisDis.py:
--------------------------------------------------------------------------------
1 | from __future__ import division, print_function, unicode_literals
2 |
3 | import tensorflow as tf
4 |
class HisDis:
    """
    Histogram discriminator: a small fully connected network (two tanh hidden layers with 64 units each and
    one linear output unit) applied to a flat input vector of size 256*n_chan.

    Only the following functions should be called from outside:
    -) constructor
    -) create()
    """

    def __init__(self,dis_name,noise=0.1,keep_prob=0.5):
        """
        Store the configuration of the histogram discriminator. No graph is built here; see create().

        INPUT: dis_name  - Name of the discriminator, used to build the names of its dense layers.
               noise     - Standard deviation of the Gaussian noise added to the input (default 0.1).
               keep_prob - Keep probability of the dropout applied before every dense layer (default 0.5).

        OUTPUT: - The model
        """
        self.dis_name = dis_name
        self.noise = noise
        self.keep_prob = keep_prob

    def _dense(self,inputs,units,suffix,activation,reuse):
        """Dropout (self.keep_prob) followed by a dense layer named 'dis_<dis_name>_<suffix>'."""
        return tf.layers.dense(tf.nn.dropout(inputs,keep_prob=self.keep_prob),
                               units,
                               reuse=reuse,
                               name='dis_'+self.dis_name+'_'+suffix,
                               activation=activation)

    def create(self,X,reuse=True):
        """
        Build the histogram discriminator graph.

        INPUT: X     - [None,256*n_chan]
               reuse - Reuse flag passed to every dense layer.

        OUTPUT: - The HisDis prediction, [None,1]
        """
        # Center the input around 0 and add Gaussian noise. Noise and dropout are always part of the graph;
        # there is no training/inference switch.
        noisy_X = X-0.5+tf.random_normal(tf.shape(X),0.,self.noise)

        self.hidden_1 = self._dense(noisy_X,64,'hidden_1',tf.nn.tanh,reuse)
        self.hidden_2 = self._dense(self.hidden_1,64,'hidden_2',tf.nn.tanh,reuse)
        # Linear output, shifted back by +0.5 to mirror the input centering.
        self.out = self._dense(self.hidden_2,1,'h_out',None,reuse) + 0.5

        return self.out
60 |
61 |
--------------------------------------------------------------------------------
/notebooks/make_raw_nanowire.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "Creates the raw images for the nanowire dataset. The raw images can be downloaded from:\n",
8 | "\n",
9 | "https://downloads.lbb.ethz.ch/Data/lbb_raw_nanowire_images.h5"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": null,
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "path_raw_images = '../Data/Nanowire/Raw/lbb_raw_nanowire_images.h5'"
19 | ]
20 | },
21 | {
22 | "cell_type": "code",
23 | "execution_count": null,
24 | "metadata": {},
25 | "outputs": [],
26 | "source": [
27 | "import numpy as np\n",
28 | "import matplotlib\n",
29 | "import matplotlib.pyplot as plt\n",
30 | "import h5py as h5\n",
31 | "\n",
32 | "import random\n",
33 | "\n",
34 | "import imageio as io\n",
35 | "\n",
36 | "from skimage.filters import gaussian"
37 | ]
38 | },
39 | {
40 | "cell_type": "code",
41 | "execution_count": null,
42 | "metadata": {},
43 | "outputs": [],
44 | "source": [
45 | "f = h5.File(path_raw_images,\"r\")"
46 | ]
47 | },
48 | {
49 | "cell_type": "code",
50 | "execution_count": null,
51 | "metadata": {},
52 | "outputs": [],
53 | "source": [
54 | "data = f['Raw/data'][...]/(2**8-1.)"
55 | ]
56 | },
57 | {
58 | "cell_type": "code",
59 | "execution_count": null,
60 | "metadata": {},
61 | "outputs": [],
62 | "source": [
63 | "np.max(data)"
64 | ]
65 | },
66 | {
67 | "cell_type": "code",
68 | "execution_count": null,
69 | "metadata": {},
70 | "outputs": [],
71 | "source": [
72 | "plt.figure(figsize=(10,10))\n",
73 | "plt.imshow(data[0,:,:,0],cmap='gray')\n",
74 | "plt.show()"
75 | ]
76 | },
77 | {
78 | "cell_type": "code",
79 | "execution_count": null,
80 | "metadata": {},
81 | "outputs": [],
82 | "source": [
83 | "for i in range(data.shape[0]):\n",
84 | " io.imsave('../Data/Nanowire/Raw/'+str(i).zfill(5)+'.png',data[i,:,:,0])"
85 | ]
86 | }
87 | ],
88 | "metadata": {
89 | "kernelspec": {
90 | "display_name": "Python 2",
91 | "language": "python",
92 | "name": "python2"
93 | },
94 | "language_info": {
95 | "codemirror_mode": {
96 | "name": "ipython",
97 | "version": 2
98 | },
99 | "file_extension": ".py",
100 | "mimetype": "text/x-python",
101 | "name": "python",
102 | "nbconvert_exporter": "python",
103 | "pygments_lexer": "ipython2",
104 | "version": "2.7.15"
105 | }
106 | },
107 | "nbformat": 4,
108 | "nbformat_minor": 2
109 | }
110 |
--------------------------------------------------------------------------------
/notebooks/make_raw_dead_live_neurons.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "Creates the raw images for the dead vs live neuron dataset. The raw brightfield images can be downloaded from:\n",
8 | "\n",
9 | "https://downloads.lbb.ethz.ch/Data/lbb_raw_neuron_images.h5"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": null,
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "path_raw_images = '../Data/Neuron_Dead_Live/Raw/lbb_raw_neuron_images.h5'"
19 | ]
20 | },
21 | {
22 | "cell_type": "code",
23 | "execution_count": null,
24 | "metadata": {},
25 | "outputs": [],
26 | "source": [
27 | "import numpy as np\n",
28 | "import matplotlib\n",
29 | "import matplotlib.pyplot as plt\n",
30 | "import h5py as h5\n",
31 | "\n",
32 | "import random\n",
33 | "\n",
34 | "import imageio as io\n",
35 | "\n",
36 | "from skimage.filters import gaussian"
37 | ]
38 | },
39 | {
40 | "cell_type": "code",
41 | "execution_count": null,
42 | "metadata": {},
43 | "outputs": [],
44 | "source": [
45 | "f = h5.File(path_raw_images,\"r\")"
46 | ]
47 | },
48 | {
49 | "cell_type": "code",
50 | "execution_count": null,
51 | "metadata": {},
52 | "outputs": [],
53 | "source": [
54 | "data = f['Raw/data'][...]/(2**16-1.)"
55 | ]
56 | },
57 | {
58 | "cell_type": "code",
59 | "execution_count": null,
60 | "metadata": {},
61 | "outputs": [],
62 | "source": [
63 | "np.max(data)"
64 | ]
65 | },
66 | {
67 | "cell_type": "code",
68 | "execution_count": null,
69 | "metadata": {},
70 | "outputs": [],
71 | "source": [
72 | "plt.figure(figsize=(10,10))\n",
73 | "plt.imshow(data[1,:,:,0],cmap='gray')\n",
74 | "plt.show()"
75 | ]
76 | },
77 | {
78 | "cell_type": "code",
79 | "execution_count": null,
80 | "metadata": {},
81 | "outputs": [],
82 | "source": [
83 | "for i in range(data.shape[0]):\n",
84 | " io.imsave('../Data/Neuron_Dead_Live/Raw/'+str(i).zfill(4)+'.png',data[i,:,:,0])"
85 | ]
86 | },
87 | {
88 | "cell_type": "code",
89 | "execution_count": null,
90 | "metadata": {},
91 | "outputs": [],
92 | "source": []
93 | }
94 | ],
95 | "metadata": {
96 | "kernelspec": {
97 | "display_name": "Python 2",
98 | "language": "python",
99 | "name": "python2"
100 | },
101 | "language_info": {
102 | "codemirror_mode": {
103 | "name": "ipython",
104 | "version": 2
105 | },
106 | "file_extension": ".py",
107 | "mimetype": "text/x-python",
108 | "name": "python",
109 | "nbconvert_exporter": "python",
110 | "pygments_lexer": "ipython2",
111 | "version": "2.7.15"
112 | }
113 | },
114 | "nbformat": 4,
115 | "nbformat_minor": 2
116 | }
117 |
--------------------------------------------------------------------------------
/Discriminator/PatchGAN34.py:
--------------------------------------------------------------------------------
1 | from __future__ import division, print_function, unicode_literals
2 |
3 | import tensorflow as tf
4 |
class PatchGAN34:
    """
    PatchGAN discriminator (Zhu et al.) with a 34x34 pixel receptive field.

    Four 4x4 convolutions (128, 256, 512 and 1 filters; strides 2, 2, 1, 1). Each convolution is preceded by
    a 1-pixel reflect padding. The first three are followed by instance normalization and a leaky ReLU. With
    these kernels and strides, every value of the output map depends on a 34x34 window of the input.

    Only the following functions should be called from outside:
    -) constructor
    -) create()
    """

    def __init__(self,dis_name,noise=0.25):
        """
        Store the discriminator configuration. No graph is built and nothing is loaded from or saved to disk
        here; the layers are created by create().

        INPUT: dis_name - Name of the discriminator, used to build the layer and scope names
                          ('dis_<dis_name>_34_conv_<i>' / 'dis_<dis_name>_34_bnorm_<i>').
               noise    - Standard deviation of the Gaussian noise added to the input (default 0.25).

        OUTPUT: - The model
        """
        self.dis_name = dis_name
        self.noise = noise

    def _conv(self,inputs,filters,stride,index,reuse):
        """Reflect-pad height and width by 1 pixel, then apply a 4x4 'valid' conv named 'dis_<dis_name>_34_conv_<index>'."""
        return tf.layers.conv2d(tf.pad(inputs,[[0,0],[1,1],[1,1],[0,0]],"Reflect"),
                                filters=filters,
                                kernel_size=4,
                                kernel_initializer=tf.initializers.random_normal(stddev=0.02),
                                strides=(stride,stride),
                                padding='valid',
                                reuse=reuse,
                                name='dis_'+self.dis_name+'_34_conv_'+str(index))

    def _instance_norm(self,inputs,index,reuse):
        """Instance normalization in scope 'dis_<dis_name>_34_bnorm_<index>' (trainable=False)."""
        return tf.contrib.layers.instance_norm(inputs,
                                               reuse=reuse,
                                               scope='dis_'+self.dis_name+'_34_bnorm_'+str(index),
                                               trainable=False)

    def create(self,X,reuse=True):
        """
        Build the discriminator graph.

        INPUT: X     - Image tensor, NHWC (axes 1 and 2 are the padded spatial axes; conv2d uses channels_last).
               reuse - Reuse flag passed to every convolution and normalization layer.

        OUTPUT: Unnormalized patch scores with 1 channel. For an H x W input with H and W divisible by 4, the
                spatial size is (H/4-2) x (W/4-2).
        """
        # Gaussian input noise (std self.noise)
        noisy_X = X+tf.random_normal(tf.shape(X),0.,self.noise)

        # C128
        self.C128_c = self._conv(noisy_X,128,2,1,reuse)
        self.C128_n = self._instance_norm(self.C128_c,1,reuse)
        self.C128 = tf.nn.leaky_relu(self.C128_n)

        # C256
        self.C256_c = self._conv(self.C128,256,2,2,reuse)
        self.C256_n = self._instance_norm(self.C256_c,2,reuse)
        self.C256 = tf.nn.leaky_relu(self.C256_n)

        # C512
        self.C512_c = self._conv(self.C256,512,1,3,reuse)
        self.C512_n = self._instance_norm(self.C512_c,3,reuse)
        self.C512 = tf.nn.leaky_relu(self.C512_n)

        # c1: single-channel output, no normalization or activation
        self.c1_c = self._conv(self.C512,1,1,4,reuse)
        return self.c1_c
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # UDCT
2 | This repository contains the CycleGAN network (TensorFlow, Python 2.7+) with our histogram loss. Additionally, we provide the scripts we used to generate the synthetic datasets.
3 |
4 | Our results can be found at https://www.biorxiv.org/content/biorxiv/early/2019/03/01/563734.full.pdf
5 |
6 | ## How to use
7 | 1. Clone or download the repository
8 | 2. Create a synthetic dataset similar to your real dataset or use example dataset in ./Data/Example
9 | 3. Execute:
python create_h5_dataset.py <directory_of_raw_images> <directory_of_syn_images> <filename_of_hdf5_file>
10 | Example: python create_h5_dataset.py ./Data/Example/Genuine/ ./Data/Example/Synthetic/ ./Data/Example/example_dataset.h5
11 | 4. Create the directory 'Models' in the root directory
12 | 5. Execute: python main.py --dataset=./Data/..../dataset.h5 --name=name_of_model
13 | Example: python main.py --dataset=./Data/Example/example_dataset.h5 --name=example_model
14 | 6. This will create a network that is saved in ./Models/ along with a parameter textfile. Furthermore, the average loss terms for each epoch are saved in this directory.
15 | 7. To generate the results after training, use:
16 | python main.py --dataset=./Data/Example/example_dataset.h5 --name=example_model --mode=gen_B
17 | The generated synthetic images can be found in ./Models/<name_of_model>_gen_B.h5
18 |
19 | ### Parameters
20 |
21 | All parameters are of the shape: --<parameter_name>=<value>
22 | Below is the list of all parameters that can be set. The default value, which is used if the parameter is not given, is shown in parentheses.
23 |
24 | name ('unnamed')
25 | Name of the model. This value should be unique so that old models are not accidentally loaded or overwritten. Its value must be changed to ensure functionality!
26 |
27 |
28 | dataset ('pathtodata.h5')
29 | Path to the HDF5 dataset file (as created by create_h5_dataset.py). Its value must be changed to ensure functionality!
30 |
31 |
32 | architecture ('Res6')
33 | The network architecture for the generators. Currently, you can choose between 'Res6' and 'Res9', which corresponds to 6 and 9 residual layers, respectively.
34 |
35 |
36 | deconv ('transpose')
37 | Upsampling method used in the generators. You can either choose transpose CNNs ('transpose') or image resizing ('resize').
38 |
39 |
40 | PatchGAN ('Patch70')
41 | Different PatchGAN Architectures: 'Patch34', 'Patch70', or 'Patch142'. A mixture of these is possible: 'MultiPatch' (experimental).
42 |
43 |
44 | mode ('training')
45 | Decides what should be done with the network. You can either train it ('training'), or create generated images: 'gen_A' for raw images from synthetic images and 'gen_B' for synthetic images from raw images.
46 |
47 |
48 |
49 |
50 | dataset ('pathtodata.h5')
51 | Path to the HDF5 dataset file (as created by create_h5_dataset.py). Its value must be changed to ensure functionality!
52 |
53 |
54 | lambda_c (10.)
55 | The loss multiplier of the cycle consistency term used while training the generators.
56 |
57 |
58 | lambda_h (1.)
59 | The loss multiplier of the histogram discriminators. If the histogram should not be used, set this term to 0.
60 |
61 |
62 | dis_noise (0.1)
63 | To make training more stable, Gaussian noise that slowly decays over time is added to the inputs of the discriminators. This value sets the standard deviation of that noise.
64 |
65 |
66 | syn_noise (0.)
67 | It is possible to add gaussian noise to the synthetic dataset. Default: not used.
68 |
69 |
70 | real_noise (0.)
71 | It is possible to add gaussian noise to the real dataset. Default: not used.
72 |
73 |
74 | epoch (200)
75 | Number of training epochs.
76 |
77 |
78 | batch_size (4)
79 | Batch size during training.
80 |
81 |
82 | buffer_size (50)
83 | Size of the buffer (history) of previously generated images that is used to train the discriminators. This makes training more stable.
84 |
85 |
86 | save (1)
87 | If value is not 0, the network progress is saved at the end of each epoch.
88 |
89 |
90 | gpu (0)
91 | If multiple GPUs exist, this parameter chooses which GPU is used. Only one GPU can currently be used. To train on the CPU only, choose a nonexistent GPU ID.
92 |
93 |
94 | verbose (0)
95 | If value is not 0, the network is more verbose.
96 |
97 |
--------------------------------------------------------------------------------
/Discriminator/PatchGAN70.py:
--------------------------------------------------------------------------------
1 | from __future__ import division, print_function, unicode_literals
2 |
3 | import tensorflow as tf
4 |
class PatchGAN70:
    """
    PatchGAN discriminator with a 70x70 pixel receptive field, following the PatchGAN
    design used by Zhu et al. (CycleGAN).

    Layer stack (every convolution: 4x4 kernel, 'valid' padding on a 1-pixel reflect-padded input):
        C64 (stride 2) -> C128 (stride 2) -> C256 (stride 2) -> C512 (stride 1) -> 1-channel output (stride 1)
    C64 is followed only by a leaky ReLU. C128, C256 and C512 are each followed by instance
    normalization (created with trainable=False) and a leaky ReLU. The output layer has no activation.

    The class only constructs graph operations; it does not save, load or run anything itself.
    Intended public interface:
        -) constructor
        -) create()
    """

    def __init__(self, dis_name, noise=0.25):
        """
        Store the settings needed to build the discriminator graph.

        INPUT: dis_name - Name of the discriminator. It is embedded in every layer / variable scope
                          name ('dis_<dis_name>_70_...').
               noise    - Standard deviation of the Gaussian noise that create() adds to its input.
        """
        self.dis_name = dis_name
        self.noise = noise

    def _padded_conv(self, inputs, filters, stride, index, reuse):
        """Reflect-pad `inputs` by one pixel on both spatial sides and apply a 4x4 'valid' convolution."""
        padded = tf.pad(inputs, [[0, 0], [1, 1], [1, 1], [0, 0]], "Reflect")
        return tf.layers.conv2d(padded,
                                filters=filters,
                                kernel_size=4,
                                kernel_initializer=tf.initializers.random_normal(stddev=0.02),
                                strides=(stride, stride),
                                padding='valid',
                                reuse=reuse,
                                name='dis_' + self.dis_name + '_70_conv_' + str(index))

    def _norm_lrelu(self, conv, index, reuse):
        """Instance-normalize `conv` (non-trainable) and apply a leaky ReLU; return (normalized, activated)."""
        normalized = tf.contrib.layers.instance_norm(conv,
                                                     reuse=reuse,
                                                     scope='dis_' + self.dis_name + '_70_bnorm_' + str(index),
                                                     trainable=False)
        return normalized, tf.nn.leaky_relu(normalized)

    def create(self, X, reuse=True):
        """
        Build the discriminator on top of X and return its patch-wise output map.

        INPUT: X     - Input image tensor, channels-last (the tf.layers.conv2d default); the spatial
                       axes 1 and 2 are the ones being padded.
               reuse - Forwarded to every layer. Presumably False on the first call and True afterwards
                       so that later calls share the variables (TODO confirm with caller).

        OUTPUT: 1-channel tensor of raw (un-activated) scores, one per image patch.
        """
        # Gaussian noise on the input (std = self.noise)
        noisy_input = X + tf.random_normal(tf.shape(X), 0., self.noise)

        # C64: the first layer is not normalized
        self.C64_c = self._padded_conv(noisy_input, 64, 2, 0, reuse)
        self.C64 = tf.nn.leaky_relu(self.C64_c)

        # C128
        self.C128_c = self._padded_conv(self.C64, 128, 2, 1, reuse)
        self.C128_n, self.C128 = self._norm_lrelu(self.C128_c, 1, reuse)

        # C256
        self.C256_c = self._padded_conv(self.C128, 256, 2, 2, reuse)
        self.C256_n, self.C256 = self._norm_lrelu(self.C256_c, 2, reuse)

        # C512
        self.C512_c = self._padded_conv(self.C256, 512, 1, 3, reuse)
        self.C512_n, self.C512 = self._norm_lrelu(self.C512_c, 3, reuse)

        # Single-channel output layer (no normalization, no activation)
        self.c1_c = self._padded_conv(self.C512, 1, 1, 4, reuse)
        return self.c1_c
--------------------------------------------------------------------------------
/create_h5_dataset.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 |
3 | import numpy as np
4 | import h5py
5 | import cv2
6 | import os
7 | import sys
8 |
def get_file_list(data_path):
    """
    This function returns the sorted list of all png images in a given directory and their common
    dimensions. It raises an error if the directory contains no png image, if an image cannot be read,
    or if the image dimensions are not consistent.

    Arguments:
        data_path (string)  Path to directory of the set of images to be extracted
                            (a trailing path separator is optional)

    Returns:
        file_list (List of strings)  Paths of all png images in the directory, sorted by name
        dimensions (3 x integer)     Dimensions of images: height, width, channels
        flag (boolean)               Is true, iff the images are grayscale (single channel)

    Raises:
        ValueError  If the directory contains no png image or an image cannot be read
        Exception   If two images have different dimensions
    """
    # Collect all png files (extension compared case-insensitively). os.listdir returns entries in
    # arbitrary order, so sort them to make the sample order reproducible.
    file_list = sorted(os.path.join(data_path, element)
                       for element in os.listdir(data_path)
                       if element.lower().endswith(".png"))
    if not file_list:
        raise ValueError('No png images found in directory: ' + data_path)

    def read_shape(fname):
        # cv2.imread returns None (instead of raising) if a file cannot be decoded
        image = cv2.imread(fname, cv2.IMREAD_UNCHANGED)
        if image is None:
            raise ValueError('Could not read image: ' + fname)
        return image.shape

    # Compare sizes
    dimensions = read_shape(file_list[0])
    for fname in file_list[1:]:
        if not np.array_equal(dimensions, read_shape(fname)):
            raise Exception('The following two images have different dimensions. Please make sure all images in this directory have the same size \n\r ' +
                            file_list[0] + '\n\r' +
                            fname)

    # Add a 3rd value to the dimensions, if it does not exist (this means it has 1 data channel)
    flag = False
    if len(dimensions) == 2:
        dimensions = np.array([dimensions[0], dimensions[1], 1])
        flag = True

    return file_list, dimensions, flag
44 |
def _store_image_group(f, group_name, image_path):
    """
    Read all png images of image_path and store them in a new group of the open hdf5 file f.

    The group `group_name` receives the scalar datasets 'num_samples' and 'num_channel' and the
    uint8 dataset 'data' of shape (num_samples, height, width, num_channel). Colour images are
    converted from OpenCV's BGR channel order to RGB.

    Returns the created 'data' dataset.
    Raises ValueError if the images are neither single-channel nor 3-channel.
    """
    files, dimensions, grayscale = get_file_list(image_path)
    num_samples = len(files)
    num_channel = dimensions[2]
    if not grayscale and num_channel != 3:
        # IMREAD_COLOR below always yields 3 channels, so e.g. RGBA images could not be stored
        raise ValueError('Only grayscale or 3-channel colour images are supported, but the images in ' +
                         image_path + ' have ' + str(num_channel) + ' channels')

    dtype = np.uint8
    group = f.create_group(group_name)
    group.create_dataset(name='num_samples', data=num_samples)
    group.create_dataset(name='num_channel', data=num_channel)

    data = np.zeros([num_samples, dimensions[0], dimensions[1], num_channel], dtype=dtype)
    for idx, fname in enumerate(files):
        if grayscale:
            data[idx, :, :, 0] = cv2.imread(fname, cv2.IMREAD_GRAYSCALE)
        else:
            # OpenCV loads colour images as BGR; flip the channel axis to get RGB
            data[idx, :, :, :] = np.flip(cv2.imread(fname, cv2.IMREAD_COLOR), 2)

    return group.create_dataset(name='data', data=data, dtype=dtype)


def main():
    """
    Command-line entry point: build an hdf5 dataset from two directories of png images.

    Expects exactly three arguments in sys.argv: the directory of genuine/raw images (stored in
    group 'A'), the directory of synthetic images (stored in group 'B') and the output filename.

    Returns -1 if the number of arguments is wrong, otherwise None.
    """
    # Check if the right amount of arguments has been given to the program
    if len(sys.argv[1:]) != 3:
        print('This script requires three arguments in order to work:')
        print('1: Path to directory containing the genuine/raw images (only png images in directory are used!)')
        print('2: Path to directory containing the synthetic images (only png images in directory are used!)')
        print('3: Output hdf5 filename')
        print(' ')
        print('Example: python create_h5_dataset.py ./Data/Example/Genuine/ ./Data/Example/Synthetic/ ./Data/Example/example_dataset.h5')
        print(' ')
        print('Script aborted')
        return -1

    raw_path, syn_path, filename = sys.argv[1:4]

    # The context manager closes the output file even if reading the images fails
    with h5py.File(filename, "w") as f:
        print('Genuine dataset: ', _store_image_group(f, 'A', raw_path))
        print('Synthetic dataset: ', _store_image_group(f, 'B', syn_path))
121 |
if __name__ == "__main__":
    # Propagate main()'s return value (-1 on wrong usage, None on success) as the exit status
    sys.exit(main())
124 |
--------------------------------------------------------------------------------
/notebooks/make_synthetic_wires.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import numpy as np\n",
10 | "import matplotlib\n",
11 | "import matplotlib.pyplot as plt\n",
12 | "%matplotlib inline\n",
13 | "import os\n",
14 | "import colorsys\n",
15 | "import imageio as io\n",
16 | "from scipy.ndimage import binary_dilation\n",
17 | "from skimage.filters import gaussian\n",
18 | "from numpy import random\n",
19 | "from scipy.stats import skewnorm"
20 | ]
21 | },
22 | {
23 | "cell_type": "code",
24 | "execution_count": null,
25 | "metadata": {},
26 | "outputs": [],
27 | "source": [
28 | "def gen_fake_img(imsize,length_var,width_var,rng_wire_bounds,rng_length_loc_scale,fix_length,fix_width):\n",
29 | " #image = np.zeros(( imsize,imsize, 3)) #replace with noise for stability\n",
30 | " image = np.random.normal(loc=0.1,scale=0.05,size=(imsize,imsize,3))\n",
31 | " saturation_map=np.random.normal(loc=0.8,scale=0.2,size=(imsize,imsize))\n",
32 | " saturation_map=gaussian(saturation_map,sigma=3) #introduce long range noise correlations\n",
33 | " hue_map=np.random.normal(loc=0.8,scale=0.2,size=(imsize,imsize))\n",
34 | " hue_map=gaussian(hue_map,sigma=2)\n",
35 | " n_wires=np.random.randint(rng_wire_bounds[0],rng_wire_bounds[1])\n",
36 | " for n in range(n_wires):\n",
37 | " if length_var:\n",
38 | " length=skewnorm.rvs(a=3,loc=rng_length_loc_scale[0],scale=rng_length_loc_scale[1],size=1)\n",
39 | " if not length_var:\n",
40 | " length=np.array([fix_length])\n",
41 | "\n",
42 | " x_loc = np.array([np.random.rand() * imsize]) \n",
43 | " y_loc = np.array(np.random.rand() * imsize) \n",
44 | " angle = np.random.random() * np.pi\n",
45 | " vec_x=np.cos(angle)\n",
46 | " vec_y=np.sin(angle)\n",
47 | " x_wires = (vec_x.reshape(1, 1) * length / 2. * np.linspace(-1, 1, 200) + x_loc.reshape(1,1)).astype(int)\n",
48 | " y_wires = (vec_y.reshape(1, 1) * length / 2. * np.linspace(-1, 1, 200) + y_loc.reshape(1,1)).astype(int)\n",
49 | " indices = (x_wires * imsize + y_wires).flatten() \n",
50 | " to_rem=np.where(indices>=imsize*imsize)[0] #make sure wire doesnt go out of image\n",
51 | " indices=[indices[i] for i in range(len(indices)) if i not in to_rem]\n",
52 | "\n",
53 | " wire = np.zeros([imsize*imsize], dtype=np.double)\n",
54 | " wire[indices] = 1\n",
55 | " wire = wire.reshape( imsize, imsize)\n",
56 | "\n",
57 | " if width_var:\n",
58 | " it=np.random.choice(np.arange(1,6),p=[0.7,0.2,0.07,0.02,0.01])\n",
59 | " if not width_var:\n",
60 | " it=fix_width\n",
61 | " wire = binary_dilation(wire,iterations=it)\n",
62 | " wire=wire.astype(bool)\n",
63 | " image[:,:,0][wire]=angle/np.pi #hue\n",
64 | " image[:,:,1][wire]=hue_map[wire] #saturation\n",
65 | " image[:,:,2][wire]=saturation_map[wire] #brightness\n",
66 | " image=np.clip(image,0,1)\n",
67 | " image=matplotlib.colors.hsv_to_rgb(image) #Better visualization, also probably better use of all 3 channels (hsv varies only little in channel 2 and 3)\n",
68 | " return image"
69 | ]
70 | },
71 | {
72 | "cell_type": "code",
73 | "execution_count": null,
74 | "metadata": {},
75 | "outputs": [],
76 | "source": [
77 | "params={'imsize': 256,\n",
78 | " 'length_var': True,\n",
79 | " 'width_var': True,\n",
80 | " 'rng_wire_bounds': [10,40],\n",
81 | " 'rng_length_loc_scale': [20,60],\n",
82 | " 'fix_length': 50,\n",
83 | " 'fix_width': 1}\n",
84 | "N_images=500"
85 | ]
86 | },
87 | {
88 | "cell_type": "code",
89 | "execution_count": null,
90 | "metadata": {},
91 | "outputs": [],
92 | "source": [
93 | "example=gen_fake_img(**params)\n",
94 | "plt.imshow(example)"
95 | ]
96 | },
97 | {
98 | "cell_type": "code",
99 | "execution_count": null,
100 | "metadata": {},
101 | "outputs": [],
102 | "source": [
103 | "directory='../Data/Nanowire/Synthetic/'"
104 | ]
105 | },
106 | {
107 | "cell_type": "code",
108 | "execution_count": null,
109 | "metadata": {},
110 | "outputs": [],
111 | "source": [
112 | "for i in range(N_images):\n",
113 | " name=directory+str(i).zfill(5)+'.png'\n",
114 | " image=(gen_fake_img(**params)*255).astype(np.uint8) #map to 0-255,8 bit image\n",
115 | " io.imsave(name,image)\n",
116 | " if i%50==0:\n",
117 | " print(i)"
118 | ]
119 | }
120 | ],
121 | "metadata": {
122 | "kernelspec": {
123 | "display_name": "Python 2",
124 | "language": "python",
125 | "name": "python2"
126 | },
127 | "language_info": {
128 | "codemirror_mode": {
129 | "name": "ipython",
130 | "version": 2
131 | },
132 | "file_extension": ".py",
133 | "mimetype": "text/x-python",
134 | "name": "python",
135 | "nbconvert_exporter": "python",
136 | "pygments_lexer": "ipython2",
137 | "version": "2.7.15"
138 | }
139 | },
140 | "nbformat": 4,
141 | "nbformat_minor": 2
142 | }
143 |
--------------------------------------------------------------------------------
/Discriminator/PatchGAN142.py:
--------------------------------------------------------------------------------
1 | from __future__ import division, print_function, unicode_literals
2 |
3 | import tensorflow as tf
4 |
class PatchGAN142:
    """
    PatchGAN discriminator with a 142x142 pixel receptive field, following the PatchGAN
    design used by Zhu et al. (CycleGAN).

    Layer stack (every convolution: 4x4 kernel, 'valid' padding on a 1-pixel reflect-padded input):
        C32 (stride 2) -> C64 (stride 2) -> C128 (stride 2) -> C256 (stride 2) -> C512 (stride 1)
        -> 1-channel output (stride 1)
    C32 is followed only by a leaky ReLU. C64 ... C512 are each followed by instance normalization
    (created with trainable=False) and a leaky ReLU. The output layer has no activation.

    The class only constructs graph operations; it does not save, load or run anything itself.
    Intended public interface:
        -) constructor
        -) create()
    """

    def __init__(self, dis_name, noise=0.25):
        """
        Store the settings needed to build the discriminator graph.

        INPUT: dis_name - Name of the discriminator. It is embedded in every layer / variable scope
                          name ('dis_<dis_name>_142_...').
               noise    - Standard deviation of the Gaussian noise that create() adds to its input.
        """
        self.dis_name = dis_name
        self.noise = noise

    def _padded_conv(self, inputs, filters, stride, index, reuse):
        """Reflect-pad `inputs` by one pixel on both spatial sides and apply a 4x4 'valid' convolution."""
        padded = tf.pad(inputs, [[0, 0], [1, 1], [1, 1], [0, 0]], "Reflect")
        return tf.layers.conv2d(padded,
                                filters=filters,
                                kernel_size=4,
                                kernel_initializer=tf.initializers.random_normal(stddev=0.02),
                                strides=(stride, stride),
                                padding='valid',
                                reuse=reuse,
                                name='dis_' + self.dis_name + '_142_conv_' + str(index))

    def _norm_lrelu(self, conv, index, reuse):
        """Instance-normalize `conv` (non-trainable) and apply a leaky ReLU; return (normalized, activated)."""
        normalized = tf.contrib.layers.instance_norm(conv,
                                                     reuse=reuse,
                                                     scope='dis_' + self.dis_name + '_142_bnorm_' + str(index),
                                                     trainable=False)
        return normalized, tf.nn.leaky_relu(normalized)

    def create(self, X, reuse=True):
        """
        Build the discriminator on top of X and return its patch-wise output map.

        INPUT: X     - Input image tensor, channels-last (the tf.layers.conv2d default); the spatial
                       axes 1 and 2 are the ones being padded.
               reuse - Forwarded to every layer. Presumably False on the first call and True afterwards
                       so that later calls share the variables (TODO confirm with caller).

        OUTPUT: 1-channel tensor of raw (un-activated) scores, one per image patch.
        """
        # Gaussian noise on the input (std = self.noise)
        noisy_input = X + tf.random_normal(tf.shape(X), 0., self.noise)

        # C32: the first layer is not normalized
        self.C32_c = self._padded_conv(noisy_input, 32, 2, 0, reuse)
        self.C32 = tf.nn.leaky_relu(self.C32_c)

        # C64
        self.C64_c = self._padded_conv(self.C32, 64, 2, 1, reuse)
        self.C64_n, self.C64 = self._norm_lrelu(self.C64_c, 1, reuse)

        # C128
        self.C128_c = self._padded_conv(self.C64, 128, 2, 2, reuse)
        self.C128_n, self.C128 = self._norm_lrelu(self.C128_c, 2, reuse)

        # C256
        self.C256_c = self._padded_conv(self.C128, 256, 2, 3, reuse)
        self.C256_n, self.C256 = self._norm_lrelu(self.C256_c, 3, reuse)

        # C512
        self.C512_c = self._padded_conv(self.C256, 512, 1, 4, reuse)
        self.C512_n, self.C512 = self._norm_lrelu(self.C512_c, 4, reuse)

        # Single-channel output layer (no normalization, no activation)
        self.c1_c = self._padded_conv(self.C512, 1, 1, 5, reuse)
        return self.c1_c
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import cycleGAN
2 | import re
3 | import sys
4 | from os import environ as cuda_environment
5 | import os
6 | import numpy as np
7 |
8 | if __name__ == "__main__":
9 | # List of floats
10 | sub_value_f = {}
11 | sub_value_f['lambda_c'] = 10. # Loss multiplier for cycle
12 | sub_value_f['lambda_h'] = 1. # Loss multiplier for histogram
13 | sub_value_f['dis_noise'] = 0.1 # Std of gauss noise added to Dis
14 | sub_value_f['syn_noise'] = 0. # Add gaussian noise to syn images to make non-flat backgrounds
15 | sub_value_f['real_noise'] = 0. # Add gaussian noise to real images to make non-flat backgrounds
16 |
17 | # List of ints
18 | sub_value_i = {}
19 | sub_value_i['epoch'] = 200 # Number of epochs to be trained
20 | sub_value_i['batch_size'] = 4 # Batch size for training
21 | sub_value_i['buffer_size'] = 50 # Number of history elements used for Dis
22 | sub_value_i['save'] = 1 # If not 0, model is saved
23 | sub_value_i['gpu'] = 0 # Choose the GPU ID (if only CPU training, choose nonexistent number)
24 | sub_value_i['verbose'] = 0 # If not 0, some network information is being plotted
25 |
26 | # List of strings
27 | sub_string = {}
28 | sub_string['name'] = 'unnamed' # Name of model (should be unique). Is used to save/load models
29 | sub_string['dataset'] = 'pathtodata.h5' # Describes which h5 file is used
30 | sub_string['architecture'] = 'Res6' # Network architecture: 'Res6' or 'Res9'
31 | sub_string['deconv'] = 'transpose' # Upsampling method: 'transpose' or 'resize'
32 | sub_string['PatchGAN'] = 'Patch70' # Choose the Gan type: 'Patch34', 'Patch70', 'Patch142', 'MultiPatch'
33 | sub_string['mode'] = 'training' # 'train', 'gen_A', 'gen_B'
34 |
35 | # Create complete dictonary
36 | var_dict = sub_string.copy()
37 | var_dict.update(sub_value_i)
38 | var_dict.update(sub_value_f)
39 |
40 | # Update all defined parameters in dictionary
41 | for arg_i in sys.argv[1:]:
42 | var = re.search('(.*)\=', arg_i) # everything before the '='
43 | g_var = var.group(1)[2:]
44 | if g_var in sub_value_i:
45 | dtype = 'int'
46 | elif g_var in sub_value_f:
47 | dtype = 'float'
48 | elif g_var in sub_string:
49 | dtype = 'string'
50 | else:
51 | print("Unknown key word: " + g_var)
52 | print("Write parameters as: =")
53 | print("Example: 'python main.py buffer_size=32'")
54 | print("Possible key words: " + str(var_dict.keys()))
55 | continue
56 |
57 | content = re.search('\=(.*)',arg_i) # everything after the '='
58 | g_content = content.group(1)
59 | if dtype == 'int':
60 | var_dict[g_var] = int(g_content)
61 | elif dtype == 'float':
62 | var_dict[g_var] = float(g_content)
63 | else:
64 | var_dict[g_var] = g_content
65 | if not os.path.isfile(var_dict['dataset']):
66 | raise ValueError('Dataset does not exist. Specify loation of an existing h5 file.')
67 | # Get the dataset filename
68 |
69 |
70 | # Restrict usage of GPUs
71 | cuda_environment["CUDA_VISIBLE_DEVICES"]=str(var_dict['gpu'])
72 | with open('Models/'+var_dict['name']+"_params.txt", "w") as myfile:
73 | for key in sorted(var_dict):
74 | myfile.write(key + "," + str(var_dict[key]) + "\n")
75 |
76 | # Find out, if whole network is needed or only the generators
77 | gen_only = False
78 | if 'gen' in var_dict['mode']:
79 | gen_only = True
80 |
81 | # Define the model
82 | model = cycleGAN.Model(\
83 | mod_name=var_dict['name'],\
84 | data_file=var_dict['dataset'],\
85 | buffer_size=var_dict['buffer_size'],\
86 | dis_noise=var_dict['dis_noise'],\
87 | architecture=var_dict['architecture'],\
88 | lambda_c=var_dict['lambda_c'],\
89 | lambda_h=var_dict['lambda_h'],\
90 | deconv=var_dict['deconv'],\
91 | patchgan=var_dict['PatchGAN'],\
92 | verbose=(var_dict['verbose']!=0),\
93 | gen_only=gen_only)
94 |
95 | # Plot parameter properties, if applicable
96 | if var_dict['verbose']:
97 | # Print the number of parameters
98 | model.print_count_variables()
99 | model.print_train_and_not_train_variables()
100 |
101 | # Create a graph file
102 | model.save_graph()
103 |
104 | elif var_dict['mode'] == 'training':
105 | # Train the model
106 | loss_gen_A = []
107 | loss_gen_B = []
108 | loss_dis_A = []
109 | loss_dis_B = []
110 |
111 | for i in range(var_dict['epoch']):
112 | print('')
113 | print('Epoch: ' + str(i+1))
114 | print('')
115 | lgA,lgB,ldA,ldB = \
116 | model.train(batch_size=var_dict['batch_size'],\
117 | lambda_c=var_dict['lambda_c'],\
118 | lambda_h=var_dict['lambda_h'],\
119 | save=bool(var_dict['save']),\
120 | epoch=i,\
121 | syn_noise=var_dict['syn_noise'],\
122 | real_noise=var_dict['real_noise'])
123 | loss_gen_A.append(lgA)
124 | loss_gen_B.append(lgB)
125 | loss_dis_A.append(ldA)
126 | loss_dis_B.append(ldB)
127 | np.save("./Models/" + var_dict['name'] + '_loss_gen_A.npy',np.array(loss_gen_A).T)
128 | np.save("./Models/" + var_dict['name'] + '_loss_gen_B.npy',np.array(loss_gen_B).T)
129 | np.save("./Models/" + var_dict['name'] + '_loss_dis_A.npy',np.array(loss_dis_A).T)
130 | np.save("./Models/" + var_dict['name'] + '_loss_dis_B.npy',np.array(loss_dis_B).T)
131 |
132 | elif var_dict['mode'] == 'gen_A':
133 | model.generator_A(batch_size=var_dict['batch_size'],\
134 | lambda_c=var_dict['lambda_c'],\
135 | lambda_h=var_dict['lambda_h'])
136 |
137 | elif var_dict['mode'] == 'gen_B':
138 | model.generator_B(batch_size=var_dict['batch_size'],\
139 | lambda_c=var_dict['lambda_c'],\
140 | lambda_h=var_dict['lambda_h'])
141 |
--------------------------------------------------------------------------------
/Generator/Res_Gen.py:
--------------------------------------------------------------------------------
1 | from __future__ import division, print_function, unicode_literals
2 |
3 | import tensorflow as tf
4 |
5 | class ResGen:
6 | """
7 | This class is creating a ResNet generator as described by Zhu et al. 2018.
8 | -) save() - Save the current model parameter
9 | -) create() - Create the model layers (graph construction)
10 | -) init() - Initialize the model (load model if exists)
11 | -) load() - Load the parameters from the file
12 | -) run() - ToDo write this
13 |
14 | Only the following functions should be called from outside:
15 | -) create()
16 | -) constructor
17 | """
18 |
19 | def __init__(self,
20 | gen_name,
21 | out_dim,
22 | gen_dim= [32, 64,128, 128,128,128,128,128,128, 64,32 ],
23 | kernel_size=[7, 3,3, 3,3,3,3,3,3, 3,3, 7],
24 | deconv='transpose',
25 | verbose=False):
26 |
27 | # gen_dim= [64, 128,256, 256,256,256,256,256,256,256,256,256, 128,64 ],
28 | # kernel_size=[7, 3,3, 3,3,3,3,3,3,3,3,3, 3,3, 7]):
29 | """
30 | Create a generator model (init). It will check, if a model with such a name has already been saved. If so, the model
31 | is being loaded. Otherwise, a new model with this name will be created. It will only be saved, if the save function
32 | is being called. The describtion of every parameter is given in the code below.
33 |
34 | INPUT: gen_name - This is the name of the generator. It is mainly used to establish the place, where the model
35 | is being saved.
36 | gen_dim - The number of channels in every layer. The first two elements are stride-2 convolutional layers.
37 | The last two layers (1 value) are fractional stride 1/2 convolutional layers. Everything in between
38 | is a ResNet layer.
39 |
40 | kernel_size - The kernel sizes of all layers. The dimension of this list must be the same as of gen_dim + 1.
41 |
42 | OUTPUT: - The model
43 | """
44 | self.gen_name = gen_name
45 | self.out_dim = out_dim
46 | self.gen_dim = gen_dim
47 | self.kernel_size = kernel_size
48 | self.deconv = deconv
49 | self.verbose = verbose
50 |
51 | if len(gen_dim) + 1 != len(kernel_size):
52 | raise NameError('The dimensions of the ResGenerator are wrong')
53 |
54 |
55 |
def create(self, X, reuse=True):
    """
    Build the generator on top of the input tensor X and return the output tensor.

    Layer layout, with num_layers = len(self.kernel_size):
      - layer 0                  : reflect padding + stride-1 conv (self.gen_dim[0] filters)
      - layers 1 and 2           : reflect padding + stride-2 conv (downsampling)
      - layers 3 .. num_layers-4 : residual blocks
                                   (conv -> instance norm -> ReLU -> reflect padding -> conv,
                                    then instance norm and the skip connection)
      - layers num_layers-3, -2  : upsampling, either a stride-2 transposed conv
                                   (self.deconv == 'transpose') or an image resize followed by
                                   reflect padding + stride-1 conv (self.deconv == 'resize')
      - layer num_layers-1       : reflect padding + stride-1 conv (self.out_dim filters),
                                   mapped to [tanh(x)+1]/2
    Every layer except the last is instance-normalised (mean/variance over axes 1 and 2, no
    learned scale or offset). Residual layers stay linear after the skip connection; the other
    hidden layers use ReLU.

    INPUT:  X     - Input tensor; axes 1 and 2 are padded/resized as the spatial axes and the
                    last axis holds the channels.
            reuse - Forwarded as `reuse` to every tf.layers convolution of this generator.

    OUTPUT: The output tensor. As side effects, self.bef_layer is set to the last layer before
            the tanh mapping and self.layer_list to [X] followed by the output of every layer.

    RAISES: ValueError if self.deconv is neither 'transpose' nor 'resize'.
    """
    # Validate before building anything, so an invalid setting never yields a half-built graph
    if self.deconv not in ('transpose', 'resize'):
        raise ValueError('Unknown deconvolution method: ' + str(self.deconv))

    num_layers = len(self.kernel_size)
    num_down = 3               # layer 0 plus the two stride-2 layers
    first_up = num_layers - 3  # index of the first upsampling layer
    last = num_layers - 1      # index of the output layer
    prefix = 'gen_' + self.gen_name

    def reflect_pad(tensor, pad):
        # Reflection padding of the two spatial axes only
        return tf.pad(tensor, [[0, 0], [pad, pad], [pad, pad], [0, 0]], "Reflect")

    def instance_norm(tensor):
        # Per-sample, per-channel normalisation over axes 1 and 2 without learned parameters
        mean, var = tf.nn.moments(tensor, axes=(1, 2), keep_dims=True)
        return (tensor - mean) / tf.sqrt(var + 1e-5)

    def conv(tensor, filters, kernel, stride, name):
        # 'valid' convolution: callers do the (reflection) padding explicitly.
        # The explicit name keeps the variable names (and thus checkpoints) stable.
        return tf.layers.conv2d(tensor,
                                filters=filters,
                                kernel_size=kernel,
                                strides=(stride, stride),
                                kernel_initializer=tf.initializers.random_normal(stddev=0.02),
                                padding='valid',
                                reuse=reuse,
                                name=name)

    layer_list = [X]

    if self.verbose:
        print('-------------------------------')
        print(' ')
        print('Create generator ' + self.gen_name)
        print(' ')
        print('Number of layers: ' + str(num_layers))
        print('Input dimension: ' + str(X.get_shape()))
        print(' ')

    for i in range(num_layers):
        k = self.kernel_size[i]
        ps = (k - 1) // 2  # for odd kernels: keeps a stride-1 'valid' conv size-preserving
        name = prefix + '_conv_' + str(i)
        upsampling = first_up <= i < last

        if self.verbose:
            print('- - - - - - - - - - - - - - - -')
        if upsampling:
            # Upsampling layers pad (if at all) after resizing, see below
            new_pad = layer_list[-1]
            if self.verbose:
                print('Load last layer: No padding')
        else:
            new_pad = reflect_pad(layer_list[-1], ps)
            if self.verbose:
                print('Load last layer: Do padding with size ' + str(ps))

        if i == 0 or i == last:
            filters = self.gen_dim[i] if i == 0 else self.out_dim
            new_conv = conv(new_pad, filters, k, 1, name)
            if self.verbose:
                print('Conv layer stride 1 with kernel size ' + str(k) + ' and number of filters ' + str(filters))
        elif i < num_down:
            # Downsampling conv layers
            new_conv = conv(new_pad, self.gen_dim[i], k, 2, name)
            if self.verbose:
                print('Conv layer stride 2 with kernel size ' + str(k) + ' and number of filters ' + str(self.gen_dim[i]))
        elif not upsampling:
            # Residual block body; the skip connection is added after the normalisation below
            inner = conv(new_pad, self.gen_dim[i], k, 1, prefix + '_conv0_' + str(i))
            inner = tf.nn.relu(instance_norm(inner))
            new_conv = conv(reflect_pad(inner, ps), self.gen_dim[i], k, 1, name)
            if self.verbose:
                print('Conv layer stride 1 with kernel size ' + str(k) + ' and number of filters ' + str(self.gen_dim[i]))
                print('Instance Normalization')
                print('Relu')
                print('Padding with pad size of ' + str(ps))
                print('Conv layer stride 1 with kernel size ' + str(k) + ' and number of filters ' + str(self.gen_dim[i]))
        elif self.deconv == 'transpose':
            new_conv = tf.layers.conv2d_transpose(new_pad,
                                                  filters=self.gen_dim[i],
                                                  kernel_size=k,
                                                  strides=(2, 2),
                                                  kernel_initializer=tf.initializers.random_normal(stddev=0.02),
                                                  padding='same',
                                                  reuse=reuse,
                                                  name=name)
            if self.verbose:
                print('Conv transpose layer stride 2 with kernel size ' + str(k) + ' and number of filters ' + str(self.gen_dim[i]))
        else:
            # 'resize' (validated at the top): the first upsampling layer goes to half of the
            # input size, the second one back to the full input size
            if i == first_up:
                target = [tf.shape(X)[1] // 2, tf.shape(X)[2] // 2]
            else:
                target = [tf.shape(X)[1], tf.shape(X)[2]]
            resized = tf.image.resize_images(new_pad, target)
            new_conv = conv(reflect_pad(resized, ps), self.gen_dim[i], k, 1, name)
            if self.verbose:
                print('Resize image')
                print('Padding with pad size of ' + str(ps))
                print('Conv layer stride 1 with kernel size ' + str(k) + ' and number of filters ' + str(self.gen_dim[i]))

        if i < last:
            new_norm = instance_norm(new_conv)
            if self.verbose:
                print('Instance normalization')
        else:
            new_norm = new_conv

        if num_down <= i < first_up:
            new_layer = new_norm + layer_list[-1]
            if self.verbose:
                print('Make residual layer (linear activation function)')
        elif i < last:
            new_layer = tf.nn.relu(new_norm)
            if self.verbose:
                print('ReLu')
        else:
            self.bef_layer = new_norm
            new_layer = (tf.nn.tanh(new_norm) + 1) / 2.
            if self.verbose:
                print('[tanh(x)+1]/2')

        layer_list.append(new_layer)
        if self.verbose:
            print(' ')

    if self.verbose:
        print('Final layer: ', new_layer)

    self.layer_list = layer_list
    return new_layer
210 |
--------------------------------------------------------------------------------
/cycleGAN.py:
--------------------------------------------------------------------------------
1 | from __future__ import division, print_function, unicode_literals
2 |
3 |
4 | import tensorflow as tf
5 |
6 | import h5py
7 | import numpy as np
8 |
9 | import os
10 |
11 | import sys
12 | sys.path.append('./Discriminator')
13 | sys.path.append('./Generator')
14 | sys.path.append('./Utilities/')
15 | import Res_Gen
16 | import PatchGAN34
17 | import PatchGAN70
18 | import PatchGAN142
19 | import MultiPatch
20 | import HisDis
21 | import Utilities
22 | import cv2
class Model:
    """
    Unpaired image-to-image translation model (cycleGAN) between the domains A and B of an hdf5
    dataset, built in its own tf.Graph (self.graph). It consists of two ResNet generators
    (genB: A->B, genA: B->A), two patch discriminators (disA, disB) and two histogram
    discriminators (disA_His, disB_His).

    Methods intended to be called from outside:
      -) constructor                  - read the dataset metadata and build the graph
      -) train()                      - run one training epoch
      -) predict()                    - translate a random window of 32 samples in both directions
      -) generator_A(), generator_B() - translate a whole domain into an hdf5 file
      -) get_loss()                   - discriminator outputs on a random window of 32 samples
    Lower level helpers: create(), save(), init(), load().
    """

    def __init__(self,
                 mod_name,
                 data_file,
                 buffer_size=32,
                 architecture='Res6',
                 lambda_h=10.,
                 lambda_c=10.,
                 dis_noise=0.25,
                 deconv='transpose',
                 patchgan='Patch70',
                 verbose=False,
                 gen_only=False):
        """
        Create a Model. The graph is built immediately. Parameters are only restored from
        ./Models/<mod_name>.ckpt when init() is called with a session, and only written when
        save() is called.

        INPUT: mod_name     - Name of the model; used to build the checkpoint and output file names.
               data_file    - hdf5 file that contains the dataset. It must provide A/num_channel,
                              B/num_channel, A/num_samples, B/num_samples and the image stacks
                              A/data and B/data.
               buffer_size  - Number of real/fake image pairs kept in the image history buffer
                              that is mixed into the discriminator batches.
               architecture - Generator architecture: 'Res6' or 'Res9'.
               lambda_h     - Default weight of the histogram discriminator losses.
               lambda_c     - Default weight of the cycle-consistency loss.
               dis_noise    - Base discriminator noise level (scaled by rel_dis_noise at run time).
               deconv       - Upsampling method of the generators ('transpose' or 'resize').
               patchgan     - Patch discriminator: 'Patch34', 'Patch70', 'Patch142' or 'MultiPatch'.
               verbose      - Print information about the built layers.
               gen_only     - If True, only the generator variables are saved and restored.

        OUTPUT: - The model

        RAISES: ValueError for an unknown architecture or patchgan name.
        """
        self.mod_name = mod_name    # Model name, used for all file names
        self.data_file = data_file  # hdf5 data file

        # Read the dataset metadata; the file is closed even if a key is missing
        with h5py.File(self.data_file, "r") as f:
            self.a_chan = int(np.array(f['A/num_channel']))  # Number of channels in A
            self.b_chan = int(np.array(f['B/num_channel']))  # Number of channels in B
            self.imsize = int(f['A/data'].shape[1])          # Image size (images assumed square)
            self.a_size = int(np.array(f['A/num_samples']))  # Number of samples in A
            self.b_size = int(np.array(f['B/num_samples']))  # Number of samples in B

        # Reset all current saved tf stuff
        tf.reset_default_graph()

        self.architecture = architecture
        self.lambda_h = lambda_h
        self.lambda_c = lambda_c
        self.dis_noise_0 = dis_noise  # Base value; create() defines self.dis_noise as a tensor
        self.deconv = deconv
        self.patchgan = patchgan
        self.verbose = verbose
        self.gen_only = gen_only      # If True, only the generator variables are saved/loaded

        # Build generators, discriminators, losses and optimizers
        self.create()

        # Image history buffer: older real/fake images are mixed into the discriminator batches
        self.buffer_size = buffer_size
        self.temp_b_s = 0  # Number of filled buffer slots
        self.buffer_real_a = np.zeros([self.buffer_size, self.imsize, self.imsize, self.a_chan])
        self.buffer_real_b = np.zeros([self.buffer_size, self.imsize, self.imsize, self.b_chan])
        self.buffer_fake_a = np.zeros([self.buffer_size, self.imsize, self.imsize, self.a_chan])
        self.buffer_fake_b = np.zeros([self.buffer_size, self.imsize, self.imsize, self.b_chan])

        # Create the model saver
        with self.graph.as_default():
            if self.gen_only:
                self.saver = tf.train.Saver(var_list=self.list_gen)
            else:
                self.saver = tf.train.Saver()

    def create(self):
        """
        Build the graph self.graph: generators, discriminators, histogram features, losses and
        the two Adam optimizers (self.opt_gen, self.opt_dis).

        The discriminators are trained towards 0 on real and 1 on generated images; the generators
        minimise the squared discriminator outputs on generated images plus the weighted cycle loss.
        relative_lr, rel_dis_noise, lambda_c and lambda_h are placeholders with defaults and can
        be overridden through feed_dict.

        RAISES: ValueError if self.architecture or self.patchgan is unknown.
        """
        # Generator layouts: (channels per layer, kernel sizes); kernel_size has one more entry
        layouts = {
            'Res6': ([64, 128, 256] + [256] * 6 + [128, 64], [7, 3, 3] + [3] * 6 + [3, 3, 7]),
            'Res9': ([64, 128, 256] + [256] * 9 + [128, 64], [7, 3, 3] + [3] * 9 + [3, 3, 7]),
        }
        patch_classes = {
            'Patch34': PatchGAN34.PatchGAN34,
            'Patch70': PatchGAN70.PatchGAN70,
            'Patch142': PatchGAN142.PatchGAN142,
            'MultiPatch': MultiPatch.MultiPatch,
        }
        if self.architecture not in layouts:
            raise ValueError('Unknown generator architecture: ' + str(self.architecture))
        if self.patchgan not in patch_classes:
            raise ValueError('Unknown Patch discriminator type: ' + str(self.patchgan))
        gen_dim, kernel_size = layouts[self.architecture]
        patch_cls = patch_classes[self.patchgan]

        def mean_sq(t):
            return tf.reduce_mean(tf.square(t))

        self.graph = tf.Graph()
        with self.graph.as_default():
            # Variable learning rate and discriminator noise (scalars fed as 1-element vectors)
            self.relative_lr = tf.placeholder_with_default([1.], [1], name="relative_lr")[0]
            self.rel_dis_noise = tf.placeholder_with_default([1.], [1], name="rel_dis_noise")[0]
            self.dis_noise = self.rel_dis_noise * self.dis_noise_0

            # Generators and discriminators
            self.genA = Res_Gen.ResGen('BtoA', self.a_chan, gen_dim=gen_dim, kernel_size=kernel_size,
                                       deconv=self.deconv, verbose=self.verbose)
            self.genB = Res_Gen.ResGen('AtoB', self.b_chan, gen_dim=gen_dim, kernel_size=kernel_size,
                                       deconv=self.deconv, verbose=self.verbose)
            self.disA = patch_cls('A', noise=self.dis_noise)
            self.disB = patch_cls('B', noise=self.dis_noise)
            self.disA_His = HisDis.HisDis('A', noise=self.dis_noise, keep_prob=1.)
            self.disB_His = HisDis.HisDis('B', noise=self.dis_noise, keep_prob=1.)

            # Input placeholders
            self.A = tf.placeholder(tf.float32, [None, None, None, self.a_chan], name="a")
            self.B = tf.placeholder(tf.float32, [None, None, None, self.b_chan], name="b")

            if self.verbose:
                print('Size A: ' + str(self.a_chan))  # Often 1 --> Real
                print('Size B: ' + str(self.b_chan))  # Often 3 --> Syn

            # Translations (the first create() call of each generator creates its variables)
            self.fake_A = self.genA.create(self.B, False)
            self.fake_B = self.genB.create(self.A, False)

            # Histogram features for the histogram discriminators
            self.s_A, self.m_A = self._histogram_features(self.A, self.a_chan)
            self.s_B, self.m_B = self._histogram_features(self.B, self.b_chan)
            self.s_fake_A, self.m_fake_A = self._histogram_features(self.fake_A, self.a_chan)
            self.s_fake_B, self.m_fake_B = self._histogram_features(self.fake_B, self.b_chan)

            # Loss weights (float() so an int argument does not create an int32 placeholder)
            self.lambda_c = tf.placeholder_with_default([float(self.lambda_c)], [1], name="lambda_c")[0]
            self.lambda_h = tf.placeholder_with_default([float(self.lambda_h)], [1], name="lambda_h")[0]

            # Discriminator outputs; the second call of each discriminator reuses its variables
            self.dis_real_A = self.disA.create(self.A, False)
            self.dis_real_Ah = self.disA_His.create(self.m_A, False)
            self.dis_real_B = self.disB.create(self.B, False)
            self.dis_real_Bh = self.disB_His.create(self.m_B, False)
            self.dis_fake_A = self.disA.create(self.fake_A, True)
            self.dis_fake_Ah = self.disA_His.create(self.m_fake_A, True)
            self.dis_fake_B = self.disB.create(self.fake_B, True)
            self.dis_fake_Bh = self.disB_His.create(self.m_fake_B, True)

            # Cycle reconstructions
            self.cyc_A = self.genA.create(self.fake_B, True)
            self.cyc_B = self.genB.create(self.fake_A, True)

            # Cycle loss (eq. 2)
            self.loss_cyc_A = tf.reduce_mean(tf.abs(self.cyc_A - self.A))
            self.loss_cyc_B = tf.reduce_mean(tf.abs(self.cyc_B - self.B))
            self.loss_cyc = self.loss_cyc_A + self.loss_cyc_B

            # Discriminator losses (eq. 1)
            self.loss_dis_A = (mean_sq(self.dis_real_A) + mean_sq(1 - self.dis_fake_A)) * 0.5 + \
                (mean_sq(self.dis_real_Ah) + mean_sq(1 - self.dis_fake_Ah)) * 0.5 * self.lambda_h
            self.loss_dis_B = (mean_sq(self.dis_real_B) + mean_sq(1 - self.dis_fake_B)) * 0.5 + \
                (mean_sq(self.dis_real_Bh) + mean_sq(1 - self.dis_fake_Bh)) * 0.5 * self.lambda_h

            # Generator losses
            self.loss_gen_A = mean_sq(self.dis_fake_A) + \
                self.lambda_h * mean_sq(self.dis_fake_Ah) + \
                self.lambda_c * self.loss_cyc / 2.
            self.loss_gen_B = mean_sq(self.dis_fake_B) + \
                self.lambda_h * mean_sq(self.dis_fake_Bh) + \
                self.lambda_c * self.loss_cyc / 2.

            # Optimizers; generator first so the optimizer variable names stay as before
            self.list_gen = [var for var in tf.trainable_variables() if 'gen' in var.name]
            optimizer_gen = tf.train.AdamOptimizer(learning_rate=self.relative_lr * 0.0002, beta1=0.5)
            self.opt_gen = optimizer_gen.minimize(self.loss_gen_A + self.loss_gen_B, var_list=self.list_gen)

            self.list_dis = [var for var in tf.trainable_variables() if 'dis' in var.name]
            optimizer_dis = tf.train.AdamOptimizer(learning_rate=self.relative_lr * 0.0002, beta1=0.5)
            self.opt_dis = optimizer_dis.minimize(self.loss_dis_A + self.loss_dis_B, var_list=self.list_dis)

    def _histogram_features(self, images, num_chan):
        """
        Sort the values of every channel over the whole batch in descending order and average them
        in self.imsize consecutive groups (the number of values per channel must be divisible by
        self.imsize).

        OUTPUT: (sorted values of shape [num_chan, num_values], features of shape [1, num_chan*imsize])
        """
        flat = tf.transpose(tf.reshape(images, [-1, num_chan]), [1, 0])
        sorted_vals, _ = tf.nn.top_k(flat, tf.shape(flat)[1])
        binned = tf.reduce_mean(tf.reshape(sorted_vals, [num_chan, self.imsize, -1]), axis=2)
        return sorted_vals, tf.reshape(binned, [1, -1])

    def _ckpt_path(self):
        """Checkpoint prefix of this model: ./Models/<mod_name>.ckpt"""
        return "./Models/" + self.mod_name + ".ckpt"

    @staticmethod
    def _ensure_parent_dir(path):
        """Create the directory that will contain `path` if it does not exist yet."""
        parent = os.path.dirname(path)
        if parent and not os.path.isdir(parent):
            os.makedirs(parent)

    @staticmethod
    def _normalize_images(images, domain):
        """
        Scale an unsigned integer image array to [0, 1] by the maximum value of its dtype.

        INPUT:  images - Array read from <domain>/data
                domain - 'A' or 'B'; only used in the error message
        OUTPUT: Float array with values in [0, 1]
        RAISES: ValueError if the dtype is neither uint8 nor uint16
        """
        if images.dtype == np.uint8:
            return images / float(2 ** 8 - 1)
        if images.dtype == np.uint16:
            return images / float(2 ** 16 - 1)
        raise ValueError('Dataset ' + domain + ' is not uint8 or uint16')

    @staticmethod
    def _num_batches(num_samples, batch_size):
        """Number of batches needed to cover num_samples samples (the last one may be partial)."""
        return (num_samples + batch_size - 1) // batch_size

    @staticmethod
    def _progress_percent(iteration, num_iterations):
        """Percentage of finished iterations after the (0-based) iteration `iteration`."""
        return 100. * (iteration + 1) / num_iterations

    @staticmethod
    def _to_display(images):
        """
        Prepare a batch for the preview: a batch with 4 channels becomes a single image made of the
        first three channels of sample 0 stacked on top of channel 3 repeated three times (with a
        leading batch axis of size 1). Batches with any other channel count are returned unchanged.
        """
        if np.shape(images)[-1] != 4:
            return images
        first = images[0]
        stacked = np.vstack((first[:, :, 0:3], np.tile(first[:, :, 3:4], [1, 1, 3])))
        return stacked[np.newaxis, :, :, :]

    def _random_window(self, f, domain, num_samples, length=32):
        """
        Read `length` consecutive samples (fewer if the domain is smaller) of <domain>/data from the
        open hdf5 file f, starting at a uniformly drawn valid offset, normalized to [0, 1].
        """
        start = np.random.randint(max(num_samples - length, 0) + 1)
        return self._normalize_images(f[domain + '/data'][start:start + length, :, :, :], domain)

    def _update_buffer(self, real_a, real_b, fake_a, fake_b, batch_size):
        """
        Store the current batch in the image history buffer. While the buffer is not full, images
        are appended at slot self.temp_b_s; afterwards randomly chosen slots are overwritten
        (drawn independently for domain A and domain B).
        """
        if self.temp_b_s >= self.buffer_size:
            n = min(batch_size, self.buffer_size)
            slots_a = np.random.permutation(self.buffer_size)[:n]
            slots_b = np.random.permutation(self.buffer_size)[:n]
            self.buffer_real_a[slots_a, ...] = real_a[:n]
            self.buffer_real_b[slots_b, ...] = real_b[:n]
            self.buffer_fake_a[slots_a, ...] = fake_a[:n]
            self.buffer_fake_b[slots_b, ...] = fake_b[:n]
        else:
            low = int(self.temp_b_s)
            high = int(min(self.temp_b_s + batch_size, self.buffer_size))
            self.temp_b_s = high
            n = high - low
            self.buffer_real_a[low:high, ...] = real_a[:n, ...]
            self.buffer_real_b[low:high, ...] = real_b[:n, ...]
            self.buffer_fake_a[low:high, ...] = fake_a[:n, ...]
            self.buffer_fake_b[low:high, ...] = fake_b[:n, ...]

    def _mix_with_buffer(self, real_a, real_b, fake_a, fake_b, batch_size):
        """
        Return copies of the current real/fake batches whose first batch_size//2 entries are
        replaced by randomly chosen buffered images (real and fake image of a domain come from
        the same buffer slot).
        """
        half = batch_size // 2
        idx_a = np.random.permutation(int(self.temp_b_s))[:half]
        idx_b = np.random.permutation(int(self.temp_b_s))[:half]

        dis_real_a = np.copy(real_a)
        dis_real_b = np.copy(real_b)
        dis_fake_a = np.copy(fake_a)
        dis_fake_b = np.copy(fake_b)
        dis_real_a[:len(idx_a), ...] = self.buffer_real_a[idx_a, ...]
        dis_fake_a[:len(idx_a), ...] = self.buffer_fake_a[idx_a, ...]
        dis_real_b[:len(idx_b), ...] = self.buffer_real_b[idx_b, ...]
        dis_fake_b[:len(idx_b), ...] = self.buffer_fake_b[idx_b, ...]
        return dis_real_a, dis_real_b, dis_fake_a, dis_fake_b

    def save(self, sess):
        """
        Save the saver's variables to ./Models/<mod_name>.ckpt, creating the directory if needed.

        INPUT: sess - The current running session
        """
        path = self._ckpt_path()
        self._ensure_parent_dir(path)
        self.saver.save(sess, path)

    def init(self, sess):
        """
        Init the model. If a checkpoint exists, load it (after initializing all variables when
        gen_only is set, because then only generator variables are restored). Otherwise,
        initialize all variables.

        INPUT:  sess - The current running session
        OUTPUT: 1 if a checkpoint was restored, 0 otherwise
        """
        if not os.path.isfile(self._ckpt_path() + ".meta"):
            sess.run(tf.global_variables_initializer())
            return 0
        if self.gen_only:
            sess.run(tf.global_variables_initializer())
        self.load(sess)
        return 1

    def load(self, sess):
        """
        Restore the saver's variables from ./Models/<mod_name>.ckpt

        INPUT: sess - The current running session
        """
        self.saver.restore(sess, self._ckpt_path())

    def train(self, batch_size=32, lambda_c=0., lambda_h=0., epoch=0, save=True, syn_noise=0., real_noise=0.):
        """
        Train the model for one epoch (min(a_size, b_size) // batch_size iterations). Every
        iteration runs one generator step and one discriminator step; the discriminator batch
        mixes the current images with images from the history buffer.

        INPUT: batch_size - Number of samples per iteration (from each domain)
               lambda_c   - Weight of the cycle loss fed into the graph (overrides the default)
               lambda_h   - Weight of the histogram losses fed into the graph (overrides the default)
               epoch      - Current epoch; controls the schedules: relative learning rate 1 up to
                            epoch 100, then decaying linearly to 0 at epoch 200 (clamped at 0);
                            relative discriminator noise 0.9**epoch below epoch 100, 0 afterwards
               save       - If True, save the parameters and the last preview image
               syn_noise  - Std. of the Gaussian noise added to the normalized B images
               real_noise - Std. of the Gaussian noise added to the normalized A images

        OUTPUT: [loss_gen_A, loss_gen_B, loss_dis_A, loss_dis_B], averaged over the epoch:
                loss_gen_X = [mean(dis_fake_X**2), mean(dis_fake_Xh**2), mean cycle loss of X]
                loss_dis_X = [mean(dis_real_X**2), mean((1-dis_fake_X)**2),
                              mean(dis_real_Xh**2), mean((1-dis_fake_Xh)**2)]
        """
        num_iterations = min(self.a_size, self.b_size) // batch_size

        a_order = np.random.permutation(self.a_size)
        b_order = np.random.permutation(self.b_size)

        if self.verbose:
            print('lambda_c: ' + str(lambda_c))
            print('lambda_h: ' + str(lambda_h))

        # Clamped at 0 so that epochs beyond 200 do not get a negative learning rate
        rel_lr = 1.
        if epoch > 100:
            rel_lr = max(2. - epoch / 100., 0.)
        rel_noise = 0.9 ** epoch if epoch < 100 else 0.

        schedule = {self.lambda_c: lambda_c,
                    self.lambda_h: lambda_h,
                    self.relative_lr: rel_lr,
                    self.rel_dis_noise: rel_noise}

        hist = {key: [] for key in ('gen_A', 'gen_B', 'dis_A', 'dis_B', 'cyc_A', 'cyc_B',
                                    'drA', 'drAh', 'dfA', 'dfAh', 'drB', 'drBh', 'dfB', 'dfBh')}
        sneak_peak = None

        with h5py.File(self.data_file, "r") as f, tf.Session(graph=self.graph) as sess:
            # initialize variables
            self.init(sess)

            for iteration in range(num_iterations):
                batch = slice(iteration * batch_size, (iteration + 1) * batch_size)
                # h5py point selections must be in increasing order, hence np.sort
                images_a = self._normalize_images(f['A/data'][np.sort(a_order[batch]), :, :, :], 'A')
                images_b = self._normalize_images(f['B/data'][np.sort(b_order[batch]), :, :, :], 'B')
                images_a += np.random.randn(*images_a.shape) * real_noise
                images_b += np.random.randn(*images_b.shape) * syn_noise

                # Generator step
                gen_feed = dict(schedule)
                gen_feed[self.A] = images_a
                gen_feed[self.B] = images_b
                (_, l_gen_A, im_fake_A, l_gen_B, im_fake_B,
                 cyc_A, cyc_B, lcA, lcB) = sess.run([self.opt_gen,
                                                     self.loss_gen_A, self.fake_A,
                                                     self.loss_gen_B, self.fake_B,
                                                     self.cyc_A, self.cyc_B,
                                                     self.loss_cyc_A, self.loss_cyc_B],
                                                    feed_dict=gen_feed)

                # Discriminator step on a batch mixed with buffered images
                self._update_buffer(images_a, images_b, im_fake_A, im_fake_B, batch_size)
                dis_real_a, dis_real_b, dis_fake_a, dis_fake_b = \
                    self._mix_with_buffer(images_a, images_b, im_fake_A, im_fake_B, batch_size)

                dis_feed = dict(schedule)
                dis_feed[self.A] = dis_real_a
                dis_feed[self.B] = dis_real_b
                dis_feed[self.fake_A] = dis_fake_a  # feeding replaces the generator output
                dis_feed[self.fake_B] = dis_fake_b
                (_, l_dis_A, l_dis_B,
                 ldrA, ldrAh, ldfA, ldfAh,
                 ldrB, ldrBh, ldfB, ldfBh) = sess.run([self.opt_dis,
                                                       self.loss_dis_A, self.loss_dis_B,
                                                       self.dis_real_A, self.dis_real_Ah,
                                                       self.dis_fake_A, self.dis_fake_Ah,
                                                       self.dis_real_B, self.dis_real_Bh,
                                                       self.dis_fake_B, self.dis_fake_Bh],
                                                      feed_dict=dis_feed)

                for key, value in (('gen_A', l_gen_A), ('gen_B', l_gen_B),
                                   ('dis_A', l_dis_A), ('dis_B', l_dis_B),
                                   ('cyc_A', lcA), ('cyc_B', lcB),
                                   ('drA', ldrA), ('drAh', ldrAh), ('dfA', ldfA), ('dfAh', ldfAh),
                                   ('drB', ldrB), ('drBh', ldrBh), ('dfB', ldfB), ('dfBh', ldfBh)):
                    hist[key].append(value)

                if iteration % 5 == 0:
                    sneak_peak = Utilities.produce_tiled_images(images_a, self._to_display(images_b),
                                                                im_fake_A, self._to_display(im_fake_B),
                                                                cyc_A, self._to_display(cyc_B))
                    cv2.imshow("", sneak_peak[:, :, [2, 1, 0]])
                    cv2.waitKey(1)

                print("\rTrain: {}/{} ({:.1f}%)".format(iteration + 1, num_iterations,
                                                      self._progress_percent(iteration, num_iterations)) +
                      " Loss_dis_A={:.4f}, Loss_dis_B={:.4f}".format(np.mean(hist['dis_A']), np.mean(hist['dis_B'])) +
                      ", Loss_gen_A={:.4f}, Loss_gen_B={:.4f}".format(np.mean(hist['gen_A']), np.mean(hist['gen_B'])),
                      end=" ")

            # Save model (and the last preview, if any iteration ran)
            if save:
                self.save(sess)
                if sneak_peak is not None:
                    image_path = "./Models/Images/" + self.mod_name + "_Epoch_" + str(epoch) + ".png"
                    self._ensure_parent_dir(image_path)
                    cv2.imwrite(image_path, sneak_peak[:, :, [2, 1, 0]] * 255)
            print("")

        def msq(values):
            return np.mean(np.square(np.array(values)))

        def msq_inv(values):
            return np.mean(np.square(1. - np.array(values)))

        loss_gen_A = [msq(hist['dfA']), msq(hist['dfAh']), np.mean(np.array(hist['cyc_A']))]
        loss_gen_B = [msq(hist['dfB']), msq(hist['dfBh']), np.mean(np.array(hist['cyc_B']))]
        loss_dis_A = [msq(hist['drA']), msq_inv(hist['dfA']), msq(hist['drAh']), msq_inv(hist['dfAh'])]
        loss_dis_B = [msq(hist['drB']), msq_inv(hist['dfB']), msq(hist['drBh']), msq_inv(hist['dfBh'])]

        return [loss_gen_A, loss_gen_B, loss_dis_A, loss_dis_B]

    def predict(self, lambda_c=0., lambda_h=0.):
        """
        Translate a random window of 32 consecutive samples of each domain (fewer if the domain is
        smaller) with the current model parameters.

        INPUT:  lambda_c, lambda_h - Values fed for the loss weights (the outputs do not depend on them)
        OUTPUT: images_a, images_b, fake_A, fake_B, cyc_A, cyc_B
        """
        with h5py.File(self.data_file, "r") as f:
            images_a = self._random_window(f, 'A', self.a_size)
            images_b = self._random_window(f, 'B', self.b_size)

        with tf.Session(graph=self.graph) as sess:
            # initialize variables
            self.init(sess)
            fake_A, fake_B, cyc_A, cyc_B = sess.run(
                [self.fake_A, self.fake_B, self.cyc_A, self.cyc_B],
                feed_dict={self.A: images_a,
                           self.B: images_b,
                           self.lambda_c: lambda_c,
                           self.lambda_h: lambda_h})

        return images_a, images_b, fake_A, fake_B, cyc_A, cyc_B

    def _generate_dataset(self, src, dst, input_tensor, output_tensor, num_samples, batch_size, lambda_c, lambda_h):
        """
        Run output_tensor on the first num_samples samples of <src>/data in batches of batch_size
        (the last batch may be smaller) and write the results, clipped to [0, 1] and scaled to the
        uint16 range, as <dst>/data into ./Models/<mod_name>_gen_<dst>.h5.
        """
        out_path = "./Models/" + self.mod_name + '_gen_' + dst + '.h5'
        self._ensure_parent_dir(out_path)
        num_iterations = self._num_batches(num_samples, batch_size)

        with h5py.File(self.data_file, "r") as f, h5py.File(out_path, "w") as f_save:
            src_data = f[src + '/data']
            gen_data = np.zeros(tuple(src_data.shape[:3]) + (f[dst + '/data'].shape[3],), dtype=np.uint16)

            with tf.Session(graph=self.graph) as sess:
                # initialize variables
                self.init(sess)

                for iteration in range(num_iterations):
                    batch = slice(iteration * batch_size, (iteration + 1) * batch_size)
                    images = self._normalize_images(src_data[batch, :, :, :], src)
                    generated = sess.run(output_tensor, feed_dict={input_tensor: images,
                                                                   self.lambda_c: lambda_c,
                                                                   self.lambda_h: lambda_h})
                    gen_data[batch, :, :, :] = (np.clip(generated, 0, 1) * (2 ** 16 - 1)).astype(np.uint16)

                    print("\rGenerator " + dst + ": {}/{} ({:.1f}%)".format(
                        iteration + 1, num_iterations, self._progress_percent(iteration, num_iterations)), end=" ")

            group = f_save.create_group(dst)
            group.create_dataset(name='data', data=gen_data, dtype=np.uint16)

    def generator_A(self, batch_size=32, lambda_c=0., lambda_h=0.):
        """
        Translate every sample of domain B into domain A (genA) and store the result as A/data
        (uint16) in ./Models/<mod_name>_gen_A.h5.
        """
        self._generate_dataset('B', 'A', self.B, self.fake_A, self.b_size, batch_size, lambda_c, lambda_h)
        return None

    def generator_B(self, batch_size=32, lambda_c=0., lambda_h=0.):
        """
        Translate every sample of domain A into domain B (genB) and store the result as B/data
        (uint16) in ./Models/<mod_name>_gen_B.h5.
        """
        self._generate_dataset('A', 'B', self.A, self.fake_B, self.a_size, batch_size, lambda_c, lambda_h)
        return None

    def get_loss(self, lambda_c=0., lambda_h=0.):
        """
        Evaluate the patch discriminators on a random window of 32 consecutive samples of each
        domain (fewer if the domain is smaller) and on their translations.

        OUTPUT: dis_real_A, dis_real_B, dis_fake_A, dis_fake_B (discriminator outputs)
        """
        with h5py.File(self.data_file, "r") as f:
            images_a = self._random_window(f, 'A', self.a_size)
            images_b = self._random_window(f, 'B', self.b_size)

        with tf.Session(graph=self.graph) as sess:
            # initialize variables
            self.init(sess)
            l_rA, l_rB, l_fA, l_fB = sess.run(
                [self.dis_real_A, self.dis_real_B, self.dis_fake_A, self.dis_fake_B],
                feed_dict={self.A: images_a,
                           self.B: images_b,
                           self.lambda_c: lambda_c,
                           self.lambda_h: lambda_h})

        return l_rA, l_rB, l_fA, l_fB
588 |
--------------------------------------------------------------------------------
/notebooks/VGG_Synthetic_by_localization.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import imageio as io\n",
10 | "import matplotlib.pyplot as plt\n",
11 | "%matplotlib inline\n",
12 | "import numpy as np\n",
13 | "import os\n",
14 | "from skimage.transform import rotate"
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": 2,
20 | "metadata": {},
21 | "outputs": [],
22 | "source": [
23 | "min_cells,max_cells=70,310"
24 | ]
25 | },
26 | {
27 | "cell_type": "code",
28 | "execution_count": 13,
29 | "metadata": {},
30 | "outputs": [],
31 | "source": [
32 | "pad=10\n",
33 | "xx,yy=np.meshgrid(np.arange(256+2*pad),np.arange(256+2*pad))\n",
34 | "#border pad\n",
35 | "fringe=np.pad(np.zeros((256,256)),((pad,pad),(pad,pad)),mode='constant',constant_values=1)\n",
36 | "fringe=fringe.astype(bool)\n",
37 | "\n",
38 | "\n",
39 | "def make_localization_synthetic_vgg():\n",
40 | " n_cell=np.random.randint(min_cells,max_cells)\n",
41 | " x,y=np.random.randint(0,255,size=(2,n_cell))+pad\n",
42 | " image=np.zeros((256+2*pad,256+2*pad,3))\n",
43 | " #homing_gaussian=np.zeros((256+2*pad,256+2*pad))\n",
44 | " #square_for_area=np.zeros((256+2*pad,256+2*pad))\n",
45 | " noise=np.clip(np.random.normal(loc=0.02,scale=0.03,size=(256+2*pad,256+2*pad)),0,1)\n",
46 | " image[:,:,0]=np.clip(np.random.normal(loc=0.02,scale=0.03,size=(256+2*pad,256+2*pad)),0,1)\n",
47 | " \n",
48 | " for i in range(n_cell):\n",
49 | " mask=(xx-x[i])**2+(yy-y[i])**2<(6)**2\n",
50 | " image[mask,0]=0.3*np.random.rand()+0.6+noise[mask]\n",
51 | " image[:,:,1]+=0.4*np.exp(-((yy-y[i])/3.)**2-((xx-x[i])/3.)**2)\n",
52 | " image[:,:,2]+=0.5*np.exp(-((yy-y[i])/1.)**2-((xx-x[i])/1.)**2)\n",
53 | " return image"
54 | ]
55 | },
56 | {
57 | "cell_type": "code",
58 | "execution_count": 14,
59 | "metadata": {},
60 | "outputs": [
61 | {
62 | "data": {
63 | "text/plain": [
64 | ""
65 | ]
66 | },
67 | "execution_count": 14,
68 | "metadata": {},
69 | "output_type": "execute_result"
70 | },
71 | {
72 | "data": {
73 | "image/png": "\n",
74 | "text/plain": [
75 | ""
76 | ]
77 | },
78 | "metadata": {
79 | "needs_background": "light"
80 | },
81 | "output_type": "display_data"
82 | }
83 | ],
84 | "source": [
85 | "plt.imshow(make_localization_synthetic_vgg())"
86 | ]
87 | },
88 | {
89 | "cell_type": "code",
90 | "execution_count": 17,
91 | "metadata": {},
92 | "outputs": [
93 | {
94 | "name": "stdout",
95 | "output_type": "stream",
96 | "text": [
97 | "\n"
98 | ]
99 | }
100 | ],
101 | "source": [
102 | "N_images=50\n",
103 | "dataname='VGG_localization'\n",
104 | "train_name='trainA'\n",
105 | "directory='../Data/'+dataname+'/'+train_name+'/'\n",
106 | "if not os.path.exists(directory):\n",
107 | " os.makedirs(directory)"
108 | ]
109 | },
110 | {
111 | "cell_type": "code",
112 | "execution_count": null,
113 | "metadata": {},
114 | "outputs": [
115 | {
116 | "name": "stderr",
117 | "output_type": "stream",
118 | "text": [
119 | "/usr/local/lib/python2.7/dist-packages/imageio/core/util.py:104: UserWarning: Conversion from float64 to uint8, range [0.0, 1.02268088151]\n",
120 | " 'range [{2}, {3}]'.format(dtype_str, out_type.__name__, mi, ma))\n"
121 | ]
122 | },
123 | {
124 | "name": "stdout",
125 | "output_type": "stream",
126 | "text": [
127 | "0"
128 | ]
129 | },
130 | {
131 | "name": "stderr",
132 | "output_type": "stream",
133 | "text": [
134 | "/usr/local/lib/python2.7/dist-packages/imageio/core/util.py:104: UserWarning: Conversion from float64 to uint8, range [0.0, 1.01464227872]\n",
135 | " 'range [{2}, {3}]'.format(dtype_str, out_type.__name__, mi, ma))\n",
136 | "/usr/local/lib/python2.7/dist-packages/imageio/core/util.py:104: UserWarning: Conversion from float64 to uint8, range [0.0, 1.02984934249]\n",
137 | " 'range [{2}, {3}]'.format(dtype_str, out_type.__name__, mi, ma))\n",
138 | "/usr/local/lib/python2.7/dist-packages/imageio/core/util.py:104: UserWarning: Conversion from float64 to uint8, range [0.0, 1.01064395335]\n",
139 | " 'range [{2}, {3}]'.format(dtype_str, out_type.__name__, mi, ma))\n",
140 | "/usr/local/lib/python2.7/dist-packages/imageio/core/util.py:104: UserWarning: Conversion from float64 to uint8, range [0.0, 1.01649309037]\n",
141 | " 'range [{2}, {3}]'.format(dtype_str, out_type.__name__, mi, ma))\n",
142 | "/usr/local/lib/python2.7/dist-packages/imageio/core/util.py:104: UserWarning: Conversion from float64 to uint8, range [0.0, 1.02021895102]\n",
143 | " 'range [{2}, {3}]'.format(dtype_str, out_type.__name__, mi, ma))\n",
144 | "/usr/local/lib/python2.7/dist-packages/imageio/core/util.py:104: UserWarning: Conversion from float64 to uint8, range [0.0, 1.00492180949]\n",
145 | " 'range [{2}, {3}]'.format(dtype_str, out_type.__name__, mi, ma))\n",
146 | "/usr/local/lib/python2.7/dist-packages/imageio/core/util.py:104: UserWarning: Conversion from float64 to uint8, range [0.0, 1.01165869974]\n",
147 | " 'range [{2}, {3}]'.format(dtype_str, out_type.__name__, mi, ma))\n",
148 | "/usr/local/lib/python2.7/dist-packages/imageio/core/util.py:104: UserWarning: Conversion from float64 to uint8, range [0.0, 1.01293333199]\n",
149 | " 'range [{2}, {3}]'.format(dtype_str, out_type.__name__, mi, ma))\n",
150 | "/usr/local/lib/python2.7/dist-packages/imageio/core/util.py:104: UserWarning: Conversion from float64 to uint8, range [0.0, 1.00414250075]\n",
151 | " 'range [{2}, {3}]'.format(dtype_str, out_type.__name__, mi, ma))\n",
152 | "/usr/local/lib/python2.7/dist-packages/imageio/core/util.py:104: UserWarning: Conversion from float64 to uint8, range [0.0, 1.00521218855]\n",
153 | " 'range [{2}, {3}]'.format(dtype_str, out_type.__name__, mi, ma))\n",
154 | "/usr/local/lib/python2.7/dist-packages/imageio/core/util.py:104: UserWarning: Conversion from float64 to uint8, range [0.0, 1.27627977337]\n",
155 | " 'range [{2}, {3}]'.format(dtype_str, out_type.__name__, mi, ma))\n",
156 | "/usr/local/lib/python2.7/dist-packages/imageio/core/util.py:104: UserWarning: Conversion from float64 to uint8, range [0.0, 1.00136268245]\n",
157 | " 'range [{2}, {3}]'.format(dtype_str, out_type.__name__, mi, ma))\n",
158 | "/usr/local/lib/python2.7/dist-packages/imageio/core/util.py:104: UserWarning: Conversion from float64 to uint8, range [0.0, 1.03352604424]\n",
159 | " 'range [{2}, {3}]'.format(dtype_str, out_type.__name__, mi, ma))\n",
160 | "/usr/local/lib/python2.7/dist-packages/imageio/core/util.py:104: UserWarning: Conversion from float64 to uint8, range [0.0, 1.01857872308]\n",
161 | " 'range [{2}, {3}]'.format(dtype_str, out_type.__name__, mi, ma))\n",
162 | "/usr/local/lib/python2.7/dist-packages/imageio/core/util.py:104: UserWarning: Conversion from float64 to uint8, range [0.0, 1.00677343811]\n",
163 | " 'range [{2}, {3}]'.format(dtype_str, out_type.__name__, mi, ma))\n",
164 | "/usr/local/lib/python2.7/dist-packages/imageio/core/util.py:104: UserWarning: Conversion from float64 to uint8, range [0.0, 1.13678852504]\n",
165 | " 'range [{2}, {3}]'.format(dtype_str, out_type.__name__, mi, ma))\n",
166 | "/usr/local/lib/python2.7/dist-packages/imageio/core/util.py:104: UserWarning: Conversion from float64 to uint8, range [0.0, 1.02354139996]\n",
167 | " 'range [{2}, {3}]'.format(dtype_str, out_type.__name__, mi, ma))\n",
168 | "/usr/local/lib/python2.7/dist-packages/imageio/core/util.py:104: UserWarning: Conversion from float64 to uint8, range [0.0, 1.00641660986]\n",
169 | " 'range [{2}, {3}]'.format(dtype_str, out_type.__name__, mi, ma))\n",
170 | "/usr/local/lib/python2.7/dist-packages/imageio/core/util.py:104: UserWarning: Conversion from float64 to uint8, range [0.0, 1.00318504915]\n",
171 | " 'range [{2}, {3}]'.format(dtype_str, out_type.__name__, mi, ma))\n",
172 | "/usr/local/lib/python2.7/dist-packages/imageio/core/util.py:104: UserWarning: Conversion from float64 to uint8, range [0.0, 1.01248495563]\n",
173 | " 'range [{2}, {3}]'.format(dtype_str, out_type.__name__, mi, ma))\n",
174 | "/usr/local/lib/python2.7/dist-packages/imageio/core/util.py:104: UserWarning: Conversion from float64 to uint8, range [0.0, 1.03589042459]\n",
175 | " 'range [{2}, {3}]'.format(dtype_str, out_type.__name__, mi, ma))\n",
176 | "/usr/local/lib/python2.7/dist-packages/imageio/core/util.py:104: UserWarning: Conversion from float64 to uint8, range [0.0, 1.00244328906]\n",
177 | " 'range [{2}, {3}]'.format(dtype_str, out_type.__name__, mi, ma))\n",
178 | "/usr/local/lib/python2.7/dist-packages/imageio/core/util.py:104: UserWarning: Conversion from float64 to uint8, range [0.0, 1.00883233411]\n",
179 | " 'range [{2}, {3}]'.format(dtype_str, out_type.__name__, mi, ma))\n"
180 | ]
181 | }
182 | ],
183 | "source": [
184 | "for i in range(N_images):\n",
185 | "    name=directory+str(i).zfill(5)+'.png'\n",
186 | "    image=make_localization_synthetic_vgg()\n",
187 | "    io.imsave(name,image.clip(0,1))  # clip: float images outside [0,1] are min-max rescaled per image by imageio (see UserWarnings), giving inconsistent brightness\n",
188 | "    if i%50==0:\n",
189 | "        print i,"
190 | ]
191 | },
192 | {
193 | "cell_type": "code",
194 | "execution_count": null,
195 | "metadata": {},
196 | "outputs": [],
197 | "source": []
198 | }
199 | ],
200 | "metadata": {
201 | "kernelspec": {
202 | "display_name": "Python 2",
203 | "language": "python",
204 | "name": "python2"
205 | },
206 | "language_info": {
207 | "codemirror_mode": {
208 | "name": "ipython",
209 | "version": 2
210 | },
211 | "file_extension": ".py",
212 | "mimetype": "text/x-python",
213 | "name": "python",
214 | "nbconvert_exporter": "python",
215 | "pygments_lexer": "ipython2",
216 | "version": "2.7.12"
217 | }
218 | },
219 | "nbformat": 4,
220 | "nbformat_minor": 2
221 | }
222 |
--------------------------------------------------------------------------------