├── .gitignore
├── README.md
├── requirements.d
├── base.txt
└── venv.txt
├── src
├── cache.py
├── cifar10.py
├── dataset.py
├── dataset_.py
├── download.py
├── download2.py
├── inception_score.py
├── jupyter
│ ├── .idea
│ │ └── vcs.xml
│ ├── gan_toy.ipynb
│ ├── record_video.py
│ ├── tensorflow_utils.py
│ └── utils.py
├── main.py
├── plot.py
├── solver.py
├── tensorflow_utils.py
├── utils.py
└── wgan_gp.py
└── tox.ini
/.gitignore:
--------------------------------------------------------------------------------
1 | src/jupyter/.idea/
2 | src/jupyter/.ipynb_checkpoints/
3 | src/jupyter/__pycache__/
4 | src/jupyter/*.jpg
5 | src/jupyter/img/
6 | src/jupyter/*.gif
7 | src/jupyter/*.mp4
8 | src/.idea/
9 | src/mnist/
10 | src/imagenet/
11 | src/imagenet64/
12 | src/cifar10/
13 | src/__pycache__/
14 | github_imgs/
15 |
16 | # Byte-compiled / optimized / DLL files
17 | __pycache__/
18 | *.py[cod]
19 | *$py.class
20 |
21 | # C extensions
22 | *.so
23 |
24 | # Distribution / packaging
25 | .Python
26 | build/
27 | develop-eggs/
28 | dist/
29 | downloads/
30 | eggs/
31 | .eggs/
32 | lib/
33 | lib64/
34 | parts/
35 | sdist/
36 | var/
37 | wheels/
38 | pip-wheel-metadata/
39 | share/python-wheels/
40 | *.egg-info/
41 | .installed.cfg
42 | *.egg
43 | MANIFEST
44 |
45 | # PyInstaller
46 | # Usually these files are written by a python script from a template
47 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
48 | *.manifest
49 | *.spec
50 |
51 | # Installer logs
52 | pip-log.txt
53 | pip-delete-this-directory.txt
54 |
55 | # Unit test / coverage reports
56 | htmlcov/
57 | .tox/
58 | .nox/
59 | .coverage
60 | .coverage.*
61 | .cache
62 | nosetests.xml
63 | coverage.xml
64 | *.cover
65 | .hypothesis/
66 | .pytest_cache/
67 |
68 | # Translations
69 | *.mo
70 | *.pot
71 |
72 | # Django stuff:
73 | *.log
74 | local_settings.py
75 | db.sqlite3
76 |
77 | # Flask stuff:
78 | instance/
79 | .webassets-cache
80 |
81 | # Scrapy stuff:
82 | .scrapy
83 |
84 | # Sphinx documentation
85 | docs/_build/
86 |
87 | # PyBuilder
88 | target/
89 |
90 | # Jupyter Notebook
91 | .ipynb_checkpoints
92 |
93 | # IPython
94 | profile_default/
95 | ipython_config.py
96 |
97 | # pyenv
98 | .python-version
99 |
100 | # pipenv
101 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
102 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
103 | # having no cross-platform support, pipenv may install dependencies that don’t work, or not
104 | # install all needed dependencies.
105 | #Pipfile.lock
106 |
107 | # celery beat schedule file
108 | celerybeat-schedule
109 |
110 | # SageMath parsed files
111 | *.sage.py
112 |
113 | # Environments
114 | .env
115 | .venv
116 | env/
117 | venv/
118 | ENV/
119 | env.bak/
120 | venv.bak/
121 |
122 | # Spyder project settings
123 | .spyderproject
124 | .spyproject
125 |
126 | # Rope project settings
127 | .ropeproject
128 |
129 | # mkdocs documentation
130 | /site
131 |
132 | # mypy
133 | .mypy_cache/
134 | .dmypy.json
135 | dmypy.json
136 |
137 | # Pyre type checker
138 | .pyre/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # WGAN-GP-tensorflow
2 | This repository is a Tensorflow implementation of the [WGAN-GP](https://arxiv.org/abs/1704.00028) for MNIST, CIFAR-10, and ImageNet64.
3 |
4 |
5 |
7 |
8 | * *All samples in README.md are generated by the neural network, except the first image in each row.*
9 |
10 | ## Install Prerequisites
11 |
12 | * python 3.5, 3.6 or 3.7
13 | * python3-tk
14 |
15 | Ubuntu/Debian/etc.:
16 |
17 | sudo apt install python3.5 python3.5-tk
18 |
19 | ## Create Virtual Environment
20 |
21 | python -m venv venv
22 |
23 | ## Activate Virtual Environment
24 |
25 | Windows:
26 |
27 | venv/Scripts/activate
28 |
29 | Bash:
30 |
31 | source venv/bin/activate
32 |
33 | ## Install Virtual Environment Requirements
34 |
35 | pip install -r requirements.d/venv.txt
36 |
37 | ## Create Execution Environments
38 |
39 | tox --notest
40 |
41 | That will install tensorflow which uses only the CPU.
42 |
43 | To use an Nvidia GPU:
44 |
45 | .tox/py35/bin/python -m pip uninstall tensorflow
46 | .tox/py35/bin/python -m pip install tensorflow-gpu==1.13.1
47 | .tox/py36/bin/python -m pip uninstall tensorflow
48 | .tox/py36/bin/python -m pip install tensorflow-gpu==1.13.1
49 | .tox/py37/bin/python -m pip uninstall tensorflow
50 | .tox/py37/bin/python -m pip install tensorflow-gpu==1.13.1
51 |
52 | To use an AMD GPU:
53 |
54 | .tox/py35/bin/python -m pip uninstall tensorflow
55 | .tox/py35/bin/python -m pip install tensorflow-rocm==1.13.1
56 | .tox/py36/bin/python -m pip uninstall tensorflow
57 | .tox/py36/bin/python -m pip install tensorflow-rocm==1.13.1
58 | .tox/py37/bin/python -m pip uninstall tensorflow
59 | .tox/py37/bin/python -m pip install tensorflow-rocm==1.13.1
60 |
61 | ## Generated Images
62 | ### 1. Toy Dataset
63 | Results on 2-dimensional toy datasets: 8 Gaussian mixture, 25 Gaussian mixture, and Swiss Roll. [Ipython Notebook](https://github.com/ChengBinJin/WGAN-GP-tensorflow/tree/master/src/jupyter).
64 |
65 | **Note:** For the following experiment, we held the generator distribution Pg fixed at the real distribution plus unit-variance Gaussian noise.
66 | - **Top:** GAN discriminator
67 | - **Middle:** WGAN critic with weight clipping
68 | - **Bottom:** WGAN critic with gradient penalty
69 |
70 |
71 |
72 |
73 |
74 |
75 | **Note:** For the next experiment, we did not fix generator and showed generated points by the generator.
76 | - **Top:** GAN discriminator
77 | - **Middle:** WGAN critic with weight clipping
78 | - **Bottom:** WGAN critic with gradient penalty
79 |
80 |
81 |
82 |
83 |
84 |
85 | ### 2. MNIST Dataset
86 |
87 |
88 |
89 |
90 |
91 |
92 |
93 |
94 | ### 3. CIFAR-10
95 |
96 |
97 |
98 |
99 |
100 |
101 |
102 |
103 | ### 4. IMAGENET64
104 |
105 |
106 |
107 |
108 |
109 |
110 |
111 |
112 | ## Documentation
113 | ### Download Dataset
114 | The 'MNIST' and 'CIFAR10' datasets are downloaded automatically by the code if they are not found in the expected folder. The 'ImageNet64' dataset can be downloaded from [Downsampled ImageNet](http://image-net.org/small/download.php).
115 |
116 | ### Directory Hierarchy
117 | ```
118 | .
119 | │ WGAN-GP
120 | │ ├── src
121 | │ │ ├── imagenet (folder where the Inception network weights downloaded by inception_score.py are saved)
122 | │ │ ├── cache.py
123 | │ │ ├── cifar10.py
124 | │ │ ├── dataset.py
125 | │ │ ├── dataset_.py
126 | │ │ ├── download.py
127 | │ │ ├── inception_score.py
128 | │ │ ├── main.py
129 | │ │ ├── plot.py
130 | │ │ ├── solver.py
131 | │ │ ├── tensorflow_utils.py
132 | │ │ ├── utils.py
133 | │ │ └── wgan_gp.py
134 | │ Data
135 | │ ├── mnist
136 | │ ├── cifar10
137 | │ └── imagenet64
138 | ```
139 | **src**: source codes of the WGAN-GP
140 |
141 | ### Training WGAN-GP
142 | Use `main.py` to train a WGAN-GP network. Example usage:
143 |
144 | ```
145 | python main.py
146 | ```
147 | - `gpu_index`: gpu index, default: `0`
148 | - `batch_size`: batch size for one feed forward, default: `64`
149 | - `dataset`: dataset name from [mnist, cifar10, imagenet64], default: `mnist`
150 |
151 | - `is_train`: training or inference mode, default: `True`
152 | - `learning_rate`: initial learning rate for Adam, default: `0.001`
153 | - `num_critic`: the number of iterations of the critic per generator iteration, default: `5`
154 | - `z_dim`: dimension of z vector, default: `128`
155 | - `lambda_`: gradient penalty lambda hyperparameter, default: `10.`
156 | - `beta1`: beta1 momentum term of Adam, default: `0.5`
157 | - `beta2`: beta2 momentum term of Adam, default: `0.9`
158 |
159 | - `iters`: number of iterations, default: `200000`
160 | - `print_freq`: print frequency for loss, default: `100`
161 | - `save_freq`: save frequency for model, default: `10000`
162 | - `sample_freq`: sample frequency for saving image, default: `500`
163 | - `inception_freq`: calculation frequency of the inception score, default: `1000`
164 | - `sample_batch`: number of sampling images for check generator quality, default: `64`
165 | - `load_model`: folder of the saved model that you wish to test (e.g. 20181120-1558), default: `None`
166 |
167 | ### WGAN-GP During Training
168 | **Note:** In the following figures, the Y axes show the negative critic loss of the WGAN-GP.
169 | 1. **MNIST**
170 |
171 |
172 |
173 |
174 | 2. **CIFAR10**
175 |
176 |
177 |
178 |
179 | 3. **IMAGENET64**
180 |
181 |
182 |
183 |
184 | ### Inception Score on CIFAR10 During Training
185 | **Note:** Inception score was calculated every 1000 iterations.
186 |
187 |
188 |
189 |
190 | ### Test WGAN-GP
191 | Use `main.py` to test a WGAN-GP network. Example usage:
192 |
193 | ```
194 | python main.py --is_train=false --load_model=folder/you/wish/to/test/e.g./20181120-1558
195 | ```
196 | Please refer to the above arguments.
197 |
198 | ### Citation
199 | ```
200 | @misc{chengbinjin2018wgan-gp,
201 | author = {Cheng-Bin Jin},
202 | title = {WGAN-GP-tensorflow},
203 | year = {2018},
204 | howpublished = {\url{https://github.com/ChengBinJin/WGAN-GP-tensorflow}},
205 | note = {commit xxxxxxx}
206 | }
207 | ```
208 |
209 | ### Attributions/Thanks
210 | - This project borrowed some code from [igul222](https://github.com/igul222/improved_wgan_training).
211 | - Some readme formatting was borrowed from [Logan Engstrom](https://github.com/lengstrom/fast-style-transfer).
212 |
213 | ## License
214 | Copyright (c) 2018 Cheng-Bin Jin. Contact me for commercial use (or rather any use that is not academic research) (email: sbkim0407@gmail.com). Free for research use, as long as proper attribution is given and this copyright notice is retained.
215 |
216 | ## Related Projects
217 | - [Vanilla GAN](https://github.com/ChengBinJin/VanillaGAN-TensorFlow)
218 | - [DCGAN](https://github.com/ChengBinJin/DCGAN-TensorFlow)
219 | - [WGAN](https://github.com/ChengBinJin/WGAN-TensorFlow)
220 |
--------------------------------------------------------------------------------
/requirements.d/base.txt:
--------------------------------------------------------------------------------
1 | absl-py==0.7.1
2 | astor==0.7.1
3 | attrs==19.1.0
4 | certifi==2019.3.9
5 | chardet==3.0.4
6 | cycler==0.10.0
7 | gast==0.2.2
8 | grpcio==1.20.1
9 | h5py==2.9.0
10 | idna==2.8
11 | jsonschema==3.0.1
12 | Keras==2.2.4
13 | Keras-Applications==1.0.7
14 | Keras-Preprocessing==1.0.9
15 | kiwisolver==1.1.0
16 | Markdown==3.1
17 | matplotlib==3.0.3
18 | mock==3.0.4
19 | numpy==1.16.3
20 | Pillow==6.2.0
21 | protobuf==3.7.1
22 | pyparsing==2.4.0
23 | pyrsistent==0.15.1
24 | python-dateutil==2.8.0
25 | PyYAML==5.1
26 | requests==2.21.0
27 | scipy==1.2.1
28 | six==1.12.0
29 | tensorboard==1.13.1
30 | tensorflow==1.15.0
31 | tensorflow-estimator==1.13.0
32 | termcolor==1.1.0
33 | tqdm==4.31.1
34 | urllib3==1.25.2
35 | Werkzeug==0.15.3
36 |
--------------------------------------------------------------------------------
/requirements.d/venv.txt:
--------------------------------------------------------------------------------
1 | filelock==3.0.10
2 | pluggy==0.9.0
3 | py==1.8.0
4 | six==1.12.0
5 | toml==0.10.0
6 | tox==3.9.0
7 | virtualenv==16.5.0
8 |
--------------------------------------------------------------------------------
/src/cache.py:
--------------------------------------------------------------------------------
1 | ########################################################################
2 | #
3 | # Cache-wrapper for a function or class.
4 | #
5 | # Save the result of calling a function or creating an object-instance
6 | # to harddisk. This is used to persist the data so it can be reloaded
7 | # very quickly and easily.
8 | #
9 | # Implemented in Python 3.5
10 | #
11 | ########################################################################
12 | #
13 | # This file is part of the TensorFlow Tutorials available at:
14 | #
15 | # https://github.com/Hvass-Labs/TensorFlow-Tutorials
16 | #
17 | # Published under the MIT License. See the file LICENSE for details.
18 | #
19 | # Copyright 2016 by Magnus Erik Hvass Pedersen
20 | #
21 | ########################################################################
22 |
23 | import os
24 | import pickle
25 | import numpy as np
26 |
27 | ########################################################################
28 |
29 |
def cache(cache_path, fn, *args, **kwargs):
    """
    Cache-wrapper for a function or class.

    If the cache-file exists, the pickled object stored in it is loaded
    and returned and fn is NOT called. Otherwise fn(*args, **kwargs) is
    called, the result is pickled to the cache-file and then returned.
    The fn-argument can also be a class, in which case an object-instance
    is created and saved to the cache-file.

    :param cache_path:
        File-path for the cache-file. Missing parent directories are
        created before the file is written.
    :param fn:
        Function or class to be called.
    :param args:
        Arguments to the function or class-init.
    :param kwargs:
        Keyword arguments to the function or class-init.
    :return:
        The result of calling the function or creating the object-instance.
    """

    # If the cache-file exists.
    if os.path.exists(cache_path):
        # Load the cached data from the file.
        # NOTE: unpickling can execute arbitrary code, so cache_path must
        # only ever point at files that were written by this function.
        with open(cache_path, mode='rb') as file:
            obj_ = pickle.load(file)

        print("- Data loaded from cache-file: " + cache_path)
    else:
        # The cache-file does not exist.

        # Call the function / class-init with the supplied arguments.
        obj_ = fn(*args, **kwargs)

        # Make sure the destination directory exists. Without this, the
        # result of a potentially expensive call would be thrown away by
        # a FileNotFoundError raised from open() below.
        cache_dir = os.path.dirname(cache_path)
        if cache_dir:
            os.makedirs(cache_dir, exist_ok=True)

        # Save the data to a cache-file.
        with open(cache_path, mode='wb') as file:
            pickle.dump(obj_, file)

        print("- Data saved to cache-file: " + cache_path)

    return obj_
69 |
70 |
71 | ########################################################################
72 |
73 |
def convert_numpy2pickle(in_path, out_path):
    """
    Re-save a numpy-file as a pickle-file.

    Early versions of the cache-function stored their data with numpy.
    This helper converts such a cache-file so the data does not have to
    be re-calculated.

    :param in_path:
        Input file in numpy-format written using numpy.save().
    :param out_path:
        Output file written as a pickle-file.
    :return:
        Nothing.
    """

    # Read the numpy-file first, so the output file is only created
    # once the input has been loaded.
    loaded = np.load(in_path)

    # Write the loaded data back out using pickle.
    with open(out_path, mode='wb') as pickle_file:
        pickle.dump(loaded, pickle_file)
94 |
95 |
96 | ########################################################################
97 |
if __name__ == '__main__':
    # Small demonstration of the cache-wrapper.

    # Example 1: cache the return-value of a function.
    # The function stands in for a computation that is slow, or whose
    # result must be persisted. It is only called when the cache-file
    # does not exist yet.
    def expensive_function(a, b):
        return a * b

    print('Computing expensive_function() ...')

    # On the first run this computes expensive_function(a=123, b=456)
    # and stores the result; later runs read it from the cache-file.
    result = cache('cache_expensive_function.pkl', expensive_function,
                   a=123, b=456)

    print('result =', result)

    # Newline.
    print()

    # Example 2: cache an object-instance of a class.
    # The class name is part of the pickled data, so it must stay the
    # same for existing cache-files to remain loadable.
    class ExpensiveClass:
        def __init__(self, c, d):
            self.c, self.d = c, d
            self.result = c * d

        def print_result(self):
            print('c =', self.c)
            print('d =', self.d)
            print('result = c * d =', self.result)

    print('Creating object from ExpensiveClass() ...')

    # On the first run this creates ExpensiveClass(c=123, d=456) and
    # stores the object; later runs read it from the cache-file.
    obj = cache('cache_ExpensiveClass.pkl', ExpensiveClass, c=123, d=456)

    obj.print_result()
146 |
147 | ########################################################################
148 |
--------------------------------------------------------------------------------
/src/cifar10.py:
--------------------------------------------------------------------------------
1 | ########################################################################
2 | #
3 | # Functions for downloading the CIFAR-10 data-set from the internet
4 | # and loading it into memory.
5 | #
6 | # Implemented in Python 3.5
7 | #
8 | # Usage:
9 | # 1) Set the variable data_path with the desired storage path.
10 | # 2) Call maybe_download_and_extract() to download the data-set
11 | # if it is not already located in the given data_path.
12 | # 3) Call load_class_names() to get an array of the class-names.
13 | # 4) Call load_training_data() and load_test_data() to get
14 | # the images, class-numbers and one-hot encoded class-labels
15 | # for the training-set and test-set.
16 | # 5) Use the returned data in your own program.
17 | #
18 | # Format:
19 | # The images for the training- and test-sets are returned as 4-dim numpy
20 | # arrays each with the shape: [image_number, height, width, channel]
21 | # where the individual pixels are floats between 0.0 and 1.0.
22 | #
23 | ########################################################################
24 | #
25 | # This file is part of the TensorFlow Tutorials available at:
26 | #
27 | # https://github.com/Hvass-Labs/TensorFlow-Tutorials
28 | #
29 | # Published under the MIT License. See the file LICENSE for details.
30 | #
31 | # Copyright 2016 by Magnus Erik Hvass Pedersen
32 | #
33 | ########################################################################
34 |
35 | import numpy as np
36 | import pickle
37 | import os
38 | import download
39 | from dataset import one_hot_encoded
40 |
41 | ########################################################################
42 |
# Where the data-set is downloaded to and loaded from.
# Assign a different path before calling any of the functions below.
data_path = "../../Data/cifar10/"

# Internet location of the data-set archive.
data_url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"

########################################################################
# Image-size constants. These are public so that they can be used
# by your own program as well.

# The images are square: this is both the width and the height.
img_size = 32

# Red, Green and Blue colour-channels.
num_channels = 3

# Number of values in one image once flattened to a 1-dim array.
img_size_flat = num_channels * img_size * img_size

# Number of classes.
num_classes = 10

########################################################################
# Private constants used to pre-allocate arrays of the correct size.

# The training-set is stored in this many files ...
_num_files_train = 5

# ... each of which holds this many images.
_images_per_file = 10000

# Total number of images in the training-set.
_num_images_train = _images_per_file * _num_files_train
78 |
79 | ########################################################################
80 | # Private functions for downloading, unpacking and loading data-files.
81 |
82 |
def _get_file_path(filename=""):
    """
    Build the full path of a data-file inside the data-set directory.
    With the default filename="" the directory itself is returned.
    """

    # Directory holding the data-set files, located below data_path.
    batches_dir = os.path.join(data_path, "cifar-10-batches-py/")

    return os.path.join(batches_dir, filename)
90 |
91 |
def _unpickle(filename):
    """
    Load a pickled data-file of the data-set and return its content.
    The filename is relative to the data-set directory; the directory
    is prepended automatically.
    """

    # Resolve the name to the full path of the file.
    file_path = _get_file_path(filename)

    print("Loading data: " + file_path)

    with open(file_path, mode='rb') as handle:
        # The encoding must be given explicitly in Python 3.X; the
        # original author notes that an exception is raised otherwise.
        # NOTE(review): pickle can execute code while loading, so this
        # should only be used on the genuine data-set files.
        return pickle.load(handle, encoding='bytes')
109 |
110 |
def _convert_images(raw):
    """
    Turn raw CIFAR-10 image data into a 4-dim float array with shape
    [image_number, height, width, channel], scaled by 1/255 so that
    byte-valued pixels end up between 0.0 and 1.0.
    """

    # Floating-point copy of the raw data, scaled by 1/255.
    scaled = np.array(raw, dtype=float) / 255.0

    # View the data as [image_number, channel, height, width] ...
    channel_first = scaled.reshape([-1, num_channels, img_size, img_size])

    # ... and move the channel-axis to the end.
    return channel_first.transpose([0, 2, 3, 1])
128 |
129 |
def _load_data(filename):
    """
    Read one pickled data-file of the CIFAR-10 data-set.
    Returns a tuple with the converted images (see _convert_images)
    and a numpy-array with the class-number of each image.
    """

    # Content of the pickled data-file; its keys are byte-strings.
    batch = _unpickle(filename)

    # Pick out the raw images and the labels.
    raw_images, labels = batch[b'data'], batch[b'labels']

    # Class-numbers as a numpy-array.
    cls = np.array(labels)

    # Images as a 4-dim float array.
    images = _convert_images(raw_images)

    return images, cls
150 |
151 |
152 | ########################################################################
153 | # Public functions that you may call to download the data-set from
154 | # the internet and load the data into memory.
155 |
156 |
def maybe_download_and_extract():
    """
    Download and extract the CIFAR-10 data-set if it doesn't already exist
    in data_path (set this variable first to the desired path).

    This is a thin wrapper that passes the module-level data_url and
    data_path on to download.maybe_download_and_extract().
    :return:
        Nothing.
    """

    # The actual work is done by the download-module of this project.
    download.maybe_download_and_extract(url=data_url, download_dir=data_path)
164 |
165 |
def load_class_names():
    """
    Return the class-names of the CIFAR-10 data-set as a list of strings.
    The list is indexed by class-number, e.g. names[3] is the name
    of class-number 3.
    """

    # The names are stored as binary strings in the meta-data file.
    meta = _unpickle(filename="batches.meta")

    # Decode each name to an ordinary string.
    return [raw_name.decode('utf-8') for raw_name in meta[b'label_names']]
180 |
181 |
def load_training_data():
    """
    Load the complete training-set of the CIFAR-10 data-set.
    The training-set is stored in several data-files, which are
    merged into single arrays here.
    Returns the images, class-numbers and one-hot encoded class-labels.
    """

    # Allocate the full-size arrays once, and fill them file by file.
    images = np.zeros(shape=[_num_images_train, img_size, img_size, num_channels], dtype=float)
    cls = np.zeros(shape=[_num_images_train], dtype=int)

    # Position in the arrays where the next data-file is written.
    offset = 0

    # The data-files are numbered starting from 1.
    for file_number in range(1, _num_files_train + 1):
        images_batch, cls_batch = _load_data(filename="data_batch_" + str(file_number))

        # This data-file occupies the slice [offset, stop).
        stop = offset + len(images_batch)

        images[offset:stop, :] = images_batch
        cls[offset:stop] = cls_batch

        # The next data-file continues where this one ended.
        offset = stop

    return images, cls, one_hot_encoded(class_numbers=cls, num_classes=num_classes)
217 |
218 |
def load_test_data():
    """
    Load the complete test-set of the CIFAR-10 data-set.
    Returns the images, class-numbers and one-hot encoded class-labels.
    """

    # The test-set is stored in a single data-file.
    images, cls = _load_data(filename="test_batch")

    # One-hot encoded version of the class-numbers.
    labels = one_hot_encoded(class_numbers=cls, num_classes=num_classes)

    return images, cls, labels
228 |
229 | ########################################################################
230 |
--------------------------------------------------------------------------------
/src/dataset.py:
--------------------------------------------------------------------------------
1 | ########################################################################
2 | #
3 | # Class for creating a data-set consisting of all files in a directory.
4 | #
5 | # Example usage is shown in the file knifey.py and Tutorial #09.
6 | #
7 | # Implemented in Python 3.5
8 | #
9 | ########################################################################
10 | #
11 | # This file is part of the TensorFlow Tutorials available at:
12 | #
13 | # https://github.com/Hvass-Labs/TensorFlow-Tutorials
14 | #
15 | # Published under the MIT License. See the file LICENSE for details.
16 | #
17 | # Copyright 2016 by Magnus Erik Hvass Pedersen
18 | #
19 | ########################################################################
20 |
21 | import numpy as np
22 | import os
23 | import shutil
24 | from cache import cache
25 |
26 |
27 | ########################################################################
28 |
29 |
def one_hot_encoded(class_numbers, num_classes=None):
    """
    Convert an array of integer class-numbers to One-Hot encoded labels.
    For example, class-number 2 with num_classes=4 becomes
    the float array: [0. 0. 1. 0.]
    :param class_numbers:
        Array of integers with class-numbers.
        Assume the integers are from zero to num_classes-1 inclusive.
    :param num_classes:
        Number of classes. If None then use max(class_numbers)+1.
    :return:
        2-dim array of shape: [len(class_numbers), num_classes]
    """

    # Infer the number of classes from the data when it is not given.
    # This relies on the class-numbers starting at zero.
    if num_classes is None:
        num_classes = 1 + np.max(class_numbers)

    # Row i of the identity-matrix is the one-hot label for class i,
    # so the labels are found by selecting one row per class-number.
    identity = np.eye(num_classes, dtype=float)

    return identity[class_numbers]
50 |
51 |
52 | ########################################################################
53 |
54 |
55 | class DataSet:
56 | def __init__(self, in_dir, exts='.jpg'):
57 | """
58 | Create a data-set consisting of the filenames in the given directory
59 | and sub-dirs that match the given filename-extensions.
60 | For example, the knifey-spoony data-set (see knifey.py) has the
61 | following dir-structure:
62 | knifey-spoony/forky/
63 | knifey-spoony/knifey/
64 | knifey-spoony/spoony/
65 | knifey-spoony/forky/test/
66 | knifey-spoony/knifey/test/
67 | knifey-spoony/spoony/test/
68 | This means there are 3 classes called: forky, knifey, and spoony.
69 | If we set in_dir = "knifey-spoony/" and create a new DataSet-object
70 | then it will scan through these directories and create a training-set
71 | and test-set for each of these classes.
72 | The training-set will contain a list of all the *.jpg filenames
73 | in the following directories:
74 | knifey-spoony/forky/
75 | knifey-spoony/knifey/
76 | knifey-spoony/spoony/
77 | The test-set will contain a list of all the *.jpg filenames
78 | in the following directories:
79 | knifey-spoony/forky/test/
80 | knifey-spoony/knifey/test/
81 | knifey-spoony/spoony/test/
82 | See the TensorFlow Tutorial #09 for a usage example.
83 | :param in_dir:
84 | Root-dir for the files in the data-set.
85 | This would be 'knifey-spoony/' in the example above.
86 | :param exts:
87 | String or tuple of strings with valid filename-extensions.
88 | Not case-sensitive.
89 | :return:
90 | Object instance.
91 | """
92 |
93 | # Extend the input directory to the full path.
94 | in_dir = os.path.abspath(in_dir)
95 |
96 | # Input directory.
97 | self.in_dir = in_dir
98 |
99 | # Convert all file-extensions to lower-case.
100 | self.exts = tuple(ext.lower() for ext in exts)
101 |
102 | # Names for the classes.
103 | self.class_names = []
104 |
105 | # Filenames for all the files in the training-set.
106 | self.filenames = []
107 |
108 | # Filenames for all the files in the test-set.
109 | self.filenames_test = []
110 |
111 | # Class-number for each file in the training-set.
112 | self.class_numbers = []
113 |
114 | # Class-number for each file in the test-set.
115 | self.class_numbers_test = []
116 |
117 | # Total number of classes in the data-set.
118 | self.num_classes = 0
119 |
120 | # For all files/dirs in the input directory.
121 | for name in os.listdir(in_dir):
122 | # Full path for the file / dir.
123 | current_dir = os.path.join(in_dir, name)
124 |
125 | # If it is a directory.
126 | if os.path.isdir(current_dir):
127 | # Add the dir-name to the list of class-names.
128 | self.class_names.append(name)
129 |
130 | # Training-set.
131 |
132 | # Get all the valid filenames in the dir (not sub-dirs).
133 | filenames = self._get_filenames(current_dir)
134 |
135 | # Append them to the list of all filenames for the training-set.
136 | self.filenames.extend(filenames)
137 |
138 | # The class-number for this class.
139 | class_number = self.num_classes
140 |
141 | # Create an array of class-numbers.
142 | class_numbers = [class_number] * len(filenames)
143 |
144 | # Append them to the list of all class-numbers for the training-set.
145 | self.class_numbers.extend(class_numbers)
146 |
147 | # Test-set.
148 |
149 | # Get all the valid filenames in the sub-dir named 'test'.
150 | filenames_test = self._get_filenames(os.path.join(current_dir, 'test'))
151 |
152 | # Append them to the list of all filenames for the test-set.
153 | self.filenames_test.extend(filenames_test)
154 |
155 | # Create an array of class-numbers.
156 | class_numbers = [class_number] * len(filenames_test)
157 |
158 | # Append them to the list of all class-numbers for the test-set.
159 | self.class_numbers_test.extend(class_numbers)
160 |
161 | # Increase the total number of classes in the data-set.
162 | self.num_classes += 1
163 |
164 | def _get_filenames(self, dir_):
165 | """
166 | Create and return a list of filenames with matching extensions in the given directory.
167 | :param dir_:
168 | Directory to scan for files. Sub-dirs are not scanned.
169 | :return:
170 | List of filenames. Only filenames. Does not include the directory.
171 | """
172 |
173 | # Initialize empty list.
174 | filenames = []
175 |
176 | # If the directory exists.
177 | if os.path.exists(dir_):
178 | # Get all the filenames with matching extensions.
179 | for filename in os.listdir(dir_):
180 | if filename.lower().endswith(self.exts):
181 | filenames.append(filename)
182 |
183 | return filenames
184 |
185 | def get_paths(self, test=False):
186 | """
187 | Get the full paths for the files in the data-set.
188 | :param test:
189 | Boolean. Return the paths for the test-set (True) or training-set (False).
190 | :return:
191 | Iterator with strings for the path-names.
192 | """
193 |
194 | if test:
195 | # Use the filenames and class-numbers for the test-set.
196 | filenames = self.filenames_test
197 | class_numbers = self.class_numbers_test
198 |
199 | # Sub-dir for test-set.
200 | test_dir = "test/"
201 | else:
202 | # Use the filenames and class-numbers for the training-set.
203 | filenames = self.filenames
204 | class_numbers = self.class_numbers
205 |
206 | # Don't use a sub-dir for test-set.
207 | test_dir = ""
208 |
209 | for filename, cls in zip(filenames, class_numbers):
210 | # Full path-name for the file.
211 | path = os.path.join(self.in_dir, self.class_names[cls], test_dir, filename)
212 |
213 | yield path
214 |
def get_training_set(self):
    """
    Return a 3-tuple describing the training-set:
    the list of file-paths,
    the class-numbers as a numpy array,
    and the result of one_hot_encoded() for those class-numbers.
    """

    paths = list(self.get_paths())
    labels = np.asarray(self.class_numbers)
    labels_one_hot = one_hot_encoded(class_numbers=self.class_numbers,
                                     num_classes=self.num_classes)

    return paths, labels, labels_one_hot
224 |
def get_test_set(self):
    """
    Return a 3-tuple describing the test-set:
    the list of file-paths,
    the class-numbers as a numpy array,
    and the result of one_hot_encoded() for those class-numbers.
    """

    paths = list(self.get_paths(test=True))
    labels = np.asarray(self.class_numbers_test)
    labels_one_hot = one_hot_encoded(class_numbers=self.class_numbers_test,
                                     num_classes=self.num_classes)

    return paths, labels, labels_one_hot
234 |
def copy_files(self, train_dir, test_dir):
    """
    Copy the training-set files into train_dir and the test-set files
    into test_dir, using one sub-directory per class in each.

    The data-set itself is laid out with one directory per class:
        knifey-spoony/forky/
        knifey-spoony/knifey/
        knifey-spoony/spoony/
    with the test-set in a "test" sub-dir of each class:
        knifey-spoony/forky/test/
        knifey-spoony/knifey/test/
        knifey-spoony/spoony/test/
    Some APIs instead expect the split to come first:
        knifey-spoony/train/forky/
        knifey-spoony/train/knifey/
        knifey-spoony/train/spoony/
        knifey-spoony/test/forky/
        knifey-spoony/test/knifey/
        knifey-spoony/test/spoony/
    This function produces that second layout.

    :param train_dir: Directory for the training-set e.g. 'knifey-spoony/train/'
    :param test_dir: Directory for the test-set e.g. 'knifey-spoony/test/'
    :return: Nothing.
    """

    def _copy_files(src_paths, dst_dir, class_numbers):
        # Destination directory for each class, in class-number order, e.g.:
        # ['knifey-spoony/test/forky/',
        #  'knifey-spoony/test/knifey/',
        #  'knifey-spoony/test/spoony/']
        class_dirs = [os.path.join(dst_dir, class_name + "/")
                      for class_name in self.class_names]

        # Create whichever class-directories are missing.
        for class_dir in class_dirs:
            if not os.path.exists(class_dir):
                os.makedirs(class_dir)

        # Each file goes into the directory of its own class.
        for src_path, class_number in zip(src_paths, class_numbers):
            shutil.copy(src=src_path, dst=class_dirs[class_number])

    # Training-set.
    _copy_files(src_paths=self.get_paths(test=False),
                dst_dir=train_dir,
                class_numbers=self.class_numbers)
    print("- Copied training-set to:", train_dir)

    # Test-set.
    _copy_files(src_paths=self.get_paths(test=True),
                dst_dir=test_dir,
                class_numbers=self.class_numbers_test)
    print("- Copied test-set to:", test_dir)
296 |
297 |
298 | ########################################################################
299 |
300 |
def load_cached(cache_path, in_dir):
    """
    Obtain a DataSet-object through the cache() helper.

    The call is delegated to cache(), which is given the cache-file path,
    the DataSet class and its in_dir argument. Going through a cache-file
    is useful when the ordering of the filenames must be the same every
    time the data-set is loaded, for example when the DataSet-object is
    combined with Transfer Values saved to another cache-file
    (see e.g. Tutorial #09).

    :param cache_path:
        File-path for the cache-file.
    :param in_dir:
        Root-dir for the files in the data-set.
        Passed on as the in_dir argument of DataSet.
    :return:
        The DataSet-object.
    """

    print("Creating dataset from the files in: " + in_dir)

    # cache() decides whether to reload from cache_path or to call
    # DataSet(in_dir=in_dir); its result is returned unchanged.
    return cache(cache_path=cache_path, fn=DataSet, in_dir=in_dir)
329 |
330 | ########################################################################
331 |
--------------------------------------------------------------------------------
/src/dataset_.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------
2 | # Tensorflow WGAN-GP Implementation
3 | # Licensed under The MIT License [see LICENSE for details]
4 | # Written by Cheng-Bin Jin
5 | # Email: sbkim0407@gmail.com
6 | # ---------------------------------------------------------
7 | import os
8 | import logging
9 | import numpy as np
10 | import scipy.misc
11 | import tensorflow as tf
12 |
13 | import utils as utils
14 |
# Module-level logger used by the dataset classes below; handlers are
# attached later by _init_logger().
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
17 |
18 |
19 | def _init_logger(flags, log_path):
20 | if flags.is_train:
21 | formatter = logging.Formatter('%(asctime)s:%(name)s:%(message)s')
22 | # file handler
23 | file_handler = logging.FileHandler(os.path.join(log_path, 'dataset.log'))
24 | file_handler.setFormatter(formatter)
25 | file_handler.setLevel(logging.INFO)
26 | # stream handler
27 | stream_handler = logging.StreamHandler()
28 | stream_handler.setFormatter(formatter)
29 | # add handlers
30 | logger.addHandler(file_handler)
31 | logger.addHandler(stream_handler)
32 |
33 |
class MnistDataset(object):
    """MNIST training batches served through a tf.data pipeline.

    The constructor loads the data with ``tf.keras.datasets.mnist.load_data()``
    and builds a shuffled, repeating, batched one-shot iterator whose
    ``get_next()`` op is stored in ``self.next_batch``.
    """

    def __init__(self, sess, flags, dataset_name):
        self.sess = sess
        self.flags = flags
        self.dataset_name = dataset_name
        self.image_size = (32, 32, 1)
        self.img_buffle = 100000  # shuffle-buffer size used for image shuffling
        self.num_trains, self.num_tests = 0, 0

        self.mnist_path = os.path.join('../../Data', self.dataset_name)
        self._load_mnist()

    def _load_mnist(self):
        """Load MNIST and create the ``next_batch`` op."""
        logger.info('Load {} dataset...'.format(self.dataset_name))

        # Each of train_data / test_data is indexed as (images, labels) below.
        self.train_data, self.test_data = tf.keras.datasets.mnist.load_data()
        train_x, train_y = self.train_data
        self.num_trains = train_x.shape[0]
        self.num_tests = self.test_data[0].shape[0]

        # TensorFlow Dataset API: shuffle -> repeat -> batch
        pipeline = tf.data.Dataset.from_tensor_slices(({'image': train_x}, train_y))
        pipeline = pipeline.shuffle(self.img_buffle)
        pipeline = pipeline.repeat()
        pipeline = pipeline.batch(self.flags.batch_size)
        self.next_batch = pipeline.make_one_shot_iterator().get_next()

        logger.info('Load {} dataset SUCCESS!'.format(self.dataset_name))
        logger.info('Img size: {}'.format(self.image_size))
        logger.info('Num. of training data: {}'.format(self.num_trains))

    def train_next_batch(self, batch_size):
        """Run ``next_batch`` once and return the images resized and rescaled.

        If ``batch_size`` is smaller than ``flags.batch_size`` only the first
        ``batch_size`` images are used; otherwise ``flags.batch_size`` images
        are used. The result has a trailing channel axis and is computed as
        ``pixels / 127.5 - 1.``.
        """
        raw_imgs = self.sess.run(self.next_batch)[0]["image"]
        # labels would be element [1] of the fetched tuple; they are not used

        full_size = self.flags.batch_size
        if full_size > batch_size:
            count = batch_size
            raw_imgs = raw_imgs[:batch_size]
        else:
            count = full_size
        frames = np.reshape(raw_imgs, [count, 28, 28])

        # NOTE(review): scipy.misc.imresize was removed in SciPy 1.3.0, so this
        # needs an older SciPy (with Pillow). It converts its output to uint8.
        target_hw = self.image_size[0:2]
        resized = [scipy.misc.imresize(frames[idx], target_hw)
                   for idx in range(frames.shape[0])]
        stacked = np.asarray(resized).astype(np.float32)

        # from [0., 255.] to [-1., 1.]
        return np.expand_dims(stacked, axis=3) / 127.5 - 1.
81 |
82 |
class Cifar10(object):
    """CIFAR-10 training images held in memory, sampled at random per batch."""

    def __init__(self, flags, dataset_name):
        self.flags = flags
        self.dataset_name = dataset_name
        self.image_size = (32, 32, 3)
        self.num_trains = 0

        self.cifar10_path = os.path.join('../../Data', self.dataset_name)
        self._load_cifar10()

    def _load_cifar10(self):
        """Fetch the data through the project's ``cifar10`` module and keep the images."""
        import cifar10

        cifar10.data_path = self.cifar10_path
        logger.info('Load {} dataset...'.format(self.dataset_name))

        # The CIFAR-10 data-set is about 163 MB; it is downloaded automatically
        # when it is not already located in the given path.
        cifar10.maybe_download_and_extract()

        # Only the first element of the returned triple (the images) is kept.
        images, _, _ = cifar10.load_training_data()
        self.train_data = images
        self.num_trains = images.shape[0]

        logger.info('Load {} dataset SUCCESS!'.format(self.dataset_name))
        logger.info('Img size: {}'.format(self.image_size))
        logger.info('Num. of training data: {}'.format(self.num_trains))

    def train_next_batch(self, batch_size):
        """Return ``batch_size`` distinct random training images, rescaled as ``x * 2 - 1``."""
        picked = np.random.choice(self.num_trains, batch_size, replace=False)
        batch_imgs = self.train_data[picked]
        # assumes pixel values are in [0., 1.], which maps them to [-1., 1.]
        return batch_imgs * 2. - 1.
113 |
114 |
class ImageNet64(object):
    """ImageNet 64x64 training images, loaded from disk one batch at a time."""

    def __init__(self, flags, dataset_name):
        self.flags = flags
        self.dataset_name = dataset_name
        self.image_size = (64, 64, 3)
        self.num_trains = 0

        self.imagenet64_path = os.path.join('../../Data', self.dataset_name, 'train_64x64')
        self._load_imagenet64()

    def _load_imagenet64(self):
        """Collect the '.png' entries under ``imagenet64_path`` via ``utils.all_files_under``."""
        logger.info('Load {} dataset...'.format(self.dataset_name))

        # presumably a list of file paths - verify against utils.all_files_under
        self.train_data = utils.all_files_under(self.imagenet64_path, extension='.png')
        self.num_trains = len(self.train_data)

        logger.info('Load {} dataset SUCCESS!'.format(self.dataset_name))
        logger.info('Img size: {}'.format(self.image_size))
        logger.info('Num. of training data: {}'.format(self.num_trains))

    def train_next_batch(self, batch_size):
        """Load ``batch_size`` distinct randomly chosen images and stack them into one array."""
        chosen_paths = np.random.choice(self.train_data, batch_size, replace=False)

        loaded = []
        for img_path in chosen_paths:
            loaded.append(utils.load_data(img_path, is_gray_scale=False))

        return np.asarray(loaded)
138 |
139 |
140 | # noinspection PyPep8Naming
def Dataset(sess, flags, dataset_name, log_path=None):
    """Factory returning the dataset object that matches ``dataset_name``.

    In training mode (``flags.is_train`` truthy) the module logger is
    initialised first via ``_init_logger(flags, log_path)``.

    :param sess: TensorFlow session; only forwarded to ``MnistDataset``.
    :param flags: Configuration object; ``is_train`` is read here and the
        object is forwarded to the dataset class.
    :param dataset_name: One of 'mnist', 'cifar10' or 'imagenet64'.
    :param log_path: Directory for the log file, forwarded to ``_init_logger``.
    :return: A ``MnistDataset``, ``Cifar10`` or ``ImageNet64`` instance.
    :raises NotImplementedError: If ``dataset_name`` is not one of the
        supported names. The message names the rejected value.
    """
    if flags.is_train:
        _init_logger(flags, log_path)  # init logger

    if dataset_name == 'mnist':
        return MnistDataset(sess, flags, dataset_name)
    if dataset_name == 'cifar10':
        return Cifar10(flags, dataset_name)
    if dataset_name == 'imagenet64':
        return ImageNet64(flags, dataset_name)

    # Same exception type as before, but say which name was rejected.
    raise NotImplementedError(
        "Unsupported dataset {!r}; expected 'mnist', 'cifar10' or 'imagenet64'".format(dataset_name))
153 |
--------------------------------------------------------------------------------
/src/download.py:
--------------------------------------------------------------------------------
1 | ########################################################################
2 | #
3 | # Functions for downloading and extracting data-files from the internet.
4 | #
5 | # Implemented in Python 3.5
6 | #
7 | ########################################################################
8 | #
9 | # This file is part of the TensorFlow Tutorials available at:
10 | #
11 | # https://github.com/Hvass-Labs/TensorFlow-Tutorials
12 | #
13 | # Published under the MIT License. See the file LICENSE for details.
14 | #
15 | # Copyright 2016 by Magnus Erik Hvass Pedersen
16 | #
17 | ########################################################################
18 |
19 | import sys
20 | import os
21 | import urllib.request
22 | import tarfile
23 | import zipfile
24 |
25 | ########################################################################
26 |
27 |
28 | def _print_download_progress(count, block_size, total_size):
29 | """
30 | Function used for printing the download progress.
31 | Used as a call-back function in maybe_download_and_extract().
32 | """
33 |
34 | # Percentage completion.
35 | pct_complete = float(count * block_size) / total_size
36 |
37 | # Status-message. Note the \r which means the line should overwrite itself.
38 | msg = "\r- Download progress: {0:.1%}".format(pct_complete)
39 |
40 | # Print it.
41 | sys.stdout.write(msg)
42 | sys.stdout.flush()
43 |
44 |
45 | ########################################################################
46 |
47 |
def maybe_download_and_extract(url, download_dir):
    """
    Download and extract the data if it doesn't already exist.
    The downloaded file is unpacked if its name ends with
    .zip, .tar.gz or .tgz; any other file is only downloaded.

    :param url:
        Internet URL for the archive-file to download.
        Example: "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
    :param download_dir:
        Directory where the downloaded file is saved and unpacked.
        Example: "data/CIFAR-10/"
    :return:
        Nothing.
    """

    # Filename for saving the file downloaded from the internet.
    # Use the filename from the URL and add it to the download_dir.
    filename = url.split('/')[-1]
    file_path = os.path.join(download_dir, filename)

    # If the file already exists then we assume it has also been
    # extracted, so there is nothing left to do.
    if os.path.exists(file_path):
        print("Data has apparently already been downloaded and unpacked.")
        return

    # Create the download directory if needed. exist_ok avoids the
    # race between checking for the directory and creating it.
    os.makedirs(download_dir, exist_ok=True)

    # Download the file from the internet.
    file_path, _ = urllib.request.urlretrieve(url=url,
                                              filename=file_path,
                                              reporthook=_print_download_progress)

    print()
    print("Download finished. Extracting files.")

    # The archives are opened as context-managers so the file-handles
    # are closed even if the extraction fails.
    # NOTE(review): extractall() trusts the member paths inside the
    # archive; only use this with archives from a trusted source.
    if file_path.endswith(".zip"):
        # Unpack the zip-file.
        with zipfile.ZipFile(file=file_path, mode="r") as archive:
            archive.extractall(download_dir)
    elif file_path.endswith((".tar.gz", ".tgz")):
        # Unpack the tar-ball.
        with tarfile.open(name=file_path, mode="r:gz") as archive:
            archive.extractall(download_dir)

    print("Done.")
93 |
94 |
95 | ########################################################################
96 |
--------------------------------------------------------------------------------
/src/download2.py:
--------------------------------------------------------------------------------
1 | """
2 | Modification of https://github.com/stanfordnlp/treelstm/blob/master/scripts/download.py
3 |
4 | Downloads the following:
5 | - Celeb-A dataset
6 | - LSUN dataset
7 | - MNIST dataset
8 | """
9 |
10 | import os
11 | import sys
12 | import json
13 | import zipfile
14 | import argparse
15 | import requests
16 | import subprocess
17 | from tqdm import tqdm
18 | from six.moves import urllib
19 |
# Command-line interface: one or more dataset names, each of which must be
# one of the values listed in `choices`.
parser = argparse.ArgumentParser(description='Download dataset for DCGAN.')
parser.add_argument('datasets', metavar='N', type=str, nargs='+', choices=['celebA', 'lsun', 'mnist'],
                    help='name of dataset to download [celebA, lsun, mnist]')
23 |
24 |
def download(url, dirpath):
    """Download ``url`` into ``dirpath`` while printing a one-line progress bar.

    The file is named after the last '/'-separated component of the URL.
    Both the network response and the output file are closed even if the
    transfer fails part-way. When the server sends no (or a zero)
    Content-Length header, a running byte count is shown instead of a
    percentage bar.

    :param url: URL to fetch.
    :param dirpath: Existing directory in which the file is written.
    :return: Path of the downloaded file.
    """
    filename = url.split('/')[-1]
    filepath = os.path.join(dirpath, filename)

    block_sz = 8192
    status_width = 70

    with urllib.request.urlopen(url) as u, open(filepath, 'wb') as f:
        # The header is optional (e.g. chunked responses); treat a missing
        # or empty value as "size unknown" rather than crashing in int().
        length = u.headers.get("Content-Length")
        filesize = int(length) if length else 0
        print("Downloading: %s Bytes: %s" % (filename, filesize if filesize > 0 else "unknown"))

        downloaded = 0
        while True:
            buf = u.read(block_sz)
            if not buf:
                print('')
                break

            downloaded += len(buf)
            f.write(buf)

            if filesize > 0:
                # Cap at 100% in case more data arrives than was announced.
                fraction = min(float(downloaded) / filesize, 1.0)
                status = (("[%-" + str(status_width + 1) + "s] %3.2f%%") %
                          ('=' * int(fraction * status_width) + '>',
                           fraction * 100.))
            else:
                status = "%d bytes" % downloaded

            # '\r' returns to the start of the line so the bar overwrites itself.
            print('\r' + status, end='')
            sys.stdout.flush()

    return filepath
54 |
55 |
def download_file_from_google_drive(id_, destination):
    """Download the Google Drive file ``id_`` and save it to ``destination``.

    A first request is made with only the file id. If that response carries
    a confirmation token (see ``get_confirm_token``), the request is repeated
    with the token added. The body of the final response is written by
    ``save_response_content``.

    :param id_: Google Drive file id.
    :param destination: Path of the file to write.
    :return: Nothing.
    """
    url = "https://docs.google.com/uc?export=download"

    # The session is used as a context manager so its pooled connections are
    # released once the download has finished or failed.
    with requests.Session() as session:
        response = session.get(url, params={'id': id_}, stream=True)
        token = get_confirm_token(response)

        if token:
            params = {'id': id_, 'confirm': token}
            response = session.get(url, params=params, stream=True)

        save_response_content(response, destination)
68 |
69 |
def get_confirm_token(response):
    """Return the value of the first cookie whose name starts with 'download_warning'.

    :param response: Object exposing a ``cookies`` mapping.
    :return: The cookie value, or None when no such cookie is present.
    """
    matching_values = (value for name, value in response.cookies.items()
                       if name.startswith('download_warning'))
    return next(matching_values, None)
76 |
77 |
def save_response_content(response, destination, chunk_size=32*1024):
    """Stream the body of ``response`` into the file ``destination``.

    A tqdm progress bar is shown. Its total comes from the response's
    'content-length' header, so it is advanced by the number of bytes
    written per chunk, not by one per chunk.

    :param response: Response object providing ``headers`` and ``iter_content``.
    :param destination: Path of the file to write; also used as the bar label.
    :param chunk_size: Number of bytes requested per chunk.
    :return: Nothing.
    """
    total_size = int(response.headers.get('content-length', 0))

    # With no known size, pass total=None so the bar only shows a running count.
    with open(destination, "wb") as f, \
            tqdm(total=total_size or None, unit='B', unit_scale=True, desc=destination) as progress:
        for chunk in response.iter_content(chunk_size):
            if chunk:  # filter out keep-alive new chunks
                f.write(chunk)
                progress.update(len(chunk))
85 |
86 |
def unzip(filepath):
    """Extract the zip archive ``filepath`` into its own directory, then delete the archive.

    :param filepath: Path of the zip file.
    :return: Nothing.
    """
    print("Extracting: " + filepath)

    # Members are unpacked next to the archive itself.
    target_dir = os.path.dirname(filepath)
    with zipfile.ZipFile(filepath) as archive:
        archive.extractall(target_dir)

    # The archive is no longer needed once its contents are on disk.
    os.remove(filepath)
93 |
94 |
def download_celeb_a(dirpath):
    """Download and unpack the Celeb-A images into ``<dirpath>/celebA``.

    Nothing is done if ``<dirpath>/celebA`` already exists. An archive
    already present at ``<dirpath>/img_align_celeba.zip`` is reused instead
    of being downloaded again. After extraction the archive is deleted and
    the archive's top-level folder is renamed to ``celebA``.

    :param dirpath: Directory that receives the ``celebA`` folder.
    :return: Nothing.
    """
    data_dir = 'celebA'
    if os.path.exists(os.path.join(dirpath, data_dir)):
        print('Found Celeb-A - skip')
        return

    filename, drive_id = "img_align_celeba.zip", "0B7EVK8r0v71pZjFTYXZWM3FlRnM"
    save_path = os.path.join(dirpath, filename)

    if os.path.exists(save_path):
        print('[*] {} already exists'.format(save_path))
    else:
        download_file_from_google_drive(drive_id, save_path)

    with zipfile.ZipFile(save_path) as zf:
        # Zip member names always use '/'. Take only the first path component
        # so the folder is found even when the first member is a file inside
        # it rather than an explicit directory entry.
        zip_dir = zf.namelist()[0].split('/')[0]
        zf.extractall(dirpath)

    os.remove(save_path)
    os.rename(os.path.join(dirpath, zip_dir), os.path.join(dirpath, data_dir))
115 |
116 |
def _list_categories(tag):
    """Fetch the LSUN category list for ``tag`` and return the decoded JSON.

    :param tag: Value appended to the ``tag=`` query parameter of the list URL.
    :return: Whatever ``json.loads`` produces for the response body.
    """
    url = 'http://lsun.cs.princeton.edu/htbin/list.cgi?tag=' + tag

    # Close the HTTP response as soon as the body has been read.
    with urllib.request.urlopen(url) as f:
        return json.loads(f.read())
122 |
123 |
def _download_lsun(out_dir, category, set_name, tag):
    """Download one LSUN archive into ``out_dir`` by running ``curl``.

    The test set is saved as ``test_lmdb.zip``; any other set is saved as
    ``<category>_<set_name>_lmdb.zip``. The exit status of ``curl`` is not
    checked.

    :param out_dir: Directory for the downloaded archive.
    :param category: Category name used in the URL and in the file name.
    :param set_name: Set name used in the URL and in the file name.
    :param tag: Release tag used in the URL.
    :return: Nothing.
    """
    url = ('http://lsun.cs.princeton.edu/htbin/download.cgi?tag={tag}'
           '&category={category}&set={set_name}').format(
        tag=tag, category=category, set_name=set_name)
    print(url)

    # The test set has a fixed file name; the others are named per category.
    out_name = ('test_lmdb.zip' if set_name == 'test'
                else '{category}_{set_name}_lmdb.zip'.format(category=category, set_name=set_name))
    out_path = os.path.join(out_dir, out_name)

    print('Downloading', category, set_name, 'set')
    subprocess.call(['curl', url, '-o', out_path])
137 |
138 |
def download_lsun(dirpath):
    """Download the LSUN bedroom train/val sets and the test set into ``<dirpath>/lsun``.

    Nothing is done if ``<dirpath>/lsun`` already exists.

    :param dirpath: Existing directory in which the ``lsun`` folder is created.
    :return: Nothing.
    """
    data_dir = os.path.join(dirpath, 'lsun')
    if os.path.exists(data_dir):
        print('Found LSUN - skip')
        return
    os.mkdir(data_dir)

    tag = 'latest'
    # Only the bedroom category is fetched; _list_categories(tag) is not used.
    categories = ['bedroom']

    for category in categories:
        for set_name in ('train', 'val'):
            _download_lsun(data_dir, category, set_name, tag)

    # The test set is requested with an empty category.
    _download_lsun(data_dir, '', 'test', tag)
155 |
156 |
def download_mnist(dirpath):
    """Download the four MNIST files into ``<dirpath>/mnist`` and decompress them.

    Nothing is done if ``<dirpath>/mnist`` already exists. Each file is
    fetched with ``curl`` and unpacked with ``gzip -d``; the exit status of
    these commands is not checked.

    :param dirpath: Existing directory in which the ``mnist`` folder is created.
    :return: Nothing.
    """
    data_dir = os.path.join(dirpath, 'mnist')
    if os.path.exists(data_dir):
        print('Found MNIST - skip')
        return
    os.mkdir(data_dir)

    url_base = 'http://yann.lecun.com/exdb/mnist/'
    file_names = ['train-images-idx3-ubyte.gz', 'train-labels-idx1-ubyte.gz',
                  't10k-images-idx3-ubyte.gz', 't10k-labels-idx1-ubyte.gz']

    for file_name in file_names:
        # The URL contains no format fields, so plain concatenation is enough.
        url = url_base + file_name
        print(url)
        out_path = os.path.join(data_dir, file_name)

        print('Downloading ', file_name)
        subprocess.call(['curl', url, '-o', out_path])

        print('Decompressing ', file_name)
        subprocess.call(['gzip', '-d', out_path])
179 |
180 |
def prepare_data_dir(path='./data'):
    """Create the directory ``path`` unless something already exists there.

    :param path: Directory to create; only the last path component is created.
    :return: Nothing.
    """
    if os.path.exists(path):
        return
    os.mkdir(path)
184 |
185 |
if __name__ == '__main__':
    # Script entry point: parse the requested dataset names, make sure the
    # default './data' directory exists, then download each requested dataset.
    args = parser.parse_args()
    prepare_data_dir()

    # NOTE(review): the parser's `choices` only admits 'celebA', so the
    # 'CelebA' and 'celeba' spellings checked here look unreachable - confirm.
    if any(name in args.datasets for name in ['CelebA', 'celebA', 'celeba']):
        download_celeb_a('./data')  # download Celeb-A dataset
    if 'lsun' in args.datasets:
        download_lsun('./data')  # download LSUN bedroom dataset
    if 'mnist' in args.datasets:
        download_mnist('./data')  # download mnist dataset
196 |
--------------------------------------------------------------------------------
/src/inception_score.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------
2 | # Tensorflow Inception Score Implementation
3 | # Licensed under The MIT License [see LICENSE for details]
4 | # From https://github.com/openai/improved-gan/blob/master/inception_score/model.py
5 | # Code derived from tensorflow/tensorflow/models/image/imagenet/classify_image.py
6 | # Reimplement by Cheng-Bin Jin
7 | # Email: sbkim0407@gmail.com
8 | # ---------------------------------------------------------
9 | from __future__ import absolute_import
10 | from __future__ import division
11 | from __future__ import print_function
12 |
13 | import os.path
14 | import tarfile
15 |
16 | import numpy as np
17 | from six.moves import urllib
18 | import tensorflow as tf
19 | import math
20 | import sys
21 |
# Directory where the Inception archive is stored and unpacked.
MODEL_DIR = './imagenet'
# Archive fetched by _init_inception() when it is not already in MODEL_DIR.
DATA_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
# Softmax output tensor; stays None until _init_inception() has built it.
softmax = None
25 |
26 |
27 | # Call this function with a list of images. Each element should be a numpy array with values ranging from 0 to 255.
def get_inception_score(images, flags, splits=10):
    """Compute the Inception Score of ``images``.

    Every image is fed through the module-level ``softmax`` tensor, the
    predictions are divided into ``splits`` consecutive groups and one score
    (the exponential of the mean KL term) is computed per group.

    :param images: List of 3-dimensional numpy arrays. The first image is
        checked by assertion to have a maximum above 10 and a minimum of at
        least 0.
    :param flags: Configuration object; ``gpu_index`` is written to the
        CUDA_VISIBLE_DEVICES environment variable.
    :param splits: Number of groups the predictions are divided into.
    :return: Tuple ``(mean, std)`` of the per-group scores.
    """
    os.environ['CUDA_VISIBLE_DEVICES'] = flags.gpu_index

    # Sanity checks on the input; only the first image is inspected.
    assert(type(images) == list)
    assert(type(images[0]) == np.ndarray)
    assert(len(images[0].shape) == 3)
    assert(np.max(images[0]) > 10)
    assert(np.min(images[0]) >= 0.0)

    # Each image becomes a float32 array with a leading batch axis of size 1.
    inps = [np.expand_dims(img.astype(np.float32), 0) for img in images]

    bs = 1  # original is 100
    with tf.Session() as sess:
        n_batches = int(math.ceil(float(len(inps)) / float(bs)))

        batch_preds = []
        for i in range(n_batches):
            start, stop = i * bs, min((i + 1) * bs, len(inps))
            feed = np.concatenate(inps[start:stop], 0)
            batch_preds.append(sess.run(softmax, {'ExpandDims:0': feed}))

        preds = np.concatenate(batch_preds, 0)
        n_preds = preds.shape[0]

        scores = []
        for i in range(splits):
            part = preds[(i * n_preds // splits):((i + 1) * n_preds // splits), :]
            # Mean prediction over the group, kept 2-D so it broadcasts per row.
            marginal = np.expand_dims(np.mean(part, axis=0), axis=0)
            kl = part * (np.log(part) - np.log(marginal))
            kl = np.mean(np.sum(kl, axis=1))
            scores.append(np.exp(kl))

        return np.mean(scores), np.std(scores)
63 |
64 |
65 | # This function is called automatically.
# This function is called automatically.
def _init_inception():
    """Download the Inception archive if needed, load its graph and build ``softmax``.

    Steps, in order: create MODEL_DIR if missing; download DATA_URL into it
    unless the archive file is already there; unpack the archive; import
    'classify_image_graph_def.pb' into the default graph; rebuild the final
    matmul + softmax on top of the 'pool_3:0' tensor. The result is stored in
    the module-level ``softmax``.
    """
    global softmax
    if not os.path.exists(MODEL_DIR):
        os.makedirs(MODEL_DIR)

    # The archive is stored in MODEL_DIR under the last path component of DATA_URL.
    filename = DATA_URL.split('/')[-1]
    filepath = os.path.join(MODEL_DIR, filename)

    if not os.path.exists(filepath):
        # Progress callback for urlretrieve; '\r' makes the line overwrite itself.
        def _progress(count, block_size, total_size):
            sys.stdout.write('\r>> Downloading %s %.1f%%' % (
                filename, float(count * block_size) / float(total_size) * 100.0))
            sys.stdout.flush()

        filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
        print()
        statinfo = os.stat(filepath)
        print('Succesfully downloaded', filename, statinfo.st_size, 'bytes.')

    # Unpack the archive into MODEL_DIR.
    tarfile.open(filepath, 'r:gz').extractall(MODEL_DIR)

    # Parse the serialized graph and import it without a name prefix.
    with tf.gfile.FastGFile(os.path.join(MODEL_DIR, 'classify_image_graph_def.pb'), 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        _ = tf.import_graph_def(graph_def, name='')

    # Works with an arbitrary minibatch size.
    with tf.Session() as sess:
        pool3 = sess.graph.get_tensor_by_name('pool_3:0')
        ops = pool3.graph.get_operations()

        for op_idx, op in enumerate(ops):
            for o in op.outputs:
                shape = o.get_shape()
                shape = [s.value for s in shape]
                new_shape = []

                # Build a copy of the shape in which a leading dimension of 1
                # is replaced by None.
                for j, s in enumerate(shape):
                    if s == 1 and j == 0:
                        new_shape.append(None)
                    else:
                        new_shape.append(s)

                # NOTE(review): this assigns a TensorShape to the attribute
                # `set_shape` instead of calling a method, so it probably does
                # not change the tensor's static shape - confirm the intent.
                o.set_shape = tf.TensorShape(new_shape)

        # Reuse the weights of the original logits matmul on the squeezed pool_3 output.
        w = sess.graph.get_operation_by_name("softmax/logits/MatMul").inputs[1]
        logits = tf.matmul(tf.squeeze(pool3, [1, 2]), w)
        softmax = tf.nn.softmax(logits)
114 |
115 |
# Build the Inception graph once, when this module is first imported.
if softmax is None:
    _init_inception()
118 |
--------------------------------------------------------------------------------
/src/jupyter/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/src/jupyter/gan_toy.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {
7 | "collapsed": true
8 | },
9 | "outputs": [],
10 | "source": [
11 | "import os\n",
12 | "import random\n",
13 | "import numpy as np\n",
14 | "import sklearn.datasets\n",
15 | "import tensorflow as tf\n",
16 | "import tensorflow_utils as tf_utils\n",
17 | "import utils as utils\n",
18 | "\n",
19 | "import matplotlib\n",
20 | "matplotlib.use('Agg')\n",
21 | "import matplotlib.pyplot as plt"
22 | ]
23 | },
24 | {
25 | "cell_type": "code",
26 | "execution_count": 2,
27 | "metadata": {},
28 | "outputs": [
29 | {
30 | "name": "stdout",
31 | "output_type": "stream",
32 | "text": [
33 | "Uppercase local vars:\n",
34 | "\tBATCH_SIZE: 256\n",
35 | "\tCRITIC_ITERS: 1\n",
36 | "\tDIM: 512\n",
37 | "\tDIS_DIM: 1\n",
38 | "\tFIXED_GENERATOR: False\n",
39 | "\tFREQ: 1000\n",
40 | "\tGEN_DIM: 2\n",
41 | "\tITERS: 100000\n",
42 | "\tLAMBDA: 0.1\n"
43 | ]
44 | }
45 | ],
46 | "source": [
47 | "DIM = 512 # model dimensionality\n",
48 | "GEN_DIM = 2 # output dimension of the generator\n",
49 | "DIS_DIM = 1 # output dimension of the discriminator\n",
50 | "FIXED_GENERATOR = False # whether to hold the generator fixed at real data plus Gaussian noise, as in the plots in the paper\n",
51 | "LAMBDA = .1 # smaller lambda makes things faster for toy tasks, but isn't necessary if you increase CRITIC_ITERS enough\n",
52 | "BATCH_SIZE = 256 # batch size\n",
53 | "ITERS = 100000 # how many generator iterations to train for\n",
54 | "FREQ = 1000 # sample frequency\n",
55 | "\n",
56 | "mode = 'wgan-g' # [gan, wgan, wgan-gp]\n",
57 | "dataset = 'swissroll' # [8gaussians, 25gaussians, swissroll]\n",
58 | "img_folder = os.path.join('img', mode + '_' + dataset + '_' + str(FIXED_GENERATOR))\n",
59 | "\n",
60 | "if mode == 'gan':\n",
61 | " CRITIC_ITERS = 1 # homw many critic iteractions per generator iteration\n",
62 | "else:\n",
63 | " CRITIC_ITERS = 5 # homw many critic iteractions per generator iteration\n",
64 | "\n",
65 | "if not os.path.isdir(img_folder):\n",
66 | " os.makedirs(img_folder)\n",
67 | "\n",
68 | "utils.print_model_setting(locals().copy())"
69 | ]
70 | },
71 | {
72 | "cell_type": "code",
73 | "execution_count": 3,
74 | "metadata": {},
75 | "outputs": [
76 | {
77 | "name": "stdout",
78 | "output_type": "stream",
79 | "text": [
80 | "gen/fc-1/add [256, 512]\n",
81 | "gen/fc-2/add [256, 512]\n",
82 | "gen/fc-3/add [256, 512]\n",
83 | "gen/fc-4/add [256, 2]\n",
84 | "is_reuse: False\n",
85 | "disc/fc-1/add [None, 512]\n",
86 | "disc/fc-2/add [None, 512]\n",
87 | "disc/fc-3/add [None, 512]\n",
88 | "disc/fc-4/add [None, 1]\n",
89 | "is_reuse: True\n",
90 | "disc_1/fc-1/add [256, 512]\n",
91 | "disc_1/fc-2/add [256, 512]\n",
92 | "disc_1/fc-3/add [256, 512]\n",
93 | "disc_1/fc-4/add [256, 1]\n"
94 | ]
95 | }
96 | ],
97 | "source": [
98 | "def Generator(n_samples, real_data_, name='gen'):\n",
99 | " if FIXED_GENERATOR:\n",
100 | " return real_data_ + (1. * tf.random_normal(tf.shape(real_data_)))\n",
101 | " else:\n",
102 | " with tf.variable_scope(name):\n",
103 | " noise = tf.random_normal([n_samples, 2])\n",
104 | " output01 = tf_utils.linear(noise, DIM, name='fc-1')\n",
105 | " output01 = tf_utils.relu(output01, name='relu-1')\n",
106 | " \n",
107 | " output02 = tf_utils.linear(output01, DIM, name='fc-2')\n",
108 | " output02 = tf_utils.relu(output02, name='relu-2')\n",
109 | " \n",
110 | " output03 = tf_utils.linear(output02, DIM, name='fc-3')\n",
111 | " output03 = tf_utils.relu(output03, name='relu-3')\n",
112 | " \n",
113 | " output04 = tf_utils.linear(output03, GEN_DIM, name='fc-4')\n",
114 | " \n",
115 | " return output04\n",
116 | " \n",
117 | "\n",
118 | "def Discriminator(inputs, is_reuse=True, name='disc'):\n",
119 | " with tf.variable_scope(name, reuse=is_reuse):\n",
120 | " print('is_reuse: {}'.format(is_reuse))\n",
121 | " output01 = tf_utils.linear(inputs, DIM, name='fc-1')\n",
122 | " output01 = tf_utils.relu(output01, name='relu-1')\n",
123 | "\n",
124 | " output02 = tf_utils.linear(output01, DIM, name='fc-2')\n",
125 | " output02 = tf_utils.relu(output02, name='relu-2')\n",
126 | "\n",
127 | " output03 = tf_utils.linear(output02, DIM, name='fc-3')\n",
128 | " output03 = tf_utils.relu(output03, name='relu-3')\n",
129 | "\n",
130 | " output04 = tf_utils.linear(output03, DIS_DIM, name='fc-4')\n",
131 | " \n",
132 | " return output04\n",
133 | " \n",
134 | "real_data = tf.placeholder(tf.float32, shape=[None, 2])\n",
135 | "fake_data = Generator(BATCH_SIZE, real_data)\n",
136 | "\n",
137 | "disc_real = Discriminator(real_data, is_reuse=False)\n",
138 | "disc_fake = Discriminator(fake_data)\n",
139 | "\n",
140 | "if mode == 'wgan' or mode == 'wgan-gp':\n",
141 | " # WGAN loss\n",
142 | " disc_cost = tf.reduce_mean(disc_fake) - tf.reduce_mean(disc_real)\n",
143 | " gen_cost = - tf.reduce_mean(disc_fake)\n",
144 | "\n",
145 | " # WGAN gradient penalty\n",
146 | " if mode == 'wgan-gp':\n",
147 | " alpha = tf.random_uniform(shape=[BATCH_SIZE, 1], minval=0., maxval=1.)\n",
148 | " interpolates = alpha*real_data + (1.-alpha) * fake_data\n",
149 | " disc_interpolates = Discriminator(interpolates)\n",
150 | " gradients = tf.gradients(disc_interpolates, [interpolates][0])\n",
151 | " slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1]))\n",
152 | " gradient_penalty = tf.reduce_mean((slopes - 1)**2)\n",
153 | " \n",
154 | " disc_cost += LAMBDA * gradient_penalty\n",
155 | "elif mode == 'gan':\n",
156 | " gen_cost = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_fake, labels=tf.ones_like(disc_fake)))\n",
157 | " \n",
158 | " disc_cost = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_fake, labels=tf.zeros_like(disc_fake)))\n",
159 | " disc_cost += tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_real, labels=tf.ones_like(disc_real)))\n",
160 | " disc_cost /= 2.\n",
161 | " \n",
162 | "disc_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='disc')\n",
163 | "gen_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='gen')\n",
164 | "\n",
165 | "if mode == 'wgan-gp':\n",
166 | " disc_train_op = tf.train.AdamOptimizer(learning_rate=1e-4, beta1=0.5, beta2=0.9).minimize(disc_cost, var_list=disc_vars)\n",
167 | " \n",
168 | " if len(gen_vars) > 0:\n",
169 | " gen_train_op = tf.train.AdamOptimizer(learning_rate=1e-4, beta1=0.5, beta2=0.9).minimize(gen_cost, var_list=gen_vars)\n",
170 | " else:\n",
171 | " gen_train_op = tf.no_op()\n",
172 | " \n",
173 | "elif mode == 'wgan':\n",
174 | " disc_train_op = tf.train.RMSPropOptimizer(learning_rate=5e-5).minimize(disc_cost, var_list=disc_vars)\n",
175 | " \n",
176 | " if len(gen_vars) > 0:\n",
177 | " gen_train_op = tf.train.RMSPropOptimizer(learning_rate=5e-5).minimize(gen_cost, var_list=gen_vars)\n",
178 | " else:\n",
179 | " gen_train_op = tf.no_op()\n",
180 | " \n",
181 | " # build an op to do the weight clipping\n",
182 | " clip_bounds = [-0.01, 0.01]\n",
183 | " clip_ops = [var.assign(tf.clip_by_value(var, clip_bounds[0], clip_bounds[1])) for var in disc_vars]\n",
184 | " clip_disc_weights = tf.group(*clip_ops)\n",
185 | " \n",
186 | "elif mode == 'gan':\n",
187 | " disc_train_op = tf.train.AdamOptimizer(learning_rate=2e-4, beta1=0.5).minimize(disc_cost, var_list=disc_vars)\n",
188 | " \n",
189 | " if len(gen_vars) > 0: \n",
190 | " gen_train_op = tf.train.AdamOptimizer(learning_rate=2e-4, beta1=0.5).minimize(gen_cost, var_list=gen_vars)\n",
191 | " else:\n",
192 | " gen_train_op = tf.no_op()"
193 | ]
194 | },
195 | {
196 | "cell_type": "code",
197 | "execution_count": 4,
198 | "metadata": {
199 | "collapsed": true
200 | },
201 | "outputs": [],
202 | "source": [
203 | "def generate_image(sess, true_dist, idx):\n",
204 | " # generates and saves a plot of the true distribution, the generator, and the critic\n",
205 | " N_POINTS = 128\n",
206 | " RANGE = 2\n",
207 | " \n",
208 | " points = np.zeros((N_POINTS, N_POINTS, 2), dtype='float32')\n",
209 | " points[:, :, 0] = np.linspace(-RANGE, RANGE, N_POINTS)[:, None]\n",
210 | " points[:, :, 1] = np.linspace(-RANGE, RANGE, N_POINTS)[None, :]\n",
211 | " points = points.reshape((-1, 2))\n",
212 | " \n",
213 | " if FIXED_GENERATOR is not True:\n",
214 | " samples = sess.run(fake_data, feed_dict={real_data: points})\n",
215 | " disc_map = sess.run(disc_real, feed_dict={real_data: points})\n",
216 | " \n",
217 | " plt.clf()\n",
218 | " x = y = np.linspace(-RANGE, RANGE, N_POINTS)\n",
219 | " plt.contour(x, y, disc_map.reshape((len(x), len(y))).transpose())\n",
220 | " plt.colorbar() # add color bar\n",
221 | " \n",
222 | " plt.scatter(true_dist[:, 0], true_dist[:, 1], c='orange', marker='+')\n",
223 | " if FIXED_GENERATOR is not True:\n",
224 | " plt.scatter(samples[:, 0], samples[:, 1], c='green', marker='*')\n",
225 | " \n",
226 | " plt.savefig(os.path.join(img_folder, str(idx).zfill(3) + '.jpg'))"
227 | ]
228 | },
229 | {
230 | "cell_type": "code",
231 | "execution_count": 5,
232 | "metadata": {
233 | "collapsed": true
234 | },
235 | "outputs": [],
236 | "source": [
237 | "# Dataset iterator\n",
238 | "def inf_train_gen():\n",
239 | " if dataset == '8gaussians':\n",
240 | " scale = 2.\n",
241 | " centers = [(1.,0.), \n",
242 | " (-1.,0.), \n",
243 | " (0., 1.), \n",
244 | " (0.,-1.),\n",
245 | " (1./np.sqrt(2), 1./np.sqrt(2)),\n",
246 | " (1./np.sqrt(2), -1/np.sqrt(2)), \n",
247 | " (-1./np.sqrt(2), 1./np.sqrt(2)), \n",
248 | " (-1./np.sqrt(2), -1./np.sqrt(2))]\n",
249 | " \n",
250 | " centers = [(scale*x, scale*y) for x, y in centers]\n",
251 | " while True:\n",
252 | " batch_data = []\n",
253 | " for _ in range(BATCH_SIZE):\n",
254 | " point = np.random.randn(2) * .02\n",
255 | " center = random.choice(centers)\n",
256 | " point[0] += center[0]\n",
257 | " point[1] += center[1]\n",
258 | " batch_data.append(point)\n",
259 | " \n",
260 | " batch_data = np.array(batch_data, dtype=np.float32)\n",
261 | " batch_data /= 1.414 # std\n",
262 | " yield batch_data\n",
263 | " \n",
264 | " elif dataset == '25gaussians':\n",
265 | " batch_data = []\n",
266 | " for i_ in range(4000):\n",
267 | " for x in range(-2, 3):\n",
268 | " for y in range(-2, 3):\n",
269 | " point = np.random.randn(2) * 0.05\n",
270 | " point[0] += 2*x\n",
271 | " point[1] += 2*y\n",
272 | " batch_data.append(point)\n",
273 | " \n",
274 | " batch_data = np.asarray(batch_data, dtype=np.float32)\n",
275 | " np.random.shuffle(batch_data)\n",
276 | " batch_data /= 2.828 # std\n",
277 | " \n",
278 | " while True:\n",
279 | " for i_ in range(int(len(batch_data)/BATCH_SIZE)):\n",
280 | " yield batch_data[i_*BATCH_SIZE:(i_+1)*BATCH_SIZE]\n",
281 | " \n",
282 | " elif dataset == 'swissroll':\n",
283 | " while True:\n",
284 | " batch_data = sklearn.datasets.make_swiss_roll(n_samples=BATCH_SIZE, noise=0.25)[0]\n",
285 | " batch_data = batch_data.astype(np.float32)[:, [0, 2]]\n",
286 | " batch_data /= 7.5 # stdev plus a little\n",
287 | " yield batch_data"
288 | ]
289 | },
290 | {
291 | "cell_type": "code",
292 | "execution_count": 6,
293 | "metadata": {},
294 | "outputs": [
295 | {
296 | "name": "stdout",
297 | "output_type": "stream",
298 | "text": [
299 | "iter 0\tdisc cost\t0.7139465808868408\n",
300 | "iter 1000\tdisc cost\t0.6814616771042347\n",
301 | "iter 2000\tdisc cost\t0.6737550102472305\n",
302 | "iter 3000\tdisc cost\t0.6920662943720818\n",
303 | "iter 4000\tdisc cost\t0.6906127047538757\n",
304 | "iter 5000\tdisc cost\t0.6930880607366562\n",
305 | "iter 6000\tdisc cost\t0.6931662155985833\n",
306 | "iter 7000\tdisc cost\t0.6919992100596428\n",
307 | "iter 8000\tdisc cost\t0.6933745203018189\n",
308 | "iter 9000\tdisc cost\t0.6926324440836906\n",
309 | "iter 10000\tdisc cost\t0.6933235999941826\n",
310 | "iter 11000\tdisc cost\t0.6933176911473274\n",
311 | "iter 12000\tdisc cost\t0.6923655412197113\n",
312 | "iter 13000\tdisc cost\t0.6932780802249908\n",
313 | "iter 14000\tdisc cost\t0.6932716737985611\n",
314 | "iter 15000\tdisc cost\t0.6932756861448288\n",
315 | "iter 16000\tdisc cost\t0.6915568546652794\n",
316 | "iter 17000\tdisc cost\t0.6933772000074386\n",
317 | "iter 18000\tdisc cost\t0.6933294381499291\n",
318 | "iter 19000\tdisc cost\t0.6924007230997086\n",
319 | "iter 20000\tdisc cost\t0.6928844445347786\n",
320 | "iter 21000\tdisc cost\t0.6933696148395538\n",
321 | "iter 22000\tdisc cost\t0.6933540980815888\n",
322 | "iter 23000\tdisc cost\t0.6933285065889359\n",
323 | "iter 24000\tdisc cost\t0.6918070427775384\n",
324 | "iter 25000\tdisc cost\t0.6933358488082886\n",
325 | "iter 26000\tdisc cost\t0.6933039631843567\n",
326 | "iter 27000\tdisc cost\t0.6932847768068313\n",
327 | "iter 28000\tdisc cost\t0.6932973915338516\n",
328 | "iter 29000\tdisc cost\t0.6932247024178505\n",
329 | "iter 30000\tdisc cost\t0.6932735745310783\n",
330 | "iter 31000\tdisc cost\t0.6932096065282821\n",
331 | "iter 32000\tdisc cost\t0.6890769819617272\n",
332 | "iter 33000\tdisc cost\t0.6923709976673126\n",
333 | "iter 34000\tdisc cost\t0.6931670038104057\n",
334 | "iter 35000\tdisc cost\t0.6929483571052552\n",
335 | "iter 36000\tdisc cost\t0.6932186397910118\n",
336 | "iter 37000\tdisc cost\t0.6932798113822937\n",
337 | "iter 38000\tdisc cost\t0.6932742780447007\n",
338 | "iter 39000\tdisc cost\t0.6931445201039315\n",
339 | "iter 40000\tdisc cost\t0.6932249497771263\n",
340 | "iter 41000\tdisc cost\t0.6932100749611855\n",
341 | "iter 42000\tdisc cost\t0.6929260883331299\n",
342 | "iter 43000\tdisc cost\t0.6932117487192154\n",
343 | "iter 44000\tdisc cost\t0.6931842603087425\n",
344 | "iter 45000\tdisc cost\t0.6931332590579986\n",
345 | "iter 46000\tdisc cost\t0.6931791034936905\n",
346 | "iter 47000\tdisc cost\t0.693181702375412\n",
347 | "iter 48000\tdisc cost\t0.6931661418676376\n",
348 | "iter 49000\tdisc cost\t0.6932356751561165\n",
349 | "iter 50000\tdisc cost\t0.6931068043708801\n",
350 | "iter 51000\tdisc cost\t0.6931518998742103\n",
351 | "iter 52000\tdisc cost\t0.6929512389302254\n",
352 | "iter 53000\tdisc cost\t0.69314514118433\n",
353 | "iter 54000\tdisc cost\t0.6931488544344903\n",
354 | "iter 55000\tdisc cost\t0.6926609833240509\n",
355 | "iter 56000\tdisc cost\t0.6931225464940071\n",
356 | "iter 57000\tdisc cost\t0.693101442694664\n",
357 | "iter 58000\tdisc cost\t0.6929626023769379\n",
358 | "iter 59000\tdisc cost\t0.6930847532749176\n",
359 | "iter 60000\tdisc cost\t0.6928317602872849\n",
360 | "iter 61000\tdisc cost\t0.6930897635817528\n",
361 | "iter 62000\tdisc cost\t0.6931442449092865\n",
362 | "iter 63000\tdisc cost\t0.6931378693580628\n",
363 | "iter 64000\tdisc cost\t0.6931423845887185\n",
364 | "iter 65000\tdisc cost\t0.6929416499137878\n",
365 | "iter 66000\tdisc cost\t0.6930960271954536\n",
366 | "iter 67000\tdisc cost\t0.6931367915272713\n",
367 | "iter 68000\tdisc cost\t0.6931659261584282\n",
368 | "iter 69000\tdisc cost\t0.6931650475263595\n",
369 | "iter 70000\tdisc cost\t0.6930810611844063\n",
370 | "iter 71000\tdisc cost\t0.6931314595341682\n",
371 | "iter 72000\tdisc cost\t0.6931117633581162\n",
372 | "iter 73000\tdisc cost\t0.6931591765880585\n",
373 | "iter 74000\tdisc cost\t0.6929754077196121\n",
374 | "iter 75000\tdisc cost\t0.693103608250618\n",
375 | "iter 76000\tdisc cost\t0.6931282907724381\n",
376 | "iter 77000\tdisc cost\t0.6930146722197532\n",
377 | "iter 78000\tdisc cost\t0.6930771964788437\n",
378 | "iter 79000\tdisc cost\t0.6931223605871201\n",
379 | "iter 80000\tdisc cost\t0.693151865184307\n",
380 | "iter 81000\tdisc cost\t0.693146475315094\n",
381 | "iter 82000\tdisc cost\t0.6931274979114532\n",
382 | "iter 83000\tdisc cost\t0.6924246903657914\n",
383 | "iter 84000\tdisc cost\t0.6930215736031532\n",
384 | "iter 85000\tdisc cost\t0.6930515897274018\n",
385 | "iter 86000\tdisc cost\t0.6930611334443092\n",
386 | "iter 87000\tdisc cost\t0.6926919516324997\n",
387 | "iter 88000\tdisc cost\t0.6930547666549682\n",
388 | "iter 89000\tdisc cost\t0.6929932239055634\n",
389 | "iter 90000\tdisc cost\t0.6930882499814034\n",
390 | "iter 91000\tdisc cost\t0.6930873312354088\n",
391 | "iter 92000\tdisc cost\t0.6923705233931542\n",
392 | "iter 93000\tdisc cost\t0.6930174446702003\n",
393 | "iter 94000\tdisc cost\t0.6927620051503182\n",
394 | "iter 95000\tdisc cost\t0.69283284419775\n",
395 | "iter 96000\tdisc cost\t0.6930583512187004\n",
396 | "iter 97000\tdisc cost\t0.6930861786007881\n",
397 | "iter 98000\tdisc cost\t0.6927603552937508\n",
398 | "iter 99000\tdisc cost\t0.6929180439114571\n",
399 | "iter 99999\tdisc cost\t0.6930565155662216\n"
400 | ]
401 | }
402 | ],
403 | "source": [
404 | "# Train loop\n",
405 | "with tf.Session() as sess:\n",
406 | " sess.run(tf.global_variables_initializer())\n",
407 | " data_gen = inf_train_gen()\n",
408 | " \n",
409 | " for iter_ in range(ITERS):\n",
410 | " batch_data, disc_cost_ = None, None\n",
411 | " \n",
412 | " # train critic\n",
413 | " for i_ in range(CRITIC_ITERS):\n",
414 | " batch_data = data_gen.__next__()\n",
415 | " disc_cost_, _ = sess.run([disc_cost, disc_train_op], feed_dict={real_data: batch_data})\n",
416 | " \n",
417 | " if mode == 'wgan':\n",
418 | " sess.run(clip_disc_weights)\n",
419 | " \n",
420 | " # train generator\n",
421 | " sess.run(gen_train_op)\n",
422 | " \n",
423 | " # write logs and save samples\n",
424 | " utils.plot('disc cost', disc_cost_)\n",
425 | " \n",
426 | " if (np.mod(iter_, FREQ) == 0) or (iter_+1 == ITERS):\n",
427 | " utils.flush(img_folder)\n",
428 | " generate_image(sess, batch_data, iter_)\n",
429 | " \n",
430 | " utils.tick() "
431 | ]
432 | }
433 | ],
434 | "metadata": {
435 | "kernelspec": {
436 | "display_name": "Python 3",
437 | "language": "python",
438 | "name": "python3"
439 | },
440 | "language_info": {
441 | "codemirror_mode": {
442 | "name": "ipython",
443 | "version": 3
444 | },
445 | "file_extension": ".py",
446 | "mimetype": "text/x-python",
447 | "name": "python",
448 | "nbconvert_exporter": "python",
449 | "pygments_lexer": "ipython3",
450 | "version": "3.5.5"
451 | }
452 | },
453 | "nbformat": 4,
454 | "nbformat_minor": 2
455 | }
456 |
--------------------------------------------------------------------------------
/src/jupyter/record_video.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import numpy as np
4 |
5 |
def all_files_under(path, extension=None, special=None, append_path=True, sort=True):
    """List the entries of directory ``path`` with optional filtering.

    Args:
        path: directory to list (passed to ``os.listdir``).
        extension: if given, keep only names ending with this suffix.
        special: if given, drop every name that contains this substring.
            ``None`` (the default) disables the exclusion filter.
        append_path: when True, return ``os.path.join(path, name)``;
            otherwise return the bare entry names.
        sort: when True, return the result sorted.

    Returns:
        A list of file names (or joined paths).
    """
    filenames = []
    for fname in os.listdir(path):
        # `special not in fname` raises TypeError when special is None,
        # so the exclusion test is only applied when a substring was given.
        if special is not None and special in fname:
            continue
        if extension is not None and not fname.endswith(extension):
            continue
        filenames.append(os.path.join(path, fname) if append_path else os.path.basename(fname))

    if sort:
        filenames = sorted(filenames)

    return filenames
24 |
25 |
def main(path_list, name_list):
    """Assemble one 3x3 comparison video per entry of ``path_list``.

    Each entry of ``path_list`` is a sequence of nine folder names under
    ``img/``; the frames of those folders are tiled row-major into a 3x3
    grid (three folders per row) and written to ``name_list[i]``.
    Files whose name contains 'disc' are excluded from the frame lists.
    """
    for idx_list, path in enumerate(path_list):
        # frame lists in grid order: indices 0-2 first row, 3-5 second, 6-8 third
        grid_paths = [all_files_under(os.path.join('img', path[k]), extension='jpg', special='disc')
                      for k in range(9)]

        # the fourth folder (index 3) determines both frame size and frame count
        reference_paths = grid_paths[3]
        frame_shape = cv2.imread(reference_paths[0]).shape
        print(frame_shape)

        # Define the codec and create VideoWriter object
        fourcc = cv2.VideoWriter_fourcc(*'XVID')
        video_size = (frame_shape[1]*3, frame_shape[0]*3)
        video_writer = cv2.VideoWriter(name_list[idx_list], fourcc, 10.0, video_size)

        for idx in range(len(reference_paths)):
            tiles = [cv2.imread(paths[idx]) for paths in grid_paths]
            rows = [np.hstack(tiles[start:start + 3]) for start in (0, 3, 6)]
            frame = np.vstack(rows)

            # write the frame and preview it
            video_writer.write(frame)
            cv2.imshow('Show', frame)
            cv2.waitKey(1)

        # Release everything if job is finished
        video_writer.release()
        cv2.destroyAllWindows()
74 |
75 |
if __name__ == '__main__':
    # folder names follow '<mode>_<dataset>_<flag>', mode varying slowest
    modes = ['gan', 'wgan', 'wgan-gp']
    datasets = ['8gaussians', '25gaussians', 'swissroll']

    path01 = ['{}_{}_True'.format(mode, data) for mode in modes for data in datasets]
    path02 = ['{}_{}_False'.format(mode, data) for mode in modes for data in datasets]
    file_names = ['generator_fixed_true.mp4', 'generator_fixed_false.mp4']

    main([path01, path02], file_names)
86 |
--------------------------------------------------------------------------------
/src/jupyter/tensorflow_utils.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------
2 | # Tensorflow Utils Implementation
3 | # Licensed under The MIT License [see LICENSE for details]
4 | # Written by Cheng-Bin Jin
5 | # Email: sbkim0407@gmail.com
6 | # ---------------------------------------------------------
7 | import tensorflow as tf
8 | import tensorflow.contrib.slim as slim
9 | from tensorflow.python.training import moving_averages
10 |
11 |
def padding2d(x, p_h=1, p_w=1, pad_type='REFLECT', name='pad2d'):
    """Pad the two middle axes of a rank-4 tensor.

    Args:
        x: input tensor; axes 1 and 2 are padded, axes 0 and 3 are untouched.
        p_h: padding added before and after axis 1.
        p_w: padding added before and after axis 2.
        pad_type: padding mode forwarded to ``tf.pad``; one of 'REFLECT',
            'SYMMETRIC' or 'CONSTANT'.
        name: name of the pad op.

    Returns:
        The padded tensor.

    Raises:
        NotImplementedError: if ``pad_type`` is not a supported mode
            (the function used to return None silently in that case).
    """
    if pad_type not in ('REFLECT', 'SYMMETRIC', 'CONSTANT'):
        raise NotImplementedError('unsupported pad_type: {}'.format(pad_type))

    return tf.pad(x, [[0, 0], [p_h, p_h], [p_w, p_w], [0, 0]], pad_type, name=name)
15 |
16 |
def conv2d(x, output_dim, k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02, padding='SAME', name='conv2d', is_print=True):
    """2-D convolution followed by a bias add, with variables under scope ``name``.

    The kernel variable 'w' has shape [k_h, k_w, in_channels, output_dim]
    (in_channels is the last dimension of ``x``) and is drawn from a
    truncated normal with the given ``stddev``; 'biases' starts at zero.
    Strides are [1, d_h, d_w, 1]. When ``is_print`` is set the output op
    name and shape are printed.
    """
    with tf.variable_scope(name):
        in_channels = x.get_shape()[-1]
        kernel = tf.get_variable('w', [k_h, k_w, in_channels, output_dim],
                                 initializer=tf.truncated_normal_initializer(stddev=stddev))
        feature_map = tf.nn.conv2d(x, kernel, strides=[1, d_h, d_w, 1], padding=padding)

        bias_term = tf.get_variable('biases', [output_dim], initializer=tf.constant_initializer(0.0))
        result = tf.nn.bias_add(feature_map, bias_term)

        if is_print:
            print_activations(result)

        return result
31 |
32 |
def deconv2d(x, k, k_h=3, k_w=3, d_h=2, d_w=2, stddev=0.02, padding_='SAME', output_size=None,
             name='deconv2d', with_w=False, is_print=True):
    """Transposed 2-D convolution followed by a bias add.

    Args:
        x: rank-4 input tensor (static height/width/channels are read from
            its shape; the batch size is taken dynamically).
        k: number of output channels.
        k_h, k_w: kernel height and width.
        d_h, d_w: strides along axes 1 and 2.
        stddev: stddev of the normal initializer for the kernel.
        padding_: padding mode passed to ``tf.nn.conv2d_transpose``.
        output_size: optional output spatial size, either a single int used
            for both axes or an ``(height, width)`` pair. When omitted the
            size is the input size multiplied by the stride (identical to
            the previous hard-coded doubling for the default stride of 2).
        name: variable scope name.
        with_w: when True also return the kernel and bias variables.
        is_print: when True print the output op name and shape.

    Returns:
        The output tensor, or ``(output, w, biases)`` when ``with_w`` is True.
    """
    with tf.variable_scope(name):
        input_shape = x.get_shape().as_list()

        # calculate output size; previously output_size was accepted but ignored,
        # which left None entries in output_shape whenever it was supplied
        if output_size:
            if isinstance(output_size, int):
                h_output = w_output = output_size
            else:
                h_output, w_output = output_size
        else:
            h_output, w_output = input_shape[1] * d_h, input_shape[2] * d_w

        # the batch dimension is dynamic so the op also works when the batch size is undefined
        output_shape = [tf.shape(x)[0], h_output, w_output, k]

        # conv2d transpose
        w = tf.get_variable('w', [k_h, k_w, k, input_shape[3]],
                            initializer=tf.random_normal_initializer(stddev=stddev))
        deconv = tf.nn.conv2d_transpose(x, w, output_shape=output_shape, strides=[1, d_h, d_w, 1],
                                        padding=padding_)

        biases = tf.get_variable('biases', [k], initializer=tf.constant_initializer(0.0))
        deconv = tf.nn.bias_add(deconv, biases)

        if is_print:
            print_activations(deconv)

        if with_w:
            return deconv, w, biases
        else:
            return deconv
61 |
62 |
def upsampling2d(x, size=(2, 2), name='upsampling2d'):
    """Nearest-neighbour upsampling of axes 1 and 2 by the factors in ``size``."""
    with tf.name_scope(name):
        dims = x.get_shape().as_list()
        target_size = (size[0] * dims[1], size[1] * dims[2])
        return tf.image.resize_nearest_neighbor(x, size=target_size)
67 |
68 |
def linear(x, output_size, bias_start=0.0, with_w=False, name='fc', is_print=True):
    """Fully connected layer: ``x @ matrix + bias`` with variables under scope ``name``.

    'matrix' has shape [x.shape[1], output_size] and uses the Xavier
    initializer; 'bias' is initialised to ``bias_start``. Returns the
    output, or ``(output, matrix, bias)`` when ``with_w`` is True.
    """
    in_features = x.get_shape().as_list()[1]

    with tf.variable_scope(name):
        weights = tf.get_variable(name="matrix", shape=[in_features, output_size],
                                  dtype=tf.float32, initializer=tf.contrib.layers.xavier_initializer())
        offset = tf.get_variable(name="bias", shape=[output_size],
                                 initializer=tf.constant_initializer(bias_start))
        projected = tf.matmul(x, weights) + offset

        if is_print:
            print_activations(projected)

        if not with_w:
            return projected
        return projected, weights, offset
86 |
87 |
def norm(x, name, _type, _ops, is_train=True):
    """Dispatch to batch or instance normalization according to ``_type``.

    Raises NotImplementedError for any ``_type`` other than 'batch' or
    'instance'. ``_ops`` and ``is_train`` are only used by the batch variant.
    """
    if _type == 'batch':
        return batch_norm(x, name=name, _ops=_ops, is_train=is_train)
    if _type == 'instance':
        return instance_norm(x, name=name)
    raise NotImplementedError
95 |
96 |
def batch_norm(x, name, _ops, is_train=True):
    """Batch normalization.

    Args:
        x: input tensor; statistics are reduced over axes [0, 1, 2], so a
            rank-4 input is expected.
        name: variable scope holding beta, gamma and the moving statistics.
        _ops: list that receives the moving-average update ops in training
            mode; running them is left to the caller.
        is_train: must be exactly ``True`` to use batch statistics; any other
            value uses the stored moving statistics.

    Returns:
        The normalized tensor, with the same static shape as ``x``.
    """
    with tf.variable_scope(name):
        params_shape = [x.get_shape()[-1]]

        beta = tf.get_variable('beta', params_shape, tf.float32,
                               initializer=tf.constant_initializer(0.0, tf.float32))
        gamma = tf.get_variable('gamma', params_shape, tf.float32,
                                initializer=tf.constant_initializer(1.0, tf.float32))

        if is_train is True:
            mean, variance = tf.nn.moments(x, [0, 1, 2], name='moments')

            moving_mean = tf.get_variable('moving_mean', params_shape, tf.float32,
                                          initializer=tf.constant_initializer(0.0, tf.float32),
                                          trainable=False)
            moving_variance = tf.get_variable('moving_variance', params_shape, tf.float32,
                                              initializer=tf.constant_initializer(1.0, tf.float32),
                                              trainable=False)

            # the updates are only collected here, not executed
            _ops.append(moving_averages.assign_moving_average(moving_mean, mean, 0.9))
            _ops.append(moving_averages.assign_moving_average(moving_variance, variance, 0.9))
        else:
            mean = tf.get_variable('moving_mean', params_shape, tf.float32,
                                   initializer=tf.constant_initializer(0.0, tf.float32), trainable=False)
            # same initial value as in the training branch, so a freshly created
            # variable does not fall back to the default (random) initializer
            variance = tf.get_variable('moving_variance', params_shape, tf.float32,
                                       initializer=tf.constant_initializer(1.0, tf.float32),
                                       trainable=False)

        # epsilon used to be 1e-5. Maybe 0.001 solves NaN problem in deeper net.
        y = tf.nn.batch_normalization(x, mean, variance, beta, gamma, 1e-5)
        y.set_shape(x.get_shape())

        return y
129 |
130 |
def instance_norm(x, name='instance_norm', mean=1.0, stddev=0.02, epsilon=1e-5):
    """Instance normalization with a learned per-channel scale and offset.

    ``mean`` and ``stddev`` parameterise the normal initializer of 'scale';
    'offset' starts at zero. Statistics are computed over axes [1, 2] of
    each sample, and ``epsilon`` is added to the variance before the rsqrt.
    """
    with tf.variable_scope(name):
        channels = x.get_shape()[3]
        scale_init = tf.random_normal_initializer(mean=mean, stddev=stddev, dtype=tf.float32)
        scale = tf.get_variable('scale', [channels], tf.float32, initializer=scale_init)
        offset = tf.get_variable('offset', [channels], initializer=tf.constant_initializer(0.0))

        # calculate mean and variance per instance
        inst_mean, inst_var = tf.nn.moments(x, axes=[1, 2], keep_dims=True)

        # normalization
        inv_std = tf.rsqrt(inst_var + epsilon)
        centered_scaled = (x - inst_mean) * inv_std

        return scale * centered_scaled + offset
147 |
148 |
def n_res_blocks(x, _ops=None, norm_='instance', is_train=True, num_blocks=6, is_print=False):
    """Chain ``num_blocks`` residual blocks named 'res1' .. 'res<num_blocks>'.

    Each block keeps the channel count of its input. Returns the output of
    the last block (None when ``num_blocks`` is 0).
    """
    current, result = x, None
    for block_no in range(1, num_blocks + 1):
        result = res_block(current, current.get_shape()[3], _ops=_ops, norm_=norm_, is_train=is_train,
                           name='res{}'.format(block_no))
        current = result

    if is_print:
        print_activations(result)

    return result
160 |
161 |
162 | # norm(x, name, _type, _ops, is_train=True)
def res_block(x, k, _ops=None, norm_='instance', is_train=True, pad_type=None, name=None):
    """Residual block: conv-norm-relu, conv-norm, then add the input.

    Both convolutions are 3x3 with stride 1 and ``k`` output channels.
    With ``pad_type=None`` they use 'SAME' padding; with 'REFLECT' the input
    is reflect-padded by one pixel and the convolution uses 'VALID'.
    """
    def _conv3x3(inputs):
        # one 3x3 stride-1 convolution using the block's padding strategy
        if pad_type is None:
            return conv2d(inputs, k, k_h=3, k_w=3, d_h=1, d_w=1, padding='SAME', name='conv')
        if pad_type == 'REFLECT':
            padded = padding2d(inputs, p_h=1, p_w=1, pad_type='REFLECT', name='padding')
            return conv2d(padded, k, k_h=3, k_w=3, d_h=1, d_w=1, padding='VALID', name='conv')
        return None

    with tf.variable_scope(name):
        # 3x3 Conv-Norm-Relu S1
        with tf.variable_scope('layer1'):
            first = norm(_conv3x3(x), name='norm', _type=norm_, _ops=_ops, is_train=is_train)
            activated = tf.nn.relu(first)

        # 3x3 Conv-Norm S1
        with tf.variable_scope('layer2'):
            second = norm(_conv3x3(activated), name='norm', _type=norm_, _ops=_ops, is_train=is_train)

        # skip connection: add the block input to the second layer's output
        return x + second
189 |
190 |
def identity(x, name='identity', is_print=False):
    """Wrap ``x`` in a ``tf.identity`` op called ``name``; optionally print its shape."""
    passthrough = tf.identity(x, name=name)
    if is_print:
        print_activations(passthrough)
    return passthrough
197 |
198 |
def max_pool_2x2(x, name='max_pool'):
    """2x2 max pooling with stride 2 and 'SAME' padding."""
    window = [1, 2, 2, 1]
    with tf.name_scope(name):
        return tf.nn.max_pool(x, ksize=window, strides=[1, 2, 2, 1], padding='SAME')
202 |
203 |
def sigmoid(x, name='sigmoid', is_print=False):
    """Apply ``tf.nn.sigmoid``; optionally print the result's op name and shape."""
    activated = tf.nn.sigmoid(x, name=name)
    if is_print:
        print_activations(activated)
    return activated
210 |
211 |
def tanh(x, name='tanh', is_print=False):
    """Apply ``tf.nn.tanh``; optionally print the result's op name and shape."""
    activated = tf.nn.tanh(x, name=name)
    if is_print:
        print_activations(activated)
    return activated
218 |
219 |
def relu(x, name='relu', is_print=False):
    """Apply ``tf.nn.relu``; optionally print the result's op name and shape."""
    activated = tf.nn.relu(x, name=name)
    if is_print:
        print_activations(activated)
    return activated
226 |
227 |
def lrelu(x, leak=0.2, name='lrelu', is_print=False):
    """Leaky ReLU computed as ``max(x, leak * x)``; optionally print the result."""
    leaked = leak * x
    activated = tf.maximum(x, leaked, name=name)
    if is_print:
        print_activations(activated)
    return activated
234 |
235 |
def xavier_init(in_dim):
    """Return the stddev ``1 / sqrt(in_dim / 2)`` for a layer with ``in_dim`` inputs."""
    return 1. / tf.sqrt(in_dim / 2.)
240 |
241 |
def print_activations(t):
    """Print the op name and static shape (as a list) of tensor ``t``."""
    op_name = t.op.name
    static_shape = t.get_shape().as_list()
    print(op_name, ' ', static_shape)
244 |
245 |
def show_all_variables():
    """Print a summary of all trainable variables via slim's model analyzer."""
    slim.model_analyzer.analyze_vars(tf.trainable_variables(), print_info=True)
249 |
250 |
def batch_convert2int(images):
    """Apply ``convert2int`` to every element along the first axis of ``images``.

    images: 4D float tensor (batch_size, image_size, image_size, depth)
    """
    converted = tf.map_fn(convert2int, images, dtype=tf.uint8)
    return converted
254 |
255 |
def convert2int(image):
    """Transform a float tensor in [-1., 1.] to a uint8 image in [0, 255]."""
    unit_range = (image + 1.0) / 2.0  # shift/scale into [0, 1] first
    return tf.image.convert_image_dtype(unit_range, tf.uint8)
259 |
--------------------------------------------------------------------------------
/src/jupyter/utils.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------
2 | # Python Utils Implementation
3 | # Licensed under The MIT License [see LICENSE for details]
4 | # Written by Cheng-Bin Jin
5 | # Email: sbkim0407@gmail.com
6 | # ---------------------------------------------------------
7 | import os
8 | import collections
9 | import matplotlib
10 | matplotlib.use('Agg')
11 | import matplotlib.pyplot as plt
12 |
13 |
def print_model_setting(locals_):
    """Print every all-uppercase entry of ``locals_``, sorted by name.

    The names 'T', 'SETTINGS' and 'ALL_SETTINGS' are always skipped.
    """
    print("Uppercase local vars:")

    excluded = ('T', 'SETTINGS', 'ALL_SETTINGS')
    selected = {key: val for key, val in locals_.items()
                if key.isupper() and key not in excluded}

    for key in sorted(selected):
        print("\t{}: {}".format(key, selected[key]))
23 |
24 |
# name -> {iteration: value}; the first keeps everything, the second only
# what was recorded since the last flush()
_since_beginning = collections.defaultdict(dict)
_since_last_flush = collections.defaultdict(dict)
# single-element list so tick() can update the counter in place
_iter = [0]
28 |
29 |
def tick():
    """Advance the module-level iteration counter by one."""
    _iter[0] = _iter[0] + 1
32 |
33 |
def plot(name, value):
    """Record ``value`` for series ``name`` at the current iteration."""
    current_iter = _iter[0]
    _since_last_flush[name][current_iter] = value
36 |
37 |
def flush(save_folder):
    """Plot every recorded series, print the recent means, and reset the buffer.

    For each series recorded through ``plot()`` since the previous flush,
    the mean of the new values is reported, the values are merged into the
    full history, and the history is saved as
    ``<save_folder>/<name with spaces replaced by '_'>.jpg``.
    """
    prints = []

    for name, vals in _since_last_flush.items():
        prints.append("{}\t{}".format(name, sum(vals.values()) / len(vals)))
        _since_beginning[name].update(vals)

        # plt.plot needs real sequences: a dict keys view is not array-like,
        # so materialise the iterations as a sorted list
        x_vals = sorted(_since_beginning[name])
        y_vals = [_since_beginning[name][x] for x in x_vals]

        plt.clf()
        plt.plot(x_vals, y_vals)
        plt.xlabel('iteration')
        plt.ylabel(name)
        plt.savefig(os.path.join(save_folder, name.replace(' ', '_')+'.jpg'))

    print("iter {}\t{}".format(_iter[0], "\t".join(prints)))
    _since_last_flush.clear()
63 |
--------------------------------------------------------------------------------
/src/main.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------
2 | # Tensorflow WGAN-GP Implementation
3 | # Licensed under The MIT License [see LICENSE for details]
4 | # Written by Cheng-Bin Jin
5 | # Email: sbkim0407@gmail.com
6 | # ---------------------------------------------------------
7 | import os
8 | import tensorflow as tf
9 | from solver import Solver
10 |
# Command-line flags; the parsed values are read through FLAGS.
FLAGS = tf.flags.FLAGS

# device / data
tf.flags.DEFINE_string('gpu_index', '0', 'gpu index if you have multiple gpus, default: 0')
tf.flags.DEFINE_integer('batch_size', 64, 'batch size, default: 64')
tf.flags.DEFINE_string('dataset', 'mnist', 'dataset name from [mnist, cifar10, imagenet64], default: mnist')

# optimisation hyper-parameters
tf.flags.DEFINE_bool('is_train', True, 'training or inference mode, default: True')
# help text corrected: the actual default is 1e-4, not 0.0002
tf.flags.DEFINE_float('learning_rate', 1e-4, 'initial learning rate for Adam, default: 0.0001')
tf.flags.DEFINE_integer('num_critic', 5, 'the number of iterations of the critic per generator iteration, default: 5')
tf.flags.DEFINE_integer('z_dim', 128, 'dimension of z vector, default: 128')
tf.flags.DEFINE_float('lambda_', 10., 'gradient penalty lambda hyperparameter, default: 10.')
tf.flags.DEFINE_float('beta1', 0.5, 'beta1 momentum term of Adam, default: 0.5')
tf.flags.DEFINE_float('beta2', 0.9, 'beta2 momentum term of Adam, default: 0.9')

# schedule / reporting
tf.flags.DEFINE_integer('iters', 200000, 'number of iterations, default: 200000')
tf.flags.DEFINE_integer('print_freq', 100, 'print frequency for loss, default: 100')
tf.flags.DEFINE_integer('save_freq', 10000, 'save frequency for model, default: 10000')
tf.flags.DEFINE_integer('sample_freq', 500, 'sample frequency for saving image, default: 500')
tf.flags.DEFINE_integer('inception_freq', 1000, 'calculation frequence of inception score, default: 1000')
tf.flags.DEFINE_integer('sample_batch', 64, 'number of sampling images for check generator quality, default: 64')
tf.flags.DEFINE_string('load_model', None, 'folder of saved model taht you wish to continue training '
                       '(e.g. 20181017-1430), default: None')
33 |
34 |
def main(_):
    """Select the GPU, build the solver, then train or test according to FLAGS.is_train."""
    os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu_index

    solver = Solver(FLAGS)
    if FLAGS.is_train:
        solver.train()
    else:
        solver.test()
43 |
44 |
if __name__ == '__main__':
    # script entry point: hand control to TensorFlow's app runner
    # (presumably parses the flags defined above and then calls main() - confirm in TF docs)
    tf.app.run()
47 |
--------------------------------------------------------------------------------
/src/plot.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------
2 | # Python Plot Function
3 | # Licensed under The MIT License [see LICENSE for details]
4 | # From https://github.com/igul222/improved_wgan_training/blob/master/tflib/plot.py
5 | # Code from igul222
6 | # ---------------------------------------------------------
7 | import os
8 | import matplotlib as mpl
9 | mpl.use('TkAgg') # or whatever other backend that you want to solve Segmentation fault (core dumped)
10 | import matplotlib.pyplot as plt
11 | import collections
12 |
13 |
# name -> {iteration: value}; the first keeps everything, the second only
# what was recorded since the last flush()
_since_beginning = collections.defaultdict(dict)
_since_last_flush = collections.defaultdict(dict)
# single-element list so tick() can update the counter in place
_iter = [0]
17 |
18 |
def tick():
    """Advance the module-level iteration counter by one."""
    _iter[0] = _iter[0] + 1
21 |
22 |
def plot(name, value):
    """Record ``value`` for series ``name`` at the current iteration."""
    current_iter = _iter[0]
    _since_last_flush[name][current_iter] = value
25 |
26 |
def flush(save_folder):
    """Plot every recorded series and reset the since-last-flush buffer.

    For each series recorded through ``plot()`` since the previous flush,
    the values are merged into the full history and the history is saved as
    ``<save_folder>/<name with spaces replaced by '_'>.jpg``.
    Nothing is printed to the console.
    """
    for name, vals in _since_last_flush.items():
        _since_beginning[name].update(vals)

        # plt.plot needs real sequences: a dict keys view is not array-like,
        # so materialise the iterations as a sorted list
        x_vals = sorted(_since_beginning[name])
        y_vals = [_since_beginning[name][x] for x in x_vals]

        plt.clf()
        plt.plot(x_vals, y_vals)
        plt.xlabel('iteration')
        plt.ylabel(name)
        plt.savefig(os.path.join(save_folder, name.replace(' ', '_')+'.jpg'))

    _since_last_flush.clear()
51 |
--------------------------------------------------------------------------------
/src/solver.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------
2 | # Tensorflow WGAN-GP Implementation
3 | # Licensed under The MIT License [see LICENSE for details]
4 | # Written by Cheng-Bin Jin
5 | # Email: sbkim0407@gmail.com
6 | # ---------------------------------------------------------
7 | import os
8 | import logging
9 | import numpy as np
10 | import tensorflow as tf
11 | from datetime import datetime
12 |
13 | # noinspection PyPep8Naming
14 | import plot as plot
15 | from dataset_ import Dataset
16 | from wgan_gp import WGAN_GP
17 | from inception_score import get_inception_score
18 |
# Module-level logger named after this module; its file and stream handlers
# are attached later in Solver._init_logger.
logger = logging.getLogger(__name__)  # logger
logger.setLevel(logging.INFO)
21 |
22 |
23 | class Solver(object):
    def __init__(self, flags):
        """Create the session, output folders, logger, dataset, model and saver.

        Args:
            flags: parsed command-line flags; stored as ``self.flags`` and
                forwarded to the dataset and the model.
        """
        # let the session allocate GPU memory on demand
        run_config = tf.ConfigProto()
        run_config.gpu_options.allow_growth = True
        self.sess = tf.Session(config=run_config)

        self.flags = flags
        self.iter_time = 0
        self.num_examples_IS = 1000
        # folders first: _init_logger writes its log file into self.log_out_dir
        self._make_folders()
        self._init_logger()

        self.dataset = Dataset(self.sess, self.flags, self.flags.dataset, log_path=self.log_out_dir)
        self.model = WGAN_GP(self.sess, self.flags, self.dataset, log_path=self.log_out_dir)

        # the saver is created after the model has been constructed
        self.saver = tf.train.Saver()
        self.sess.run([tf.global_variables_initializer()])
40 |
41 | # tf_utils.show_all_variables()
42 |
43 | def _make_folders(self):
44 | if self.flags.is_train: # train stage
45 | if self.flags.load_model is None:
46 | cur_time = datetime.now().strftime("%Y%m%d-%H%M")
47 | self.model_out_dir = "{}/model/{}".format(self.flags.dataset, cur_time)
48 | if not os.path.isdir(self.model_out_dir):
49 | os.makedirs(self.model_out_dir)
50 | else:
51 | cur_time = self.flags.load_model
52 | self.model_out_dir = "{}/model/{}".format(self.flags.dataset, cur_time)
53 |
54 | self.sample_out_dir = "{}/sample/{}".format(self.flags.dataset, cur_time)
55 | if not os.path.isdir(self.sample_out_dir):
56 | os.makedirs(self.sample_out_dir)
57 |
58 | self.log_out_dir = "{}/logs/{}".format(self.flags.dataset, cur_time)
59 | self.train_writer = tf.summary.FileWriter("{}/logs/{}".format(self.flags.dataset, cur_time),
60 | graph_def=self.sess.graph_def)
61 |
62 | elif not self.flags.is_train: # test stage
63 | self.model_out_dir = "{}/model/{}".format(self.flags.dataset, self.flags.load_model)
64 | self.test_out_dir = "{}/test/{}".format(self.flags.dataset, self.flags.load_model)
65 | self.log_out_dir = "{}/logs/{}".format(self.flags.dataset, self.flags.load_model)
66 |
67 | if not os.path.isdir(self.test_out_dir):
68 | os.makedirs(self.test_out_dir)
69 |
70 | def _init_logger(self):
71 | formatter = logging.Formatter('%(asctime)s:%(name)s:%(message)s')
72 | # file handler
73 | file_handler = logging.FileHandler(os.path.join(self.log_out_dir, 'solver.log'))
74 | file_handler.setFormatter(formatter)
75 | file_handler.setLevel(logging.INFO)
76 | # stream handler
77 | stream_handler = logging.StreamHandler()
78 | stream_handler.setFormatter(formatter)
79 | # add handlers
80 | logger.addHandler(file_handler)
81 | logger.addHandler(stream_handler)
82 |
83 | if self.flags.is_train:
84 | logger.info('gpu_index: {}'.format(self.flags.gpu_index))
85 | logger.info('batch_size: {}'.format(self.flags.batch_size))
86 | logger.info('dataset: {}'.format(self.flags.dataset))
87 |
88 | logger.info('is_train: {}'.format(self.flags.is_train))
89 | logger.info('learning_rate: {}'.format(self.flags.learning_rate))
90 | logger.info('num_critic: {}'.format(self.flags.num_critic))
91 | logger.info('z_dim: {}'.format(self.flags.z_dim))
92 | logger.info('lambda_: {}'.format(self.flags.lambda_))
93 | logger.info('beta1: {}'.format(self.flags.beta1))
94 | logger.info('beta2: {}'.format(self.flags.beta2))
95 |
96 | logger.info('iters: {}'.format(self.flags.iters))
97 | logger.info('print_freq: {}'.format(self.flags.print_freq))
98 | logger.info('save_freq: {}'.format(self.flags.save_freq))
99 | logger.info('sample_freq: {}'.format(self.flags.sample_freq))
100 | logger.info('inception_freq: {}'.format(self.flags.inception_freq))
101 | logger.info('sample_batch: {}'.format(self.flags.sample_batch))
102 | logger.info('load_model: {}'.format(self.flags.load_model))
103 |
104 | def train(self):
105 | # load initialized checkpoint that provided
106 | if self.flags.load_model is not None:
107 | if self.load_model():
108 | logger.info(' [*] Load SUCCESS!\n')
109 | else:
110 | logger.info(' [!] Load Failed...\n')
111 |
112 | # for iter_time in range(self.flags.iters):
113 | while self.iter_time < self.flags.iters:
114 | # sampling images and save them
115 | self.sample(self.iter_time)
116 |
117 | # train_step
118 | loss, summary = self.model.train_step()
119 | self.model.print_info(loss, self.iter_time)
120 | self.train_writer.add_summary(summary, self.iter_time)
121 | self.train_writer.flush()
122 |
123 | if self.flags.dataset == 'cifar10':
124 | self.get_inception_score(self.iter_time) # calculate inception score
125 |
126 | # save model
127 | self.save_model(self.iter_time)
128 | self.iter_time += 1
129 |
130 | self.save_model(self.flags.iters)
131 |
132 | def test(self):
133 | if self.load_model():
134 | logger.info(' [*] Load SUCCESS!')
135 | else:
136 | logger.info(' [!] Load Failed...')
137 |
138 | num_iters = 20
139 | for iter_time in range(num_iters):
140 | print('iter_time: {}'.format(iter_time))
141 |
142 | imgs = self.model.test_step()
143 | self.model.plots(imgs, iter_time, self.test_out_dir)
144 |
145 | def get_inception_score(self, iter_time):
146 | if np.mod(iter_time, self.flags.inception_freq) == 0:
147 | sample_size = 100
148 | all_samples = []
149 | for _ in range(int(self.num_examples_IS/sample_size)):
150 | imgs = self.model.sample_imgs(sample_size=sample_size)
151 | all_samples.append(imgs[0])
152 |
153 | all_samples = np.concatenate(all_samples, axis=0)
154 | all_samples = ((all_samples + 1.) * 255. / 2.).astype(np.uint8)
155 |
156 | mean_IS, std_IS = get_inception_score(list(all_samples), self.flags)
157 | # print('Inception score iter: {}, IS: {}'.format(self.iter_time, mean_IS))
158 |
159 | plot.plot('inception score', mean_IS)
160 | plot.flush(self.log_out_dir) # write logs
161 | plot.tick()
162 |
163 | def sample(self, iter_time):
164 | if np.mod(iter_time, self.flags.sample_freq) == 0:
165 | imgs = self.model.sample_imgs(sample_size=self.flags.sample_batch)
166 | self.model.plots(imgs, iter_time, self.sample_out_dir)
167 |
168 | def save_model(self, iter_time):
169 | if np.mod(iter_time + 1, self.flags.save_freq) == 0:
170 | model_name = 'model'
171 | self.saver.save(self.sess, os.path.join(self.model_out_dir, model_name), global_step=iter_time)
172 | logger.info('[*] Model saved! Iter: {}'.format(iter_time))
173 |
174 | def load_model(self):
175 | logger.info(' [*] Reading checkpoint...')
176 |
177 | checkpoint = tf.train.get_checkpoint_state(self.model_out_dir)
178 | if checkpoint and checkpoint.model_checkpoint_path:
179 | ckpt_name = os.path.basename(checkpoint.model_checkpoint_path)
180 | self.saver.restore(self.sess, os.path.join(self.model_out_dir, ckpt_name))
181 |
182 | meta_graph_path = checkpoint.model_checkpoint_path + '.meta'
183 | self.iter_time = int(meta_graph_path.split('-')[-1].split('.')[0])
184 |
185 | logger.info('[*] Load iter_time: {}'.format(self.iter_time))
186 | return True
187 | else:
188 | return False
189 |
--------------------------------------------------------------------------------
/src/tensorflow_utils.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------
2 | # Tensorflow Utils Implementation
3 | # Licensed under The MIT License [see LICENSE for details]
4 | # Written by Cheng-Bin Jin
5 | # Email: sbkim0407@gmail.com
6 | # ---------------------------------------------------------
7 | import os
8 | import logging
9 | import functools
10 | import tensorflow as tf
11 | import tensorflow.contrib.slim as slim
12 | from tensorflow.python.training import moving_averages
13 |
# Module-level logger named after this module; records below INFO are dropped.
logger = logging.getLogger(__name__) # logger
logger.setLevel(logging.INFO)
16 |
17 |
def _init_logger(log_path):
    """Attach a ``model.log`` file handler and a stream handler to the module logger.

    Args:
        log_path: directory in which ``model.log`` is opened.
    """
    fmt = logging.Formatter('%(asctime)s:%(name)s:%(message)s')

    # log file inside log_path, INFO and above
    to_file = logging.FileHandler(os.path.join(log_path, 'model.log'))
    to_file.setFormatter(fmt)
    to_file.setLevel(logging.INFO)

    # console output
    to_stream = logging.StreamHandler()
    to_stream.setFormatter(fmt)

    for handler in (to_file, to_stream):
        logger.addHandler(handler)
30 |
31 |
def padding2d(x, p_h=1, p_w=1, pad_type='REFLECT', name='pad2d'):
    """Pad the two middle axes of a 4-D tensor.

    Args:
        x: 4-D tensor; axes 1 and 2 are padded, axes 0 and 3 are left alone.
        p_h: padding added before and after axis 1.
        p_w: padding added before and after axis 2.
        pad_type: one of the ``tf.pad`` modes 'CONSTANT', 'REFLECT' or 'SYMMETRIC'.
        name: name of the padding op.

    Returns:
        The padded tensor.

    Raises:
        ValueError: if ``pad_type`` is not a supported mode (previously this
            case silently returned None).
    """
    if pad_type not in ('CONSTANT', 'REFLECT', 'SYMMETRIC'):
        raise ValueError('unsupported pad_type: {!r}'.format(pad_type))

    return tf.pad(x, [[0, 0], [p_h, p_h], [p_w, p_w], [0, 0]], pad_type, name=name)
35 |
36 |
def conv2d(x, output_dim, k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02, padding='SAME', name='conv2d', is_print=True):
    """Apply a 2-D convolution followed by a bias add.

    Args:
        x: input tensor; its last dimension is used as the input channel count.
        output_dim: number of output channels.
        k_h, k_w: kernel height and width.
        d_h, d_w: strides along axes 1 and 2.
        stddev: standard deviation of the truncated-normal weight initializer.
        padding: padding mode passed to ``tf.nn.conv2d``.
        name: variable scope holding the 'w' and 'biases' variables.
        is_print: when true, log the output op name and shape.

    Returns:
        The convolved, bias-added tensor.
    """
    with tf.variable_scope(name):
        in_channels = x.get_shape()[-1]
        kernel = tf.get_variable('w', [k_h, k_w, in_channels, output_dim],
                                 initializer=tf.truncated_normal_initializer(stddev=stddev))
        bias = tf.get_variable('biases', [output_dim], initializer=tf.constant_initializer(0.0))

        result = tf.nn.conv2d(x, kernel, strides=[1, d_h, d_w, 1], padding=padding)
        result = tf.nn.bias_add(result, bias)

        if is_print:
            print_activations(result)

        return result
51 |
52 |
def deconv2d(x, k, k_h=3, k_w=3, d_h=2, d_w=2, stddev=0.02, padding_='SAME', output_size=None,
             name='deconv2d', with_w=False, is_print=True):
    """Apply a transposed 2-D convolution followed by a bias add.

    Args:
        x: 4-D input tensor with statically known height, width and channels.
        k: number of output channels.
        k_h, k_w: kernel height and width.
        d_h, d_w: strides along axes 1 and 2.
        stddev: standard deviation of the normal weight initializer.
        padding_: padding mode passed to ``tf.nn.conv2d_transpose``.
        output_size: optional ``(height, width)`` of the output. When omitted
            the output is the input size multiplied by the strides, which is
            the output size for 'SAME' padding; pass it explicitly for other
            padding modes.
        name: variable scope holding the 'w' and 'biases' variables.
        with_w: when true, also return the weight and bias variables.
        is_print: when true, log the output op name and shape.

    Returns:
        The output tensor, or ``(output, w, biases)`` when ``with_w`` is true.
    """
    with tf.variable_scope(name):
        input_shape = x.get_shape().as_list()

        # calculate output size; the original code ignored a supplied
        # output_size (leaving None dimensions) and always doubled the input.
        if output_size:
            h_output, w_output = output_size
        else:
            h_output, w_output = input_shape[1] * d_h, input_shape[2] * d_w
        # batch dimension is taken dynamically so the batch size may stay undefined
        output_shape = [tf.shape(x)[0], h_output, w_output, k]

        # conv2d transpose
        w = tf.get_variable('w', [k_h, k_w, k, input_shape[3]],
                            initializer=tf.random_normal_initializer(stddev=stddev))
        deconv = tf.nn.conv2d_transpose(x, w, output_shape=output_shape, strides=[1, d_h, d_w, 1],
                                        padding=padding_)

        biases = tf.get_variable('biases', [k], initializer=tf.constant_initializer(0.0))
        deconv = tf.nn.bias_add(deconv, biases)

        if is_print:
            print_activations(deconv)

        if with_w:
            return deconv, w, biases
        return deconv
81 |
82 |
def upsampling2d(x, size=(2, 2), name='upsampling2d'):
    """Nearest-neighbour resize of ``x`` by the integer factors in ``size``.

    The target size is the static size of axes 1 and 2 multiplied by
    ``size[0]`` and ``size[1]`` respectively.
    """
    with tf.name_scope(name):
        static_shape = x.get_shape().as_list()
        new_h = size[0] * static_shape[1]
        new_w = size[1] * static_shape[2]
        return tf.image.resize_nearest_neighbor(x, size=(new_h, new_w))
87 |
88 |
def linear(x, output_size, bias_start=0.0, with_w=False, name='fc'):
    """Fully connected layer: ``x @ matrix + bias``.

    Args:
        x: 2-D input tensor; its second static dimension sets the matrix rows.
        output_size: number of output units.
        bias_start: constant used to initialise the bias.
        with_w: when true, also return the matrix and bias variables.
        name: variable scope holding the 'matrix' and 'bias' variables.

    Returns:
        The output tensor, or ``(output, matrix, bias)`` when ``with_w`` is true.
    """
    in_features = x.get_shape().as_list()[1]

    with tf.variable_scope(name):
        matrix = tf.get_variable(name="matrix", shape=[in_features, output_size],
                                 dtype=tf.float32, initializer=tf.contrib.layers.xavier_initializer())
        bias = tf.get_variable(name="bias", shape=[output_size],
                               initializer=tf.constant_initializer(bias_start))
        projected = tf.matmul(x, matrix) + bias
        if not with_w:
            return projected
        return projected, matrix, bias
101 |
102 |
def norm(x, name, _type, _ops, is_train=True):
    """Dispatch to the normalization layer selected by ``_type``.

    Args:
        x: tensor to normalize.
        name: variable scope name for the layer.
        _type: 'batch', 'instance' or 'layer'.
        _ops: list passed on to ``batch_norm`` (only used for 'batch').
        is_train: passed on to ``batch_norm`` (only used for 'batch').

    Raises:
        NotImplementedError: for any other ``_type``.
    """
    if _type == 'batch':
        return batch_norm(x, name=name, _ops=_ops, is_train=is_train)
    if _type == 'instance':
        return instance_norm(x, name=name)
    if _type == 'layer':
        return layer_norm(x, name=name)
    raise NotImplementedError
112 |
113 |
def batch_norm(x, name, _ops, is_train=True):
    """Batch normalization.

    Args:
        x: tensor normalized over axes 0, 1 and 2; its last dimension sizes
            the beta, gamma and moving-statistics variables.
        name: variable scope holding the layer's variables.
        _ops: list to which the two moving-average update ops are appended in
            training mode; this function does not run them.
        is_train: only the literal ``True`` selects training mode (batch
            statistics); anything else uses the stored moving statistics.

    Returns:
        The normalized tensor, with the same static shape as ``x``.
    """
    with tf.variable_scope(name):
        params_shape = [x.get_shape()[-1]]

        beta = tf.get_variable('beta', params_shape, tf.float32,
                               initializer=tf.constant_initializer(0.0, tf.float32))
        gamma = tf.get_variable('gamma', params_shape, tf.float32,
                                initializer=tf.constant_initializer(1.0, tf.float32))

        if is_train is True:
            mean, variance = tf.nn.moments(x, [0, 1, 2], name='moments')

            moving_mean = tf.get_variable('moving_mean', params_shape, tf.float32,
                                          initializer=tf.constant_initializer(0.0, tf.float32),
                                          trainable=False)
            moving_variance = tf.get_variable('moving_variance', params_shape, tf.float32,
                                              initializer=tf.constant_initializer(1.0, tf.float32),
                                              trainable=False)

            _ops.append(moving_averages.assign_moving_average(moving_mean, mean, 0.9))
            _ops.append(moving_averages.assign_moving_average(moving_variance, variance, 0.9))
        else:
            mean = tf.get_variable('moving_mean', params_shape, tf.float32,
                                   initializer=tf.constant_initializer(0.0, tf.float32), trainable=False)
            # Same initializer as the training branch, so a freshly created
            # variance starts at 1.0 instead of the framework default.
            variance = tf.get_variable('moving_variance', params_shape, tf.float32,
                                       initializer=tf.constant_initializer(1.0, tf.float32),
                                       trainable=False)

        # epsilon used to be 1e-5. Maybe 0.001 solves NaN problem in deeper net.
        y = tf.nn.batch_normalization(x, mean, variance, beta, gamma, 1e-5)
        y.set_shape(x.get_shape())

        return y
146 |
147 |
def instance_norm(x, name='instance_norm', mean=1.0, stddev=0.02, epsilon=1e-5):
    """Normalize each sample over axes 1 and 2, then apply a learned scale and offset.

    Args:
        x: 4-D tensor; its fourth dimension sizes the scale and offset variables.
        name: variable scope holding 'scale' and 'offset'.
        mean: mean of the normal initializer used for 'scale'.
        stddev: standard deviation of the normal initializer used for 'scale'.
        epsilon: added to the variance before the reciprocal square root.

    Returns:
        ``scale * normalized + offset``.
    """
    with tf.variable_scope(name):
        channels = x.get_shape()[3]
        scale_init = tf.random_normal_initializer(mean=mean, stddev=stddev, dtype=tf.float32)
        scale = tf.get_variable('scale', [channels], tf.float32, initializer=scale_init)
        offset = tf.get_variable('offset', [channels], initializer=tf.constant_initializer(0.0))

        # calculate mean and variance per instance
        inst_mean, inst_var = tf.nn.moments(x, axes=[1, 2], keep_dims=True)

        # normalization
        normalized = (x - inst_mean) * tf.rsqrt(inst_var + epsilon)

        return scale * normalized + offset
164 |
165 |
# TODO: I'm not sure whether this is a good implementation of layer normalization...
def layer_norm(x, name='layer_norm'):
    """Normalize ``x`` over axes 1, 2 and 3 with a learned offset and scale.

    The offset and scale variables are sized by the static dimension of the
    last normalized axis and reshaped so they broadcast against ``x``.
    """
    with tf.variable_scope(name):
        norm_axes = [1, 2, 3]
        mean, var = tf.nn.moments(x, axes=norm_axes, keep_dims=True)

        # Assume the 'neurons' axis is the third of norm_axes. This is the case for fully-connected
        # and BHWC conv layers.
        n_neurons = x.get_shape().as_list()[norm_axes[2]]
        offset = tf.get_variable('offset', n_neurons, tf.float32,
                                 initializer=tf.constant_initializer(0.0, tf.float32))
        scale = tf.get_variable('scale', n_neurons, tf.float32,
                                initializer=tf.constant_initializer(1.0, tf.float32))

        # Add broadcasting dims to offset and scale
        broadcast_shape = [1] * (len(norm_axes) - 1) + [-1]
        offset = tf.reshape(offset, broadcast_shape)
        scale = tf.reshape(scale, broadcast_shape)

        return tf.nn.batch_normalization(x, mean, var, offset, scale, 1e-5)
185 |
186 |
def n_res_blocks(x, _ops=None, norm_='instance', is_train=True, num_blocks=6, is_print=False):
    """Chain ``num_blocks`` residual blocks named 'res1', 'res2', ...

    Each block keeps the channel count of its input. Returns the output of
    the last block (None when ``num_blocks`` is 0).
    """
    output = None
    current = x
    for idx in range(num_blocks):
        scope_name = 'res{}'.format(idx + 1)
        output = res_block(current, current.get_shape()[3], _ops=_ops, norm_=norm_, is_train=is_train,
                           name=scope_name)
        current = output

    if is_print:
        print_activations(output)

    return output
198 |
199 |
200 | # norm(x, name, _type, _ops, is_train=True)
def res_block(x, k, _ops=None, norm_='instance', is_train=True, pad_type=None, name=None):
    """Residual block: conv-norm-relu, conv-norm, then add the input.

    Args:
        x: input tensor; it is added to the block output, so ``k`` must match
            its channel count.
        k: number of channels of both 3x3 stride-1 convolutions.
        _ops, norm_, is_train: forwarded to ``norm``.
        pad_type: None for 'SAME' convolutions, or 'REFLECT' for explicit
            reflect padding followed by 'VALID' convolutions.
        name: variable scope of the block.

    Returns:
        ``x`` plus the normalized output of the second convolution.

    Raises:
        ValueError: if ``pad_type`` is neither None nor 'REFLECT' (previously
            this failed later inside ``norm`` with a None input).
    """
    if pad_type not in (None, 'REFLECT'):
        raise ValueError('unsupported pad_type: {!r}'.format(pad_type))

    with tf.variable_scope(name):
        # 3x3 Conv-Norm-Relu S1
        with tf.variable_scope('layer1'):
            conv1 = _res_block_conv(x, k, pad_type)
            normalized1 = norm(conv1, name='norm', _type=norm_, _ops=_ops, is_train=is_train)
            relu1 = tf.nn.relu(normalized1)

        # 3x3 Conv-Norm S1
        with tf.variable_scope('layer2'):
            conv2 = _res_block_conv(relu1, k, pad_type)
            normalized2 = norm(conv2, name='norm', _type=norm_, _ops=_ops, is_train=is_train)

        # sum input and layer2
        output = x + normalized2
        return output


def _res_block_conv(x, k, pad_type):
    """3x3 stride-1 convolution named 'conv', with the padding scheme selected by ``pad_type``."""
    if pad_type == 'REFLECT':
        padded = padding2d(x, p_h=1, p_w=1, pad_type='REFLECT', name='padding')
        return conv2d(padded, k, k_h=3, k_w=3, d_h=1, d_w=1, padding='VALID', name='conv')
    return conv2d(x, k, k_h=3, k_w=3, d_h=1, d_w=1, padding='SAME', name='conv')
227 |
228 |
def identity(x, name='identity', is_print=False):
    """Return ``tf.identity(x)`` under ``name``, optionally logging its shape."""
    passthrough = tf.identity(x, name=name)
    if not is_print:
        return passthrough

    print_activations(passthrough)
    return passthrough
235 |
236 |
def avgPoolConv(x, output_dim, filter_size=3, stride=1, name='avgPoolConv', is_print=True):
    """2x2 average pooling followed by a square convolution.

    Args:
        x: input tensor.
        output_dim: number of output channels of the convolution.
        filter_size: kernel height and width.
        stride: convolution stride along both spatial axes.
        name: variable scope of the layer.
        is_print: when true, log the output op name and shape.
    """
    with tf.variable_scope(name):
        pooled = avg_pool_2x2(x)
        result = conv2d(pooled, output_dim=output_dim, k_h=filter_size, k_w=filter_size,
                        d_h=stride, d_w=stride)
        if is_print:
            print_activations(result)

        return result
245 |
246 |
def convAvgPool(x, output_dim, filter_size=3, stride=1, name='convAvgPool', is_print=True):
    """Square convolution followed by 2x2 average pooling.

    Args:
        x: input tensor.
        output_dim: number of output channels of the convolution.
        filter_size: kernel height and width.
        stride: convolution stride along both spatial axes.
        name: variable scope of the layer.
        is_print: when true, log the output op name and shape.
    """
    with tf.variable_scope(name):
        convolved = conv2d(x, output_dim=output_dim, k_h=filter_size, k_w=filter_size,
                           d_h=stride, d_w=stride)
        result = avg_pool_2x2(convolved)
        if is_print:
            print_activations(result)

        return result
255 |
256 |
def max_pool_2x2(x, name='max_pool'):
    """Max-pool ``x`` with a 2x2 window, stride 2 and 'SAME' padding."""
    window = [1, 2, 2, 1]
    with tf.name_scope(name):
        return tf.nn.max_pool(x, ksize=window, strides=window, padding='SAME')
260 |
261 |
def avg_pool_2x2(x, name='avg_pool'):
    """Average-pool ``x`` with a 2x2 window, stride 2 and 'SAME' padding."""
    window = [1, 2, 2, 1]
    with tf.name_scope(name):
        return tf.nn.avg_pool(x, ksize=window, strides=window, padding='SAME')
265 |
266 |
def sigmoid(x, name='sigmoid', is_print=False):
    """Apply ``tf.nn.sigmoid`` under ``name``, optionally logging the output shape."""
    activated = tf.nn.sigmoid(x, name=name)
    if not is_print:
        return activated

    print_activations(activated)
    return activated
273 |
274 |
def tanh(x, name='tanh', is_print=False):
    """Apply ``tf.nn.tanh`` under ``name``, optionally logging the output shape."""
    activated = tf.nn.tanh(x, name=name)
    if not is_print:
        return activated

    print_activations(activated)
    return activated
281 |
282 |
def relu(x, name='relu', is_print=False):
    """Apply ``tf.nn.relu`` under ``name``, optionally logging the output shape."""
    activated = tf.nn.relu(x, name=name)
    if not is_print:
        return activated

    print_activations(activated)
    return activated
289 |
290 |
def lrelu(x, leak=0.2, name='lrelu', is_print=False):
    """Leaky ReLU computed as ``max(x, leak * x)``, optionally logging the output shape."""
    leaked = leak * x
    activated = tf.maximum(x, leaked, name=name)
    if not is_print:
        return activated

    print_activations(activated)
    return activated
297 |
298 |
def xavier_init(in_dim):
    """Return ``1 / sqrt(in_dim / 2)``, computed with ``tf.sqrt``."""
    half_fan_in = in_dim / 2.
    return 1. / tf.sqrt(half_fan_in)
303 |
304 |
def print_activations(t):
    """Log the op name of tensor ``t`` immediately followed by its static shape list."""
    static_shape = t.get_shape().as_list()
    logger.info('{}{}'.format(t.op.name, static_shape))
308 |
309 |
def show_all_variables():
    """Print an analysis of all trainable variables via ``slim.model_analyzer``."""
    slim.model_analyzer.analyze_vars(tf.trainable_variables(), print_info=True)
313 |
314 |
def batch_convert2int(images):
    """Apply ``convert2int`` to every element along the first axis of ``images``.

    Args:
        images: 4D float tensor (batch_size, image_size, image_size, depth).

    Returns:
        A uint8 tensor built with ``tf.map_fn``.
    """
    return tf.map_fn(fn=convert2int, elems=images, dtype=tf.uint8)
318 |
319 |
def convert2int(image):
    """Transform a float image in [-1., 1.] into a uint8 image in [0, 255]."""
    unit_range = (image + 1.0) / 2.0  # shift and scale to [0, 1] first
    return tf.image.convert_image_dtype(unit_range, tf.uint8)
323 |
324 |
def res_block_v2(x, k, filter_size, _ops=None, norm_='instance', is_train=True, resample=None, name=None):
    """Pre-activation residual block with optional down- or up-sampling.

    The main path is norm-relu-conv1-norm-relu-conv2; the shortcut is the
    input itself when no resampling is requested and the channel count
    already equals ``k``, otherwise a resampling/projection layer.

    Args:
        x: 4-D input tensor.
        k: number of output channels.
        filter_size: kernel size of the main-path convolutions.
        _ops, norm_, is_train: forwarded to ``norm``.
        resample: 'down', 'up' or None.
        name: variable scope of the block.

    Returns:
        Shortcut plus main-path output.

    Raises:
        Exception: if ``resample`` is not 'down', 'up' or None.
    """
    with tf.variable_scope(name):
        same_size_conv = functools.partial(conv2d, output_dim=k, k_h=filter_size, k_w=filter_size,
                                           d_h=1, d_w=1)
        if resample == 'down':
            conv_shortcut = functools.partial(avgPoolConv, output_dim=k, filter_size=1)
            conv_1 = same_size_conv
            # filter_size is now forwarded; it used to fall back to convAvgPool's default of 3
            conv_2 = functools.partial(convAvgPool, output_dim=k, filter_size=filter_size)
        elif resample == 'up':
            conv_shortcut = functools.partial(deconv2d, k=k)
            conv_1 = functools.partial(deconv2d, k=k, k_h=filter_size, k_w=filter_size)
            conv_2 = same_size_conv
        elif resample is None:
            conv_shortcut = same_size_conv
            conv_1 = same_size_conv
            conv_2 = same_size_conv
        else:
            raise Exception('invalid resample value')

        if (k == x.get_shape().as_list()[3]) and (resample is None):
            shortcut = x  # Identity skip-connection
        else:
            shortcut = conv_shortcut(x, name='shortcut')

        output = x
        output = norm(output, _type=norm_, _ops=_ops, is_train=is_train, name='norm1')
        output = relu(output, name='relu1')
        output = conv_1(output, name='conv1')
        output = norm(output, _type=norm_, _ops=_ops, is_train=is_train, name='norm2')
        output = relu(output, name='relu2')
        output = conv_2(output, name='conv2')

        return shortcut + output
356 |
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------
2 | # Python Implementation
3 | # Licensed under The MIT License [see LICENSE for details]
4 | # Written by Cheng-Bin Jin
5 | # Email: sbkim0407@gmail.com
6 | # ---------------------------------------------------------
7 | import os
8 | import sys
9 | import random
10 | import numpy as np
11 | import matplotlib as mpl
12 | import scipy.misc
13 | mpl.use('TkAgg') # or whatever other backend that you want to solve Segmentation fault (core dumped)
14 | import matplotlib.pyplot as plt
15 | import matplotlib.gridspec as gridspec
16 | from PIL import Image
17 |
18 |
class ImagePool(object):
    """Fixed-size history of images that randomly swaps new images for stored ones."""

    def __init__(self, pool_size=50):
        """Create an empty pool holding at most ``pool_size`` images (0 disables it)."""
        self.pool_size = pool_size
        self.imgs = []

    def query(self, img):
        """Return either ``img`` or a previously stored image.

        With a pool size of 0, ``img`` is returned untouched. While the pool
        is filling, ``img`` is stored and returned. Once full, a random draw
        above 0.5 replaces a random slot with a copy of ``img`` and returns a
        copy of the old occupant; otherwise ``img`` is returned.
        """
        if self.pool_size == 0:
            return img

        if len(self.imgs) < self.pool_size:
            self.imgs.append(img)
            return img

        if not random.random() > 0.5:
            return img

        # use old image
        slot = random.randrange(0, self.pool_size)
        previous = self.imgs[slot].copy()
        self.imgs[slot] = img.copy()
        return previous
40 |
41 |
def all_files_under(path, extension=None, append_path=True, sort=True):
    """List the entries of directory ``path``.

    Args:
        path: directory to list (not recursive).
        extension: when given, keep only names ending with this suffix.
        append_path: when true, return names joined onto ``path``; otherwise
            return bare names.
        sort: when true, return the names sorted.

    Returns:
        A list of file names or paths.
    """
    names = os.listdir(path)
    if extension is not None:
        names = [entry for entry in names if entry.endswith(extension)]

    if append_path:
        names = [os.path.join(path, entry) for entry in names]
    else:
        names = [os.path.basename(entry) for entry in names]

    return sorted(names) if sort else names
60 |
61 |
def imagefiles2arrs(filenames):
    """Load image files into one float32 array.

    The array shape is ``(len(filenames),) + shape of the first image``, so
    every file must decode to that same shape.

    Args:
        filenames: non-empty sequence of image paths.

    Returns:
        A float32 numpy array holding all images.
    """
    img_shape = image_shape(filenames[0])
    # Sized from the first image for any rank; the original only handled
    # rank 2 and 3 and crashed on a None array otherwise.
    images_arr = np.zeros((len(filenames),) + tuple(img_shape), dtype=np.float32)

    for file_index, filename in enumerate(filenames):
        # context manager closes the underlying file once the pixels are copied
        with Image.open(filename) as img:
            images_arr[file_index] = np.asarray(img).astype(np.float32)

    return images_arr
76 |
77 |
def image_shape(filename):
    """Return the numpy array shape of the image stored at ``filename``."""
    # context manager closes the file handle that Image.open keeps open
    with Image.open(filename, mode="r") as img:
        return np.asarray(img).shape
83 |
84 |
def print_metrics(itr, kargs):
    """Print an iteration header, one ``name : value,`` line per metric, and a blank line.

    Args:
        itr: iteration number shown in the header.
        kargs: mapping of metric name to value.
    """
    lines = ["*** Iteration {} ====> ".format(itr)]
    lines.extend("{} : {}, ".format(name, value) for name, value in kargs.items())
    lines.append("")

    for line in lines:
        print(line)
    sys.stdout.flush()
91 |
92 |
def transform(img):
    """Map values from [0, 255] to [-1., 1.] via ``img / 127.5 - 1.0``."""
    scaled = img / 127.5
    return scaled - 1.0
95 |
96 |
def inverse_transform(img):
    """Map values from [-1., 1.] to [0., 1.] via ``(img + 1.) / 2.``."""
    shifted = img + 1.
    return shifted / 2.
99 |
100 |
def preprocess_pair(img_a, img_b, load_size=286, fine_size=256, flip=True, is_test=False):
    """Resize a pair of images identically, with random crop and flip outside test mode.

    In test mode both images are resized to ``fine_size`` x ``fine_size``.
    Otherwise both are resized to ``load_size`` x ``load_size``, cropped to
    ``fine_size`` at one shared random offset, and (when ``flip`` is set)
    mirrored left-right together with probability 0.5.

    NOTE(review): the dump this file was recovered from lost its indentation;
    the crop and flip steps are assumed to belong to the non-test branch.
    Confirm against the original file.
    NOTE(review): ``scipy.misc.imresize`` is no longer provided by recent
    SciPy releases, so this function needs an old SciPy to run.

    Returns:
        The processed ``(img_a, img_b)`` pair.
    """
    if is_test:
        img_a = scipy.misc.imresize(img_a, [fine_size, fine_size])
        img_b = scipy.misc.imresize(img_b, [fine_size, fine_size])
    else:
        img_a = scipy.misc.imresize(img_a, [load_size, load_size])
        img_b = scipy.misc.imresize(img_b, [load_size, load_size])

        # same random top-left corner for both images so the pair stays aligned
        h1 = int(np.ceil(np.random.uniform(1e-2, load_size - fine_size)))
        w1 = int(np.ceil(np.random.uniform(1e-2, load_size - fine_size)))
        img_a = img_a[h1:h1 + fine_size, w1:w1 + fine_size]
        img_b = img_b[h1:h1 + fine_size, w1:w1 + fine_size]

        if flip and np.random.random() > 0.5:
            img_a = np.fliplr(img_a)
            img_b = np.fliplr(img_b)

    return img_a, img_b
119 |
120 |
def imread(path, is_gray_scale=False, img_size=None):
    """Read an image file as a float array, optionally resizing it.

    Args:
        path: image file path.
        is_gray_scale: when true, read with ``flatten=True``; otherwise read
            in 'RGB' mode and stack a non-3-channel result into 3 channels.
        img_size: when given, the result is passed through
            ``scipy.misc.imresize`` with this size.

    Returns:
        The image array.

    NOTE(review): ``scipy.misc.imread`` / ``imresize`` are no longer provided
    by recent SciPy releases, so this function needs an old SciPy to run.
    """
    # np.float64 is what the removed alias ``np.float`` (builtin float) resolved to
    if is_gray_scale:
        img = scipy.misc.imread(path, flatten=True).astype(np.float64)
    else:
        img = scipy.misc.imread(path, mode='RGB').astype(np.float64)

        if not (img.ndim == 3 and img.shape[2] == 3):
            img = np.dstack((img, img, img))

    if img_size is not None:
        img = scipy.misc.imresize(img, img_size)

    return img
134 |
135 |
def load_image(image_path, which_direction=0, is_gray_scale=True, img_size=(256, 256, 1)):
    """Read a side-by-side image pair and split it down the middle.

    Args:
        image_path: path of the combined image.
        which_direction: 0 returns (left half, right half); any other value
            returns (right half, left half).
        is_gray_scale: forwarded to ``imread``.
        img_size: forwarded to ``imread``.

    Returns:
        The ``(img_a, img_b)`` pair.
    """
    pair = imread(image_path, is_gray_scale=is_gray_scale, img_size=img_size)
    full_width = int(pair.shape[1])
    half_width = int(full_width / 2)

    left_half = pair[:, 0:half_width]
    right_half = pair[:, half_width:full_width]

    if which_direction == 0:  # A to B
        return left_half, right_half
    return right_half, left_half  # B to A
149 |
150 |
def load_data(image_path, is_gray_scale=False):
    """Read an image and rescale it with ``transform``.

    A 2-D grayscale result gets a trailing channel axis added.
    """
    scaled = transform(imread(path=image_path, is_gray_scale=is_gray_scale))  # from [0, 255] to [-1., 1.]

    needs_channel_axis = is_gray_scale and (scaled.ndim == 2)
    if needs_channel_axis:
        scaled = np.expand_dims(scaled, axis=2)

    return scaled
159 |
160 |
def plots(imgs, iter_time, save_file, grid_cols, grid_rows, sample_batch, name=None):
    """Draw the first ``sample_batch`` images on a grid and save it as a PNG.

    Args:
        imgs: array whose trailing three dimensions are (height, width, channels).
        iter_time: iteration number used in the file name.
        save_file: directory the PNG is written to.
        grid_cols, grid_rows: grid layout.
        sample_batch: number of images drawn.
        name: suffix used in the file name ``<iter_time>_<name>.png``.
    """
    # parameters for plot size
    scale, margin = 0.02, 0.02

    img_h, img_w, img_c = imgs.shape[1:]
    fig_size = (img_w * grid_cols * scale, img_h * grid_rows * scale)  # (column, row)
    fig = plt.figure(figsize=fig_size)
    grid = gridspec.GridSpec(grid_rows, grid_cols)  # (row, column)
    grid.update(wspace=margin, hspace=margin)

    for img_idx in range(sample_batch):
        axis = plt.subplot(grid[img_idx])
        plt.axis('off')
        axis.set_xticklabels([])
        axis.set_yticklabels([])
        axis.set_aspect('equal')

        img = imgs[img_idx]
        if img.shape[2] == 1:  # gray scale
            target_shape = (img_h, img_w)
        else:
            target_shape = (img_h, img_w, img_c)
        plt.imshow(img.reshape(target_shape), cmap='Greys_r')

    out_path = save_file + '/{}_{}.png'.format(str(iter_time), name)
    plt.savefig(out_path, bbox_inches='tight')
    plt.close(fig)
185 |
--------------------------------------------------------------------------------
/src/wgan_gp.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------
2 | # TensorFlow WGAN-GP Implementation
3 | # Licensed under The MIT License [see LICENSE for details]
4 | # Written by Cheng-Bin Jin
5 | # Email: sbkim0407@gmail.com
6 | # ---------------------------------------------------------
7 | import logging
8 | import collections
9 | import numpy as np
10 | import tensorflow as tf
11 | from tensorflow.contrib.layers import flatten
12 | import matplotlib as mpl
13 | mpl.use('TkAgg') # or whatever other backend that you want to solve Segmentation fault (core dumped)
14 | import matplotlib.pyplot as plt
15 | import matplotlib.gridspec as gridspec
16 |
17 | # noinspection PyPep8Naming
18 | import tensorflow_utils as tf_utils
19 | import utils as utils
20 |
# Module-level logger named after this module; records below INFO are dropped.
logger = logging.getLogger(__name__) # logger
logger.setLevel(logging.INFO)
23 |
24 |
25 | # noinspection PyPep8Naming
26 | class WGAN_GP(object):
def __init__(self, sess, flags, dataset, log_path=None):
    """Store the run context, pick channel sizes for the dataset and build the graph.

    Args:
        sess: TensorFlow session used by the model.
        flags: run configuration; ``flags.dataset`` selects the channel sizes.
        dataset: dataset object; its ``image_size`` is read here.
        log_path: directory handed to the logger setup.

    Raises:
        NotImplementedError: if ``flags.dataset`` is not 'mnist', 'cifar10'
            or 'imagenet64'.
    """
    self.sess = sess
    self.flags = flags
    self.dataset = dataset
    self.image_size = dataset.image_size
    self.log_path = log_path

    # (generator channels, discriminator channels) per supported dataset
    channel_table = {
        'mnist': ([4*4*256, 128, 64], [64, 128, 256]),
        'cifar10': ([4*4*4*128, 256, 128], [128, 256, 512]),
        'imagenet64': ([4*4*8*64, 512, 256, 128, 64], [64, 128, 256, 512, 512]),
    }
    if self.flags.dataset not in channel_table:
        raise NotImplementedError
    self.gen_c, self.dis_c = channel_table[self.flags.dataset]

    self.gen_train_ops, self.dis_train_ops = [], []

    self._init_logger()  # init logger
    self._build_net()  # init graph
    self._tensorboard()  # init tensorboard
    logger.info("Initialized WGAN-GP SUCCESS!")
52 |
def _init_logger(self):
    """Set up the ``tensorflow_utils`` logger under ``log_path``, in training mode only."""
    if not self.flags.is_train:
        return
    tf_utils._init_logger(self.log_path)
56 |
def _build_net(self):
    """Build placeholders, generator, critic, losses and the two Adam optimizers.

    The critic loss is the Wasserstein estimate plus ``flags.lambda_`` times
    the gradient penalty. Adam's beta1/beta2 are taken from ``flags.beta1``
    and ``flags.beta2`` instead of being hard-coded here, so the values
    configured for the run are the ones actually used.
    """
    self.Y = tf.placeholder(tf.float32, shape=[None, *self.image_size], name='real_data')
    self.z = tf.placeholder(tf.float32, shape=[None, self.flags.z_dim], name='latent_vector')

    if self.flags.dataset == 'imagenet64':
        self.generator = self.resnetGenerator
        self.discriminator = self.resnetDiscriminator
    else:
        self.generator = self.basicGenerator
        self.discriminator = self.basicDiscriminator

    self.g_samples = self.generator(self.z)
    _, d_logit_real = self.discriminator(self.Y)
    _, d_logit_fake = self.discriminator(self.g_samples, is_reuse=True)

    # discriminator loss
    self.wgan_d_loss = tf.reduce_mean(d_logit_fake) - tf.reduce_mean(d_logit_real)
    # generator loss
    self.g_loss = -tf.reduce_mean(d_logit_fake)

    # variables are split by the 'd_' / 'g_' scope prefixes of the two networks
    d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='d_')
    g_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='g_')

    # gradient penalty
    self.gp_loss = self.gradient_penalty()
    self.d_loss = self.wgan_d_loss + self.flags.lambda_ * self.gp_loss

    # Optimizers for generator and discriminator
    self.gen_optim = tf.train.AdamOptimizer(
        learning_rate=self.flags.learning_rate, beta1=self.flags.beta1,
        beta2=self.flags.beta2).minimize(self.g_loss, var_list=g_vars)
    self.dis_optim = tf.train.AdamOptimizer(
        learning_rate=self.flags.learning_rate, beta1=self.flags.beta1,
        beta2=self.flags.beta2).minimize(self.d_loss, var_list=d_vars)
89 |
def gradient_penalty(self):
    """Return the WGAN-GP penalty ``mean((||grad critic(x_hat)|| - 1)^2)``.

    ``x_hat`` is a random per-sample interpolation between the real batch
    ``self.Y`` and the generated batch ``self.g_samples``.
    """
    # One alpha per sample; the batch size is read from the fed data so the
    # penalty also works when a batch differs from flags.batch_size.
    batch_size = tf.shape(self.Y)[0]
    alpha = tf.random_uniform(shape=[batch_size, 1, 1, 1], minval=0., maxval=1.)
    differences = self.g_samples - self.Y
    interpolates = self.Y + (alpha * differences)

    # The discriminator returns a pair whose second element is the critic
    # logit. Differentiate only that: passing the whole pair to tf.gradients
    # would sum the gradients of both outputs.
    _, d_logit_interp = self.discriminator(interpolates, is_reuse=True)
    gradients = tf.gradients(d_logit_interp, [interpolates])[0]
    slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=[1, 2, 3]))
    gradient_penalty = tf.reduce_mean((slopes - 1.) ** 2)

    return gradient_penalty
99 |
def _tensorboard(self):
    """Register scalar summaries for the losses and merge all summaries into ``summary_op``."""
    scalars = (
        ('loss/negative_wgan_d_loss', -self.wgan_d_loss),
        ('loss/gp_loss', self.gp_loss),
        ('loss/negative_d_loss', -self.d_loss),  # negative critic loss
        ('loss/g_loss', self.g_loss),
    )
    for tag, value in scalars:
        tf.summary.scalar(tag, value)

    self.summary_op = tf.summary.merge_all()
107 |
def basicGenerator(self, data, name='g_'):
    """Generator: linear projection to a 4x4 map, three transposed convolutions, tanh.

    Batch normalization is inserted after the projection and after the first
    two transposed convolutions only when ``flags.dataset`` is 'cifar10'.

    Args:
        data: latent input; it is flattened before the projection.
        name: variable scope of the generator.

    Returns:
        The tanh-activated output of the last transposed convolution.
    """
    use_batch_norm = self.flags.dataset == 'cifar10'

    with tf.variable_scope(name):
        data_flatten = flatten(data)
        tf_utils.print_activations(data_flatten)

        # project the latent vector and view it as a (N, 4, 4, C) feature map
        base_channels = int(self.gen_c[0] / (4 * 4))
        hidden = tf_utils.linear(data_flatten, self.gen_c[0], name='h0_linear')
        if use_batch_norm:
            hidden = tf.reshape(hidden, [tf.shape(hidden)[0], 4, 4, base_channels])
            hidden = tf_utils.norm(hidden, _type='batch', _ops=self.gen_train_ops, name='h0_norm')
        hidden = tf.nn.relu(hidden, name='h0_relu')
        hidden = tf.reshape(hidden, [tf.shape(hidden)[0], 4, 4, base_channels])

        # two upsampling stages, each doubling height and width
        for idx in (1, 2):
            hidden = tf_utils.deconv2d(hidden, self.gen_c[idx], k_h=5, k_w=5,
                                       name='h{}_deconv2d'.format(idx))
            if use_batch_norm:
                hidden = tf_utils.norm(hidden, _type='batch', _ops=self.gen_train_ops,
                                       name='h{}_norm'.format(idx))
            hidden = tf.nn.relu(hidden, name='h{}_relu'.format(idx))

        # last upsampling stage maps to the image channel count
        output = tf_utils.deconv2d(hidden, self.image_size[2], k_h=5, k_w=5, name='h3_deconv2d')

        return tf_utils.tanh(output)
137 |
def basicDiscriminator(self, data, name='d_', is_reuse=False):
    """Critic: three strided 5x5 convolutions with leaky ReLU, then one linear unit.

    Args:
        data: image batch.
        name: variable scope of the discriminator.
        is_reuse: when ``True``, reuse the variables already created in the scope.

    Returns:
        ``(sigmoid(logit), logit)``.
    """
    with tf.variable_scope(name) as scope:
        if is_reuse is True:
            scope.reuse_variables()
        tf_utils.print_activations(data)

        # each stage uses the default stride-2 convolution of tf_utils.conv2d
        hidden = data
        for idx in range(3):
            hidden = tf_utils.conv2d(hidden, self.dis_c[idx], k_h=5, k_w=5,
                                     name='h{}_conv2d'.format(idx))
            hidden = tf_utils.lrelu(hidden, name='h{}_lrelu'.format(idx))

        # flatten the last feature map and project it to a single logit
        h2_flatten = flatten(hidden)
        h3_linear = tf_utils.linear(h2_flatten, 1, name='h3_linear')

        return tf.nn.sigmoid(h3_linear), h3_linear
161 |
def resnetGenerator(self, data, name='g_'):
    """Build the residual generator inside variable scope ``name``.

    The flattened input is projected to ``gen_c[0]`` units, reshaped to a
    4x4 map, passed through four up-sampling residual blocks
    (``gen_c[1..4]`` channels, batch norm), then batch norm + ReLU and a
    3x3 stride-1 conv to ``image_size[2]`` channels, finished with tanh.

    Returns the tanh-activated output tensor.
    """
    with tf.variable_scope(name):
        z_flat = flatten(data)
        tf_utils.print_activations(z_flat)

        # Expected (per the original notes): (N, 128) -> (N, 4, 4, 512)
        net = tf_utils.linear(z_flat, self.gen_c[0], name='h0_linear')
        net = tf.reshape(net, [tf.shape(net)[0], 4, 4, int(self.gen_c[0] / (4 * 4))])

        # Expected spatial sizes after each block: 8, 16, 32, 64.
        up_blocks = (
            ('res_block_1', self.gen_c[1]),
            ('res_block_2', self.gen_c[2]),
            ('res_block_3', self.gen_c[3]),
            ('res_block_4', self.gen_c[4]),
        )
        for block_name, channels in up_blocks:
            net = tf_utils.res_block_v2(net, channels, filter_size=3, _ops=self.gen_train_ops,
                                        norm_='batch', resample='up', name=block_name)

        net = tf_utils.norm(net, _type='batch', _ops=self.gen_train_ops, name='norm_5')
        net = tf_utils.relu(net, name='relu_5')
        # Expected: (N, 64, 64, C) with C = image_size[2]
        output = tf_utils.conv2d(net, output_dim=self.image_size[2], k_w=3, k_h=3, d_h=1, d_w=1,
                                 name='output')

        return tf_utils.tanh(output)
190 |
def resnetDiscriminator(self, data, name='d_', is_reuse=False):
    """Build the residual critic inside variable scope ``name``.

    A 3x3 stride-1 conv (``dis_c[0]`` channels) is followed by four
    down-sampling residual blocks (``dis_c[1..4]`` channels, layer norm),
    a flatten and a single-unit linear layer.

    Args:
        data: input image batch tensor.
        name: variable scope name.
        is_reuse: when exactly ``True``, the scope's variables are reused
            instead of created.

    Returns:
        (sigmoid(logits), logits)
    """
    with tf.variable_scope(name) as scope:
        if is_reuse is True:
            scope.reuse_variables()
        tf_utils.print_activations(data)

        # Expected (per the original notes): (N, 64, 64, 64)
        net = tf_utils.conv2d(data, output_dim=self.dis_c[0], k_h=3, k_w=3, d_h=1, d_w=1, name='conv_0')

        # Expected spatial sizes after each block: 32, 16, 8, 4.
        down_blocks = (
            ('res_block_1', self.dis_c[1]),
            ('res_block_2', self.dis_c[2]),
            ('res_block_3', self.dis_c[3]),
            ('res_block_4', self.dis_c[4]),
        )
        for block_name, channels in down_blocks:
            net = tf_utils.res_block_v2(net, channels, filter_size=3, _ops=self.dis_train_ops,
                                        norm_='layer', resample='down', name=block_name)

        # Flatten the final feature map and project it to one logit per sample.
        logits = tf_utils.linear(flatten(net), 1, name='output')

        return tf.nn.sigmoid(logits), logits
216 |
def train_step(self):
    """Run one training iteration: critic updates, then one generator update.

    The critic is optimised ``self.flags.num_critic`` times, each time on a
    fresh training batch and fresh noise. The generator is then optimised
    once; that run also evaluates ``self.summary_op``, so a real image batch
    is fed for it as well.

    Returns:
        A pair ``(losses, summary)`` where ``losses`` is
        ``[-wgan_d_loss, gp_loss, -d_loss, g_loss]``. The critic values come
        from the last critic update and two of them are negated.

    Raises:
        ValueError: if ``self.flags.num_critic`` is less than 1. Without at
            least one critic update there are no critic losses to report,
            and the method used to fail later with an opaque ``TypeError``
            after having already applied a generator update.
    """
    num_critic = self.flags.num_critic
    if num_critic < 1:
        raise ValueError('num_critic must be >= 1, got {}'.format(num_critic))

    batch_size = self.flags.batch_size
    dis_run = [self.dis_optim, self.wgan_d_loss, self.gp_loss, self.d_loss]
    wgan_d_loss, gp_loss, d_loss = None, None, None

    # train discriminator
    for _ in range(num_critic):
        batch_imgs = self.dataset.train_next_batch(batch_size=batch_size)
        dis_feed = {self.z: self.sample_z(num=batch_size), self.Y: batch_imgs}
        _, wgan_d_loss, gp_loss, d_loss = self.sess.run(dis_run, feed_dict=dis_feed)

    # train generator
    batch_imgs = self.dataset.train_next_batch(batch_size=batch_size)
    gen_feed = {self.z: self.sample_z(num=batch_size), self.Y: batch_imgs}
    _, g_loss, summary = self.sess.run([self.gen_optim, self.g_loss, self.summary_op], feed_dict=gen_feed)

    # negative critic loss
    return [-wgan_d_loss, gp_loss, -d_loss, g_loss], summary
234 |
def test_step(self):
    """Produce generator samples for evaluation.

    Delegates to ``sample_imgs`` using its default sample size and returns
    its result unchanged.
    """
    samples = self.sample_imgs()
    return samples
237 |
def sample_imgs(self, sample_size=64):
    """Generate ``sample_size`` images from freshly sampled noise.

    Returns a one-element list holding whatever the session returns for
    ``self.g_samples``.
    """
    noise = self.sample_z(num=sample_size)
    generated = self.sess.run(self.g_samples, feed_dict={self.z: noise})

    return [generated]
243 |
def sample_z(self, num=64):
    """Draw latent noise uniformly from [-1, 1) with shape (num, flags.z_dim)."""
    noise_shape = (num, self.flags.z_dim)
    return np.random.uniform(low=-1., high=1., size=noise_shape)
246 |
def print_info(self, loss, iter_time):
    """Print training metrics every ``flags.print_freq`` iterations.

    Args:
        loss: sequence indexed as [wgan_d_loss, gp_loss, d_loss, g_loss].
        iter_time: current iteration number.
    """
    # Nothing to do unless this iteration falls on the print interval.
    if np.mod(iter_time, self.flags.print_freq) != 0:
        return

    flags = self.flags
    ord_output = collections.OrderedDict()
    ord_output['cur_iter'] = iter_time
    ord_output['tar_iters'] = flags.iters
    ord_output['batch_size'] = flags.batch_size
    ord_output['wgan_d_loss'] = loss[0]
    ord_output['gp_loss'] = loss[1]
    ord_output['d_loss'] = loss[2]
    ord_output['g_loss'] = loss[3]
    ord_output['dataset'] = flags.dataset
    ord_output['gpu_index'] = flags.gpu_index

    utils.print_metrics(iter_time, ord_output)
257 |
def plots(self, imgs_, iter_time, save_file):
    """Save a square grid of generated samples as ``sample_<iter_time>.png``.

    Args:
        imgs_: sequence whose first element holds the generated images
            (flat vectors or already shaped); it is reshaped to
            ``(N, *self.image_size)``.
        iter_time: iteration number, used in the output file name.
        save_file: directory the PNG is written into.

    Raises:
        NotImplementedError: if ``self.image_size[2]`` is neither 1 nor 3.

    The grid is ``floor(sqrt(N))`` cells per side, so when N is not a
    perfect square the trailing images are not drawn.
    """
    height, width, channels = self.image_size[0], self.image_size[1], self.image_size[2]
    # Decide the per-cell shape up front so an unsupported channel count
    # fails before any figure is created.
    if channels == 3:
        cell_shape = (height, width, channels)
    elif channels == 1:
        cell_shape = (height, width)
    else:
        raise NotImplementedError

    # reshape image from vector to (N, H, W, C); -1 infers N from the data so
    # the plot does not depend on flags.sample_batch matching the number of
    # samples that were actually generated.
    imgs_fake = np.reshape(imgs_[0], (-1, *self.image_size))

    # parameters for plot size
    scale, margin = 0.04, 0.01
    n_cols = n_rows = int(np.sqrt(len(imgs_fake)))
    cell_size_h, cell_size_w = imgs_fake[0].shape[0] * scale, imgs_fake[0].shape[1] * scale

    imgs = [utils.inverse_transform(img) for img in imgs_fake]

    fig = plt.figure(figsize=(cell_size_w * n_cols, cell_size_h * n_rows))  # (column, row)
    try:
        gs = gridspec.GridSpec(n_rows, n_cols)  # (row, column)
        gs.update(wspace=margin, hspace=margin)

        for col_index in range(n_cols):
            for row_index in range(n_rows):
                cell = row_index * n_cols + col_index
                ax = plt.subplot(gs[cell])
                plt.axis('off')
                ax.set_xticklabels([])
                ax.set_yticklabels([])
                ax.set_aspect('equal')
                plt.imshow(imgs[cell].reshape(cell_shape), cmap='Greys_r')

        plt.savefig(save_file + '/sample_{}.png'.format(str(iter_time)), bbox_inches='tight')
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)
296 |
--------------------------------------------------------------------------------
/tox.ini:
--------------------------------------------------------------------------------
1 | [tox]
2 | skipsdist=True
3 | envlist = py35, py36, py37
4 |
5 | [testenv]
6 | deps = -rrequirements.d/base.txt
7 |
--------------------------------------------------------------------------------