├── .gitignore ├── LICENSE ├── README.md ├── assets ├── 104050_AE_G.png ├── 104050_G.png ├── 107300_AE_G.png ├── 107300_G.png ├── 115827_AE_G.png ├── 115827_G.png ├── 58600.png ├── 82400_1.png ├── 82400_2.png ├── AE_G1.png ├── AE_G10.png ├── AE_G11.png ├── AE_G12.png ├── AE_G13.png ├── AE_G14.png ├── AE_G15.png ├── AE_G16.png ├── AE_G17.png ├── AE_G18.png ├── AE_G19.png ├── AE_G2.png ├── AE_G20.png ├── AE_G21.png ├── AE_G22.png ├── AE_G23.png ├── AE_G24.png ├── AE_G25.png ├── AE_G26.png ├── AE_G3.png ├── AE_G4.png ├── AE_G5.png ├── AE_G6.png ├── AE_G7.png ├── AE_G8.png ├── AE_G9.png ├── AE_batch.png ├── G1.png ├── G2.png ├── G3.png ├── G4.png ├── G5.png ├── G6.png ├── G7.png ├── all_G_z0_128x128.png ├── all_G_z0_64x64.png ├── interp_1.png ├── interp_10.png ├── interp_2.png ├── interp_3.png ├── interp_4.png ├── interp_5.png ├── interp_6.png ├── interp_7.png ├── interp_8.png ├── interp_9.png ├── interp_G0_128x128.png ├── interp_G0_64x64.png └── model.png ├── config.py ├── data_loader.py ├── download.py ├── folder.py ├── layers.py ├── main.py ├── models.py ├── trainer.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Data 2 | test* 3 | data/hand 4 | data/gaze 5 | data/* 6 | samples 7 | outputs 8 | 9 | # ipython checkpoints 10 | .ipynb_checkpoints 11 | 12 | # Log 13 | logs 14 | 15 | # ETC 16 | paper.pdf 17 | .DS_Store 18 | 19 | # Created by https://www.gitignore.io/api/python,vim 20 | 21 | ### Python ### 22 | # Byte-compiled / optimized / DLL files 23 | __pycache__/ 24 | *.py[cod] 25 | *$py.class 26 | 27 | # C extensions 28 | *.so 29 | 30 | # Distribution / packaging 31 | .Python 32 | env/ 33 | build/ 34 | develop-eggs/ 35 | dist/ 36 | downloads/ 37 | eggs/ 38 | .eggs/ 39 | lib/ 40 | lib64/ 41 | parts/ 42 | sdist/ 43 | var/ 44 | wheels/ 45 | *.egg-info/ 46 | .installed.cfg 47 | *.egg 48 | 49 | # PyInstaller 50 | # Usually these files are written by a python script from a template 51 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 52 | *.manifest 53 | *.spec 54 | 55 | # Installer logs 56 | pip-log.txt 57 | pip-delete-this-directory.txt 58 | 59 | # Unit test / coverage reports 60 | htmlcov/ 61 | .tox/ 62 | .coverage 63 | .coverage.* 64 | .cache 65 | nosetests.xml 66 | coverage.xml 67 | *,cover 68 | .hypothesis/ 69 | 70 | # Translations 71 | *.mo 72 | *.pot 73 | 74 | # Django stuff: 75 | *.log 76 | local_settings.py 77 | 78 | # Flask stuff: 79 | instance/ 80 | .webassets-cache 81 | 82 | # Scrapy stuff: 83 | .scrapy 84 | 85 | # Sphinx documentation 86 | docs/_build/ 87 | 88 | # PyBuilder 89 | target/ 90 | 91 | # Jupyter Notebook 92 | .ipynb_checkpoints 93 | 94 | # pyenv 95 | .python-version 96 | 97 | # celery beat schedule file 98 | celerybeat-schedule 99 | 100 | # dotenv 101 | .env 102 | 103 | # virtualenv 104 | .venv/ 105 | venv/ 106 | ENV/ 107 | 108 | # Spyder project settings 109 | .spyderproject 110 | 111 | # Rope project settings 112 | .ropeproject 113 | 114 | 115 | ### Vim ### 116 | # swap 117 | [._]*.s[a-v][a-z] 118 | [._]*.sw[a-p] 119 | [._]s[a-v][a-z] 120 | [._]sw[a-p] 121 | # session 122 | Session.vim 123 | # temporary 124 | .netrwhist 125 | *~ 126 | # auto-generated tag files 127 | tags 128 | 129 | # End of https://www.gitignore.io/api/python,vim 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # BEGAN in Tensorflow 2 | 3 | Tensorflow implementation of [BEGAN: Boundary Equilibrium Generative Adversarial Networks](https://arxiv.org/abs/1703.10717). 4 | 5 | ![alt tag](./assets/model.png) 6 | 7 | 8 | ## Requirements 9 | 10 | - Python 2.7 or 3.x 11 | - [Pillow](https://pillow.readthedocs.io/en/4.0.x/) 12 | - [tqdm](https://github.com/tqdm/tqdm) 13 | - [requests](https://github.com/kennethreitz/requests) (Only used for downloading CelebA dataset) 14 | - [TensorFlow 1.3.0](https://github.com/tensorflow/tensorflow) 15 | 16 | 17 | ## Usage 18 | 19 | First download [CelebA](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) datasets with: 20 | 21 | $ apt-get install p7zip-full # ubuntu 22 | $ brew install p7zip # Mac 23 | $ python download.py 24 | 25 | or you can use your own dataset by placing images like: 26 | 27 | data 28 | └── YOUR_DATASET_NAME 29 | ├── xxx.jpg (name doesn't matter) 30 | ├── yyy.jpg 31 | └── ... 32 | 33 | To train a model: 34 | 35 | $ python main.py --dataset=CelebA --use_gpu=True 36 | $ python main.py --dataset=YOUR_DATASET_NAME --use_gpu=True 37 | 38 | To test a model (use your `load_path`): 39 | 40 | $ python main.py --dataset=CelebA --load_path=CelebA_0405_124806 --use_gpu=True --is_train=False --split valid 41 | 42 | 43 | ## Results 44 | 45 | ### Generator output (64x64) with `gamma=0.5` after 300k steps 46 | 47 | ![all_G_z0_64x64](./assets/all_G_z0_64x64.png) 48 | 49 | 50 | ### Generator output (128x128) with `gamma=0.5` after 200k steps 51 | 52 | ![all_G_z0_64x64](./assets/all_G_z0_128x128.png) 53 | 54 | 55 | ### Interpolation of Generator output (64x64) with `gamma=0.5` after 300k steps 56 | 57 | ![interp_G0_64x64](./assets/interp_G0_64x64.png) 58 | 59 | 60 | ### Interpolation of Generator output (128x128) with `gamma=0.5` after 200k steps 61 | 62 | ![interp_G0_128x128](./assets/interp_G0_128x128.png) 63 | 64 | 65 | ### Interpolation of Discriminator output of real images 66 | 67 | ![alt tag](./assets/AE_batch.png) 68 | ![alt tag](./assets/interp_1.png) 69 | ![alt tag](./assets/interp_2.png) 70 | ![alt tag](./assets/interp_3.png) 71 | ![alt tag](./assets/interp_4.png) 72 | ![alt tag](./assets/interp_5.png) 73 | ![alt tag](./assets/interp_6.png) 74 | ![alt tag](./assets/interp_7.png) 75 | ![alt tag](./assets/interp_8.png) 76 | ![alt tag](./assets/interp_9.png) 77 | ![alt tag](./assets/interp_10.png) 78 | 79 | 80 | ## Related works 81 | 82 | - [DCGAN-tensorflow](https://github.com/carpedm20/DCGAN-tensorflow) 83 | - [DiscoGAN-pytorch](https://github.com/carpedm20/DiscoGAN-pytorch) 84 | - [simulated-unsupervised-tensorflow](https://github.com/carpedm20/simulated-unsupervised-tensorflow) 85 | 86 | 87 | ## Author 88 | 89 | Taehoon Kim / [@carpedm20](http://carpedm20.github.io) 90 | -------------------------------------------------------------------------------- /assets/104050_AE_G.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/BEGAN-tensorflow/722b3f1a2ff9fd79b8508d3fd29b2a15e4e8ae36/assets/104050_AE_G.png -------------------------------------------------------------------------------- /assets/104050_G.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/BEGAN-tensorflow/722b3f1a2ff9fd79b8508d3fd29b2a15e4e8ae36/assets/104050_G.png -------------------------------------------------------------------------------- /assets/107300_AE_G.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/BEGAN-tensorflow/722b3f1a2ff9fd79b8508d3fd29b2a15e4e8ae36/assets/107300_AE_G.png -------------------------------------------------------------------------------- /assets/107300_G.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/BEGAN-tensorflow/722b3f1a2ff9fd79b8508d3fd29b2a15e4e8ae36/assets/107300_G.png -------------------------------------------------------------------------------- /assets/115827_AE_G.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/BEGAN-tensorflow/722b3f1a2ff9fd79b8508d3fd29b2a15e4e8ae36/assets/115827_AE_G.png -------------------------------------------------------------------------------- /assets/115827_G.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/BEGAN-tensorflow/722b3f1a2ff9fd79b8508d3fd29b2a15e4e8ae36/assets/115827_G.png -------------------------------------------------------------------------------- /assets/58600.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/BEGAN-tensorflow/722b3f1a2ff9fd79b8508d3fd29b2a15e4e8ae36/assets/58600.png -------------------------------------------------------------------------------- /assets/82400_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/BEGAN-tensorflow/722b3f1a2ff9fd79b8508d3fd29b2a15e4e8ae36/assets/82400_1.png -------------------------------------------------------------------------------- /assets/82400_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/BEGAN-tensorflow/722b3f1a2ff9fd79b8508d3fd29b2a15e4e8ae36/assets/82400_2.png -------------------------------------------------------------------------------- /assets/AE_G1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/BEGAN-tensorflow/722b3f1a2ff9fd79b8508d3fd29b2a15e4e8ae36/assets/AE_G1.png -------------------------------------------------------------------------------- /assets/AE_G10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/BEGAN-tensorflow/722b3f1a2ff9fd79b8508d3fd29b2a15e4e8ae36/assets/AE_G10.png -------------------------------------------------------------------------------- /assets/AE_G11.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/BEGAN-tensorflow/722b3f1a2ff9fd79b8508d3fd29b2a15e4e8ae36/assets/AE_G11.png -------------------------------------------------------------------------------- /assets/AE_G12.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/BEGAN-tensorflow/722b3f1a2ff9fd79b8508d3fd29b2a15e4e8ae36/assets/AE_G12.png -------------------------------------------------------------------------------- /assets/AE_G13.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/BEGAN-tensorflow/722b3f1a2ff9fd79b8508d3fd29b2a15e4e8ae36/assets/AE_G13.png -------------------------------------------------------------------------------- /assets/AE_G14.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/BEGAN-tensorflow/722b3f1a2ff9fd79b8508d3fd29b2a15e4e8ae36/assets/AE_G14.png -------------------------------------------------------------------------------- /assets/AE_G15.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/BEGAN-tensorflow/722b3f1a2ff9fd79b8508d3fd29b2a15e4e8ae36/assets/AE_G15.png -------------------------------------------------------------------------------- /assets/AE_G16.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/BEGAN-tensorflow/722b3f1a2ff9fd79b8508d3fd29b2a15e4e8ae36/assets/AE_G16.png -------------------------------------------------------------------------------- /assets/AE_G17.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/BEGAN-tensorflow/722b3f1a2ff9fd79b8508d3fd29b2a15e4e8ae36/assets/AE_G17.png -------------------------------------------------------------------------------- /assets/AE_G18.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/BEGAN-tensorflow/722b3f1a2ff9fd79b8508d3fd29b2a15e4e8ae36/assets/AE_G18.png -------------------------------------------------------------------------------- /assets/AE_G19.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/BEGAN-tensorflow/722b3f1a2ff9fd79b8508d3fd29b2a15e4e8ae36/assets/AE_G19.png -------------------------------------------------------------------------------- /assets/AE_G2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/BEGAN-tensorflow/722b3f1a2ff9fd79b8508d3fd29b2a15e4e8ae36/assets/AE_G2.png -------------------------------------------------------------------------------- /assets/AE_G20.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/BEGAN-tensorflow/722b3f1a2ff9fd79b8508d3fd29b2a15e4e8ae36/assets/AE_G20.png -------------------------------------------------------------------------------- /assets/AE_G21.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/BEGAN-tensorflow/722b3f1a2ff9fd79b8508d3fd29b2a15e4e8ae36/assets/AE_G21.png -------------------------------------------------------------------------------- /assets/AE_G22.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/BEGAN-tensorflow/722b3f1a2ff9fd79b8508d3fd29b2a15e4e8ae36/assets/AE_G22.png -------------------------------------------------------------------------------- /assets/AE_G23.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/BEGAN-tensorflow/722b3f1a2ff9fd79b8508d3fd29b2a15e4e8ae36/assets/AE_G23.png -------------------------------------------------------------------------------- /assets/AE_G24.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/BEGAN-tensorflow/722b3f1a2ff9fd79b8508d3fd29b2a15e4e8ae36/assets/AE_G24.png -------------------------------------------------------------------------------- /assets/AE_G25.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/BEGAN-tensorflow/722b3f1a2ff9fd79b8508d3fd29b2a15e4e8ae36/assets/AE_G25.png -------------------------------------------------------------------------------- /assets/AE_G26.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/BEGAN-tensorflow/722b3f1a2ff9fd79b8508d3fd29b2a15e4e8ae36/assets/AE_G26.png -------------------------------------------------------------------------------- /assets/AE_G3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/BEGAN-tensorflow/722b3f1a2ff9fd79b8508d3fd29b2a15e4e8ae36/assets/AE_G3.png -------------------------------------------------------------------------------- /assets/AE_G4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/BEGAN-tensorflow/722b3f1a2ff9fd79b8508d3fd29b2a15e4e8ae36/assets/AE_G4.png -------------------------------------------------------------------------------- /assets/AE_G5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/BEGAN-tensorflow/722b3f1a2ff9fd79b8508d3fd29b2a15e4e8ae36/assets/AE_G5.png -------------------------------------------------------------------------------- /assets/AE_G6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/BEGAN-tensorflow/722b3f1a2ff9fd79b8508d3fd29b2a15e4e8ae36/assets/AE_G6.png -------------------------------------------------------------------------------- /assets/AE_G7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/BEGAN-tensorflow/722b3f1a2ff9fd79b8508d3fd29b2a15e4e8ae36/assets/AE_G7.png -------------------------------------------------------------------------------- /assets/AE_G8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/BEGAN-tensorflow/722b3f1a2ff9fd79b8508d3fd29b2a15e4e8ae36/assets/AE_G8.png -------------------------------------------------------------------------------- /assets/AE_G9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/BEGAN-tensorflow/722b3f1a2ff9fd79b8508d3fd29b2a15e4e8ae36/assets/AE_G9.png -------------------------------------------------------------------------------- /assets/AE_batch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/BEGAN-tensorflow/722b3f1a2ff9fd79b8508d3fd29b2a15e4e8ae36/assets/AE_batch.png -------------------------------------------------------------------------------- /assets/G1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/BEGAN-tensorflow/722b3f1a2ff9fd79b8508d3fd29b2a15e4e8ae36/assets/G1.png -------------------------------------------------------------------------------- /assets/G2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/BEGAN-tensorflow/722b3f1a2ff9fd79b8508d3fd29b2a15e4e8ae36/assets/G2.png -------------------------------------------------------------------------------- /assets/G3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/BEGAN-tensorflow/722b3f1a2ff9fd79b8508d3fd29b2a15e4e8ae36/assets/G3.png -------------------------------------------------------------------------------- /assets/G4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/BEGAN-tensorflow/722b3f1a2ff9fd79b8508d3fd29b2a15e4e8ae36/assets/G4.png -------------------------------------------------------------------------------- /assets/G5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/BEGAN-tensorflow/722b3f1a2ff9fd79b8508d3fd29b2a15e4e8ae36/assets/G5.png -------------------------------------------------------------------------------- /assets/G6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/BEGAN-tensorflow/722b3f1a2ff9fd79b8508d3fd29b2a15e4e8ae36/assets/G6.png -------------------------------------------------------------------------------- /assets/G7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/BEGAN-tensorflow/722b3f1a2ff9fd79b8508d3fd29b2a15e4e8ae36/assets/G7.png -------------------------------------------------------------------------------- /assets/all_G_z0_128x128.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/BEGAN-tensorflow/722b3f1a2ff9fd79b8508d3fd29b2a15e4e8ae36/assets/all_G_z0_128x128.png -------------------------------------------------------------------------------- /assets/all_G_z0_64x64.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/BEGAN-tensorflow/722b3f1a2ff9fd79b8508d3fd29b2a15e4e8ae36/assets/all_G_z0_64x64.png -------------------------------------------------------------------------------- /assets/interp_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/BEGAN-tensorflow/722b3f1a2ff9fd79b8508d3fd29b2a15e4e8ae36/assets/interp_1.png -------------------------------------------------------------------------------- /assets/interp_10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/BEGAN-tensorflow/722b3f1a2ff9fd79b8508d3fd29b2a15e4e8ae36/assets/interp_10.png -------------------------------------------------------------------------------- /assets/interp_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/BEGAN-tensorflow/722b3f1a2ff9fd79b8508d3fd29b2a15e4e8ae36/assets/interp_2.png -------------------------------------------------------------------------------- /assets/interp_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/BEGAN-tensorflow/722b3f1a2ff9fd79b8508d3fd29b2a15e4e8ae36/assets/interp_3.png -------------------------------------------------------------------------------- /assets/interp_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/BEGAN-tensorflow/722b3f1a2ff9fd79b8508d3fd29b2a15e4e8ae36/assets/interp_4.png -------------------------------------------------------------------------------- /assets/interp_5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/BEGAN-tensorflow/722b3f1a2ff9fd79b8508d3fd29b2a15e4e8ae36/assets/interp_5.png -------------------------------------------------------------------------------- /assets/interp_6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/BEGAN-tensorflow/722b3f1a2ff9fd79b8508d3fd29b2a15e4e8ae36/assets/interp_6.png -------------------------------------------------------------------------------- /assets/interp_7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/BEGAN-tensorflow/722b3f1a2ff9fd79b8508d3fd29b2a15e4e8ae36/assets/interp_7.png -------------------------------------------------------------------------------- /assets/interp_8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/BEGAN-tensorflow/722b3f1a2ff9fd79b8508d3fd29b2a15e4e8ae36/assets/interp_8.png -------------------------------------------------------------------------------- /assets/interp_9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/BEGAN-tensorflow/722b3f1a2ff9fd79b8508d3fd29b2a15e4e8ae36/assets/interp_9.png -------------------------------------------------------------------------------- /assets/interp_G0_128x128.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/BEGAN-tensorflow/722b3f1a2ff9fd79b8508d3fd29b2a15e4e8ae36/assets/interp_G0_128x128.png -------------------------------------------------------------------------------- /assets/interp_G0_64x64.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/BEGAN-tensorflow/722b3f1a2ff9fd79b8508d3fd29b2a15e4e8ae36/assets/interp_G0_64x64.png -------------------------------------------------------------------------------- /assets/model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/BEGAN-tensorflow/722b3f1a2ff9fd79b8508d3fd29b2a15e4e8ae36/assets/model.png -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | #-*- coding: utf-8 -*- 2 | import argparse 3 | 4 | def str2bool(v): 5 | return v.lower() in ('true', '1') 6 | 7 | arg_lists = [] 8 | parser = argparse.ArgumentParser() 9 | 10 | def add_argument_group(name): 11 | arg = parser.add_argument_group(name) 12 | arg_lists.append(arg) 13 | return arg 14 | 15 | # Network 16 | net_arg = add_argument_group('Network') 17 | net_arg.add_argument('--input_scale_size', type=int, default=64, 18 | help='input image will be resized with the given value as width and height') 19 | net_arg.add_argument('--conv_hidden_num', type=int, default=128, 20 | choices=[64, 128],help='n in the paper') 21 | net_arg.add_argument('--z_num', type=int, default=64, choices=[64, 128]) 22 | 23 | # Data 24 | data_arg = add_argument_group('Data') 25 | data_arg.add_argument('--dataset', type=str, default='CelebA') 26 | data_arg.add_argument('--split', type=str, default='train') 27 | data_arg.add_argument('--batch_size', type=int, default=16) 28 | data_arg.add_argument('--grayscale', type=str2bool, default=False) 29 | data_arg.add_argument('--num_worker', type=int, default=4) 30 | 31 | # Training / test parameters 32 | train_arg = add_argument_group('Training') 33 | train_arg.add_argument('--is_train', type=str2bool, default=True) 34 | train_arg.add_argument('--optimizer', type=str, default='adam') 35 | train_arg.add_argument('--max_step', type=int, default=500000) 36 | train_arg.add_argument('--lr_update_step', type=int, default=100000, choices=[100000, 75000]) 37 | train_arg.add_argument('--d_lr', type=float, default=0.00008) 38 | train_arg.add_argument('--g_lr', type=float, default=0.00008) 39 | train_arg.add_argument('--lr_lower_boundary', type=float, default=0.00002) 40 | train_arg.add_argument('--beta1', type=float, default=0.5) 41 | train_arg.add_argument('--beta2', type=float, default=0.999) 42 | train_arg.add_argument('--gamma', type=float, default=0.5) 43 | train_arg.add_argument('--lambda_k', type=float, default=0.001) 44 | train_arg.add_argument('--use_gpu', type=str2bool, default=True) 45 | 46 | # Misc 47 | misc_arg = add_argument_group('Misc') 48 | misc_arg.add_argument('--load_path', type=str, default='') 49 | misc_arg.add_argument('--log_step', type=int, default=50) 50 | misc_arg.add_argument('--save_step', type=int, default=5000) 51 | misc_arg.add_argument('--num_log_samples', type=int, default=3) 52 | misc_arg.add_argument('--log_level', type=str, default='INFO', choices=['INFO', 'DEBUG', 'WARN']) 53 | misc_arg.add_argument('--log_dir', type=str, default='logs') 54 | misc_arg.add_argument('--data_dir', type=str, default='data') 55 | misc_arg.add_argument('--test_data_path', type=str, default=None, 56 | help='directory with images which will be used in test sample generation') 57 | misc_arg.add_argument('--sample_per_image', type=int, default=64, 58 | help='# of sample per image during test sample generation') 59 | misc_arg.add_argument('--random_seed', type=int, default=123) 60 | 61 | def get_config(): 62 | config, unparsed = parser.parse_known_args() 63 | if config.use_gpu: 64 | data_format = 'NCHW' 65 | else: 66 | data_format = 'NHWC' 67 | setattr(config, 'data_format', data_format) 68 | return config, unparsed 69 | -------------------------------------------------------------------------------- /data_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | from glob import glob 4 | import tensorflow as tf 5 | 6 | def get_loader(root, batch_size, scale_size, data_format, split=None, is_grayscale=False, seed=None): 7 | dataset_name = os.path.basename(root) 8 | if dataset_name in ['CelebA'] and split: 9 | root = os.path.join(root, 'splits', split) 10 | 11 | for ext in ["jpg", "png"]: 12 | paths = glob("{}/*.{}".format(root, ext)) 13 | 14 | if ext == "jpg": 15 | tf_decode = tf.image.decode_jpeg 16 | elif ext == "png": 17 | tf_decode = tf.image.decode_png 18 | 19 | if len(paths) != 0: 20 | break 21 | 22 | with Image.open(paths[0]) as img: 23 | w, h = img.size 24 | shape = [h, w, 3] 25 | 26 | filename_queue = tf.train.string_input_producer(list(paths), shuffle=False, seed=seed) 27 | reader = tf.WholeFileReader() 28 | filename, data = reader.read(filename_queue) 29 | image = tf_decode(data, channels=3) 30 | 31 | if is_grayscale: 32 | image = tf.image.rgb_to_grayscale(image) 33 | image.set_shape(shape) 34 | 35 | min_after_dequeue = 5000 36 | capacity = min_after_dequeue + 3 * batch_size 37 | 38 | queue = tf.train.shuffle_batch( 39 | [image], batch_size=batch_size, 40 | num_threads=4, capacity=capacity, 41 | min_after_dequeue=min_after_dequeue, name='synthetic_inputs') 42 | 43 | if dataset_name in ['CelebA']: 44 | queue = tf.image.crop_to_bounding_box(queue, 50, 25, 128, 128) 45 | queue = tf.image.resize_nearest_neighbor(queue, [scale_size, scale_size]) 46 | else: 47 | queue = tf.image.resize_nearest_neighbor(queue, [scale_size, scale_size]) 48 | 49 | if data_format == 'NCHW': 50 | queue = tf.transpose(queue, [0, 3, 1, 2]) 51 | elif data_format == 'NHWC': 52 | pass 53 | else: 54 | raise Exception("[!] Unkown data_format: {}".format(data_format)) 55 | 56 | return tf.to_float(queue) 57 | -------------------------------------------------------------------------------- /download.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modification of 3 | - https://github.com/carpedm20/DCGAN-tensorflow/blob/master/download.py 4 | - http://stackoverflow.com/a/39225039 5 | """ 6 | from __future__ import print_function 7 | import os 8 | import zipfile 9 | import requests 10 | import subprocess 11 | from tqdm import tqdm 12 | from collections import OrderedDict 13 | 14 | def download_file_from_google_drive(id, destination): 15 | URL = "https://docs.google.com/uc?export=download" 16 | session = requests.Session() 17 | 18 | response = session.get(URL, params={ 'id': id }, stream=True) 19 | token = get_confirm_token(response) 20 | 21 | if token: 22 | params = { 'id' : id, 'confirm' : token } 23 | response = session.get(URL, params=params, stream=True) 24 | 25 | save_response_content(response, destination) 26 | 27 | def get_confirm_token(response): 28 | for key, value in response.cookies.items(): 29 | if key.startswith('download_warning'): 30 | return value 31 | return None 32 | 33 | def save_response_content(response, destination, chunk_size=32*1024): 34 | total_size = int(response.headers.get('content-length', 0)) 35 | with open(destination, "wb") as f: 36 | for chunk in tqdm(response.iter_content(chunk_size), total=total_size, 37 | unit='B', unit_scale=True, desc=destination): 38 | if chunk: # filter out keep-alive new chunks 39 | f.write(chunk) 40 | 41 | def unzip(filepath): 42 | print("Extracting: " + filepath) 43 | base_path = os.path.dirname(filepath) 44 | with zipfile.ZipFile(filepath) as zf: 45 | zf.extractall(base_path) 46 | os.remove(filepath) 47 | 48 | def download_celeb_a(base_path): 49 | data_path = os.path.join(base_path, 'CelebA') 50 | images_path = os.path.join(data_path, 'images') 51 | if os.path.exists(data_path): 52 | print('[!] Found Celeb-A - skip') 53 | return 54 | 55 | filename, drive_id = "img_align_celeba.zip", "0B7EVK8r0v71pZjFTYXZWM3FlRnM" 56 | save_path = os.path.join(base_path, filename) 57 | 58 | if os.path.exists(save_path): 59 | print('[*] {} already exists'.format(save_path)) 60 | else: 61 | download_file_from_google_drive(drive_id, save_path) 62 | 63 | zip_dir = '' 64 | with zipfile.ZipFile(save_path) as zf: 65 | zip_dir = zf.namelist()[0] 66 | zf.extractall(base_path) 67 | if not os.path.exists(data_path): 68 | os.mkdir(data_path) 69 | os.rename(os.path.join(base_path, "img_align_celeba"), images_path) 70 | os.remove(save_path) 71 | 72 | def prepare_data_dir(path = './data'): 73 | if not os.path.exists(path): 74 | os.mkdir(path) 75 | 76 | # check, if file exists, make link 77 | def check_link(in_dir, basename, out_dir): 78 | in_file = os.path.join(in_dir, basename) 79 | if os.path.exists(in_file): 80 | link_file = os.path.join(out_dir, basename) 81 | rel_link = os.path.relpath(in_file, out_dir) 82 | os.symlink(rel_link, link_file) 83 | 84 | def add_splits(base_path): 85 | data_path = os.path.join(base_path, 'CelebA') 86 | images_path = os.path.join(data_path, 'images') 87 | train_dir = os.path.join(data_path, 'splits', 'train') 88 | valid_dir = os.path.join(data_path, 'splits', 'valid') 89 | test_dir = os.path.join(data_path, 'splits', 'test') 90 | if not os.path.exists(train_dir): 91 | os.makedirs(train_dir) 92 | if not os.path.exists(valid_dir): 93 | os.makedirs(valid_dir) 94 | if not os.path.exists(test_dir): 95 | os.makedirs(test_dir) 96 | 97 | # these constants based on the standard CelebA splits 98 | NUM_EXAMPLES = 202599 99 | TRAIN_STOP = 162770 100 | VALID_STOP = 182637 101 | 102 | for i in range(0, TRAIN_STOP): 103 | basename = "{:06d}.jpg".format(i+1) 104 | check_link(images_path, basename, train_dir) 105 | for i in range(TRAIN_STOP, VALID_STOP): 106 | basename = "{:06d}.jpg".format(i+1) 107 | check_link(images_path, basename, valid_dir) 108 | for i in range(VALID_STOP, NUM_EXAMPLES): 109 | basename = "{:06d}.jpg".format(i+1) 110 | check_link(images_path, basename, test_dir) 111 | 112 | if __name__ == '__main__': 113 | base_path = './data' 114 | prepare_data_dir() 115 | download_celeb_a(base_path) 116 | add_splits(base_path) 117 | -------------------------------------------------------------------------------- /folder.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | 3 | from PIL import Image 4 | import os 5 | import os.path 6 | 7 | IMG_EXTENSIONS = [ 8 | '.jpg', '.JPG', '.jpeg', '.JPEG', 9 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 10 | ] 11 | 12 | def is_image_file(filename): 13 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 14 | 15 | def make_dataset(dir): 16 | images = [] 17 | for root, _, fnames in sorted(os.walk(dir)): 18 | for fname in sorted(fnames): 19 | if is_image_file(fname): 20 | path = os.path.join(root, fname) 21 | item = (path, 0) 22 | images.append(item) 23 | 24 | return images 25 | 26 | def default_loader(path): 27 | return Image.open(path).convert('RGB') 28 | 29 | class ImageFolder(data.Dataset): 30 | 31 | def __init__(self, root, transform=None, target_transform=None, 32 | loader=default_loader): 33 | imgs = make_dataset(root) 34 | if len(imgs) == 0: 35 | raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n" 36 | "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) 37 | 38 | print("Found {} images in subfolders of: {}".format(len(imgs), root)) 39 | 40 | self.root = root 41 | self.imgs = imgs 42 | self.transform = transform 43 | self.target_transform = target_transform 44 | self.loader = loader 45 | 46 | def __getitem__(self, index): 47 | path, target = self.imgs[index] 48 | img = self.loader(path) 49 | if self.transform is not None: 50 | img = self.transform(img) 51 | if self.target_transform is not None: 52 | target = self.target_transform(target) 53 | 54 | return img, target 55 | 56 | def __len__(self): 57 | return len(self.imgs) 58 | -------------------------------------------------------------------------------- /layers.py: -------------------------------------------------------------------------------- 1 | # Code from https://github.com/david-berthelot/tf_img_tech/blob/master/tfswag/layers.py 2 | import numpy as N 3 | import numpy.linalg as LA 4 | import tensorflow as tf 5 | 6 | __author__ = 'David Berthelot' 7 | 8 | 9 | def unboxn(vin, n): 10 | """vin = (batch, h, w, depth), returns vout = (batch, n*h, n*w, depth), each pixel is duplicated.""" 11 | s = tf.shape(vin) 12 | vout = tf.concat([vin] * (n ** 2), 0) # Poor man's replacement for tf.tile (required for Adversarial Training support). 13 | vout = tf.reshape(vout, [s[0] * (n ** 2), s[1], s[2], s[3]]) 14 | vout = tf.batch_to_space(vout, [[0, 0], [0, 0]], n) 15 | return vout 16 | 17 | 18 | def boxn(vin, n): 19 | """vin = (batch, h, w, depth), returns vout = (batch, h//n, w//n, depth), each pixel is averaged.""" 20 | if n == 1: 21 | return vin 22 | s = tf.shape(vin) 23 | vout = tf.reshape(vin, [s[0], s[1] // n, n, s[2] // n, n, s[3]]) 24 | vout = tf.reduce_mean(vout, [2, 4]) 25 | return vout 26 | 27 | 28 | class LayerBase: 29 | pass 30 | 31 | 32 | class LayerConv(LayerBase): 33 | def __init__(self, name, w, n, nl=lambda x, y: x + y, strides=(1, 1, 1, 1), 34 | padding='SAME', conv=None, use_bias=True, data_format="NCHW"): 35 | """w = (wy, wx), n = (n_in, n_out)""" 36 | self.nl = nl 37 | self.strides = list(strides) 38 | self.padding = padding 39 | self.data_format = data_format 40 | with tf.name_scope(name): 41 | if conv is None: 42 | conv = tf.Variable(tf.truncated_normal([w[0], w[1], n[0], n[1]], stddev=0.01), name='conv') 43 | self.conv = conv 44 | self.bias = tf.Variable(tf.zeros([n[1]]), name='bias') if use_bias else 0 45 | 46 | def __call__(self, vin): 47 | return self.nl(tf.nn.conv2d(vin, self.conv, strides=self.strides, 48 | padding=self.padding, data_format=self.data_format), self.bias) 49 | 50 | class LayerEncodeConvGrowLinear(LayerBase): 51 | def __init__(self, name, n, width, colors, depth, scales, nl=lambda x, y: x + y, data_format="NCHW"): 52 | with tf.variable_scope(name) as vs: 53 | encode = [] 54 | nn = n 55 | for x in range(scales): 56 | cl = [] 57 | for y in range(depth - 1): 58 | cl.append(LayerConv('conv_%d_%d' % (x, y), [width, width], 59 | [nn, nn], nl, data_format=data_format)) 60 | cl.append(LayerConv('conv_%d_%d' % (x, depth - 1), [width, width], 61 | [nn, nn + n], nl, strides=[1, 2, 2, 1], data_format=data_format)) 62 | encode.append(cl) 63 | nn += n 64 | self.encode = [LayerConv('conv_pre', [width, width], [colors, n], nl, data_format=data_format), encode] 65 | self.variables = tf.contrib.framework.get_variables(vs) 66 | 67 | def __call__(self, vin, carry=0, train=True): 68 | vout = self.encode[0](vin) 69 | for convs in self.encode[1]: 70 | for conv in convs[:-1]: 71 | vtmp = tf.nn.elu(conv(vout)) 72 | vout = carry * vout + (1 - carry) * vtmp 73 | vout = convs[-1](vout) 74 | return vout, self.variables 75 | 76 | 77 | class LayerDecodeConvBlend(LayerBase): 78 | def __init__(self, name, n, width, colors, depth, scales, nl=lambda x, y: x + y, data_format="NCHW"): 79 | with tf.variable_scope(name) as vs: 80 | decode = [] 81 | for x in range(scales): 82 | cl = [] 83 | n2 = 2 * n if x else n 84 | cl.append(LayerConv('conv_%d_%d' % (x, 0), [width, width], 85 | [n2, n], nl, data_format=data_format)) 86 | for y in range(1, depth): 87 | cl.append(LayerConv('conv_%d_%d' % (x, y), [width, width], [n, n], nl, data_format=data_format)) 88 | decode.append(cl) 89 | self.decode = [decode, LayerConv('conv_post', [width, width], [n, colors], data_format=data_format)] 90 | self.variables = tf.contrib.framework.get_variables(vs) 91 | 92 | def __call__(self, data, carry, train=True): 93 | vout = data 94 | layers = [] 95 | for x, convs in enumerate(self.decode[0]): 96 | vout = tf.concat([vout, data], 3) if x else vout 97 | vout = unboxn(convs[0](vout), 2) 98 | data = unboxn(data, 2) 99 | for conv in convs[1:]: 100 | vtmp = tf.nn.elu(conv(vout)) 101 | vout = carry * vout + (1 - carry) * vtmp 102 | layers.append(vout) 103 | return self.decode[1](vout), self.variables 104 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | from trainer import Trainer 5 | from config import get_config 6 | from data_loader import get_loader 7 | from utils import prepare_dirs_and_logger, save_config 8 | 9 | def main(config): 10 | prepare_dirs_and_logger(config) 11 | 12 | rng = np.random.RandomState(config.random_seed) 13 | tf.set_random_seed(config.random_seed) 14 | 15 | if config.is_train: 16 | data_path = config.data_path 17 | batch_size = config.batch_size 18 | do_shuffle = True 19 | else: 20 | setattr(config, 'batch_size', 64) 21 | if config.test_data_path is None: 22 | data_path = config.data_path 23 | else: 24 | data_path = config.test_data_path 25 | batch_size = config.sample_per_image 26 | do_shuffle = False 27 | 28 | data_loader = get_loader( 29 | data_path, config.batch_size, config.input_scale_size, 30 | config.data_format, config.split) 31 | trainer = Trainer(config, data_loader) 32 | 33 | if config.is_train: 34 | save_config(config) 35 | trainer.train() 36 | else: 37 | if not config.load_path: 38 | raise Exception("[!] You should specify `load_path` to load a pretrained model") 39 | trainer.test() 40 | 41 | if __name__ == "__main__": 42 | config, unparsed = get_config() 43 | main(config) 44 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | slim = tf.contrib.slim 4 | 5 | def GeneratorCNN(z, hidden_num, output_num, repeat_num, data_format, reuse): 6 | with tf.variable_scope("G", reuse=reuse) as vs: 7 | num_output = int(np.prod([8, 8, hidden_num])) 8 | x = slim.fully_connected(z, num_output, activation_fn=None) 9 | x = reshape(x, 8, 8, hidden_num, data_format) 10 | 11 | for idx in range(repeat_num): 12 | x = slim.conv2d(x, hidden_num, 3, 1, activation_fn=tf.nn.elu, data_format=data_format) 13 | x = slim.conv2d(x, hidden_num, 3, 1, activation_fn=tf.nn.elu, data_format=data_format) 14 | if idx < repeat_num - 1: 15 | x = upscale(x, 2, data_format) 16 | 17 | out = slim.conv2d(x, 3, 3, 1, activation_fn=None, data_format=data_format) 18 | 19 | variables = tf.contrib.framework.get_variables(vs) 20 | return out, variables 21 | 22 | def DiscriminatorCNN(x, input_channel, z_num, repeat_num, hidden_num, data_format): 23 | with tf.variable_scope("D") as vs: 24 | # Encoder 25 | x = slim.conv2d(x, hidden_num, 3, 1, activation_fn=tf.nn.elu, data_format=data_format) 26 | 27 | prev_channel_num = hidden_num 28 | for idx in range(repeat_num): 29 | channel_num = hidden_num * (idx + 1) 30 | x = slim.conv2d(x, channel_num, 3, 1, activation_fn=tf.nn.elu, data_format=data_format) 31 | x = slim.conv2d(x, channel_num, 3, 1, activation_fn=tf.nn.elu, data_format=data_format) 32 | if idx < repeat_num - 1: 33 | x = slim.conv2d(x, channel_num, 3, 2, activation_fn=tf.nn.elu, data_format=data_format) 34 | #x = tf.contrib.layers.max_pool2d(x, [2, 2], [2, 2], padding='VALID') 35 | 36 | x = tf.reshape(x, [-1, np.prod([8, 8, channel_num])]) 37 | z = x = slim.fully_connected(x, z_num, activation_fn=None) 38 | 39 | # Decoder 40 | num_output = int(np.prod([8, 8, hidden_num])) 41 | x = slim.fully_connected(x, num_output, activation_fn=None) 42 | x = reshape(x, 8, 8, hidden_num, data_format) 43 | 44 | for idx in range(repeat_num): 45 | x = slim.conv2d(x, hidden_num, 3, 1, activation_fn=tf.nn.elu, data_format=data_format) 46 | x = slim.conv2d(x, hidden_num, 3, 1, activation_fn=tf.nn.elu, data_format=data_format) 47 | if idx < repeat_num - 1: 48 | x = upscale(x, 2, data_format) 49 | 50 | out = slim.conv2d(x, input_channel, 3, 1, activation_fn=None, data_format=data_format) 51 | 52 | variables = tf.contrib.framework.get_variables(vs) 53 | return out, z, variables 54 | 55 | def int_shape(tensor): 56 | shape = tensor.get_shape().as_list() 57 | return [num if num is not None else -1 for num in shape] 58 | 59 | def get_conv_shape(tensor, data_format): 60 | shape = int_shape(tensor) 61 | # always return [N, H, W, C] 62 | if data_format == 'NCHW': 63 | return [shape[0], shape[2], shape[3], shape[1]] 64 | elif data_format == 'NHWC': 65 | return shape 66 | 67 | def nchw_to_nhwc(x): 68 | return tf.transpose(x, [0, 2, 3, 1]) 69 | 70 | def nhwc_to_nchw(x): 71 | return tf.transpose(x, [0, 3, 1, 2]) 72 | 73 | def reshape(x, h, w, c, data_format): 74 | if data_format == 'NCHW': 75 | x = tf.reshape(x, [-1, c, h, w]) 76 | else: 77 | x = tf.reshape(x, [-1, h, w, c]) 78 | return x 79 | 80 | def resize_nearest_neighbor(x, new_size, data_format): 81 | if data_format == 'NCHW': 82 | x = nchw_to_nhwc(x) 83 | x = tf.image.resize_nearest_neighbor(x, new_size) 84 | x = nhwc_to_nchw(x) 85 | else: 86 | x = tf.image.resize_nearest_neighbor(x, new_size) 87 | return x 88 | 89 | def upscale(x, scale, data_format): 90 | _, h, w, _ = get_conv_shape(x, data_format) 91 | return resize_nearest_neighbor(x, (h*scale, w*scale), data_format) 92 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import os 4 | try: 5 | from StringIO import StringIO 6 | except ImportError: 7 | from io import StringIO 8 | import scipy.misc 9 | import numpy as np 10 | from glob import glob 11 | from tqdm import trange 12 | from itertools import chain 13 | from collections import deque 14 | 15 | from models import * 16 | from utils import save_image 17 | 18 | def next(loader): 19 | return loader.next()[0].data.numpy() 20 | 21 | def to_nhwc(image, data_format): 22 | if data_format == 'NCHW': 23 | new_image = nchw_to_nhwc(image) 24 | else: 25 | new_image = image 26 | return new_image 27 | 28 | def to_nchw_numpy(image): 29 | if image.shape[3] in [1, 3]: 30 | new_image = image.transpose([0, 3, 1, 2]) 31 | else: 32 | new_image = image 33 | return new_image 34 | 35 | def norm_img(image, data_format=None): 36 | image = image/127.5 - 1. 37 | if data_format: 38 | image = to_nhwc(image, data_format) 39 | return image 40 | 41 | def denorm_img(norm, data_format): 42 | return tf.clip_by_value(to_nhwc((norm + 1)*127.5, data_format), 0, 255) 43 | 44 | def slerp(val, low, high): 45 | """Code from https://github.com/soumith/dcgan.torch/issues/14""" 46 | omega = np.arccos(np.clip(np.dot(low/np.linalg.norm(low), high/np.linalg.norm(high)), -1, 1)) 47 | so = np.sin(omega) 48 | if so == 0: 49 | return (1.0-val) * low + val * high # L'Hopital's rule/LERP 50 | return np.sin((1.0-val)*omega) / so * low + np.sin(val*omega) / so * high 51 | 52 | class Trainer(object): 53 | def __init__(self, config, data_loader): 54 | self.config = config 55 | self.data_loader = data_loader 56 | self.dataset = config.dataset 57 | 58 | self.beta1 = config.beta1 59 | self.beta2 = config.beta2 60 | self.optimizer = config.optimizer 61 | self.batch_size = config.batch_size 62 | 63 | self.step = tf.Variable(0, name='step', trainable=False) 64 | 65 | self.g_lr = tf.Variable(config.g_lr, name='g_lr') 66 | self.d_lr = tf.Variable(config.d_lr, name='d_lr') 67 | 68 | self.g_lr_update = tf.assign(self.g_lr, tf.maximum(self.g_lr * 0.5, config.lr_lower_boundary), name='g_lr_update') 69 | self.d_lr_update = tf.assign(self.d_lr, tf.maximum(self.d_lr * 0.5, config.lr_lower_boundary), name='d_lr_update') 70 | 71 | self.gamma = config.gamma 72 | self.lambda_k = config.lambda_k 73 | 74 | self.z_num = config.z_num 75 | self.conv_hidden_num = config.conv_hidden_num 76 | self.input_scale_size = config.input_scale_size 77 | 78 | self.model_dir = config.model_dir 79 | self.load_path = config.load_path 80 | 81 | self.use_gpu = config.use_gpu 82 | self.data_format = config.data_format 83 | 84 | _, height, width, self.channel = \ 85 | get_conv_shape(self.data_loader, self.data_format) 86 | self.repeat_num = int(np.log2(height)) - 2 87 | 88 | self.start_step = 0 89 | self.log_step = config.log_step 90 | self.max_step = config.max_step 91 | self.save_step = config.save_step 92 | self.lr_update_step = config.lr_update_step 93 | 94 | self.is_train = config.is_train 95 | self.build_model() 96 | 97 | self.saver = tf.train.Saver() 98 | self.summary_writer = tf.summary.FileWriter(self.model_dir) 99 | 100 | sv = tf.train.Supervisor(logdir=self.model_dir, 101 | is_chief=True, 102 | saver=self.saver, 103 | summary_op=None, 104 | summary_writer=self.summary_writer, 105 | save_model_secs=300, 106 | global_step=self.step, 107 | ready_for_local_init_op=None) 108 | 109 | gpu_options = tf.GPUOptions(allow_growth=True) 110 | sess_config = tf.ConfigProto(allow_soft_placement=True, 111 | gpu_options=gpu_options) 112 | 113 | self.sess = sv.prepare_or_wait_for_session(config=sess_config) 114 | 115 | if not self.is_train: 116 | # dirty way to bypass graph finilization error 117 | g = tf.get_default_graph() 118 | g._finalized = False 119 | 120 | self.build_test_model() 121 | 122 | def train(self): 123 | z_fixed = np.random.uniform(-1, 1, size=(self.batch_size, self.z_num)) 124 | 125 | x_fixed = self.get_image_from_loader() 126 | save_image(x_fixed, '{}/x_fixed.png'.format(self.model_dir)) 127 | 128 | prev_measure = 1 129 | measure_history = deque([0]*self.lr_update_step, self.lr_update_step) 130 | 131 | for step in trange(self.start_step, self.max_step): 132 | fetch_dict = { 133 | "k_update": self.k_update, 134 | "measure": self.measure, 135 | } 136 | if step % self.log_step == 0: 137 | fetch_dict.update({ 138 | "summary": self.summary_op, 139 | "g_loss": self.g_loss, 140 | "d_loss": self.d_loss, 141 | "k_t": self.k_t, 142 | }) 143 | result = self.sess.run(fetch_dict) 144 | 145 | measure = result['measure'] 146 | measure_history.append(measure) 147 | 148 | if step % self.log_step == 0: 149 | self.summary_writer.add_summary(result['summary'], step) 150 | self.summary_writer.flush() 151 | 152 | g_loss = result['g_loss'] 153 | d_loss = result['d_loss'] 154 | k_t = result['k_t'] 155 | 156 | print("[{}/{}] Loss_D: {:.6f} Loss_G: {:.6f} measure: {:.4f}, k_t: {:.4f}". \ 157 | format(step, self.max_step, d_loss, g_loss, measure, k_t)) 158 | 159 | if step % (self.log_step * 10) == 0: 160 | x_fake = self.generate(z_fixed, self.model_dir, idx=step) 161 | self.autoencode(x_fixed, self.model_dir, idx=step, x_fake=x_fake) 162 | 163 | if step % self.lr_update_step == self.lr_update_step - 1: 164 | self.sess.run([self.g_lr_update, self.d_lr_update]) 165 | #cur_measure = np.mean(measure_history) 166 | #if cur_measure > prev_measure * 0.99: 167 | #prev_measure = cur_measure 168 | 169 | def build_model(self): 170 | self.x = self.data_loader 171 | x = norm_img(self.x) 172 | 173 | self.z = tf.random_uniform( 174 | (tf.shape(x)[0], self.z_num), minval=-1.0, maxval=1.0) 175 | self.k_t = tf.Variable(0., trainable=False, name='k_t') 176 | 177 | G, self.G_var = GeneratorCNN( 178 | self.z, self.conv_hidden_num, self.channel, 179 | self.repeat_num, self.data_format, reuse=False) 180 | 181 | d_out, self.D_z, self.D_var = DiscriminatorCNN( 182 | tf.concat([G, x], 0), self.channel, self.z_num, self.repeat_num, 183 | self.conv_hidden_num, self.data_format) 184 | AE_G, AE_x = tf.split(d_out, 2) 185 | 186 | self.G = denorm_img(G, self.data_format) 187 | self.AE_G, self.AE_x = denorm_img(AE_G, self.data_format), denorm_img(AE_x, self.data_format) 188 | 189 | if self.optimizer == 'adam': 190 | optimizer = tf.train.AdamOptimizer 191 | else: 192 | raise Exception("[!] Caution! Paper didn't use {} opimizer other than Adam".format(config.optimizer)) 193 | 194 | g_optimizer, d_optimizer = optimizer(self.g_lr), optimizer(self.d_lr) 195 | 196 | self.d_loss_real = tf.reduce_mean(tf.abs(AE_x - x)) 197 | self.d_loss_fake = tf.reduce_mean(tf.abs(AE_G - G)) 198 | 199 | self.d_loss = self.d_loss_real - self.k_t * self.d_loss_fake 200 | self.g_loss = tf.reduce_mean(tf.abs(AE_G - G)) 201 | 202 | d_optim = d_optimizer.minimize(self.d_loss, var_list=self.D_var) 203 | g_optim = g_optimizer.minimize(self.g_loss, global_step=self.step, var_list=self.G_var) 204 | 205 | self.balance = self.gamma * self.d_loss_real - self.g_loss 206 | self.measure = self.d_loss_real + tf.abs(self.balance) 207 | 208 | with tf.control_dependencies([d_optim, g_optim]): 209 | self.k_update = tf.assign( 210 | self.k_t, tf.clip_by_value(self.k_t + self.lambda_k * self.balance, 0, 1)) 211 | 212 | self.summary_op = tf.summary.merge([ 213 | tf.summary.image("G", self.G), 214 | tf.summary.image("AE_G", self.AE_G), 215 | tf.summary.image("AE_x", self.AE_x), 216 | 217 | tf.summary.scalar("loss/d_loss", self.d_loss), 218 | tf.summary.scalar("loss/d_loss_real", self.d_loss_real), 219 | tf.summary.scalar("loss/d_loss_fake", self.d_loss_fake), 220 | tf.summary.scalar("loss/g_loss", self.g_loss), 221 | tf.summary.scalar("misc/measure", self.measure), 222 | tf.summary.scalar("misc/k_t", self.k_t), 223 | tf.summary.scalar("misc/d_lr", self.d_lr), 224 | tf.summary.scalar("misc/g_lr", self.g_lr), 225 | tf.summary.scalar("misc/balance", self.balance), 226 | ]) 227 | 228 | def build_test_model(self): 229 | with tf.variable_scope("test") as vs: 230 | # Extra ops for interpolation 231 | z_optimizer = tf.train.AdamOptimizer(0.0001) 232 | 233 | self.z_r = tf.get_variable("z_r", [self.batch_size, self.z_num], tf.float32) 234 | self.z_r_update = tf.assign(self.z_r, self.z) 235 | 236 | G_z_r, _ = GeneratorCNN( 237 | self.z_r, self.conv_hidden_num, self.channel, self.repeat_num, self.data_format, reuse=True) 238 | 239 | with tf.variable_scope("test") as vs: 240 | self.z_r_loss = tf.reduce_mean(tf.abs(self.x - G_z_r)) 241 | self.z_r_optim = z_optimizer.minimize(self.z_r_loss, var_list=[self.z_r]) 242 | 243 | test_variables = tf.contrib.framework.get_variables(vs) 244 | self.sess.run(tf.variables_initializer(test_variables)) 245 | 246 | def generate(self, inputs, root_path=None, path=None, idx=None, save=True): 247 | x = self.sess.run(self.G, {self.z: inputs}) 248 | if path is None and save: 249 | path = os.path.join(root_path, '{}_G.png'.format(idx)) 250 | save_image(x, path) 251 | print("[*] Samples saved: {}".format(path)) 252 | return x 253 | 254 | def autoencode(self, inputs, path, idx=None, x_fake=None): 255 | items = { 256 | 'real': inputs, 257 | 'fake': x_fake, 258 | } 259 | for key, img in items.items(): 260 | if img is None: 261 | continue 262 | if img.shape[3] in [1, 3]: 263 | img = img.transpose([0, 3, 1, 2]) 264 | 265 | x_path = os.path.join(path, '{}_D_{}.png'.format(idx, key)) 266 | x = self.sess.run(self.AE_x, {self.x: img}) 267 | save_image(x, x_path) 268 | print("[*] Samples saved: {}".format(x_path)) 269 | 270 | def encode(self, inputs): 271 | if inputs.shape[3] in [1, 3]: 272 | inputs = inputs.transpose([0, 3, 1, 2]) 273 | return self.sess.run(self.D_z, {self.x: inputs}) 274 | 275 | def decode(self, z): 276 | return self.sess.run(self.AE_x, {self.D_z: z}) 277 | 278 | def interpolate_G(self, real_batch, step=0, root_path='.', train_epoch=0): 279 | batch_size = len(real_batch) 280 | half_batch_size = int(batch_size/2) 281 | 282 | self.sess.run(self.z_r_update) 283 | tf_real_batch = to_nchw_numpy(real_batch) 284 | for i in trange(train_epoch): 285 | z_r_loss, _ = self.sess.run([self.z_r_loss, self.z_r_optim], {self.x: tf_real_batch}) 286 | z = self.sess.run(self.z_r) 287 | 288 | z1, z2 = z[:half_batch_size], z[half_batch_size:] 289 | real1_batch, real2_batch = real_batch[:half_batch_size], real_batch[half_batch_size:] 290 | 291 | generated = [] 292 | for idx, ratio in enumerate(np.linspace(0, 1, 10)): 293 | z = np.stack([slerp(ratio, r1, r2) for r1, r2 in zip(z1, z2)]) 294 | z_decode = self.generate(z, save=False) 295 | generated.append(z_decode) 296 | 297 | generated = np.stack(generated).transpose([1, 0, 2, 3, 4]) 298 | for idx, img in enumerate(generated): 299 | save_image(img, os.path.join(root_path, 'test{}_interp_G_{}.png'.format(step, idx)), nrow=10) 300 | 301 | all_img_num = np.prod(generated.shape[:2]) 302 | batch_generated = np.reshape(generated, [all_img_num] + list(generated.shape[2:])) 303 | save_image(batch_generated, os.path.join(root_path, 'test{}_interp_G.png'.format(step)), nrow=10) 304 | 305 | def interpolate_D(self, real1_batch, real2_batch, step=0, root_path="."): 306 | real1_encode = self.encode(real1_batch) 307 | real2_encode = self.encode(real2_batch) 308 | 309 | decodes = [] 310 | for idx, ratio in enumerate(np.linspace(0, 1, 10)): 311 | z = np.stack([slerp(ratio, r1, r2) for r1, r2 in zip(real1_encode, real2_encode)]) 312 | z_decode = self.decode(z) 313 | decodes.append(z_decode) 314 | 315 | decodes = np.stack(decodes).transpose([1, 0, 2, 3, 4]) 316 | for idx, img in enumerate(decodes): 317 | img = np.concatenate([[real1_batch[idx]], img, [real2_batch[idx]]], 0) 318 | save_image(img, os.path.join(root_path, 'test{}_interp_D_{}.png'.format(step, idx)), nrow=10 + 2) 319 | 320 | def test(self): 321 | root_path = "./"#self.model_dir 322 | 323 | all_G_z = None 324 | for step in range(3): 325 | real1_batch = self.get_image_from_loader() 326 | real2_batch = self.get_image_from_loader() 327 | 328 | save_image(real1_batch, os.path.join(root_path, 'test{}_real1.png'.format(step))) 329 | save_image(real2_batch, os.path.join(root_path, 'test{}_real2.png'.format(step))) 330 | 331 | self.autoencode( 332 | real1_batch, self.model_dir, idx=os.path.join(root_path, "test{}_real1".format(step))) 333 | self.autoencode( 334 | real2_batch, self.model_dir, idx=os.path.join(root_path, "test{}_real2".format(step))) 335 | 336 | self.interpolate_G(real1_batch, step, root_path) 337 | #self.interpolate_D(real1_batch, real2_batch, step, root_path) 338 | 339 | z_fixed = np.random.uniform(-1, 1, size=(self.batch_size, self.z_num)) 340 | G_z = self.generate(z_fixed, path=os.path.join(root_path, "test{}_G_z.png".format(step))) 341 | 342 | if all_G_z is None: 343 | all_G_z = G_z 344 | else: 345 | all_G_z = np.concatenate([all_G_z, G_z]) 346 | save_image(all_G_z, '{}/G_z{}.png'.format(root_path, step)) 347 | 348 | save_image(all_G_z, '{}/all_G_z.png'.format(root_path), nrow=16) 349 | 350 | def get_image_from_loader(self): 351 | x = self.data_loader.eval(session=self.sess) 352 | if self.data_format == 'NCHW': 353 | x = x.transpose([0, 2, 3, 1]) 354 | return x 355 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import os 4 | import math 5 | import json 6 | import logging 7 | import numpy as np 8 | from PIL import Image 9 | from datetime import datetime 10 | 11 | def prepare_dirs_and_logger(config): 12 | formatter = logging.Formatter("%(asctime)s:%(levelname)s::%(message)s") 13 | logger = logging.getLogger() 14 | 15 | for hdlr in logger.handlers: 16 | logger.removeHandler(hdlr) 17 | 18 | handler = logging.StreamHandler() 19 | handler.setFormatter(formatter) 20 | 21 | logger.addHandler(handler) 22 | 23 | if config.load_path: 24 | if config.load_path.startswith(config.log_dir): 25 | config.model_dir = config.load_path 26 | else: 27 | if config.load_path.startswith(config.dataset): 28 | config.model_name = config.load_path 29 | else: 30 | config.model_name = "{}_{}".format(config.dataset, config.load_path) 31 | else: 32 | config.model_name = "{}_{}".format(config.dataset, get_time()) 33 | 34 | if not hasattr(config, 'model_dir'): 35 | config.model_dir = os.path.join(config.log_dir, config.model_name) 36 | config.data_path = os.path.join(config.data_dir, config.dataset) 37 | 38 | for path in [config.log_dir, config.data_dir, config.model_dir]: 39 | if not os.path.exists(path): 40 | os.makedirs(path) 41 | 42 | def get_time(): 43 | return datetime.now().strftime("%m%d_%H%M%S") 44 | 45 | def save_config(config): 46 | param_path = os.path.join(config.model_dir, "params.json") 47 | 48 | print("[*] MODEL dir: %s" % config.model_dir) 49 | print("[*] PARAM path: %s" % param_path) 50 | 51 | with open(param_path, 'w') as fp: 52 | json.dump(config.__dict__, fp, indent=4, sort_keys=True) 53 | 54 | def rank(array): 55 | return len(array.shape) 56 | 57 | def make_grid(tensor, nrow=8, padding=2, 58 | normalize=False, scale_each=False): 59 | """Code based on https://github.com/pytorch/vision/blob/master/torchvision/utils.py""" 60 | nmaps = tensor.shape[0] 61 | xmaps = min(nrow, nmaps) 62 | ymaps = int(math.ceil(float(nmaps) / xmaps)) 63 | height, width = int(tensor.shape[1] + padding), int(tensor.shape[2] + padding) 64 | grid = np.zeros([height * ymaps + 1 + padding // 2, width * xmaps + 1 + padding // 2, 3], dtype=np.uint8) 65 | k = 0 66 | for y in range(ymaps): 67 | for x in range(xmaps): 68 | if k >= nmaps: 69 | break 70 | h, h_width = y * height + 1 + padding // 2, height - padding 71 | w, w_width = x * width + 1 + padding // 2, width - padding 72 | 73 | grid[h:h+h_width, w:w+w_width] = tensor[k] 74 | k = k + 1 75 | return grid 76 | 77 | def save_image(tensor, filename, nrow=8, padding=2, 78 | normalize=False, scale_each=False): 79 | ndarr = make_grid(tensor, nrow=nrow, padding=padding, 80 | normalize=normalize, scale_each=scale_each) 81 | im = Image.fromarray(ndarr) 82 | im.save(filename) 83 | --------------------------------------------------------------------------------