├── .gitignore
├── LICENSE
├── README.md
├── img
│   ├── cgan.png
│   ├── con_gan.png
│   ├── gan.png
│   ├── gan_new.png
│   ├── places365.jpg
│   ├── places365.png
│   └── unet.png
├── requirements.txt
├── setup.cfg
├── src
│   ├── __init__.py
│   ├── dataset.py
│   ├── main.py
│   ├── models.py
│   ├── networks.py
│   ├── ops.py
│   ├── options.py
│   └── utils.py
├── test-turing.py
├── test.py
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 |
49 | # Translations
50 | *.mo
51 | *.pot
52 |
53 | # Django stuff:
54 | *.log
55 | local_settings.py
56 |
57 | # Flask stuff:
58 | instance/
59 | .webassets-cache
60 |
61 | # Scrapy stuff:
62 | .scrapy
63 |
64 | # Sphinx documentation
65 | docs/_build/
66 |
67 | # PyBuilder
68 | target/
69 |
70 | # Jupyter Notebook
71 | .ipynb_checkpoints
72 |
73 | # pyenv
74 | .python-version
75 |
76 | # celery beat schedule file
77 | celerybeat-schedule
78 |
79 | # SageMath parsed files
80 | *.sage.py
81 |
82 | # dotenv
83 | .env
84 |
85 | # virtualenv
86 | .venv
87 | venv/
88 | ENV/
89 |
90 | # Spyder project settings
91 | .spyderproject
92 | .spyproject
93 |
94 | # Rope project settings
95 | .ropeproject
96 |
97 | # mkdocs documentation
98 | /site
99 |
100 | # mypy
101 | .mypy_cache/
102 |
103 | # custom
104 | _TODO
105 | checkpoints
106 | plots
107 | vcs.xml
108 | .idea
109 | .vscode
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Image Colorization with Generative Adversarial Networks
2 | In this work, we generalize the colorization procedure using a conditional Deep Convolutional Generative Adversarial Network (DCGAN), as suggested by [Pix2Pix](https://arxiv.org/abs/1611.07004). The network is trained on the [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html) and [Places365](http://places2.csail.mit.edu) datasets. Some results on the Places365 dataset are [shown here](#places365-results).
3 |
4 | ## Prerequisites
5 | - Linux
6 | - TensorFlow 1.7
7 | - NVIDIA GPU (12 GB or 24 GB memory) + CUDA + cuDNN
8 |
9 | ## Getting Started
10 | ### Installation
11 | - Clone this repo:
12 | ```bash
13 | git clone https://github.com/ImagingLab/Colorizing-with-GANs.git
14 | cd Colorizing-with-GANs
15 | ```
16 | - Install TensorFlow and its dependencies from https://www.tensorflow.org/install/
17 | - Install the Python requirements:
18 | ```bash
19 | pip install -r requirements.txt
20 | ```
21 |
22 | ### Dataset
23 | - We use the [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html) and [Places365](http://places2.csail.mit.edu) datasets. To train a model on the full dataset, download the datasets from the official websites.
24 | After downloading, put them under the `dataset` folder (see the layout sketch below).
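
The data loaders in `src/dataset.py` look for the files below, so a layout along these lines should work (a sketch, not enforced by the code; `data_256`/`val_256` are the directory names used by the official Places365 256x256 archives, and the `data_batch_*`/`test_batch` files come from the Python version of CIFAR-10):
```
dataset
├── cifar10
│   ├── data_batch_1 ... data_batch_5   # CIFAR-10 training batches (pickled)
│   └── test_batch                      # CIFAR-10 test batch
└── places365
    ├── data_256/                       # Places365 training images (*.jpg)
    └── val_256/                        # Places365 validation images (*.jpg)
```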
25 |
26 | ### Training
27 | - To train the model with default settings, run the `train.py` script:
28 | ```bash
29 | python train.py
30 | ```
31 | - To train the model on the Places365 dataset with tuned hyperparameters:
32 | ```bash
33 | python train.py \
34 | --seed 100 \
35 | --dataset places365 \
36 | --dataset-path ./dataset/places365 \
37 | --checkpoints-path ./checkpoints \
38 | --batch-size 16 \
39 | --epochs 10 \
40 | --lr 3e-4 \
41 | --label-smoothing 1
42 |
43 | ```
44 |
45 | - To train the model on the CIFAR-10 dataset with tuned hyperparameters:
46 | ```bash
47 | python train.py \
48 | --seed 100 \
49 | --dataset cifar10 \
50 | --dataset-path ./dataset/cifar10 \
51 | --checkpoints-path ./checkpoints \
52 | --batch-size 128 \
53 | --epochs 200 \
54 | --lr 3e-4 \
55 | --lr-decay-steps 1e4 \
56 | --augment True
57 |
58 | ```
59 |
60 | ### Test
61 | - Download the pre-trained weights [from here](https://drive.google.com/open?id=1jTsAUAKrMiHO2gn7s-fFZ_zUSzgKoPyp) and copy them into the `checkpoints` folder.
62 | - To test the model on custom image(s), run the `test.py` script with the checkpoints path, the test image(s) path, and the output path:
63 | ```bash
64 | python test.py \
65 |   --checkpoints-path ./checkpoints \
66 |   --test-input ./checkpoints/test \
67 |   --test-output ./checkpoints/output
68 | ```
69 |
70 | ### Visual Turing Test
71 | - Download the pre-trained weights [from here](https://drive.google.com/open?id=1jTsAUAKrMiHO2gn7s-fFZ_zUSzgKoPyp) and copy them into the `checkpoints` folder.
72 | - To evaluate the model qualitatively using a visual Turing test, run `test-turing.py`:
73 | ```bash
74 | python test-turing.py
75 | ```
76 |
77 | - To apply a time-based visual Turing test (2-second decision time), run:
78 | ```bash
79 | python test-turing.py --test-delay 2
80 | ```
81 |
82 |
83 | ## Networks Architecture
84 | The architecture of the generator is inspired by [U-Net](https://arxiv.org/abs/1505.04597): the model is symmetric, with `n` encoding units and `n` decoding units. The contracting path consists of 4x4 convolution layers with stride 2 for downsampling, each followed by batch normalization and a leaky ReLU activation with a slope of 0.2. The number of channels is doubled after each step. Each unit in the expansive path consists of a 4x4 transposed convolution layer with stride 2 for upsampling, concatenation with the activation map of the mirroring layer in the contracting path, followed by batch normalization and a ReLU activation. The last layer of the network is a 1x1 convolution, which is equivalent to a cross-channel parametric pooling layer; we use a `tanh` activation for the output.
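
Below is a minimal sketch of one encoding/decoding step as described above (illustrative only; the actual implementation lives in `src/networks.py`):
```python
import tensorflow as tf

x = tf.placeholder(tf.float32, shape=(None, 32, 32, 1))   # grayscale input

# contracting path: 4x4 convolutions, batch-norm, leaky ReLU (slope 0.2)
e1 = tf.layers.conv2d(x, 64, kernel_size=4, strides=1, padding='same')
e1 = tf.nn.leaky_relu(tf.layers.batch_normalization(e1, training=True), alpha=0.2)
e2 = tf.layers.conv2d(e1, 128, kernel_size=4, strides=2, padding='same')
e2 = tf.nn.leaky_relu(tf.layers.batch_normalization(e2, training=True), alpha=0.2)

# expansive path: 4x4 transposed convolution, batch-norm, ReLU,
# then concatenation with the mirroring encoder activation (skip connection)
d1 = tf.layers.conv2d_transpose(e2, 64, kernel_size=4, strides=2, padding='same')
d1 = tf.nn.relu(tf.layers.batch_normalization(d1, training=True))
d1 = tf.concat([e1, d1], axis=3)

# 1x1 convolution (cross-channel parametric pooling) with tanh output
out = tf.layers.conv2d(d1, 3, kernel_size=1, strides=1, padding='same',
                       activation=tf.nn.tanh)
```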
85 |
86 |
87 |
88 |
89 | For the discriminator, we use a PatchGAN architecture with a contracting path similar to the baselines: a series of 4x4 convolution layers with stride 2, with the number of channels doubled after each downsampling. All convolution layers are followed by batch normalization and a leaky ReLU activation with slope 0.2. After the last layer, a sigmoid function is applied to return the probability of each `70x70` patch of the input being real or fake. We take the average of these probabilities as the network output.
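
As a sketch of this last step (assuming hypothetical `dis_logits` holding the per-patch outputs of the final convolution), the patch probabilities and their average could be computed as:
```python
import tensorflow as tf

# hypothetical per-patch logits: a batch of images, one score per patch
dis_logits = tf.placeholder(tf.float32, shape=(None, 32, 32, 1))

patch_probs = tf.nn.sigmoid(dis_logits)                    # real/fake probability per patch
dis_output = tf.reduce_mean(patch_probs, axis=[1, 2, 3])   # average probability per image
```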
90 |
91 | ## Places365 Results
92 | Colorization results with Places365. (a) Grayscale. (b) Original Image. (c) Colorized with GAN.
93 |
94 |
95 |
96 |
97 | ## Citation
98 | If you use this code for your research, please cite our paper Image Colorization with Generative Adversarial Networks:
99 |
100 | ```
101 | @inproceedings{nazeri2018image,
102 | title={Image Colorization Using Generative Adversarial Networks},
103 | author={Nazeri, Kamyar and Ng, Eric and Ebrahimi, Mehran},
104 | booktitle={International Conference on Articulated Motion and Deformable Objects},
105 | pages={85--94},
106 | year={2018},
107 | organization={Springer}
108 | }
109 | ```
110 |
--------------------------------------------------------------------------------
/img/cgan.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ImagingLab/Colorizing-with-GANs/13bb4c49144788d20729cab120f205440917d8c3/img/cgan.png
--------------------------------------------------------------------------------
/img/con_gan.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ImagingLab/Colorizing-with-GANs/13bb4c49144788d20729cab120f205440917d8c3/img/con_gan.png
--------------------------------------------------------------------------------
/img/gan.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ImagingLab/Colorizing-with-GANs/13bb4c49144788d20729cab120f205440917d8c3/img/gan.png
--------------------------------------------------------------------------------
/img/gan_new.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ImagingLab/Colorizing-with-GANs/13bb4c49144788d20729cab120f205440917d8c3/img/gan_new.png
--------------------------------------------------------------------------------
/img/places365.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ImagingLab/Colorizing-with-GANs/13bb4c49144788d20729cab120f205440917d8c3/img/places365.jpg
--------------------------------------------------------------------------------
/img/places365.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ImagingLab/Colorizing-with-GANs/13bb4c49144788d20729cab120f205440917d8c3/img/places365.png
--------------------------------------------------------------------------------
/img/unet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ImagingLab/Colorizing-with-GANs/13bb4c49144788d20729cab120f205440917d8c3/img/unet.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy ~= 1.19
2 | scipy ~= 1.0.1
3 | future ~= 0.16.0
4 | matplotlib ~= 2.2.2
5 | pillow >= 6.2.0
6 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [pycodestyle]
2 | ignore = E303
3 | max-line-length = 200
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
1 | from .options import *
2 | from .models import *
3 | from .utils import *
4 | from .dataset import *
5 | from .main import *
--------------------------------------------------------------------------------
/src/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import numpy as np
4 | import tensorflow as tf
5 | from scipy.misc import imread
6 | from abc import abstractmethod
7 | from .utils import unpickle
8 |
9 | CIFAR10_DATASET = 'cifar10'
10 | PLACES365_DATASET = 'places365'
11 |
12 |
13 | class BaseDataset():
14 | def __init__(self, name, path, training=True, augment=True):
15 | self.name = name
16 | self.augment = augment and training
17 | self.training = training
18 | self.path = path
19 | self._data = []
20 |
21 | def __len__(self):
22 | return len(self.data)
23 |
24 | def __iter__(self):
25 | total = len(self)
26 | start = 0
27 |
28 | while start < total:
29 | item = self[start]
30 | start += 1
31 | yield item
32 |
33 |         return
34 |
35 | def __getitem__(self, index):
36 | val = self.data[index]
37 | try:
38 | img = imread(val) if isinstance(val, str) else val
39 |
40 | # grayscale images
41 | if np.sum(img[:,:,0] - img[:,:,1]) == 0 and np.sum(img[:,:,0] - img[:,:,2]) == 0:
42 | return None
43 |
44 | if self.augment and np.random.binomial(1, 0.5) == 1:
45 | img = img[:, ::-1, :]
46 |
47 | except:
48 | img = None
49 |
50 | return img
51 |
52 |     def generator(self, batch_size, recursive=False):
53 | start = 0
54 | total = len(self)
55 |
56 | while True:
57 | while start < total:
58 | end = np.min([start + batch_size, total])
59 | items = []
60 |
61 | for ix in range(start, end):
62 | item = self[ix]
63 | if item is not None:
64 | items.append(item)
65 |
66 | start = end
67 | yield items
68 |
69 |             if recursive:
70 | start = 0
71 |
72 | else:
73 |                 return
74 |
75 | @property
76 | def data(self):
77 | if len(self._data) == 0:
78 | self._data = self.load()
79 | np.random.shuffle(self._data)
80 |
81 | return self._data
82 |
83 | @abstractmethod
84 | def load(self):
85 | return []
86 |
87 |
88 | class Cifar10Dataset(BaseDataset):
89 | def __init__(self, path, training=True, augment=True):
90 | super(Cifar10Dataset, self).__init__(CIFAR10_DATASET, path, training, augment)
91 |
92 | def load(self):
93 | data = []
94 | if self.training:
95 | for i in range(1, 6):
96 | filename = '{}/data_batch_{}'.format(self.path, i)
97 | batch_data = unpickle(filename)
98 | if len(data) > 0:
99 | data = np.vstack((data, batch_data[b'data']))
100 | else:
101 | data = batch_data[b'data']
102 |
103 | else:
104 | filename = '{}/test_batch'.format(self.path)
105 | batch_data = unpickle(filename)
106 | data = batch_data[b'data']
107 |
108 | w = 32
109 | h = 32
110 | s = w * h
111 | data = np.array(data)
112 | data = np.dstack((data[:, :s], data[:, s:2 * s], data[:, 2 * s:]))
113 | data = data.reshape((-1, w, h, 3))
114 | return data
115 |
116 |
117 | class Places365Dataset(BaseDataset):
118 | def __init__(self, path, training=True, augment=True):
119 | super(Places365Dataset, self).__init__(PLACES365_DATASET, path, training, augment)
120 |
121 | def load(self):
122 | if self.training:
123 | flist = os.path.join(self.path, 'train.flist')
124 | if os.path.exists(flist):
125 | data = np.genfromtxt(flist, dtype=np.str, encoding='utf-8')
126 | else:
127 | data = glob.glob(self.path + '/data_256/**/*.jpg', recursive=True)
128 | np.savetxt(flist, data, fmt='%s')
129 |
130 | else:
131 | flist = os.path.join(self.path, 'test.flist')
132 | if os.path.exists(flist):
133 | data = np.genfromtxt(flist, dtype=np.str, encoding='utf-8')
134 | else:
135 | data = np.array(glob.glob(self.path + '/val_256/*.jpg'))
136 | np.savetxt(flist, data, fmt='%s')
137 |
138 | return data
139 |
140 |
141 | class TestDataset(BaseDataset):
142 | def __init__(self, path):
143 | super(TestDataset, self).__init__('TEST', path, training=False, augment=False)
144 |
145 | def __getitem__(self, index):
146 | path = self.data[index]
147 | img = imread(path)
148 | return path, img
149 |
150 | def load(self):
151 |
152 | if os.path.isfile(self.path):
153 | data = [self.path]
154 |
155 | elif os.path.isdir(self.path):
156 | data = list(glob.glob(self.path + '/*.jpg')) + list(glob.glob(self.path + '/*.png'))
157 |
158 | return data
159 |
--------------------------------------------------------------------------------
/src/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import numpy as np
4 | import tensorflow as tf
5 | from .options import ModelOptions
6 | from .models import Cifar10Model, Places365Model
7 | from .dataset import CIFAR10_DATASET, PLACES365_DATASET
8 |
9 |
10 | def main(options):
11 |
12 | # reset tensorflow graph
13 | tf.reset_default_graph()
14 |
15 |
16 | # initialize random seed
17 | tf.set_random_seed(options.seed)
18 | np.random.seed(options.seed)
19 | random.seed(options.seed)
20 |
21 |
22 | # create a session environment
23 | with tf.Session() as sess:
24 |
25 | if options.dataset == CIFAR10_DATASET:
26 | model = Cifar10Model(sess, options)
27 |
28 | elif options.dataset == PLACES365_DATASET:
29 | model = Places365Model(sess, options)
30 |
31 | if not os.path.exists(options.checkpoints_path):
32 | os.makedirs(options.checkpoints_path)
33 |
34 | if options.log:
35 | open(model.train_log_file, 'w').close()
36 | open(model.test_log_file, 'w').close()
37 |
38 | # build the model and initialize
39 | model.build()
40 | sess.run(tf.global_variables_initializer())
41 |
42 |
43 | # load model only after global variables initialization
44 | model.load()
45 |
46 |
47 | if options.mode == 0:
48 | args = vars(options)
49 | print('\n------------ Options -------------')
50 | with open(os.path.join(options.checkpoints_path, 'options.dat'), 'w') as f:
51 | for k, v in sorted(args.items()):
52 | print('%s: %s' % (str(k), str(v)))
53 | f.write('%s: %s\n' % (str(k), str(v)))
54 | print('-------------- End ----------------\n')
55 |
56 | model.train()
57 |
58 | elif options.mode == 1:
59 | model.test()
60 |
61 | else:
62 | model.turing_test()
63 |
64 |
65 | if __name__ == "__main__":
66 | main(ModelOptions().parse())
67 |
--------------------------------------------------------------------------------
/src/models.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import os
4 | import time
5 | import numpy as np
6 | import tensorflow as tf
7 |
8 | from abc import abstractmethod
9 | from .networks import Generator, Discriminator
10 | from .ops import pixelwise_accuracy, preprocess, postprocess
11 | from .ops import COLORSPACE_RGB, COLORSPACE_LAB
12 | from .dataset import Places365Dataset, Cifar10Dataset, TestDataset
13 | from .utils import stitch_images, turing_test, imshow, imsave, create_dir, visualize, Progbar
14 |
15 |
16 | class BaseModel:
17 | def __init__(self, sess, options):
18 | self.sess = sess
19 | self.options = options
20 | self.name = options.name
21 | self.samples_dir = os.path.join(options.checkpoints_path, 'samples')
22 | self.test_log_file = os.path.join(options.checkpoints_path, 'log_test.dat')
23 | self.train_log_file = os.path.join(options.checkpoints_path, 'log_train.dat')
24 | self.global_step = tf.Variable(0, name='global_step', trainable=False)
25 | self.dataset_train = self.create_dataset(True)
26 | self.dataset_val = self.create_dataset(False)
27 | self.sample_generator = self.dataset_val.generator(options.sample_size, True)
28 | self.iteration = 0
29 | self.epoch = 0
30 | self.is_built = False
31 |
32 | def train(self):
33 | total = len(self.dataset_train)
34 |
35 | for epoch in range(self.options.epochs):
36 | lr_rate = self.sess.run(self.learning_rate)
37 |
38 | print('Training epoch: %d' % (epoch + 1) + " - learning rate: " + str(lr_rate))
39 |
40 | self.epoch = epoch + 1
41 | self.iteration = 0
42 |
43 | generator = self.dataset_train.generator(self.options.batch_size)
44 | progbar = Progbar(total, width=25, stateful_metrics=['epoch', 'iter', 'step'])
45 |
46 | for input_rgb in generator:
47 | feed_dic = {self.input_rgb: input_rgb}
48 |
49 | self.iteration = self.iteration + 1
50 | self.sess.run([self.dis_train], feed_dict=feed_dic)
51 | self.sess.run([self.gen_train, self.accuracy], feed_dict=feed_dic)
52 | self.sess.run([self.gen_train, self.accuracy], feed_dict=feed_dic)
53 |
54 | lossD, lossD_fake, lossD_real, lossG, lossG_l1, lossG_gan, acc, step = self.eval_outputs(feed_dic=feed_dic)
55 |
56 | progbar.add(len(input_rgb), values=[
57 | ("epoch", epoch + 1),
58 | ("iter", self.iteration),
59 | ("step", step),
60 | ("D loss", lossD),
61 | ("D fake", lossD_fake),
62 | ("D real", lossD_real),
63 | ("G loss", lossG),
64 | ("G L1", lossG_l1),
65 | ("G gan", lossG_gan),
66 | ("accuracy", acc)
67 | ])
68 |
69 | # log model at checkpoints
70 | if self.options.log and step % self.options.log_interval == 0:
71 | with open(self.train_log_file, 'a') as f:
72 | f.write('%d %d %f %f %f %f %f %f %f\n' % (self.epoch, step, lossD, lossD_fake, lossD_real, lossG, lossG_l1, lossG_gan, acc))
73 |
74 | if self.options.visualize:
75 | visualize(self.train_log_file, self.test_log_file, self.options.visualize_window, self.name)
76 |
77 | # sample model at checkpoints
78 | if self.options.sample and step % self.options.sample_interval == 0:
79 | self.sample(show=False)
80 |
81 | # validate model at checkpoints
82 | if self.options.validate and self.options.validate_interval > 0 and step % self.options.validate_interval == 0:
83 | self.validate()
84 |
85 | # save model at checkpoints
86 | if self.options.save and step % self.options.save_interval == 0:
87 | self.save()
88 |
89 | if self.options.validate:
90 | self.validate()
91 |
92 | def validate(self):
93 | print('\n\nValidating epoch: %d' % self.epoch)
94 | total = len(self.dataset_val)
95 | val_generator = self.dataset_val.generator(self.options.batch_size)
96 | progbar = Progbar(total, width=25)
97 |
98 | for input_rgb in val_generator:
99 | feed_dic = {self.input_rgb: input_rgb}
100 |
101 | self.sess.run([self.dis_loss, self.gen_loss, self.accuracy], feed_dict=feed_dic)
102 |
103 | lossD, lossD_fake, lossD_real, lossG, lossG_l1, lossG_gan, acc, step = self.eval_outputs(feed_dic=feed_dic)
104 |
105 | progbar.add(len(input_rgb), values=[
106 | ("D loss", lossD),
107 | ("D fake", lossD_fake),
108 | ("D real", lossD_real),
109 | ("G loss", lossG),
110 | ("G L1", lossG_l1),
111 | ("G gan", lossG_gan),
112 | ("accuracy", acc)
113 | ])
114 |
115 | print('\n')
116 |
117 | def test(self):
118 | print('\nTesting...')
119 | dataset = TestDataset(self.options.test_input or (self.options.checkpoints_path + '/test'))
120 | outputs_path = create_dir(self.options.test_output or (self.options.checkpoints_path + '/output'))
121 |
122 | for index in range(len(dataset)):
123 | img_gray_path, img_gray = dataset[index]
124 | name = os.path.basename(img_gray_path)
125 | path = os.path.join(outputs_path, name)
126 |
127 | feed_dic = {self.input_gray: img_gray[None, :, :, None]}
128 | outputs = self.sess.run(self.sampler, feed_dict=feed_dic)
129 | outputs = postprocess(tf.convert_to_tensor(outputs), colorspace_in=self.options.color_space, colorspace_out=COLORSPACE_RGB).eval() * 255
130 | print(path)
131 | imsave(outputs[0], path)
132 |
133 | def sample(self, show=True):
134 | input_rgb = next(self.sample_generator)
135 | feed_dic = {self.input_rgb: input_rgb}
136 |
137 | step, rate = self.sess.run([self.global_step, self.learning_rate])
138 | fake_image, input_gray = self.sess.run([self.sampler, self.input_gray], feed_dict=feed_dic)
139 | fake_image = postprocess(tf.convert_to_tensor(fake_image), colorspace_in=self.options.color_space, colorspace_out=COLORSPACE_RGB)
140 | img = stitch_images(input_gray, input_rgb, fake_image.eval())
141 |
142 | create_dir(self.samples_dir)
143 | sample = self.options.dataset + "_" + str(step).zfill(5) + ".png"
144 |
145 | if show:
146 | imshow(np.array(img), self.name)
147 | else:
148 | print('\nsaving sample ' + sample + ' - learning rate: ' + str(rate))
149 | img.save(os.path.join(self.samples_dir, sample))
150 |
151 | def turing_test(self):
152 | batch_size = self.options.batch_size
153 | gen = self.dataset_val.generator(batch_size, True)
154 | count = 0
155 | score = 0
156 | size = self.options.turing_test_size
157 |
158 | while count < size:
159 | input_rgb = next(gen)
160 | feed_dic = {self.input_rgb: input_rgb}
161 | fake_image = self.sess.run(self.sampler, feed_dict=feed_dic)
162 | fake_image = postprocess(tf.convert_to_tensor(fake_image), colorspace_in=self.options.color_space, colorspace_out=COLORSPACE_RGB)
163 |
164 | for i in range(np.min([batch_size, size - count])):
165 | res = turing_test(input_rgb[i], fake_image.eval()[i], self.options.turing_test_delay)
166 | count += 1
167 | score += res
168 | print('success: %d - fail: %d - rate: %f' % (score, count - score, (count - score) / count))
169 |
170 | def build(self):
171 | if self.is_built:
172 | return
173 |
174 | self.is_built = True
175 |
176 | gen_factory = self.create_generator()
177 | dis_factory = self.create_discriminator()
178 | smoothing = 0.9 if self.options.label_smoothing else 1
179 | seed = self.options.seed
180 | kernel = 4
181 |
182 |         # model input placeholder: RGB image
183 | self.input_rgb = tf.placeholder(tf.float32, shape=(None, None, None, 3), name='input_rgb')
184 |
185 | # model input after preprocessing: LAB image
186 | self.input_color = preprocess(self.input_rgb, colorspace_in=COLORSPACE_RGB, colorspace_out=self.options.color_space)
187 |
188 |         # test mode: model input is a grayscale placeholder
189 | if self.options.mode == 1:
190 | self.input_gray = tf.placeholder(tf.float32, shape=(None, None, None, 1), name='input_gray')
191 |
192 |         # train/turing-test mode: extract the grayscale image from the color image
193 | else:
194 | self.input_gray = tf.image.rgb_to_grayscale(self.input_rgb)
195 |
196 | gen = gen_factory.create(self.input_gray, kernel, seed)
197 | dis_real = dis_factory.create(tf.concat([self.input_gray, self.input_color], 3), kernel, seed)
198 | dis_fake = dis_factory.create(tf.concat([self.input_gray, gen], 3), kernel, seed, reuse_variables=True)
199 |
200 | gen_ce = tf.nn.sigmoid_cross_entropy_with_logits(logits=dis_fake, labels=tf.ones_like(dis_fake))
201 | dis_real_ce = tf.nn.sigmoid_cross_entropy_with_logits(logits=dis_real, labels=tf.ones_like(dis_real) * smoothing)
202 | dis_fake_ce = tf.nn.sigmoid_cross_entropy_with_logits(logits=dis_fake, labels=tf.zeros_like(dis_fake))
203 |
204 | self.dis_loss_real = tf.reduce_mean(dis_real_ce)
205 | self.dis_loss_fake = tf.reduce_mean(dis_fake_ce)
206 | self.dis_loss = tf.reduce_mean(dis_real_ce + dis_fake_ce)
207 |
208 | self.gen_loss_gan = tf.reduce_mean(gen_ce)
209 | self.gen_loss_l1 = tf.reduce_mean(tf.abs(self.input_color - gen)) * self.options.l1_weight
210 | self.gen_loss = self.gen_loss_gan + self.gen_loss_l1
211 |
212 | self.sampler = tf.identity(gen_factory.create(self.input_gray, kernel, seed, reuse_variables=True), name='output')
213 | self.accuracy = pixelwise_accuracy(self.input_color, gen, self.options.color_space, self.options.acc_thresh)
214 | self.learning_rate = tf.constant(self.options.lr)
215 |
216 | # learning rate decay
217 | if self.options.lr_decay and self.options.lr_decay_rate > 0:
218 | self.learning_rate = tf.maximum(1e-6, tf.train.exponential_decay(
219 | learning_rate=self.options.lr,
220 | global_step=self.global_step,
221 | decay_steps=self.options.lr_decay_steps,
222 | decay_rate=self.options.lr_decay_rate))
223 |
224 |         # generator optimizer
225 | self.gen_train = tf.train.AdamOptimizer(
226 | learning_rate=self.learning_rate,
227 | beta1=self.options.beta1
228 | ).minimize(self.gen_loss, var_list=gen_factory.var_list)
229 |
230 |         # discriminator optimizer
231 | self.dis_train = tf.train.AdamOptimizer(
232 | learning_rate=self.learning_rate / 10,
233 | beta1=self.options.beta1
234 | ).minimize(self.dis_loss, var_list=dis_factory.var_list, global_step=self.global_step)
235 |
236 | self.saver = tf.train.Saver()
237 |
238 | def load(self):
239 | ckpt = tf.train.get_checkpoint_state(self.options.checkpoints_path)
240 | if ckpt is not None:
241 | print('loading model...\n')
242 | ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
243 | self.saver.restore(self.sess, os.path.join(self.options.checkpoints_path, ckpt_name))
244 | return True
245 |
246 | return False
247 |
248 | def save(self):
249 | print('saving model...\n')
250 | self.saver.save(self.sess, os.path.join(self.options.checkpoints_path, 'CGAN_' + self.options.dataset), write_meta_graph=False)
251 |
252 | def eval_outputs(self, feed_dic):
253 | '''
254 | evaluates the loss and accuracy
255 | returns (D loss, D_fake loss, D_real loss, G loss, G_L1 loss, G_gan loss, accuracy, step)
256 | '''
257 | lossD_fake = self.dis_loss_fake.eval(feed_dict=feed_dic)
258 | lossD_real = self.dis_loss_real.eval(feed_dict=feed_dic)
259 | lossD = self.dis_loss.eval(feed_dict=feed_dic)
260 |
261 | lossG_l1 = self.gen_loss_l1.eval(feed_dict=feed_dic)
262 | lossG_gan = self.gen_loss_gan.eval(feed_dict=feed_dic)
263 | lossG = lossG_l1 + lossG_gan
264 |
265 | acc = self.accuracy.eval(feed_dict=feed_dic)
266 | step = self.sess.run(self.global_step)
267 |
268 | return lossD, lossD_fake, lossD_real, lossG, lossG_l1, lossG_gan, acc, step
269 |
270 | @abstractmethod
271 | def create_generator(self):
272 | raise NotImplementedError
273 |
274 | @abstractmethod
275 | def create_discriminator(self):
276 | raise NotImplementedError
277 |
278 | @abstractmethod
279 | def create_dataset(self, training):
280 | raise NotImplementedError
281 |
282 |
283 | class Cifar10Model(BaseModel):
284 | def __init__(self, sess, options):
285 | super(Cifar10Model, self).__init__(sess, options)
286 |
287 | def create_generator(self):
288 | kernels_gen_encoder = [
289 | (64, 1, 0), # [batch, 32, 32, ch] => [batch, 32, 32, 64]
290 | (128, 2, 0), # [batch, 32, 32, 64] => [batch, 16, 16, 128]
291 | (256, 2, 0), # [batch, 16, 16, 128] => [batch, 8, 8, 256]
292 | (512, 2, 0), # [batch, 8, 8, 256] => [batch, 4, 4, 512]
293 | (512, 2, 0), # [batch, 4, 4, 512] => [batch, 2, 2, 512]
294 | ]
295 |
296 | kernels_gen_decoder = [
297 | (512, 2, 0.5), # [batch, 2, 2, 512] => [batch, 4, 4, 512]
298 | (256, 2, 0.5), # [batch, 4, 4, 512] => [batch, 8, 8, 256]
299 | (128, 2, 0), # [batch, 8, 8, 256] => [batch, 16, 16, 128]
300 | (64, 2, 0), # [batch, 16, 16, 128] => [batch, 32, 32, 64]
301 | ]
302 |
303 | return Generator('gen', kernels_gen_encoder, kernels_gen_decoder, training=self.options.training)
304 |
305 | def create_discriminator(self):
306 | kernels_dis = [
307 | (64, 2, 0), # [batch, 32, 32, ch] => [batch, 16, 16, 64]
308 | (128, 2, 0), # [batch, 16, 16, 64] => [batch, 8, 8, 128]
309 | (256, 2, 0), # [batch, 8, 8, 128] => [batch, 4, 4, 256]
310 | (512, 1, 0), # [batch, 4, 4, 256] => [batch, 4, 4, 512]
311 | ]
312 |
313 | return Discriminator('dis', kernels_dis, training=self.options.training)
314 |
315 | def create_dataset(self, training=True):
316 | return Cifar10Dataset(
317 | path=self.options.dataset_path,
318 | training=training,
319 | augment=self.options.augment)
320 |
321 |
322 | class Places365Model(BaseModel):
323 | def __init__(self, sess, options):
324 | super(Places365Model, self).__init__(sess, options)
325 |
326 | def create_generator(self):
327 | kernels_gen_encoder = [
328 | (64, 1, 0), # [batch, 256, 256, ch] => [batch, 256, 256, 64]
329 | (64, 2, 0), # [batch, 256, 256, 64] => [batch, 128, 128, 64]
330 | (128, 2, 0), # [batch, 128, 128, 64] => [batch, 64, 64, 128]
331 | (256, 2, 0), # [batch, 64, 64, 128] => [batch, 32, 32, 256]
332 | (512, 2, 0), # [batch, 32, 32, 256] => [batch, 16, 16, 512]
333 | (512, 2, 0), # [batch, 16, 16, 512] => [batch, 8, 8, 512]
334 | (512, 2, 0), # [batch, 8, 8, 512] => [batch, 4, 4, 512]
335 | (512, 2, 0) # [batch, 4, 4, 512] => [batch, 2, 2, 512]
336 | ]
337 |
338 | kernels_gen_decoder = [
339 | (512, 2, 0), # [batch, 2, 2, 512] => [batch, 4, 4, 512]
340 | (512, 2, 0), # [batch, 4, 4, 512] => [batch, 8, 8, 512]
341 | (512, 2, 0), # [batch, 8, 8, 512] => [batch, 16, 16, 512]
342 | (256, 2, 0), # [batch, 16, 16, 512] => [batch, 32, 32, 256]
343 | (128, 2, 0), # [batch, 32, 32, 256] => [batch, 64, 64, 128]
344 | (64, 2, 0), # [batch, 64, 64, 128] => [batch, 128, 128, 64]
345 | (64, 2, 0) # [batch, 128, 128, 64] => [batch, 256, 256, 64]
346 | ]
347 |
348 | return Generator('gen', kernels_gen_encoder, kernels_gen_decoder, training=self.options.training)
349 |
350 | def create_discriminator(self):
351 | kernels_dis = [
352 | (64, 2, 0), # [batch, 256, 256, ch] => [batch, 128, 128, 64]
353 | (128, 2, 0), # [batch, 128, 128, 64] => [batch, 64, 64, 128]
354 | (256, 2, 0), # [batch, 64, 64, 128] => [batch, 32, 32, 256]
355 | (512, 1, 0), # [batch, 32, 32, 256] => [batch, 32, 32, 512]
356 | ]
357 |
358 | return Discriminator('dis', kernels_dis, training=self.options.training)
359 |
360 | def create_dataset(self, training=True):
361 | return Places365Dataset(
362 | path=self.options.dataset_path,
363 | training=training,
364 | augment=self.options.augment)
365 |
--------------------------------------------------------------------------------
/src/networks.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | from .ops import conv2d, conv2d_transpose, pixelwise_accuracy
4 |
5 |
6 | class Discriminator(object):
7 | def __init__(self, name, kernels, training=True):
8 | self.name = name
9 | self.kernels = kernels
10 | self.training = training
11 | self.var_list = []
12 |
13 | def create(self, inputs, kernel_size=None, seed=None, reuse_variables=None):
14 | output = inputs
15 | with tf.variable_scope(self.name, reuse=reuse_variables):
16 | for index, kernel in enumerate(self.kernels):
17 |
18 |                 # do not use batch-norm in the first layer
19 | bnorm = False if index == 0 else True
20 | name = 'conv' + str(index)
21 | output = conv2d(
22 | inputs=output,
23 | name=name,
24 | kernel_size=kernel_size,
25 | filters=kernel[0],
26 | strides=kernel[1],
27 | bnorm=bnorm,
28 | activation=tf.nn.leaky_relu,
29 | seed=seed
30 | )
31 |
32 | if kernel[2] > 0:
33 | keep_prob = 1.0 - kernel[2] if self.training else 1.0
34 | output = tf.nn.dropout(output, keep_prob=keep_prob, name='dropout_' + name, seed=seed)
35 |
36 | output = conv2d(
37 | inputs=output,
38 | name='conv_last',
39 | filters=1,
40 | kernel_size=4, # last layer kernel size = 4
41 | strides=1, # last layer stride = 1
42 | bnorm=False, # do not use batch-norm for the last layer
43 | seed=seed
44 | )
45 |
46 | self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, self.name)
47 |
48 | return output
49 |
50 |
51 | class Generator(object):
52 | def __init__(self, name, encoder_kernels, decoder_kernels, output_channels=3, training=True):
53 | self.name = name
54 | self.encoder_kernels = encoder_kernels
55 | self.decoder_kernels = decoder_kernels
56 | self.output_channels = output_channels
57 | self.training = training
58 | self.var_list = []
59 |
60 | def create(self, inputs, kernel_size=None, seed=None, reuse_variables=None):
61 | output = inputs
62 |
63 | with tf.variable_scope(self.name, reuse=reuse_variables):
64 |
65 | layers = []
66 |
67 | # encoder branch
68 | for index, kernel in enumerate(self.encoder_kernels):
69 |
70 | name = 'conv' + str(index)
71 | output = conv2d(
72 | inputs=output,
73 | name=name,
74 | kernel_size=kernel_size,
75 | filters=kernel[0],
76 | strides=kernel[1],
77 | activation=tf.nn.leaky_relu,
78 | seed=seed
79 | )
80 |
81 | # save contracting path layers to be used for skip connections
82 | layers.append(output)
83 |
84 | if kernel[2] > 0:
85 | keep_prob = 1.0 - kernel[2] if self.training else 1.0
86 | output = tf.nn.dropout(output, keep_prob=keep_prob, name='dropout_' + name, seed=seed)
87 |
88 | # decoder branch
89 | for index, kernel in enumerate(self.decoder_kernels):
90 |
91 | name = 'deconv' + str(index)
92 | output = conv2d_transpose(
93 | inputs=output,
94 | name=name,
95 | kernel_size=kernel_size,
96 | filters=kernel[0],
97 | strides=kernel[1],
98 | activation=tf.nn.relu,
99 | seed=seed
100 | )
101 |
102 | if kernel[2] > 0:
103 | keep_prob = 1.0 - kernel[2] if self.training else 1.0
104 | output = tf.nn.dropout(output, keep_prob=keep_prob, name='dropout_' + name, seed=seed)
105 |
106 | # concat the layer from the contracting path with the output of the current layer
107 | # concat only the channels (axis=3)
108 | output = tf.concat([layers[len(layers) - index - 2], output], axis=3)
109 |
110 | output = conv2d(
111 | inputs=output,
112 | name='conv_last',
113 |                 filters=self.output_channels,   # number of output channels
114 | kernel_size=1, # last layer kernel size = 1
115 | strides=1, # last layer stride = 1
116 | bnorm=False, # do not use batch-norm for the last layer
117 | activation=tf.nn.tanh, # tanh activation function for the output
118 | seed=seed
119 | )
120 |
121 | self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, self.name)
122 |
123 | return output
124 |
--------------------------------------------------------------------------------
/src/ops.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 |
4 | COLORSPACE_RGB = 'RGB'
5 | COLORSPACE_LAB = 'LAB'
6 |
7 |
8 | def conv2d(inputs, filters, name, kernel_size=4, strides=2, bnorm=True, activation=None, seed=None):
9 | """
10 | Creates a conv2D block
11 | """
12 | initializer=tf.variance_scaling_initializer(seed=seed)
13 | res = tf.layers.conv2d(
14 | name=name,
15 | inputs=inputs,
16 | filters=filters,
17 | kernel_size=kernel_size,
18 | strides=strides,
19 | padding="same",
20 | kernel_initializer=initializer)
21 |
22 | if bnorm:
23 | res = tf.layers.batch_normalization(inputs=res, name='bn_' + name, training=True)
24 |
25 | # activation after batch-norm
26 | if activation is not None:
27 | res = activation(res)
28 |
29 | return res
30 |
31 |
32 | def conv2d_transpose(inputs, filters, name, kernel_size=4, strides=2, bnorm=True, activation=None, seed=None):
33 | """
34 | Creates a conv2D-transpose block
35 | """
36 | initializer=tf.variance_scaling_initializer(seed=seed)
37 | res = tf.layers.conv2d_transpose(
38 | name=name,
39 | inputs=inputs,
40 | filters=filters,
41 | kernel_size=kernel_size,
42 | strides=strides,
43 | padding="same",
44 | kernel_initializer=initializer)
45 |
46 | if bnorm:
47 | res = tf.layers.batch_normalization(inputs=res, name='bn_' + name, training=True)
48 |
49 | # activation after batch-norm
50 | if activation is not None:
51 | res = activation(res)
52 |
53 | return res
54 |
55 |
56 | def pixelwise_accuracy(img_real, img_fake, colorspace, thresh):
57 | """
58 | Measures the accuracy of the colorization process by comparing pixels
59 | """
60 | img_real = postprocess(img_real, colorspace, COLORSPACE_LAB)
61 | img_fake = postprocess(img_fake, colorspace, COLORSPACE_LAB)
62 |
63 | diffL = tf.abs(tf.round(img_real[..., 0]) - tf.round(img_fake[..., 0]))
64 | diffA = tf.abs(tf.round(img_real[..., 1]) - tf.round(img_fake[..., 1]))
65 | diffB = tf.abs(tf.round(img_real[..., 2]) - tf.round(img_fake[..., 2]))
66 |
67 | # within %thresh of the original
68 | predL = tf.cast(tf.less_equal(diffL, 1 * thresh), tf.float64) # L: [0, 100]
69 | predA = tf.cast(tf.less_equal(diffA, 2.2 * thresh), tf.float64) # A: [-110, 110]
70 | predB = tf.cast(tf.less_equal(diffB, 2.2 * thresh), tf.float64) # B: [-110, 110]
71 |
72 | # all three channels are within the threshold
73 | pred = predL * predA * predB
74 |
75 | return tf.reduce_mean(pred)
76 |
77 |
78 | def preprocess(img, colorspace_in, colorspace_out):
79 | if colorspace_out.upper() == COLORSPACE_RGB:
80 | if colorspace_in == COLORSPACE_LAB:
81 | img = lab_to_rgb(img)
82 |
83 |         # [0, 255] => [-1, 1]
84 | img = (img / 255.0) * 2 - 1
85 |
86 | elif colorspace_out.upper() == COLORSPACE_LAB:
87 | if colorspace_in == COLORSPACE_RGB:
88 | img = rgb_to_lab(img / 255.0)
89 |
90 | L_chan, a_chan, b_chan = tf.unstack(img, axis=3)
91 |
92 | # L: [0, 100] => [-1, 1]
93 | # A, B: [-110, 110] => [-1, 1]
94 | img = tf.stack([L_chan / 50 - 1, a_chan / 110, b_chan / 110], axis=3)
95 |
96 | return img
97 |
98 |
99 | def postprocess(img, colorspace_in, colorspace_out):
100 | if colorspace_in.upper() == COLORSPACE_RGB:
101 | # [-1, 1] => [0, 1]
102 | img = (img + 1) / 2
103 |
104 | if colorspace_out == COLORSPACE_LAB:
105 | img = rgb_to_lab(img)
106 |
107 | elif colorspace_in.upper() == COLORSPACE_LAB:
108 | L_chan, a_chan, b_chan = tf.unstack(img, axis=3)
109 |
110 | # L: [-1, 1] => [0, 100]
111 | # A, B: [-1, 1] => [-110, 110]
112 | img = tf.stack([(L_chan + 1) / 2 * 100, a_chan * 110, b_chan * 110], axis=3)
113 |
114 | if colorspace_out == COLORSPACE_RGB:
115 | img = lab_to_rgb(img)
116 |
117 | return img
118 |
119 |
120 | def rgb_to_lab(srgb):
121 | # based on https://github.com/torch/image/blob/9f65c30167b2048ecbe8b7befdc6b2d6d12baee9/generic/image.c
122 | with tf.name_scope("rgb_to_lab"):
123 | srgb_pixels = tf.reshape(srgb, [-1, 3])
124 |
125 | with tf.name_scope("srgb_to_xyz"):
126 | linear_mask = tf.cast(srgb_pixels <= 0.04045, dtype=tf.float32)
127 | exponential_mask = tf.cast(srgb_pixels > 0.04045, dtype=tf.float32)
128 | rgb_pixels = (srgb_pixels / 12.92 * linear_mask) + (((srgb_pixels + 0.055) / 1.055) ** 2.4) * exponential_mask
129 | rgb_to_xyz = tf.constant([
130 | # X Y Z
131 | [0.412453, 0.212671, 0.019334], # R
132 | [0.357580, 0.715160, 0.119193], # G
133 | [0.180423, 0.072169, 0.950227], # B
134 | ])
135 | xyz_pixels = tf.matmul(rgb_pixels, rgb_to_xyz)
136 |
137 | # https://en.wikipedia.org/wiki/Lab_color_space#CIELAB-CIEXYZ_conversions
138 | with tf.name_scope("xyz_to_cielab"):
139 |
140 | # normalize for D65 white point
141 | xyz_normalized_pixels = tf.multiply(xyz_pixels, [1 / 0.950456, 1.0, 1 / 1.088754])
142 |
143 | epsilon = 6 / 29
144 | linear_mask = tf.cast(xyz_normalized_pixels <= (epsilon**3), dtype=tf.float32)
145 | exponential_mask = tf.cast(xyz_normalized_pixels > (epsilon**3), dtype=tf.float32)
146 | fxfyfz_pixels = (xyz_normalized_pixels / (3 * epsilon**2) + 4 / 29) * linear_mask + (xyz_normalized_pixels ** (1 / 3)) * exponential_mask
147 |
148 | # convert to lab
149 | fxfyfz_to_lab = tf.constant([
150 | # l a b
151 | [0.0, 500.0, 0.0], # fx
152 | [116.0, -500.0, 200.0], # fy
153 | [0.0, 0.0, -200.0], # fz
154 | ])
155 | lab_pixels = tf.matmul(fxfyfz_pixels, fxfyfz_to_lab) + tf.constant([-16.0, 0.0, 0.0])
156 |
157 | return tf.reshape(lab_pixels, tf.shape(srgb))
158 |
159 |
160 | def lab_to_rgb(lab):
161 | with tf.name_scope("lab_to_rgb"):
162 | lab_pixels = tf.reshape(lab, [-1, 3])
163 |
164 | # https://en.wikipedia.org/wiki/Lab_color_space#CIELAB-CIEXYZ_conversions
165 | with tf.name_scope("cielab_to_xyz"):
166 | # convert to fxfyfz
167 | lab_to_fxfyfz = tf.constant([
168 | # fx fy fz
169 | [1 / 116.0, 1 / 116.0, 1 / 116.0], # l
170 | [1 / 500.0, 0.0, 0.0], # a
171 | [0.0, 0.0, -1 / 200.0], # b
172 | ])
173 | fxfyfz_pixels = tf.matmul(lab_pixels + tf.constant([16.0, 0.0, 0.0]), lab_to_fxfyfz)
174 |
175 | # convert to xyz
176 | epsilon = 6 / 29
177 | linear_mask = tf.cast(fxfyfz_pixels <= epsilon, dtype=tf.float32)
178 | exponential_mask = tf.cast(fxfyfz_pixels > epsilon, dtype=tf.float32)
179 | xyz_pixels = (3 * epsilon**2 * (fxfyfz_pixels - 4 / 29)) * linear_mask + (fxfyfz_pixels ** 3) * exponential_mask
180 |
181 | # denormalize for D65 white point
182 | xyz_pixels = tf.multiply(xyz_pixels, [0.950456, 1.0, 1.088754])
183 |
184 | with tf.name_scope("xyz_to_srgb"):
185 | xyz_to_rgb = tf.constant([
186 | # r g b
187 | [3.2404542, -0.9692660, 0.0556434], # x
188 | [-1.5371385, 1.8760108, -0.2040259], # y
189 | [-0.4985314, 0.0415560, 1.0572252], # z
190 | ])
191 | rgb_pixels = tf.matmul(xyz_pixels, xyz_to_rgb)
192 | # avoid a slightly negative number messing up the conversion
193 | rgb_pixels = tf.clip_by_value(rgb_pixels, 0.0, 1.0)
194 | linear_mask = tf.cast(rgb_pixels <= 0.0031308, dtype=tf.float32)
195 | exponential_mask = tf.cast(rgb_pixels > 0.0031308, dtype=tf.float32)
196 | srgb_pixels = (rgb_pixels * 12.92 * linear_mask) + ((rgb_pixels ** (1 / 2.4) * 1.055) - 0.055) * exponential_mask
197 |
198 | return tf.reshape(srgb_pixels, tf.shape(lab))
199 |
--------------------------------------------------------------------------------
/src/options.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import os
3 | import random
4 | import argparse
5 |
6 |
7 | def str2bool(v):
8 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
9 | return True
10 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
11 | return False
12 | else:
13 | raise argparse.ArgumentTypeError('Boolean value expected.')
14 |
15 |
16 | class ModelOptions:
17 | def __init__(self):
18 | parser = argparse.ArgumentParser(description='Colorization with GANs')
19 | parser.add_argument('--seed', type=int, default=0, metavar='S', help='random seed (default: 0)')
20 | parser.add_argument('--name', type=str, default='CGAN', help='arbitrary model name (default: CGAN)')
21 |         parser.add_argument('--mode', type=int, default=0, help='run mode [0: train, 1: test, 2: turing-test] (default: 0)')
22 | parser.add_argument('--dataset', type=str, default='places365', help='the name of dataset [places365, cifar10] (default: places365)')
23 | parser.add_argument('--dataset-path', type=str, default='./dataset', help='dataset path (default: ./dataset)')
24 | parser.add_argument('--checkpoints-path', type=str, default='./checkpoints', help='models are saved here (default: ./checkpoints)')
25 | parser.add_argument('--batch-size', type=int, default=16, metavar='N', help='input batch size for training (default: 16)')
26 | parser.add_argument('--color-space', type=str, default='lab', help='model color space [lab, rgb] (default: lab)')
27 | parser.add_argument('--epochs', type=int, default=30, metavar='N', help='number of epochs to train (default: 30)')
28 | parser.add_argument('--lr', type=float, default=3e-4, metavar='LR', help='learning rate (default: 3e-4)')
29 | parser.add_argument('--lr-decay', type=str2bool, default=True, help='True for learning-rate decay (default: True)')
30 |         parser.add_argument('--lr-decay-rate', type=float, default=0.1, help='learning rate exponential decay rate (default: 0.1)')
31 |         parser.add_argument('--lr-decay-steps', type=float, default=5e5, help='learning rate exponential decay steps (default: 5e5)')
32 | parser.add_argument('--beta1', type=float, default=0, help='momentum term of adam optimizer (default: 0)')
33 | parser.add_argument("--l1-weight", type=float, default=100.0, help="weight on L1 term for generator gradient (default: 100.0)")
34 | parser.add_argument('--augment', type=str2bool, default=True, help='True for augmentation (default: True)')
35 | parser.add_argument('--label-smoothing', type=str2bool, default=False, help='True for one-sided label smoothing (default: False)')
36 | parser.add_argument('--acc-thresh', type=float, default=2.0, help="accuracy threshold (default: 2.0)")
37 | parser.add_argument('--gpu-ids', type=str, default='0', help='gpu ids, e.g. 0 or 0,1,2 or 0,2; use -1 for CPU')
38 |
39 | parser.add_argument('--save', type=str2bool, default=True, help='True for saving (default: True)')
40 | parser.add_argument('--save-interval', type=int, default=1000, help='how many batches to wait before saving model (default: 1000)')
41 | parser.add_argument('--sample', type=str2bool, default=True, help='True for sampling (default: True)')
42 | parser.add_argument('--sample-size', type=int, default=8, help='number of images to sample (default: 8)')
43 | parser.add_argument('--sample-interval', type=int, default=1000, help='how many batches to wait before sampling (default: 1000)')
44 | parser.add_argument('--validate', type=str2bool, default=True, help='True for validation (default: True)')
45 | parser.add_argument('--validate-interval', type=int, default=0, help='how many batches to wait before validating (default: 0)')
46 | parser.add_argument('--log', type=str2bool, default=False, help='True for logging (default: False)')
47 | parser.add_argument('--log-interval', type=int, default=10, help='how many iterations to wait before logging training status (default: 10)')
48 | parser.add_argument('--visualize', type=str2bool, default=False, help='True for accuracy visualization (default: False)')
49 | parser.add_argument('--visualize-window', type=int, default=100, help='moving average window width for accuracy visualization (default: 100)')
50 |
51 | parser.add_argument('--test-input', type=str, default='', help='path to the grayscale images directory or a grayscale file')
52 | parser.add_argument('--test-output', type=str, default='', help='model test output directory')
53 | parser.add_argument('--turing-test-size', type=int, default=100, metavar='N', help='number of Turing tests (default: 100)')
54 | parser.add_argument('--turing-test-delay', type=int, default=0, metavar='N', help='number of seconds to wait when doing Turing test, 0 for unlimited (default: 0)')
55 |
56 | self._parser = parser
57 |
58 | def parse(self):
59 | opt = self._parser.parse_args()
60 | os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_ids
61 |
62 | opt.color_space = opt.color_space.upper()
63 | opt.training = opt.mode == 0  # mode 0 is training (see --mode help above)
64 |
65 | if opt.seed == 0:
66 | opt.seed = random.randint(0, 2**31 - 1)
67 |
68 | if opt.dataset_path == './dataset':
69 | opt.dataset_path += ('/' + opt.dataset)
70 |
71 | if opt.checkpoints_path == './checkpoints':
72 | opt.checkpoints_path += ('/' + opt.dataset)
73 |
74 | return opt
75 |
--------------------------------------------------------------------------------
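Note: a minimal sketch of how these options are consumed by the entry scripts further below (values shown are just the parser defaults; parse() reads sys.argv):

    from src import ModelOptions

    opt = ModelOptions().parse()
    print(opt.dataset, opt.batch_size)   # 'places365' 16 when left at the defaults
    print(opt.checkpoints_path)          # './checkpoints/places365' (dataset name appended to the default path)
    print(opt.color_space)               # 'LAB' (upper-cased in parse())
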
/src/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import time
4 | import random
5 | import pickle
6 | import numpy as np
7 | from PIL import Image
8 | import matplotlib.pyplot as plt
9 |
10 |
11 | def stitch_images(grayscale, original, pred):
12 | gap = 5
13 | height, width = original[0][:, :, 0].shape  # ndarray shape is (rows, cols)
14 | img_per_row = 2 if width > 200 else 4
15 | img = Image.new('RGB', (width * img_per_row * 3 + gap * (img_per_row - 1), height * int(len(original) / img_per_row)))
16 |
17 | grayscale = np.array(grayscale).squeeze()
18 | original = np.array(original)
19 | pred = np.array(pred)
20 |
21 | for ix in range(len(original)):
22 | xoffset = int(ix % img_per_row) * width * 3 + int(ix % img_per_row) * gap
23 | yoffset = int(ix / img_per_row) * height
24 | im1 = Image.fromarray(grayscale[ix])
25 | im2 = Image.fromarray(original[ix])
26 | im3 = Image.fromarray((pred[ix] * 255).astype(np.uint8))
27 | img.paste(im1, (xoffset, yoffset))
28 | img.paste(im2, (xoffset + width, yoffset))
29 | img.paste(im3, (xoffset + width + width, yoffset))
30 |
31 | return img
32 |
33 |
34 | def create_dir(path):
35 | if not os.path.exists(path):
36 | os.makedirs(path)
37 |
38 | return path
39 |
40 |
41 | def unpickle(file):
42 | with open(file, 'rb') as fo:
43 | data = pickle.load(fo, encoding='bytes')
44 | return data
45 |
46 |
47 | def moving_average(data, window_width):
48 | cumsum_vec = np.cumsum(np.insert(data, 0, 0))
49 | ma_vec = (cumsum_vec[window_width:] - cumsum_vec[:-window_width]) / window_width
50 | return ma_vec
51 |
52 |
53 | def imshow(img, title=''):
54 | fig = plt.gcf()
55 | fig.canvas.set_window_title(title)
56 | plt.axis('off')
57 | plt.imshow(img, interpolation='none')
58 | plt.show()
59 |
60 |
61 | def imsave(img, path):
62 | im = Image.fromarray(np.array(img).astype(np.uint8).squeeze())
63 | im.save(path)
64 |
65 |
66 | def turing_test(real_img, fake_img, delay=0):
67 | height, width, _ = real_img.shape
68 | imgs = np.array([real_img, (fake_img * 255).astype(np.uint8)])
69 | real_index = np.random.binomial(1, 0.5)  # coin flip: 0 puts the real image on the left, 1 on the right
70 | fake_index = (real_index + 1) % 2
71 |
72 | img = Image.new('RGB', (2 + width * 2, height))
73 | img.paste(Image.fromarray(imgs[real_index]), (0, 0))
74 | img.paste(Image.fromarray(imgs[fake_index]), (2 + width, 0))
75 |
76 | img.success = 0  # result flag; set to 1 by the click handler when the user picks the real image
77 |
78 | def onclick(event):
79 | if event.xdata is not None:
80 | if event.xdata < width and real_index == 0:  # clicked the left half and the real image is on the left
81 | img.success = 1
82 |
83 | elif event.xdata > width and real_index == 1:  # clicked the right half and the real image is on the right
84 | img.success = 1
85 |
86 | plt.gcf().canvas.stop_event_loop()
87 |
88 | plt.ion()
89 | plt.gcf().canvas.mpl_connect('button_press_event', onclick)
90 | plt.title('click on the real image')
91 | plt.axis('off')
92 | plt.imshow(img, interpolation='none')
93 | plt.show()
94 | plt.draw()
95 | plt.gcf().canvas.start_event_loop(delay)
96 |
97 | return img.success
98 |
99 |
100 | def visualize(train_log_file, test_log_file, window_width, title=''):
101 | train_data = np.loadtxt(train_log_file)
102 | test_data = np.loadtxt(test_log_file)
103 |
104 | if len(train_data.shape) < 2:
105 | return
106 |
107 | if len(train_data) < window_width:
108 | window_width = len(train_data) - 1
109 |
110 | fig = plt.gcf()
111 | fig.canvas.set_window_title(title)
112 |
113 | plt.ion()
114 | plt.subplot(121)
115 | plt.cla()
116 | if len(train_data) > 1:
117 | plt.plot(moving_average(train_data[:, 8], window_width))
118 | plt.title('train')
119 |
120 | plt.subplot(122)
121 | plt.cla()
122 | if len(test_data) > 1:
123 | plt.plot(test_data[:, 8])
124 | plt.title('test')
125 |
126 | plt.show()
127 | plt.draw()
128 | plt.pause(.01)
129 |
130 |
131 |
132 | class Progbar(object):
133 | """Displays a progress bar.
134 |
135 | Arguments:
136 | target: Total number of steps expected, None if unknown.
137 | width: Progress bar width on screen.
138 | verbose: Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose)
139 | stateful_metrics: Iterable of string names of metrics that
140 | should *not* be averaged over time. Metrics in this list
141 | will be displayed as-is. All others will be averaged
142 | by the progbar before display.
143 | interval: Minimum visual progress update interval (in seconds).
144 | """
145 |
146 | def __init__(self, target, width=25, verbose=1, interval=0.05,
147 | stateful_metrics=None):
148 | self.target = target
149 | self.width = width
150 | self.verbose = verbose
151 | self.interval = interval
152 | if stateful_metrics:
153 | self.stateful_metrics = set(stateful_metrics)
154 | else:
155 | self.stateful_metrics = set()
156 |
157 | self._dynamic_display = ((hasattr(sys.stdout, 'isatty') and
158 | sys.stdout.isatty()) or
159 | 'ipykernel' in sys.modules or
160 | 'posix' in sys.modules)
161 | self._total_width = 0
162 | self._seen_so_far = 0
163 | # We use a dict + list to avoid garbage collection
164 | # issues found in OrderedDict
165 | self._values = {}
166 | self._values_order = []
167 | self._start = time.time()
168 | self._last_update = 0
169 |
170 | def update(self, current, values=None):
171 | """Updates the progress bar.
172 |
173 | Arguments:
174 | current: Index of current step.
175 | values: List of tuples:
176 | `(name, value_for_last_step)`.
177 | If `name` is in `stateful_metrics`,
178 | `value_for_last_step` will be displayed as-is.
179 | Else, an average of the metric over time will be displayed.
180 | """
181 | values = values or []
182 | for k, v in values:
183 | if k not in self._values_order:
184 | self._values_order.append(k)
185 | if k not in self.stateful_metrics:
186 | if k not in self._values:
187 | self._values[k] = [v * (current - self._seen_so_far),
188 | current - self._seen_so_far]
189 | else:
190 | self._values[k][0] += v * (current - self._seen_so_far)
191 | self._values[k][1] += (current - self._seen_so_far)
192 | else:
193 | self._values[k] = v
194 | self._seen_so_far = current
195 |
196 | now = time.time()
197 | info = ' - %.0fs' % (now - self._start)
198 | if self.verbose == 1:
199 | if (now - self._last_update < self.interval and
200 | self.target is not None and current < self.target):
201 | return
202 |
203 | prev_total_width = self._total_width
204 | if self._dynamic_display:
205 | sys.stdout.write('\b' * prev_total_width)
206 | sys.stdout.write('\r')
207 | else:
208 | sys.stdout.write('\n')
209 |
210 | if self.target is not None:
211 | numdigits = int(np.floor(np.log10(self.target))) + 1
212 | barstr = '%%%dd/%d [' % (numdigits, self.target)
213 | bar = barstr % current
214 | prog = float(current) / self.target
215 | prog_width = int(self.width * prog)
216 | if prog_width > 0:
217 | bar += ('=' * (prog_width - 1))
218 | if current < self.target:
219 | bar += '>'
220 | else:
221 | bar += '='
222 | bar += ('.' * (self.width - prog_width))
223 | bar += ']'
224 | else:
225 | bar = '%7d/Unknown' % current
226 |
227 | self._total_width = len(bar)
228 | sys.stdout.write(bar)
229 |
230 | if current:
231 | time_per_unit = (now - self._start) / current
232 | else:
233 | time_per_unit = 0
234 | if self.target is not None and current < self.target:
235 | eta = time_per_unit * (self.target - current)
236 | if eta > 3600:
237 | eta_format = '%d:%02d:%02d' % (eta // 3600,
238 | (eta % 3600) // 60,
239 | eta % 60)
240 | elif eta > 60:
241 | eta_format = '%d:%02d' % (eta // 60, eta % 60)
242 | else:
243 | eta_format = '%ds' % eta
244 |
245 | info = ' - ETA: %s' % eta_format
246 | else:
247 | if time_per_unit >= 1:
248 | info += ' %.0fs/step' % time_per_unit
249 | elif time_per_unit >= 1e-3:
250 | info += ' %.0fms/step' % (time_per_unit * 1e3)
251 | else:
252 | info += ' %.0fus/step' % (time_per_unit * 1e6)
253 |
254 | for k in self._values_order:
255 | info += ' - %s:' % k
256 | if isinstance(self._values[k], list):
257 | avg = np.mean(self._values[k][0] / max(1, self._values[k][1]))
258 | if abs(avg) > 1e-3:
259 | info += ' %.4f' % avg
260 | else:
261 | info += ' %.4e' % avg
262 | else:
263 | info += ' %s' % self._values[k]
264 |
265 | self._total_width += len(info)
266 | if prev_total_width > self._total_width:
267 | info += (' ' * (prev_total_width - self._total_width))
268 |
269 | if self.target is not None and current >= self.target:
270 | info += '\n'
271 |
272 | sys.stdout.write(info)
273 | sys.stdout.flush()
274 |
275 | elif self.verbose == 2:
276 | if self.target is None or current >= self.target:
277 | for k in self._values_order:
278 | info += ' - %s:' % k
279 | avg = np.mean(self._values[k][0] / max(1, self._values[k][1]))
280 | if avg > 1e-3:
281 | info += ' %.4f' % avg
282 | else:
283 | info += ' %.4e' % avg
284 | info += '\n'
285 |
286 | sys.stdout.write(info)
287 | sys.stdout.flush()
288 |
289 | self._last_update = now
290 |
291 | def add(self, n, values=None):
292 | self.update(self._seen_so_far + n, values)
293 |
--------------------------------------------------------------------------------
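Note: a minimal usage sketch for the Progbar class above (the loss values are placeholders for illustration):

    from src.utils import Progbar

    bar = Progbar(target=1000, width=25, verbose=1)
    for step in range(1000):
        d_loss, g_loss = 0.69, 1.32   # placeholder metrics
        bar.add(1, values=[('d_loss', d_loss), ('g_loss', g_loss)])
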
/test-turing.py:
--------------------------------------------------------------------------------
1 | from src import ModelOptions, main
2 |
3 | options = ModelOptions().parse()
4 | options.mode = 2
5 | main(options)
6 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | from src import ModelOptions, main
2 |
3 | options = ModelOptions().parse()
4 | options.mode = 1
5 | main(options)
6 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | from src import ModelOptions, main
2 |
3 | options = ModelOptions().parse()
4 | options.mode = 0
5 | main(options)
6 |
--------------------------------------------------------------------------------
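Note: each entry script only forces the run mode (train.py sets mode 0, test.py mode 1, test-turing.py mode 2); every other flag comes from src/options.py, e.g. python train.py --dataset places365 --batch-size 16 --epochs 30, or python test-turing.py --turing-test-size 100 --turing-test-delay 3 (illustrative values).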