├── .gitignore ├── LICENSE ├── README.md ├── img ├── cgan.png ├── con_gan.png ├── gan.png ├── gan_new.png ├── places365.jpg ├── places365.png └── unet.png ├── requirements.txt ├── setup.cfg ├── src ├── __init__.py ├── dataset.py ├── main.py ├── models.py ├── networks.py ├── ops.py ├── options.py └── utils.py ├── test-turing.py ├── test.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | 103 | # custom 104 | _TODO 105 | checkpoints 106 | plots 107 | vcs.xml 108 | .idea 109 | .vscode -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Image Colorization with Generative Adversarial Networks 2 | In this work, we generalize the colorization procedure using a conditional Deep Convolutional Generative Adversarial Network (DCGAN) as suggested by [Pix2Pix](https://arxiv.org/abs/1611.07004). The network is trained on the [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html) and [Places365](http://places2.csail.mit.edu) datasets. Some of the results from the Places365 dataset are [shown here.](#places365-results) 3 | 4 | ## Prerequisites 5 | - Linux 6 | - TensorFlow 1.7 7 | - NVIDIA GPU (12G or 24G memory) + CUDA cuDNN 8 | 9 | ## Getting Started 10 | ### Installation 11 | - Clone this repo: 12 | ```bash 13 | git clone https://github.com/ImagingLab/Colorizing-with-GANs.git 14 | cd Colorizing-with-GANs 15 | ``` 16 | - Install TensorFlow and dependencies from https://www.tensorflow.org/install/ 17 | - Install Python requirements: 18 | ```bash 19 | pip install -r requirements.txt 20 | ``` 21 | 22 | ### Dataset 23 | - We use the [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html) and [Places365](http://places2.csail.mit.edu) datasets. To train a model on the full dataset, download the datasets from the official websites. 24 | After downloading, put them under the `dataset` folder. 25 | 26 | ### Training 27 | - To train the model, run the `main.py` script: 28 | ```bash 29 | python main.py 30 | ``` 31 | - To train the model on the places365 dataset with tuned hyperparameters: 32 | ``` 33 | python train.py \ 34 | --seed 100 \ 35 | --dataset places365 \ 36 | --dataset-path ./dataset/places365 \ 37 | --checkpoints-path ./checkpoints \ 38 | --batch-size 16 \ 39 | --epochs 10 \ 40 | --lr 3e-4 \ 41 | --label-smoothing 1 42 | 43 | ``` 44 | 45 | - To train the model on the cifar10 dataset with tuned hyperparameters: 46 | ``` 47 | python train.py \ 48 | --seed 100 \ 49 | --dataset cifar10 \ 50 | --dataset-path ./dataset/cifar10 \ 51 | --checkpoints-path ./checkpoints \ 52 | --batch-size 128 \ 53 | --epochs 200 \ 54 | --lr 3e-4 \ 55 | --lr-decay-steps 1e4 \ 56 | --augment True 57 | 58 | ``` 59 | 60 | ### Test 61 | - Download the pre-trained weights [from here](https://drive.google.com/open?id=1jTsAUAKrMiHO2gn7s-fFZ_zUSzgKoPyp) and copy them into the `checkpoints` folder. 62 | - To test the model on custom image(s), run the `test.py` script: 63 | ```bash 64 | python test.py \ 65 | --checkpoints-path ./checkpoints \ # checkpoints path 66 | --test-input ./checkpoints/test \ # test image(s) path 67 | --test-output ./checkpoints/output \ # output image(s) path 68 | ``` 69 | 70 | ### Visual Turing Test 71 | - Download the pre-trained weights [from here](https://drive.google.com/open?id=1jTsAUAKrMiHO2gn7s-fFZ_zUSzgKoPyp) and copy them into the `checkpoints` folder. 72 | - To evaluate the model qualitatively using a visual Turing test, run `test-turing.py`: 73 | ```bash 74 | python test-turing.py 75 | ``` 76 | 77 | - To run a time-based visual Turing test (2-second decision time): 78 | ```bash 79 | python test-turing.py --test-delay 2 80 | ``` 81 | 82 | 83 | ## Networks Architecture 84 | The architecture of the generator is inspired by [U-Net](https://arxiv.org/abs/1505.04597): the model is symmetric, with `n` encoding units and `n` decoding units.
The contracting path consists of 4x4 convolution layers with stride 2 for downsampling, each followed by batch normalization and a Leaky-ReLU activation with a slope of 0.2. The number of channels is doubled after each step. Each unit in the expansive path consists of a 4x4 transposed convolution layer with stride 2 for upsampling, concatenation with the activation map of the mirroring layer in the contracting path, followed by batch normalization and a ReLU activation. The last layer of the network is a 1x1 convolution, which is equivalent to a cross-channel parametric pooling layer. We use the `tanh` activation for the last layer. 85 |
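As a rough sketch of the units described above (written against the `conv2d` / `conv2d_transpose` helpers defined in `src/ops.py`; the actual generator in `src/networks.py` stacks several such units, and the layer names here are illustrative only):

```python
import tensorflow as tf
from src.ops import conv2d, conv2d_transpose

def unet_pair(inputs, filters, seed=None):
    # encoder unit: 4x4 convolution, stride 2, batch-norm, Leaky-ReLU (slope 0.2)
    down = conv2d(inputs, filters=filters, name='enc0', kernel_size=4,
                  strides=2, activation=tf.nn.leaky_relu, seed=seed)
    # decoder unit: 4x4 transposed convolution, stride 2, batch-norm, ReLU
    up = conv2d_transpose(down, filters=filters, name='dec0', kernel_size=4,
                          strides=2, activation=tf.nn.relu, seed=seed)
    # skip connection: concatenate with the mirroring activation along the channel axis
    return tf.concat([inputs, up], axis=3)
```

In the full model, a final 1x1 convolution with `tanh` then maps the concatenated features back to the target color channels.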

86 | 87 |

88 | 89 | For the discriminator, we use a PatchGAN architecture with a contracting path similar to the baselines: a series of 4x4 convolutional layers with stride 2, with the number of channels doubled after each downsampling step. All convolution layers are followed by batch normalization and a leaky ReLU activation with slope 0.2. After the last layer, a sigmoid function is applied to return the probability of each `70x70` patch of the input being real or fake. We take the average of these probabilities as the network output. 90 | 91 | ## Places365 Results 92 | Colorization results with Places365. (a) Grayscale. (b) Original Image. (c) Colorized with GAN. 93 |
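The snippet below is an illustrative sketch (not code from this repo) of that final averaging step; `patch_logits` stands in for the discriminator's output grid:

```python
import tensorflow as tf

patch_logits = tf.zeros([8, 4, 4, 1])                 # placeholder for D's per-patch logits
patch_probs = tf.nn.sigmoid(patch_logits)             # probability of each patch being real
score = tf.reduce_mean(patch_probs, axis=[1, 2, 3])   # average over patches -> shape [8]
```

During training, `src/models.py` works with the raw patch logits directly through `tf.nn.sigmoid_cross_entropy_with_logits`.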

94 | 95 |

96 | 97 | ## Citation 98 | If you use this code for your research, please cite our paper Image Colorization with Generative Adversarial Networks: 99 | 100 | ``` 101 | @inproceedings{nazeri2018image, 102 | title={Image Colorization Using Generative Adversarial Networks}, 103 | author={Nazeri, Kamyar and Ng, Eric and Ebrahimi, Mehran}, 104 | booktitle={International Conference on Articulated Motion and Deformable Objects}, 105 | pages={85--94}, 106 | year={2018}, 107 | organization={Springer} 108 | } 109 | ``` 110 | -------------------------------------------------------------------------------- /img/cgan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ImagingLab/Colorizing-with-GANs/13bb4c49144788d20729cab120f205440917d8c3/img/cgan.png -------------------------------------------------------------------------------- /img/con_gan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ImagingLab/Colorizing-with-GANs/13bb4c49144788d20729cab120f205440917d8c3/img/con_gan.png -------------------------------------------------------------------------------- /img/gan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ImagingLab/Colorizing-with-GANs/13bb4c49144788d20729cab120f205440917d8c3/img/gan.png -------------------------------------------------------------------------------- /img/gan_new.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ImagingLab/Colorizing-with-GANs/13bb4c49144788d20729cab120f205440917d8c3/img/gan_new.png -------------------------------------------------------------------------------- /img/places365.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ImagingLab/Colorizing-with-GANs/13bb4c49144788d20729cab120f205440917d8c3/img/places365.jpg -------------------------------------------------------------------------------- /img/places365.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ImagingLab/Colorizing-with-GANs/13bb4c49144788d20729cab120f205440917d8c3/img/places365.png -------------------------------------------------------------------------------- /img/unet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ImagingLab/Colorizing-with-GANs/13bb4c49144788d20729cab120f205440917d8c3/img/unet.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy ~= 1.19 2 | scipy ~= 1.0.1 3 | future ~= 0.16.0 4 | matplotlib ~= 2.2.2 5 | pillow >= 6.2.0 6 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [pycodestyle] 2 | ignore = E303 3 | max-line-length = 200 -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | from .options import * 2 | from .models import * 3 | from .utils import * 4 | from .dataset import * 5 | from .main import * -------------------------------------------------------------------------------- 
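A short usage sketch of the data loaders defined in `src/dataset.py` below (paths assume the defaults in `src/options.py`):

```python
from src.dataset import Cifar10Dataset

# CIFAR-10 training data is read from the pickled data_batch_* files under ./dataset/cifar10
train_set = Cifar10Dataset(path='./dataset/cifar10', training=True, augment=True)

for batch in train_set.generator(batch_size=128):
    # each batch is a list of 32x32x3 images; unreadable or grayscale items are skipped
    pass
```

Note that `BaseDataset.generator` ends a non-recursive pass with `raise StopIteration`; under Python 3.7+ (PEP 479) this surfaces as a `RuntimeError`, so a plain `return` is the safer way to end the generator.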
/src/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import numpy as np 4 | import tensorflow as tf 5 | from scipy.misc import imread 6 | from abc import abstractmethod 7 | from .utils import unpickle 8 | 9 | CIFAR10_DATASET = 'cifar10' 10 | PLACES365_DATASET = 'places365' 11 | 12 | 13 | class BaseDataset(): 14 | def __init__(self, name, path, training=True, augment=True): 15 | self.name = name 16 | self.augment = augment and training 17 | self.training = training 18 | self.path = path 19 | self._data = [] 20 | 21 | def __len__(self): 22 | return len(self.data) 23 | 24 | def __iter__(self): 25 | total = len(self) 26 | start = 0 27 | 28 | while start < total: 29 | item = self[start] 30 | start += 1 31 | yield item 32 | 33 | raise StopIteration 34 | 35 | def __getitem__(self, index): 36 | val = self.data[index] 37 | try: 38 | img = imread(val) if isinstance(val, str) else val 39 | 40 | # grayscale images 41 | if np.sum(img[:,:,0] - img[:,:,1]) == 0 and np.sum(img[:,:,0] - img[:,:,2]) == 0: 42 | return None 43 | 44 | if self.augment and np.random.binomial(1, 0.5) == 1: 45 | img = img[:, ::-1, :] 46 | 47 | except: 48 | img = None 49 | 50 | return img 51 | 52 | def generator(self, batch_size, recusrive=False): 53 | start = 0 54 | total = len(self) 55 | 56 | while True: 57 | while start < total: 58 | end = np.min([start + batch_size, total]) 59 | items = [] 60 | 61 | for ix in range(start, end): 62 | item = self[ix] 63 | if item is not None: 64 | items.append(item) 65 | 66 | start = end 67 | yield items 68 | 69 | if recusrive: 70 | start = 0 71 | 72 | else: 73 | raise StopIteration 74 | 75 | @property 76 | def data(self): 77 | if len(self._data) == 0: 78 | self._data = self.load() 79 | np.random.shuffle(self._data) 80 | 81 | return self._data 82 | 83 | @abstractmethod 84 | def load(self): 85 | return [] 86 | 87 | 88 | class Cifar10Dataset(BaseDataset): 89 | def __init__(self, path, training=True, augment=True): 90 | super(Cifar10Dataset, self).__init__(CIFAR10_DATASET, path, training, augment) 91 | 92 | def load(self): 93 | data = [] 94 | if self.training: 95 | for i in range(1, 6): 96 | filename = '{}/data_batch_{}'.format(self.path, i) 97 | batch_data = unpickle(filename) 98 | if len(data) > 0: 99 | data = np.vstack((data, batch_data[b'data'])) 100 | else: 101 | data = batch_data[b'data'] 102 | 103 | else: 104 | filename = '{}/test_batch'.format(self.path) 105 | batch_data = unpickle(filename) 106 | data = batch_data[b'data'] 107 | 108 | w = 32 109 | h = 32 110 | s = w * h 111 | data = np.array(data) 112 | data = np.dstack((data[:, :s], data[:, s:2 * s], data[:, 2 * s:])) 113 | data = data.reshape((-1, w, h, 3)) 114 | return data 115 | 116 | 117 | class Places365Dataset(BaseDataset): 118 | def __init__(self, path, training=True, augment=True): 119 | super(Places365Dataset, self).__init__(PLACES365_DATASET, path, training, augment) 120 | 121 | def load(self): 122 | if self.training: 123 | flist = os.path.join(self.path, 'train.flist') 124 | if os.path.exists(flist): 125 | data = np.genfromtxt(flist, dtype=np.str, encoding='utf-8') 126 | else: 127 | data = glob.glob(self.path + '/data_256/**/*.jpg', recursive=True) 128 | np.savetxt(flist, data, fmt='%s') 129 | 130 | else: 131 | flist = os.path.join(self.path, 'test.flist') 132 | if os.path.exists(flist): 133 | data = np.genfromtxt(flist, dtype=np.str, encoding='utf-8') 134 | else: 135 | data = np.array(glob.glob(self.path + '/val_256/*.jpg')) 136 | np.savetxt(flist, 
data, fmt='%s') 137 | 138 | return data 139 | 140 | 141 | class TestDataset(BaseDataset): 142 | def __init__(self, path): 143 | super(TestDataset, self).__init__('TEST', path, training=False, augment=False) 144 | 145 | def __getitem__(self, index): 146 | path = self.data[index] 147 | img = imread(path) 148 | return path, img 149 | 150 | def load(self): 151 | 152 | if os.path.isfile(self.path): 153 | data = [self.path] 154 | 155 | elif os.path.isdir(self.path): 156 | data = list(glob.glob(self.path + '/*.jpg')) + list(glob.glob(self.path + '/*.png')) 157 | 158 | return data 159 | -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import numpy as np 4 | import tensorflow as tf 5 | from .options import ModelOptions 6 | from .models import Cifar10Model, Places365Model 7 | from .dataset import CIFAR10_DATASET, PLACES365_DATASET 8 | 9 | 10 | def main(options): 11 | 12 | # reset tensorflow graph 13 | tf.reset_default_graph() 14 | 15 | 16 | # initialize random seed 17 | tf.set_random_seed(options.seed) 18 | np.random.seed(options.seed) 19 | random.seed(options.seed) 20 | 21 | 22 | # create a session environment 23 | with tf.Session() as sess: 24 | 25 | if options.dataset == CIFAR10_DATASET: 26 | model = Cifar10Model(sess, options) 27 | 28 | elif options.dataset == PLACES365_DATASET: 29 | model = Places365Model(sess, options) 30 | 31 | if not os.path.exists(options.checkpoints_path): 32 | os.makedirs(options.checkpoints_path) 33 | 34 | if options.log: 35 | open(model.train_log_file, 'w').close() 36 | open(model.test_log_file, 'w').close() 37 | 38 | # build the model and initialize 39 | model.build() 40 | sess.run(tf.global_variables_initializer()) 41 | 42 | 43 | # load model only after global variables initialization 44 | model.load() 45 | 46 | 47 | if options.mode == 0: 48 | args = vars(options) 49 | print('\n------------ Options -------------') 50 | with open(os.path.join(options.checkpoints_path, 'options.dat'), 'w') as f: 51 | for k, v in sorted(args.items()): 52 | print('%s: %s' % (str(k), str(v))) 53 | f.write('%s: %s\n' % (str(k), str(v))) 54 | print('-------------- End ----------------\n') 55 | 56 | model.train() 57 | 58 | elif options.mode == 1: 59 | model.test() 60 | 61 | else: 62 | model.turing_test() 63 | 64 | 65 | if __name__ == "__main__": 66 | main(ModelOptions().parse()) 67 | -------------------------------------------------------------------------------- /src/models.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import os 4 | import time 5 | import numpy as np 6 | import tensorflow as tf 7 | 8 | from abc import abstractmethod 9 | from .networks import Generator, Discriminator 10 | from .ops import pixelwise_accuracy, preprocess, postprocess 11 | from .ops import COLORSPACE_RGB, COLORSPACE_LAB 12 | from .dataset import Places365Dataset, Cifar10Dataset, TestDataset 13 | from .utils import stitch_images, turing_test, imshow, imsave, create_dir, visualize, Progbar 14 | 15 | 16 | class BaseModel: 17 | def __init__(self, sess, options): 18 | self.sess = sess 19 | self.options = options 20 | self.name = options.name 21 | self.samples_dir = os.path.join(options.checkpoints_path, 'samples') 22 | self.test_log_file = os.path.join(options.checkpoints_path, 'log_test.dat') 23 | self.train_log_file = os.path.join(options.checkpoints_path, 
'log_train.dat') 24 | self.global_step = tf.Variable(0, name='global_step', trainable=False) 25 | self.dataset_train = self.create_dataset(True) 26 | self.dataset_val = self.create_dataset(False) 27 | self.sample_generator = self.dataset_val.generator(options.sample_size, True) 28 | self.iteration = 0 29 | self.epoch = 0 30 | self.is_built = False 31 | 32 | def train(self): 33 | total = len(self.dataset_train) 34 | 35 | for epoch in range(self.options.epochs): 36 | lr_rate = self.sess.run(self.learning_rate) 37 | 38 | print('Training epoch: %d' % (epoch + 1) + " - learning rate: " + str(lr_rate)) 39 | 40 | self.epoch = epoch + 1 41 | self.iteration = 0 42 | 43 | generator = self.dataset_train.generator(self.options.batch_size) 44 | progbar = Progbar(total, width=25, stateful_metrics=['epoch', 'iter', 'step']) 45 | 46 | for input_rgb in generator: 47 | feed_dic = {self.input_rgb: input_rgb} 48 | 49 | self.iteration = self.iteration + 1 50 | self.sess.run([self.dis_train], feed_dict=feed_dic) 51 | self.sess.run([self.gen_train, self.accuracy], feed_dict=feed_dic) 52 | self.sess.run([self.gen_train, self.accuracy], feed_dict=feed_dic) 53 | 54 | lossD, lossD_fake, lossD_real, lossG, lossG_l1, lossG_gan, acc, step = self.eval_outputs(feed_dic=feed_dic) 55 | 56 | progbar.add(len(input_rgb), values=[ 57 | ("epoch", epoch + 1), 58 | ("iter", self.iteration), 59 | ("step", step), 60 | ("D loss", lossD), 61 | ("D fake", lossD_fake), 62 | ("D real", lossD_real), 63 | ("G loss", lossG), 64 | ("G L1", lossG_l1), 65 | ("G gan", lossG_gan), 66 | ("accuracy", acc) 67 | ]) 68 | 69 | # log model at checkpoints 70 | if self.options.log and step % self.options.log_interval == 0: 71 | with open(self.train_log_file, 'a') as f: 72 | f.write('%d %d %f %f %f %f %f %f %f\n' % (self.epoch, step, lossD, lossD_fake, lossD_real, lossG, lossG_l1, lossG_gan, acc)) 73 | 74 | if self.options.visualize: 75 | visualize(self.train_log_file, self.test_log_file, self.options.visualize_window, self.name) 76 | 77 | # sample model at checkpoints 78 | if self.options.sample and step % self.options.sample_interval == 0: 79 | self.sample(show=False) 80 | 81 | # validate model at checkpoints 82 | if self.options.validate and self.options.validate_interval > 0 and step % self.options.validate_interval == 0: 83 | self.validate() 84 | 85 | # save model at checkpoints 86 | if self.options.save and step % self.options.save_interval == 0: 87 | self.save() 88 | 89 | if self.options.validate: 90 | self.validate() 91 | 92 | def validate(self): 93 | print('\n\nValidating epoch: %d' % self.epoch) 94 | total = len(self.dataset_val) 95 | val_generator = self.dataset_val.generator(self.options.batch_size) 96 | progbar = Progbar(total, width=25) 97 | 98 | for input_rgb in val_generator: 99 | feed_dic = {self.input_rgb: input_rgb} 100 | 101 | self.sess.run([self.dis_loss, self.gen_loss, self.accuracy], feed_dict=feed_dic) 102 | 103 | lossD, lossD_fake, lossD_real, lossG, lossG_l1, lossG_gan, acc, step = self.eval_outputs(feed_dic=feed_dic) 104 | 105 | progbar.add(len(input_rgb), values=[ 106 | ("D loss", lossD), 107 | ("D fake", lossD_fake), 108 | ("D real", lossD_real), 109 | ("G loss", lossG), 110 | ("G L1", lossG_l1), 111 | ("G gan", lossG_gan), 112 | ("accuracy", acc) 113 | ]) 114 | 115 | print('\n') 116 | 117 | def test(self): 118 | print('\nTesting...') 119 | dataset = TestDataset(self.options.test_input or (self.options.checkpoints_path + '/test')) 120 | outputs_path = create_dir(self.options.test_output or (self.options.checkpoints_path + 
'/output')) 121 | 122 | for index in range(len(dataset)): 123 | img_gray_path, img_gray = dataset[index] 124 | name = os.path.basename(img_gray_path) 125 | path = os.path.join(outputs_path, name) 126 | 127 | feed_dic = {self.input_gray: img_gray[None, :, :, None]} 128 | outputs = self.sess.run(self.sampler, feed_dict=feed_dic) 129 | outputs = postprocess(tf.convert_to_tensor(outputs), colorspace_in=self.options.color_space, colorspace_out=COLORSPACE_RGB).eval() * 255 130 | print(path) 131 | imsave(outputs[0], path) 132 | 133 | def sample(self, show=True): 134 | input_rgb = next(self.sample_generator) 135 | feed_dic = {self.input_rgb: input_rgb} 136 | 137 | step, rate = self.sess.run([self.global_step, self.learning_rate]) 138 | fake_image, input_gray = self.sess.run([self.sampler, self.input_gray], feed_dict=feed_dic) 139 | fake_image = postprocess(tf.convert_to_tensor(fake_image), colorspace_in=self.options.color_space, colorspace_out=COLORSPACE_RGB) 140 | img = stitch_images(input_gray, input_rgb, fake_image.eval()) 141 | 142 | create_dir(self.samples_dir) 143 | sample = self.options.dataset + "_" + str(step).zfill(5) + ".png" 144 | 145 | if show: 146 | imshow(np.array(img), self.name) 147 | else: 148 | print('\nsaving sample ' + sample + ' - learning rate: ' + str(rate)) 149 | img.save(os.path.join(self.samples_dir, sample)) 150 | 151 | def turing_test(self): 152 | batch_size = self.options.batch_size 153 | gen = self.dataset_val.generator(batch_size, True) 154 | count = 0 155 | score = 0 156 | size = self.options.turing_test_size 157 | 158 | while count < size: 159 | input_rgb = next(gen) 160 | feed_dic = {self.input_rgb: input_rgb} 161 | fake_image = self.sess.run(self.sampler, feed_dict=feed_dic) 162 | fake_image = postprocess(tf.convert_to_tensor(fake_image), colorspace_in=self.options.color_space, colorspace_out=COLORSPACE_RGB) 163 | 164 | for i in range(np.min([batch_size, size - count])): 165 | res = turing_test(input_rgb[i], fake_image.eval()[i], self.options.turing_test_delay) 166 | count += 1 167 | score += res 168 | print('success: %d - fail: %d - rate: %f' % (score, count - score, (count - score) / count)) 169 | 170 | def build(self): 171 | if self.is_built: 172 | return 173 | 174 | self.is_built = True 175 | 176 | gen_factory = self.create_generator() 177 | dis_factory = self.create_discriminator() 178 | smoothing = 0.9 if self.options.label_smoothing else 1 179 | seed = self.options.seed 180 | kernel = 4 181 | 182 | # model input placeholder: RGB imaege 183 | self.input_rgb = tf.placeholder(tf.float32, shape=(None, None, None, 3), name='input_rgb') 184 | 185 | # model input after preprocessing: LAB image 186 | self.input_color = preprocess(self.input_rgb, colorspace_in=COLORSPACE_RGB, colorspace_out=self.options.color_space) 187 | 188 | # test mode: model input is a graycale placeholder 189 | if self.options.mode == 1: 190 | self.input_gray = tf.placeholder(tf.float32, shape=(None, None, None, 1), name='input_gray') 191 | 192 | # train/turing-test we extract grayscale image from color image 193 | else: 194 | self.input_gray = tf.image.rgb_to_grayscale(self.input_rgb) 195 | 196 | gen = gen_factory.create(self.input_gray, kernel, seed) 197 | dis_real = dis_factory.create(tf.concat([self.input_gray, self.input_color], 3), kernel, seed) 198 | dis_fake = dis_factory.create(tf.concat([self.input_gray, gen], 3), kernel, seed, reuse_variables=True) 199 | 200 | gen_ce = tf.nn.sigmoid_cross_entropy_with_logits(logits=dis_fake, labels=tf.ones_like(dis_fake)) 201 | dis_real_ce = 
tf.nn.sigmoid_cross_entropy_with_logits(logits=dis_real, labels=tf.ones_like(dis_real) * smoothing) 202 | dis_fake_ce = tf.nn.sigmoid_cross_entropy_with_logits(logits=dis_fake, labels=tf.zeros_like(dis_fake)) 203 | 204 | self.dis_loss_real = tf.reduce_mean(dis_real_ce) 205 | self.dis_loss_fake = tf.reduce_mean(dis_fake_ce) 206 | self.dis_loss = tf.reduce_mean(dis_real_ce + dis_fake_ce) 207 | 208 | self.gen_loss_gan = tf.reduce_mean(gen_ce) 209 | self.gen_loss_l1 = tf.reduce_mean(tf.abs(self.input_color - gen)) * self.options.l1_weight 210 | self.gen_loss = self.gen_loss_gan + self.gen_loss_l1 211 | 212 | self.sampler = tf.identity(gen_factory.create(self.input_gray, kernel, seed, reuse_variables=True), name='output') 213 | self.accuracy = pixelwise_accuracy(self.input_color, gen, self.options.color_space, self.options.acc_thresh) 214 | self.learning_rate = tf.constant(self.options.lr) 215 | 216 | # learning rate decay 217 | if self.options.lr_decay and self.options.lr_decay_rate > 0: 218 | self.learning_rate = tf.maximum(1e-6, tf.train.exponential_decay( 219 | learning_rate=self.options.lr, 220 | global_step=self.global_step, 221 | decay_steps=self.options.lr_decay_steps, 222 | decay_rate=self.options.lr_decay_rate)) 223 | 224 | # generator optimizaer 225 | self.gen_train = tf.train.AdamOptimizer( 226 | learning_rate=self.learning_rate, 227 | beta1=self.options.beta1 228 | ).minimize(self.gen_loss, var_list=gen_factory.var_list) 229 | 230 | # discriminator optimizaer 231 | self.dis_train = tf.train.AdamOptimizer( 232 | learning_rate=self.learning_rate / 10, 233 | beta1=self.options.beta1 234 | ).minimize(self.dis_loss, var_list=dis_factory.var_list, global_step=self.global_step) 235 | 236 | self.saver = tf.train.Saver() 237 | 238 | def load(self): 239 | ckpt = tf.train.get_checkpoint_state(self.options.checkpoints_path) 240 | if ckpt is not None: 241 | print('loading model...\n') 242 | ckpt_name = os.path.basename(ckpt.model_checkpoint_path) 243 | self.saver.restore(self.sess, os.path.join(self.options.checkpoints_path, ckpt_name)) 244 | return True 245 | 246 | return False 247 | 248 | def save(self): 249 | print('saving model...\n') 250 | self.saver.save(self.sess, os.path.join(self.options.checkpoints_path, 'CGAN_' + self.options.dataset), write_meta_graph=False) 251 | 252 | def eval_outputs(self, feed_dic): 253 | ''' 254 | evaluates the loss and accuracy 255 | returns (D loss, D_fake loss, D_real loss, G loss, G_L1 loss, G_gan loss, accuracy, step) 256 | ''' 257 | lossD_fake = self.dis_loss_fake.eval(feed_dict=feed_dic) 258 | lossD_real = self.dis_loss_real.eval(feed_dict=feed_dic) 259 | lossD = self.dis_loss.eval(feed_dict=feed_dic) 260 | 261 | lossG_l1 = self.gen_loss_l1.eval(feed_dict=feed_dic) 262 | lossG_gan = self.gen_loss_gan.eval(feed_dict=feed_dic) 263 | lossG = lossG_l1 + lossG_gan 264 | 265 | acc = self.accuracy.eval(feed_dict=feed_dic) 266 | step = self.sess.run(self.global_step) 267 | 268 | return lossD, lossD_fake, lossD_real, lossG, lossG_l1, lossG_gan, acc, step 269 | 270 | @abstractmethod 271 | def create_generator(self): 272 | raise NotImplementedError 273 | 274 | @abstractmethod 275 | def create_discriminator(self): 276 | raise NotImplementedError 277 | 278 | @abstractmethod 279 | def create_dataset(self, training): 280 | raise NotImplementedError 281 | 282 | 283 | class Cifar10Model(BaseModel): 284 | def __init__(self, sess, options): 285 | super(Cifar10Model, self).__init__(sess, options) 286 | 287 | def create_generator(self): 288 | kernels_gen_encoder = [ 289 | 
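kernels_gen_encoder = [ 289 |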
(64, 1, 0), # [batch, 32, 32, ch] => [batch, 32, 32, 64] 290 | (128, 2, 0), # [batch, 32, 32, 64] => [batch, 16, 16, 128] 291 | (256, 2, 0), # [batch, 16, 16, 128] => [batch, 8, 8, 256] 292 | (512, 2, 0), # [batch, 8, 8, 256] => [batch, 4, 4, 512] 293 | (512, 2, 0), # [batch, 4, 4, 512] => [batch, 2, 2, 512] 294 | ] 295 | 296 | kernels_gen_decoder = [ 297 | (512, 2, 0.5), # [batch, 2, 2, 512] => [batch, 4, 4, 512] 298 | (256, 2, 0.5), # [batch, 4, 4, 512] => [batch, 8, 8, 256] 299 | (128, 2, 0), # [batch, 8, 8, 256] => [batch, 16, 16, 128] 300 | (64, 2, 0), # [batch, 16, 16, 128] => [batch, 32, 32, 64] 301 | ] 302 | 303 | return Generator('gen', kernels_gen_encoder, kernels_gen_decoder, training=self.options.training) 304 | 305 | def create_discriminator(self): 306 | kernels_dis = [ 307 | (64, 2, 0), # [batch, 32, 32, ch] => [batch, 16, 16, 64] 308 | (128, 2, 0), # [batch, 16, 16, 64] => [batch, 8, 8, 128] 309 | (256, 2, 0), # [batch, 8, 8, 128] => [batch, 4, 4, 256] 310 | (512, 1, 0), # [batch, 4, 4, 256] => [batch, 4, 4, 512] 311 | ] 312 | 313 | return Discriminator('dis', kernels_dis, training=self.options.training) 314 | 315 | def create_dataset(self, training=True): 316 | return Cifar10Dataset( 317 | path=self.options.dataset_path, 318 | training=training, 319 | augment=self.options.augment) 320 | 321 | 322 | class Places365Model(BaseModel): 323 | def __init__(self, sess, options): 324 | super(Places365Model, self).__init__(sess, options) 325 | 326 | def create_generator(self): 327 | kernels_gen_encoder = [ 328 | (64, 1, 0), # [batch, 256, 256, ch] => [batch, 256, 256, 64] 329 | (64, 2, 0), # [batch, 256, 256, 64] => [batch, 128, 128, 64] 330 | (128, 2, 0), # [batch, 128, 128, 64] => [batch, 64, 64, 128] 331 | (256, 2, 0), # [batch, 64, 64, 128] => [batch, 32, 32, 256] 332 | (512, 2, 0), # [batch, 32, 32, 256] => [batch, 16, 16, 512] 333 | (512, 2, 0), # [batch, 16, 16, 512] => [batch, 8, 8, 512] 334 | (512, 2, 0), # [batch, 8, 8, 512] => [batch, 4, 4, 512] 335 | (512, 2, 0) # [batch, 4, 4, 512] => [batch, 2, 2, 512] 336 | ] 337 | 338 | kernels_gen_decoder = [ 339 | (512, 2, 0), # [batch, 2, 2, 512] => [batch, 4, 4, 512] 340 | (512, 2, 0), # [batch, 4, 4, 512] => [batch, 8, 8, 512] 341 | (512, 2, 0), # [batch, 8, 8, 512] => [batch, 16, 16, 512] 342 | (256, 2, 0), # [batch, 16, 16, 512] => [batch, 32, 32, 256] 343 | (128, 2, 0), # [batch, 32, 32, 256] => [batch, 64, 64, 128] 344 | (64, 2, 0), # [batch, 64, 64, 128] => [batch, 128, 128, 64] 345 | (64, 2, 0) # [batch, 128, 128, 64] => [batch, 256, 256, 64] 346 | ] 347 | 348 | return Generator('gen', kernels_gen_encoder, kernels_gen_decoder, training=self.options.training) 349 | 350 | def create_discriminator(self): 351 | kernels_dis = [ 352 | (64, 2, 0), # [batch, 256, 256, ch] => [batch, 128, 128, 64] 353 | (128, 2, 0), # [batch, 128, 128, 64] => [batch, 64, 64, 128] 354 | (256, 2, 0), # [batch, 64, 64, 128] => [batch, 32, 32, 256] 355 | (512, 1, 0), # [batch, 32, 32, 256] => [batch, 32, 32, 512] 356 | ] 357 | 358 | return Discriminator('dis', kernels_dis, training=self.options.training) 359 | 360 | def create_dataset(self, training=True): 361 | return Places365Dataset( 362 | path=self.options.dataset_path, 363 | training=training, 364 | augment=self.options.augment) 365 | -------------------------------------------------------------------------------- /src/networks.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from .ops import conv2d, 
conv2d_transpose, pixelwise_accuracy 4 | 5 | 6 | class Discriminator(object): 7 | def __init__(self, name, kernels, training=True): 8 | self.name = name 9 | self.kernels = kernels 10 | self.training = training 11 | self.var_list = [] 12 | 13 | def create(self, inputs, kernel_size=None, seed=None, reuse_variables=None): 14 | output = inputs 15 | with tf.variable_scope(self.name, reuse=reuse_variables): 16 | for index, kernel in enumerate(self.kernels): 17 | 18 | # not use batch-norm in the first layer 19 | bnorm = False if index == 0 else True 20 | name = 'conv' + str(index) 21 | output = conv2d( 22 | inputs=output, 23 | name=name, 24 | kernel_size=kernel_size, 25 | filters=kernel[0], 26 | strides=kernel[1], 27 | bnorm=bnorm, 28 | activation=tf.nn.leaky_relu, 29 | seed=seed 30 | ) 31 | 32 | if kernel[2] > 0: 33 | keep_prob = 1.0 - kernel[2] if self.training else 1.0 34 | output = tf.nn.dropout(output, keep_prob=keep_prob, name='dropout_' + name, seed=seed) 35 | 36 | output = conv2d( 37 | inputs=output, 38 | name='conv_last', 39 | filters=1, 40 | kernel_size=4, # last layer kernel size = 4 41 | strides=1, # last layer stride = 1 42 | bnorm=False, # do not use batch-norm for the last layer 43 | seed=seed 44 | ) 45 | 46 | self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, self.name) 47 | 48 | return output 49 | 50 | 51 | class Generator(object): 52 | def __init__(self, name, encoder_kernels, decoder_kernels, output_channels=3, training=True): 53 | self.name = name 54 | self.encoder_kernels = encoder_kernels 55 | self.decoder_kernels = decoder_kernels 56 | self.output_channels = output_channels 57 | self.training = training 58 | self.var_list = [] 59 | 60 | def create(self, inputs, kernel_size=None, seed=None, reuse_variables=None): 61 | output = inputs 62 | 63 | with tf.variable_scope(self.name, reuse=reuse_variables): 64 | 65 | layers = [] 66 | 67 | # encoder branch 68 | for index, kernel in enumerate(self.encoder_kernels): 69 | 70 | name = 'conv' + str(index) 71 | output = conv2d( 72 | inputs=output, 73 | name=name, 74 | kernel_size=kernel_size, 75 | filters=kernel[0], 76 | strides=kernel[1], 77 | activation=tf.nn.leaky_relu, 78 | seed=seed 79 | ) 80 | 81 | # save contracting path layers to be used for skip connections 82 | layers.append(output) 83 | 84 | if kernel[2] > 0: 85 | keep_prob = 1.0 - kernel[2] if self.training else 1.0 86 | output = tf.nn.dropout(output, keep_prob=keep_prob, name='dropout_' + name, seed=seed) 87 | 88 | # decoder branch 89 | for index, kernel in enumerate(self.decoder_kernels): 90 | 91 | name = 'deconv' + str(index) 92 | output = conv2d_transpose( 93 | inputs=output, 94 | name=name, 95 | kernel_size=kernel_size, 96 | filters=kernel[0], 97 | strides=kernel[1], 98 | activation=tf.nn.relu, 99 | seed=seed 100 | ) 101 | 102 | if kernel[2] > 0: 103 | keep_prob = 1.0 - kernel[2] if self.training else 1.0 104 | output = tf.nn.dropout(output, keep_prob=keep_prob, name='dropout_' + name, seed=seed) 105 | 106 | # concat the layer from the contracting path with the output of the current layer 107 | # concat only the channels (axis=3) 108 | output = tf.concat([layers[len(layers) - index - 2], output], axis=3) 109 | 110 | output = conv2d( 111 | inputs=output, 112 | name='conv_last', 113 | filters=self.output_channels, # number of output chanels 114 | kernel_size=1, # last layer kernel size = 1 115 | strides=1, # last layer stride = 1 116 | bnorm=False, # do not use batch-norm for the last layer 117 | activation=tf.nn.tanh, # tanh activation function for the 
output 118 | seed=seed 119 | ) 120 | 121 | self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, self.name) 122 | 123 | return output 124 | -------------------------------------------------------------------------------- /src/ops.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | COLORSPACE_RGB = 'RGB' 5 | COLORSPACE_LAB = 'LAB' 6 | tf.nn.softmax_cross_entropy_with_logits_v2 7 | 8 | def conv2d(inputs, filters, name, kernel_size=4, strides=2, bnorm=True, activation=None, seed=None): 9 | """ 10 | Creates a conv2D block 11 | """ 12 | initializer=tf.variance_scaling_initializer(seed=seed) 13 | res = tf.layers.conv2d( 14 | name=name, 15 | inputs=inputs, 16 | filters=filters, 17 | kernel_size=kernel_size, 18 | strides=strides, 19 | padding="same", 20 | kernel_initializer=initializer) 21 | 22 | if bnorm: 23 | res = tf.layers.batch_normalization(inputs=res, name='bn_' + name, training=True) 24 | 25 | # activation after batch-norm 26 | if activation is not None: 27 | res = activation(res) 28 | 29 | return res 30 | 31 | 32 | def conv2d_transpose(inputs, filters, name, kernel_size=4, strides=2, bnorm=True, activation=None, seed=None): 33 | """ 34 | Creates a conv2D-transpose block 35 | """ 36 | initializer=tf.variance_scaling_initializer(seed=seed) 37 | res = tf.layers.conv2d_transpose( 38 | name=name, 39 | inputs=inputs, 40 | filters=filters, 41 | kernel_size=kernel_size, 42 | strides=strides, 43 | padding="same", 44 | kernel_initializer=initializer) 45 | 46 | if bnorm: 47 | res = tf.layers.batch_normalization(inputs=res, name='bn_' + name, training=True) 48 | 49 | # activation after batch-norm 50 | if activation is not None: 51 | res = activation(res) 52 | 53 | return res 54 | 55 | 56 | def pixelwise_accuracy(img_real, img_fake, colorspace, thresh): 57 | """ 58 | Measures the accuracy of the colorization process by comparing pixels 59 | """ 60 | img_real = postprocess(img_real, colorspace, COLORSPACE_LAB) 61 | img_fake = postprocess(img_fake, colorspace, COLORSPACE_LAB) 62 | 63 | diffL = tf.abs(tf.round(img_real[..., 0]) - tf.round(img_fake[..., 0])) 64 | diffA = tf.abs(tf.round(img_real[..., 1]) - tf.round(img_fake[..., 1])) 65 | diffB = tf.abs(tf.round(img_real[..., 2]) - tf.round(img_fake[..., 2])) 66 | 67 | # within %thresh of the original 68 | predL = tf.cast(tf.less_equal(diffL, 1 * thresh), tf.float64) # L: [0, 100] 69 | predA = tf.cast(tf.less_equal(diffA, 2.2 * thresh), tf.float64) # A: [-110, 110] 70 | predB = tf.cast(tf.less_equal(diffB, 2.2 * thresh), tf.float64) # B: [-110, 110] 71 | 72 | # all three channels are within the threshold 73 | pred = predL * predA * predB 74 | 75 | return tf.reduce_mean(pred) 76 | 77 | 78 | def preprocess(img, colorspace_in, colorspace_out): 79 | if colorspace_out.upper() == COLORSPACE_RGB: 80 | if colorspace_in == COLORSPACE_LAB: 81 | img = lab_to_rgb(img) 82 | 83 | # [0, 1] => [-1, 1] 84 | img = (img / 255.0) * 2 - 1 85 | 86 | elif colorspace_out.upper() == COLORSPACE_LAB: 87 | if colorspace_in == COLORSPACE_RGB: 88 | img = rgb_to_lab(img / 255.0) 89 | 90 | L_chan, a_chan, b_chan = tf.unstack(img, axis=3) 91 | 92 | # L: [0, 100] => [-1, 1] 93 | # A, B: [-110, 110] => [-1, 1] 94 | img = tf.stack([L_chan / 50 - 1, a_chan / 110, b_chan / 110], axis=3) 95 | 96 | return img 97 | 98 | 99 | def postprocess(img, colorspace_in, colorspace_out): 100 | if colorspace_in.upper() == COLORSPACE_RGB: 101 | # [-1, 1] => [0, 1] 102 | img = (img + 1) / 2 103 | 104 | if 
colorspace_out == COLORSPACE_LAB: 105 | img = rgb_to_lab(img) 106 | 107 | elif colorspace_in.upper() == COLORSPACE_LAB: 108 | L_chan, a_chan, b_chan = tf.unstack(img, axis=3) 109 | 110 | # L: [-1, 1] => [0, 100] 111 | # A, B: [-1, 1] => [-110, 110] 112 | img = tf.stack([(L_chan + 1) / 2 * 100, a_chan * 110, b_chan * 110], axis=3) 113 | 114 | if colorspace_out == COLORSPACE_RGB: 115 | img = lab_to_rgb(img) 116 | 117 | return img 118 | 119 | 120 | def rgb_to_lab(srgb): 121 | # based on https://github.com/torch/image/blob/9f65c30167b2048ecbe8b7befdc6b2d6d12baee9/generic/image.c 122 | with tf.name_scope("rgb_to_lab"): 123 | srgb_pixels = tf.reshape(srgb, [-1, 3]) 124 | 125 | with tf.name_scope("srgb_to_xyz"): 126 | linear_mask = tf.cast(srgb_pixels <= 0.04045, dtype=tf.float32) 127 | exponential_mask = tf.cast(srgb_pixels > 0.04045, dtype=tf.float32) 128 | rgb_pixels = (srgb_pixels / 12.92 * linear_mask) + (((srgb_pixels + 0.055) / 1.055) ** 2.4) * exponential_mask 129 | rgb_to_xyz = tf.constant([ 130 | # X Y Z 131 | [0.412453, 0.212671, 0.019334], # R 132 | [0.357580, 0.715160, 0.119193], # G 133 | [0.180423, 0.072169, 0.950227], # B 134 | ]) 135 | xyz_pixels = tf.matmul(rgb_pixels, rgb_to_xyz) 136 | 137 | # https://en.wikipedia.org/wiki/Lab_color_space#CIELAB-CIEXYZ_conversions 138 | with tf.name_scope("xyz_to_cielab"): 139 | 140 | # normalize for D65 white point 141 | xyz_normalized_pixels = tf.multiply(xyz_pixels, [1 / 0.950456, 1.0, 1 / 1.088754]) 142 | 143 | epsilon = 6 / 29 144 | linear_mask = tf.cast(xyz_normalized_pixels <= (epsilon**3), dtype=tf.float32) 145 | exponential_mask = tf.cast(xyz_normalized_pixels > (epsilon**3), dtype=tf.float32) 146 | fxfyfz_pixels = (xyz_normalized_pixels / (3 * epsilon**2) + 4 / 29) * linear_mask + (xyz_normalized_pixels ** (1 / 3)) * exponential_mask 147 | 148 | # convert to lab 149 | fxfyfz_to_lab = tf.constant([ 150 | # l a b 151 | [0.0, 500.0, 0.0], # fx 152 | [116.0, -500.0, 200.0], # fy 153 | [0.0, 0.0, -200.0], # fz 154 | ]) 155 | lab_pixels = tf.matmul(fxfyfz_pixels, fxfyfz_to_lab) + tf.constant([-16.0, 0.0, 0.0]) 156 | 157 | return tf.reshape(lab_pixels, tf.shape(srgb)) 158 | 159 | 160 | def lab_to_rgb(lab): 161 | with tf.name_scope("lab_to_rgb"): 162 | lab_pixels = tf.reshape(lab, [-1, 3]) 163 | 164 | # https://en.wikipedia.org/wiki/Lab_color_space#CIELAB-CIEXYZ_conversions 165 | with tf.name_scope("cielab_to_xyz"): 166 | # convert to fxfyfz 167 | lab_to_fxfyfz = tf.constant([ 168 | # fx fy fz 169 | [1 / 116.0, 1 / 116.0, 1 / 116.0], # l 170 | [1 / 500.0, 0.0, 0.0], # a 171 | [0.0, 0.0, -1 / 200.0], # b 172 | ]) 173 | fxfyfz_pixels = tf.matmul(lab_pixels + tf.constant([16.0, 0.0, 0.0]), lab_to_fxfyfz) 174 | 175 | # convert to xyz 176 | epsilon = 6 / 29 177 | linear_mask = tf.cast(fxfyfz_pixels <= epsilon, dtype=tf.float32) 178 | exponential_mask = tf.cast(fxfyfz_pixels > epsilon, dtype=tf.float32) 179 | xyz_pixels = (3 * epsilon**2 * (fxfyfz_pixels - 4 / 29)) * linear_mask + (fxfyfz_pixels ** 3) * exponential_mask 180 | 181 | # denormalize for D65 white point 182 | xyz_pixels = tf.multiply(xyz_pixels, [0.950456, 1.0, 1.088754]) 183 | 184 | with tf.name_scope("xyz_to_srgb"): 185 | xyz_to_rgb = tf.constant([ 186 | # r g b 187 | [3.2404542, -0.9692660, 0.0556434], # x 188 | [-1.5371385, 1.8760108, -0.2040259], # y 189 | [-0.4985314, 0.0415560, 1.0572252], # z 190 | ]) 191 | rgb_pixels = tf.matmul(xyz_pixels, xyz_to_rgb) 192 | # avoid a slightly negative number messing up the conversion 193 | rgb_pixels = tf.clip_by_value(rgb_pixels, 0.0, 1.0) 
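# linear RGB -> sRGB gamma encoding: values <= 0.0031308 stay on the 12.92 linear segment, larger values use the 1.055 * x^(1/2.4) - 0.055 curve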
194 | linear_mask = tf.cast(rgb_pixels <= 0.0031308, dtype=tf.float32) 195 | exponential_mask = tf.cast(rgb_pixels > 0.0031308, dtype=tf.float32) 196 | srgb_pixels = (rgb_pixels * 12.92 * linear_mask) + ((rgb_pixels ** (1 / 2.4) * 1.055) - 0.055) * exponential_mask 197 | 198 | return tf.reshape(srgb_pixels, tf.shape(lab)) 199 | -------------------------------------------------------------------------------- /src/options.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import random 4 | import argparse 5 | 6 | 7 | def str2bool(v): 8 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 9 | return True 10 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 11 | return False 12 | else: 13 | raise argparse.ArgumentTypeError('Boolean value expected.') 14 | 15 | 16 | class ModelOptions: 17 | def __init__(self): 18 | parser = argparse.ArgumentParser(description='Colorization with GANs') 19 | parser.add_argument('--seed', type=int, default=0, metavar='S', help='random seed (default: 0)') 20 | parser.add_argument('--name', type=str, default='CGAN', help='arbitrary model name (default: CGAN)') 21 | parser.add_argument('--mode', default=0, help='run mode [0: train, 1: test, 2: turing-test] (default: 0)') 22 | parser.add_argument('--dataset', type=str, default='places365', help='the name of dataset [places365, cifar10] (default: places365)') 23 | parser.add_argument('--dataset-path', type=str, default='./dataset', help='dataset path (default: ./dataset)') 24 | parser.add_argument('--checkpoints-path', type=str, default='./checkpoints', help='models are saved here (default: ./checkpoints)') 25 | parser.add_argument('--batch-size', type=int, default=16, metavar='N', help='input batch size for training (default: 16)') 26 | parser.add_argument('--color-space', type=str, default='lab', help='model color space [lab, rgb] (default: lab)') 27 | parser.add_argument('--epochs', type=int, default=30, metavar='N', help='number of epochs to train (default: 30)') 28 | parser.add_argument('--lr', type=float, default=3e-4, metavar='LR', help='learning rate (default: 3e-4)') 29 | parser.add_argument('--lr-decay', type=str2bool, default=True, help='True for learning-rate decay (default: True)') 30 | parser.add_argument('--lr-decay-rate', type=float, default=0.1, help='learning rate exponentially decay rate (default: 0.1)') 31 | parser.add_argument('--lr-decay-steps', type=float, default=5e5, help='learning rate exponentially decay steps (default: 1e5)') 32 | parser.add_argument('--beta1', type=float, default=0, help='momentum term of adam optimizer (default: 0)') 33 | parser.add_argument("--l1-weight", type=float, default=100.0, help="weight on L1 term for generator gradient (default: 100.0)") 34 | parser.add_argument('--augment', type=str2bool, default=True, help='True for augmentation (default: True)') 35 | parser.add_argument('--label-smoothing', type=str2bool, default=False, help='True for one-sided label smoothing (default: False)') 36 | parser.add_argument('--acc-thresh', type=float, default=2.0, help="accuracy threshold (default: 2.0)") 37 | parser.add_argument('--gpu-ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. 
use -1 for CPU') 38 | 39 | parser.add_argument('--save', type=str2bool, default=True, help='True for saving (default: True)') 40 | parser.add_argument('--save-interval', type=int, default=1000, help='how many batches to wait before saving model (default: 1000)') 41 | parser.add_argument('--sample', type=str2bool, default=True, help='True for sampling (default: True)') 42 | parser.add_argument('--sample-size', type=int, default=8, help='number of images to sample (default: 8)') 43 | parser.add_argument('--sample-interval', type=int, default=1000, help='how many batches to wait before sampling (default: 1000)') 44 | parser.add_argument('--validate', type=str2bool, default=True, help='True for validation (default: True)') 45 | parser.add_argument('--validate-interval', type=int, default=0, help='how many batches to wait before validating (default: 0)') 46 | parser.add_argument('--log', type=str2bool, default=False, help='True for logging (default: False)') 47 | parser.add_argument('--log-interval', type=int, default=10, help='how many iterations to wait before logging training status (default: 10)') 48 | parser.add_argument('--visualize', type=str2bool, default=False, help='True for accuracy visualization (default: False)') 49 | parser.add_argument('--visualize-window', type=int, default=100, help='the exponential moving-average window width (default: 100)') 50 | 51 | parser.add_argument('--test-input', type=str, default='', help='path to the grayscale images directory or a grayscale file') 52 | parser.add_argument('--test-output', type=str, default='', help='model test output directory') 53 | parser.add_argument('--turing-test-size', type=int, default=100, metavar='N', help='number of Turing tests (default: 100)') 54 | parser.add_argument('--turing-test-delay', type=int, default=0, metavar='N', help='number of seconds to wait when doing Turing test, 0 for unlimited (default: 0)') 55 | 56 | self._parser = parser 57 | 58 | def parse(self): 59 | opt = self._parser.parse_args() 60 | os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_ids 61 | 62 | opt.color_space = opt.color_space.upper() 63 | opt.training = opt.mode == 0  # mode 0 is training (see --mode help and train.py) 64 | 65 | if opt.seed == 0: 66 | opt.seed = random.randint(0, 2**31 - 1) 67 | 68 | if opt.dataset_path == './dataset': 69 | opt.dataset_path += ('/' + opt.dataset) 70 | 71 | if opt.checkpoints_path == './checkpoints': 72 | opt.checkpoints_path += ('/' + opt.dataset) 73 | 74 | return opt 75 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import random 5 | import pickle 6 | import numpy as np 7 | from PIL import Image 8 | import matplotlib.pyplot as plt 9 | 10 | 11 | def stitch_images(grayscale, original, pred): 12 | gap = 5 13 | width, height = original[0][:, :, 0].shape 14 | img_per_row = 2 if width > 200 else 4 15 | img = Image.new('RGB', (width * img_per_row * 3 + gap * (img_per_row - 1), height * int(len(original) / img_per_row))) 16 | 17 | grayscale = np.array(grayscale).squeeze() 18 | original = np.array(original) 19 | pred = np.array(pred) 20 | 21 | for ix in range(len(original)): 22 | xoffset = int(ix % img_per_row) * width * 3 + int(ix % img_per_row) * gap 23 | yoffset = int(ix / img_per_row) * height 24 | im1 = Image.fromarray(grayscale[ix]) 25 | im2 = Image.fromarray(original[ix]) 26 | im3 = Image.fromarray((pred[ix] * 255).astype(np.uint8)) 27 | img.paste(im1, (xoffset, yoffset)) 28 | 
img.paste(im2, (xoffset + width, yoffset)) 29 | img.paste(im3, (xoffset + width + width, yoffset)) 30 | 31 | return img 32 | 33 | 34 | def create_dir(dir): 35 | if not os.path.exists(dir): 36 | os.makedirs(dir) 37 | 38 | return dir 39 | 40 | 41 | def unpickle(file): 42 | with open(file, 'rb') as fo: 43 | dict = pickle.load(fo, encoding='bytes') 44 | return dict 45 | 46 | 47 | def moving_average(data, window_width): 48 | cumsum_vec = np.cumsum(np.insert(data, 0, 0)) 49 | ma_vec = (cumsum_vec[window_width:] - cumsum_vec[:-window_width]) / window_width 50 | return ma_vec 51 | 52 | 53 | def imshow(img, title=''): 54 | fig = plt.gcf() 55 | fig.canvas.set_window_title(title) 56 | plt.axis('off') 57 | plt.imshow(img, interpolation='none') 58 | plt.show() 59 | 60 | 61 | def imsave(img, path): 62 | im = Image.fromarray(np.array(img).astype(np.uint8).squeeze()) 63 | im.save(path) 64 | 65 | 66 | def turing_test(real_img, fake_img, delay=0): 67 | height, width, _ = real_img.shape 68 | imgs = np.array([real_img, (fake_img * 255).astype(np.uint8)]) 69 | real_index = np.random.binomial(1, 0.5) 70 | fake_index = (real_index + 1) % 2 71 | 72 | img = Image.new('RGB', (2 + width * 2, height)) 73 | img.paste(Image.fromarray(imgs[real_index]), (0, 0)) 74 | img.paste(Image.fromarray(imgs[fake_index]), (2 + width, 0)) 75 | 76 | img.success = 0 77 | 78 | def onclick(event): 79 | if event.xdata is not None: 80 | if event.x < width and real_index == 0: 81 | img.success = 1 82 | 83 | elif event.x > width and real_index == 1: 84 | img.success = 1 85 | 86 | plt.gcf().canvas.stop_event_loop() 87 | 88 | plt.ion() 89 | plt.gcf().canvas.mpl_connect('button_press_event', onclick) 90 | plt.title('click on the real image') 91 | plt.axis('off') 92 | plt.imshow(img, interpolation='none') 93 | plt.show() 94 | plt.draw() 95 | plt.gcf().canvas.start_event_loop(delay) 96 | 97 | return img.success 98 | 99 | 100 | def visualize(train_log_file, test_log_file, window_width, title=''): 101 | train_data = np.loadtxt(train_log_file) 102 | test_data = np.loadtxt(test_log_file) 103 | 104 | if len(train_data.shape) < 2: 105 | return 106 | 107 | if len(train_data) < window_width: 108 | window_width = len(train_data) - 1 109 | 110 | fig = plt.gcf() 111 | fig.canvas.set_window_title(title) 112 | 113 | plt.ion() 114 | plt.subplot('121') 115 | plt.cla() 116 | if len(train_data) > 1: 117 | plt.plot(moving_average(train_data[:, 8], window_width)) 118 | plt.title('train') 119 | 120 | plt.subplot('122') 121 | plt.cla() 122 | if len(test_data) > 1: 123 | plt.plot(test_data[:, 8]) 124 | plt.title('test') 125 | 126 | plt.show() 127 | plt.draw() 128 | plt.pause(.01) 129 | 130 | 131 | 132 | class Progbar(object): 133 | """Displays a progress bar. 134 | 135 | Arguments: 136 | target: Total number of steps expected, None if unknown. 137 | width: Progress bar width on screen. 138 | verbose: Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose) 139 | stateful_metrics: Iterable of string names of metrics that 140 | should *not* be averaged over time. Metrics in this list 141 | will be displayed as-is. All others will be averaged 142 | by the progbar before display. 143 | interval: Minimum visual progress update interval (in seconds). 
144 | """ 145 | 146 | def __init__(self, target, width=25, verbose=1, interval=0.05, 147 | stateful_metrics=None): 148 | self.target = target 149 | self.width = width 150 | self.verbose = verbose 151 | self.interval = interval 152 | if stateful_metrics: 153 | self.stateful_metrics = set(stateful_metrics) 154 | else: 155 | self.stateful_metrics = set() 156 | 157 | self._dynamic_display = ((hasattr(sys.stdout, 'isatty') and 158 | sys.stdout.isatty()) or 159 | 'ipykernel' in sys.modules or 160 | 'posix' in sys.modules) 161 | self._total_width = 0 162 | self._seen_so_far = 0 163 | # We use a dict + list to avoid garbage collection 164 | # issues found in OrderedDict 165 | self._values = {} 166 | self._values_order = [] 167 | self._start = time.time() 168 | self._last_update = 0 169 | 170 | def update(self, current, values=None): 171 | """Updates the progress bar. 172 | 173 | Arguments: 174 | current: Index of current step. 175 | values: List of tuples: 176 | `(name, value_for_last_step)`. 177 | If `name` is in `stateful_metrics`, 178 | `value_for_last_step` will be displayed as-is. 179 | Else, an average of the metric over time will be displayed. 180 | """ 181 | values = values or [] 182 | for k, v in values: 183 | if k not in self._values_order: 184 | self._values_order.append(k) 185 | if k not in self.stateful_metrics: 186 | if k not in self._values: 187 | self._values[k] = [v * (current - self._seen_so_far), 188 | current - self._seen_so_far] 189 | else: 190 | self._values[k][0] += v * (current - self._seen_so_far) 191 | self._values[k][1] += (current - self._seen_so_far) 192 | else: 193 | self._values[k] = v 194 | self._seen_so_far = current 195 | 196 | now = time.time() 197 | info = ' - %.0fs' % (now - self._start) 198 | if self.verbose == 1: 199 | if (now - self._last_update < self.interval and 200 | self.target is not None and current < self.target): 201 | return 202 | 203 | prev_total_width = self._total_width 204 | if self._dynamic_display: 205 | sys.stdout.write('\b' * prev_total_width) 206 | sys.stdout.write('\r') 207 | else: 208 | sys.stdout.write('\n') 209 | 210 | if self.target is not None: 211 | numdigits = int(np.floor(np.log10(self.target))) + 1 212 | barstr = '%%%dd/%d [' % (numdigits, self.target) 213 | bar = barstr % current 214 | prog = float(current) / self.target 215 | prog_width = int(self.width * prog) 216 | if prog_width > 0: 217 | bar += ('=' * (prog_width - 1)) 218 | if current < self.target: 219 | bar += '>' 220 | else: 221 | bar += '=' 222 | bar += ('.' 
* (self.width - prog_width)) 223 | bar += ']' 224 | else: 225 | bar = '%7d/Unknown' % current 226 | 227 | self._total_width = len(bar) 228 | sys.stdout.write(bar) 229 | 230 | if current: 231 | time_per_unit = (now - self._start) / current 232 | else: 233 | time_per_unit = 0 234 | if self.target is not None and current < self.target: 235 | eta = time_per_unit * (self.target - current) 236 | if eta > 3600: 237 | eta_format = '%d:%02d:%02d' % (eta // 3600, 238 | (eta % 3600) // 60, 239 | eta % 60) 240 | elif eta > 60: 241 | eta_format = '%d:%02d' % (eta // 60, eta % 60) 242 | else: 243 | eta_format = '%ds' % eta 244 | 245 | info = ' - ETA: %s' % eta_format 246 | else: 247 | if time_per_unit >= 1: 248 | info += ' %.0fs/step' % time_per_unit 249 | elif time_per_unit >= 1e-3: 250 | info += ' %.0fms/step' % (time_per_unit * 1e3) 251 | else: 252 | info += ' %.0fus/step' % (time_per_unit * 1e6) 253 | 254 | for k in self._values_order: 255 | info += ' - %s:' % k 256 | if isinstance(self._values[k], list): 257 | avg = np.mean(self._values[k][0] / max(1, self._values[k][1])) 258 | if abs(avg) > 1e-3: 259 | info += ' %.4f' % avg 260 | else: 261 | info += ' %.4e' % avg 262 | else: 263 | info += ' %s' % self._values[k] 264 | 265 | self._total_width += len(info) 266 | if prev_total_width > self._total_width: 267 | info += (' ' * (prev_total_width - self._total_width)) 268 | 269 | if self.target is not None and current >= self.target: 270 | info += '\n' 271 | 272 | sys.stdout.write(info) 273 | sys.stdout.flush() 274 | 275 | elif self.verbose == 2: 276 | if self.target is None or current >= self.target: 277 | for k in self._values_order: 278 | info += ' - %s:' % k 279 | avg = np.mean(self._values[k][0] / max(1, self._values[k][1])) 280 | if avg > 1e-3: 281 | info += ' %.4f' % avg 282 | else: 283 | info += ' %.4e' % avg 284 | info += '\n' 285 | 286 | sys.stdout.write(info) 287 | sys.stdout.flush() 288 | 289 | self._last_update = now 290 | 291 | def add(self, n, values=None): 292 | self.update(self._seen_so_far + n, values) 293 | -------------------------------------------------------------------------------- /test-turing.py: -------------------------------------------------------------------------------- 1 | from src import ModelOptions, main 2 | 3 | options = ModelOptions().parse() 4 | options.mode = 2 5 | main(options) 6 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from src import ModelOptions, main 2 | 3 | options = ModelOptions().parse() 4 | options.mode = 1 5 | main(options) 6 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from src import ModelOptions, main 2 | 3 | options = ModelOptions().parse() 4 | options.mode = 0 5 | main(options) 6 | --------------------------------------------------------------------------------