├── .gitignore ├── Dockerfile ├── LICENSE.md ├── README.md ├── TUTORIAL.md ├── USAGE.md ├── configs ├── demo_edges2handbags_folder.yaml ├── demo_edges2handbags_list.yaml ├── edges2handbags_folder.yaml ├── edges2shoes_folder.yaml ├── summer2winter_yosemite256_folder.yaml └── synthia2cityscape_folder.yaml ├── data.py ├── datasets └── demo_edges2handbags │ ├── list_testA.txt │ ├── list_testB.txt │ ├── list_trainA.txt │ ├── list_trainB.txt │ ├── testA │ ├── 00000.jpg │ ├── 00001.jpg │ ├── 00002.jpg │ ├── 00003.jpg │ ├── 00004.jpg │ ├── 00005.jpg │ ├── 00006.jpg │ ├── 00007.jpg │ └── 00008.jpg │ ├── testB │ ├── 00000.jpg │ ├── 00001.jpg │ ├── 00002.jpg │ ├── 00003.jpg │ ├── 00004.jpg │ ├── 00005.jpg │ ├── 00006.jpg │ ├── 00007.jpg │ └── 00008.jpg │ ├── trainA │ ├── 00000.jpg │ ├── 00002.jpg │ ├── 00004.jpg │ ├── 00006.jpg │ ├── 00008.jpg │ ├── 00010.jpg │ ├── 00012.jpg │ └── 00014.jpg │ └── trainB │ ├── 00001.jpg │ ├── 00003.jpg │ ├── 00005.jpg │ ├── 00007.jpg │ ├── 00009.jpg │ ├── 00011.jpg │ ├── 00013.jpg │ └── 00015.jpg ├── docs └── munit_assumption.jpg ├── inputs ├── edges2handbags_edge.jpg ├── edges2handbags_handbag.jpg ├── edges2shoes_edge.jpg └── edges2shoes_shoe.jpg ├── networks.py ├── results ├── animal.jpg ├── edges2handbags │ ├── input.jpg │ ├── output000.jpg │ ├── output001.jpg │ ├── output002.jpg │ ├── output003.jpg │ ├── output004.jpg │ ├── output005.jpg │ ├── output006.jpg │ ├── output007.jpg │ ├── output008.jpg │ └── output009.jpg ├── edges2shoes │ ├── input.jpg │ ├── output000.jpg │ ├── output001.jpg │ ├── output002.jpg │ ├── output003.jpg │ ├── output004.jpg │ ├── output005.jpg │ ├── output006.jpg │ ├── output007.jpg │ ├── output008.jpg │ └── output009.jpg ├── edges2shoes_handbags.jpg ├── example_guided.jpg ├── input.jpg ├── output000.jpg ├── street.jpg ├── summer2winter_yosemite.jpg └── video.jpg ├── scripts ├── demo_train_edges2handbags.sh ├── demo_train_edges2shoes.sh └── demo_train_summer2winter_yosemite256.sh ├── test.py ├── test_batch.py ├── 
train.py ├── trainer.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | outputs/ 2 | src/models/ 3 | data/ 4 | logs/ 5 | .idea/ 6 | dgx/scripts/ 7 | .ipynb_checkpoints/ 8 | src/*jpg 9 | notebooks/.ipynb_checkpoints/* 10 | exps/ 11 | src/yaml_generator.py 12 | *.tar.gz 13 | *ipynb 14 | *.zip 15 | *.pkl 16 | *.pyc 17 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:9.1-cudnn7-runtime-ubuntu16.04 2 | # Set anaconda path 3 | ENV ANACONDA /opt/anaconda 4 | ENV PATH $ANACONDA/bin:$PATH 5 | RUN apt-get update && apt-get install -y --no-install-recommends \ 6 | wget \ 7 | libopencv-dev \ 8 | python-opencv \ 9 | build-essential \ 10 | cmake \ 11 | git \ 12 | curl \ 13 | ca-certificates \ 14 | libjpeg-dev \ 15 | libpng-dev \ 16 | axel \ 17 | zip \ 18 | unzip 19 | RUN wget https://repo.continuum.io/archive/Anaconda3-5.0.1-Linux-x86_64.sh -P /tmp 20 | RUN bash /tmp/Anaconda3-5.0.1-Linux-x86_64.sh -b -p $ANACONDA 21 | RUN rm /tmp/Anaconda3-5.0.1-Linux-x86_64.sh -rf 22 | RUN conda install -y pytorch=0.4.1 torchvision cuda91 -c pytorch 23 | RUN conda install -y -c anaconda pip 24 | RUN conda install -y -c anaconda yaml 25 | RUN pip install tensorboard tensorboardX; 26 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | ## creative commons 2 | 3 | # Attribution-NonCommercial-ShareAlike 4.0 International 4 | 5 | Creative Commons Corporation (“Creative Commons”) is not a law firm and does not provide legal services or legal advice. Distribution of Creative Commons public licenses does not create a lawyer-client or other relationship. Creative Commons makes its licenses and related information available on an “as-is” basis. 
Creative Commons gives no warranties regarding its licenses, any material licensed under their terms and conditions, or any related information. Creative Commons disclaims all liability for damages resulting from their use to the fullest extent possible. 6 | 7 | ### Using Creative Commons Public Licenses 8 | 9 | Creative Commons public licenses provide a standard set of terms and conditions that creators and other rights holders may use to share original works of authorship and other material subject to copyright and certain other rights specified in the public license below. The following considerations are for informational purposes only, are not exhaustive, and do not form part of our licenses. 10 | 11 | * __Considerations for licensors:__ Our public licenses are intended for use by those authorized to give the public permission to use material in ways otherwise restricted by copyright and certain other rights. Our licenses are irrevocable. Licensors should read and understand the terms and conditions of the license they choose before applying it. Licensors should also secure all rights necessary before applying our licenses so that the public can reuse the material as expected. Licensors should clearly mark any material not subject to the license. This includes other CC-licensed material, or material used under an exception or limitation to copyright. [More considerations for licensors](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensors). 12 | 13 | * __Considerations for the public:__ By using one of our public licenses, a licensor grants the public permission to use the licensed material under specified terms and conditions. If the licensor’s permission is not necessary for any reason–for example, because of any applicable exception or limitation to copyright–then that use is not regulated by the license. 
Our licenses grant only permissions under copyright and certain other rights that a licensor has authority to grant. Use of the licensed material may still be restricted for other reasons, including because others have copyright or other rights in the material. A licensor may make special requests, such as asking that all changes be marked or described. Although not required by our licenses, you are encouraged to respect those requests where reasonable. [More considerations for the public](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensees). 14 | 15 | ## Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License 16 | 17 | By exercising the Licensed Rights (defined below), You accept and agree to be bound by the terms and conditions of this Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License ("Public License"). To the extent this Public License may be interpreted as a contract, You are granted the Licensed Rights in consideration of Your acceptance of these terms and conditions, and the Licensor grants You such rights in consideration of benefits the Licensor receives from making the Licensed Material available under these terms and conditions. 18 | 19 | ### Section 1 – Definitions. 20 | 21 | a. __Adapted Material__ means material subject to Copyright and Similar Rights that is derived from or based upon the Licensed Material and in which the Licensed Material is translated, altered, arranged, transformed, or otherwise modified in a manner requiring permission under the Copyright and Similar Rights held by the Licensor. For purposes of this Public License, where the Licensed Material is a musical work, performance, or sound recording, Adapted Material is always produced where the Licensed Material is synched in timed relation with a moving image. 22 | 23 | b. 
__Adapter's License__ means the license You apply to Your Copyright and Similar Rights in Your contributions to Adapted Material in accordance with the terms and conditions of this Public License. 24 | 25 | c. __BY-NC-SA Compatible License__ means a license listed at [creativecommons.org/compatiblelicenses](http://creativecommons.org/compatiblelicenses), approved by Creative Commons as essentially the equivalent of this Public License. 26 | 27 | d. __Copyright and Similar Rights__ means copyright and/or similar rights closely related to copyright including, without limitation, performance, broadcast, sound recording, and Sui Generis Database Rights, without regard to how the rights are labeled or categorized. For purposes of this Public License, the rights specified in Section 2(b)(1)-(2) are not Copyright and Similar Rights. 28 | 29 | e. __Effective Technological Measures__ means those measures that, in the absence of proper authority, may not be circumvented under laws fulfilling obligations under Article 11 of the WIPO Copyright Treaty adopted on December 20, 1996, and/or similar international agreements. 30 | 31 | f. __Exceptions and Limitations__ means fair use, fair dealing, and/or any other exception or limitation to Copyright and Similar Rights that applies to Your use of the Licensed Material. 32 | 33 | g. __License Elements__ means the license attributes listed in the name of a Creative Commons Public License. The License Elements of this Public License are Attribution, NonCommercial, and ShareAlike. 34 | 35 | h. __Licensed Material__ means the artistic or literary work, database, or other material to which the Licensor applied this Public License. 36 | 37 | i. __Licensed Rights__ means the rights granted to You subject to the terms and conditions of this Public License, which are limited to all Copyright and Similar Rights that apply to Your use of the Licensed Material and that the Licensor has authority to license. 38 | 39 | j.
__Licensor__ means the individual(s) or entity(ies) granting rights under this Public License. 40 | 41 | k. __NonCommercial__ means not primarily intended for or directed towards commercial advantage or monetary compensation. For purposes of this Public License, the exchange of the Licensed Material for other material subject to Copyright and Similar Rights by digital file-sharing or similar means is NonCommercial provided there is no payment of monetary compensation in connection with the exchange. 42 | 43 | l. __Share__ means to provide material to the public by any means or process that requires permission under the Licensed Rights, such as reproduction, public display, public performance, distribution, dissemination, communication, or importation, and to make material available to the public including in ways that members of the public may access the material from a place and at a time individually chosen by them. 44 | 45 | m. __Sui Generis Database Rights__ means rights other than copyright resulting from Directive 96/9/EC of the European Parliament and of the Council of 11 March 1996 on the legal protection of databases, as amended and/or succeeded, as well as other essentially equivalent rights anywhere in the world. 46 | 47 | n. __You__ means the individual or entity exercising the Licensed Rights under this Public License. Your has a corresponding meaning. 48 | 49 | ### Section 2 – Scope. 50 | 51 | a. ___License grant.___ 52 | 53 | 1. Subject to the terms and conditions of this Public License, the Licensor hereby grants You a worldwide, royalty-free, non-sublicensable, non-exclusive, irrevocable license to exercise the Licensed Rights in the Licensed Material to: 54 | 55 | A. reproduce and Share the Licensed Material, in whole or in part, for NonCommercial purposes only; and 56 | 57 | B. produce, reproduce, and Share Adapted Material for NonCommercial purposes only. 58 | 59 | 2.
__Exceptions and Limitations.__ For the avoidance of doubt, where Exceptions and Limitations apply to Your use, this Public License does not apply, and You do not need to comply with its terms and conditions. 60 | 61 | 3. __Term.__ The term of this Public License is specified in Section 6(a). 62 | 63 | 4. __Media and formats; technical modifications allowed.__ The Licensor authorizes You to exercise the Licensed Rights in all media and formats whether now known or hereafter created, and to make technical modifications necessary to do so. The Licensor waives and/or agrees not to assert any right or authority to forbid You from making technical modifications necessary to exercise the Licensed Rights, including technical modifications necessary to circumvent Effective Technological Measures. For purposes of this Public License, simply making modifications authorized by this Section 2(a)(4) never produces Adapted Material. 64 | 65 | 5. __Downstream recipients.__ 66 | 67 | A. __Offer from the Licensor – Licensed Material.__ Every recipient of the Licensed Material automatically receives an offer from the Licensor to exercise the Licensed Rights under the terms and conditions of this Public License. 68 | 69 | B. __Additional offer from the Licensor – Adapted Material.__ Every recipient of Adapted Material from You automatically receives an offer from the Licensor to exercise the Licensed Rights in the Adapted Material under the conditions of the Adapter’s License You apply. 70 | 71 | C. __No downstream restrictions.__ You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, the Licensed Material if doing so restricts exercise of the Licensed Rights by any recipient of the Licensed Material. 72 | 73 | 6. 
__No endorsement.__ Nothing in this Public License constitutes or may be construed as permission to assert or imply that You are, or that Your use of the Licensed Material is, connected with, or sponsored, endorsed, or granted official status by, the Licensor or others designated to receive attribution as provided in Section 3(a)(1)(A)(i). 74 | 75 | b. ___Other rights.___ 76 | 77 | 1. Moral rights, such as the right of integrity, are not licensed under this Public License, nor are publicity, privacy, and/or other similar personality rights; however, to the extent possible, the Licensor waives and/or agrees not to assert any such rights held by the Licensor to the limited extent necessary to allow You to exercise the Licensed Rights, but not otherwise. 78 | 79 | 2. Patent and trademark rights are not licensed under this Public License. 80 | 81 | 3. To the extent possible, the Licensor waives any right to collect royalties from You for the exercise of the Licensed Rights, whether directly or through a collecting society under any voluntary or waivable statutory or compulsory licensing scheme. In all other cases the Licensor expressly reserves any right to collect such royalties, including when the Licensed Material is used other than for NonCommercial purposes. 82 | 83 | ### Section 3 – License Conditions. 84 | 85 | Your exercise of the Licensed Rights is expressly made subject to the following conditions. 86 | 87 | a. ___Attribution.___ 88 | 89 | 1. If You Share the Licensed Material (including in modified form), You must: 90 | 91 | A. retain the following if it is supplied by the Licensor with the Licensed Material: 92 | 93 | i. identification of the creator(s) of the Licensed Material and any others designated to receive attribution, in any reasonable manner requested by the Licensor (including by pseudonym if designated); 94 | 95 | ii. a copyright notice; 96 | 97 | iii. a notice that refers to this Public License; 98 | 99 | iv. 
a notice that refers to the disclaimer of warranties; 100 | 101 | v. a URI or hyperlink to the Licensed Material to the extent reasonably practicable; 102 | 103 | B. indicate if You modified the Licensed Material and retain an indication of any previous modifications; and 104 | 105 | C. indicate the Licensed Material is licensed under this Public License, and include the text of, or the URI or hyperlink to, this Public License. 106 | 107 | 2. You may satisfy the conditions in Section 3(a)(1) in any reasonable manner based on the medium, means, and context in which You Share the Licensed Material. For example, it may be reasonable to satisfy the conditions by providing a URI or hyperlink to a resource that includes the required information. 108 | 109 | 3. If requested by the Licensor, You must remove any of the information required by Section 3(a)(1)(A) to the extent reasonably practicable. 110 | 111 | b. ___ShareAlike.___ 112 | 113 | In addition to the conditions in Section 3(a), if You Share Adapted Material You produce, the following conditions also apply. 114 | 115 | 1. The Adapter’s License You apply must be a Creative Commons license with the same License Elements, this version or later, or a BY-NC-SA Compatible License. 116 | 117 | 2. You must include the text of, or the URI or hyperlink to, the Adapter's License You apply. You may satisfy this condition in any reasonable manner based on the medium, means, and context in which You Share Adapted Material. 118 | 119 | 3. You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, Adapted Material that restrict exercise of the rights granted under the Adapter's License You apply. 120 | 121 | ### Section 4 – Sui Generis Database Rights. 122 | 123 | Where the Licensed Rights include Sui Generis Database Rights that apply to Your use of the Licensed Material: 124 | 125 | a. 
for the avoidance of doubt, Section 2(a)(1) grants You the right to extract, reuse, reproduce, and Share all or a substantial portion of the contents of the database for NonCommercial purposes only; 126 | 127 | b. if You include all or a substantial portion of the database contents in a database in which You have Sui Generis Database Rights, then the database in which You have Sui Generis Database Rights (but not its individual contents) is Adapted Material, including for purposes of Section 3(b); and 128 | 129 | c. You must comply with the conditions in Section 3(a) if You Share all or a substantial portion of the contents of the database. 130 | 131 | For the avoidance of doubt, this Section 4 supplements and does not replace Your obligations under this Public License where the Licensed Rights include other Copyright and Similar Rights. 132 | 133 | ### Section 5 – Disclaimer of Warranties and Limitation of Liability. 134 | 135 | a. __Unless otherwise separately undertaken by the Licensor, to the extent possible, the Licensor offers the Licensed Material as-is and as-available, and makes no representations or warranties of any kind concerning the Licensed Material, whether express, implied, statutory, or other. This includes, without limitation, warranties of title, merchantability, fitness for a particular purpose, non-infringement, absence of latent or other defects, accuracy, or the presence or absence of errors, whether or not known or discoverable. Where disclaimers of warranties are not allowed in full or in part, this disclaimer may not apply to You.__ 136 | 137 | b. 
__To the extent possible, in no event will the Licensor be liable to You on any legal theory (including, without limitation, negligence) or otherwise for any direct, special, indirect, incidental, consequential, punitive, exemplary, or other losses, costs, expenses, or damages arising out of this Public License or use of the Licensed Material, even if the Licensor has been advised of the possibility of such losses, costs, expenses, or damages. Where a limitation of liability is not allowed in full or in part, this limitation may not apply to You.__ 138 | 139 | c. The disclaimer of warranties and limitation of liability provided above shall be interpreted in a manner that, to the extent possible, most closely approximates an absolute disclaimer and waiver of all liability. 140 | 141 | ### Section 6 – Term and Termination. 142 | 143 | a. This Public License applies for the term of the Copyright and Similar Rights licensed here. However, if You fail to comply with this Public License, then Your rights under this Public License terminate automatically. 144 | 145 | b. Where Your right to use the Licensed Material has terminated under Section 6(a), it reinstates: 146 | 147 | 1. automatically as of the date the violation is cured, provided it is cured within 30 days of Your discovery of the violation; or 148 | 149 | 2. upon express reinstatement by the Licensor. 150 | 151 | For the avoidance of doubt, this Section 6(b) does not affect any right the Licensor may have to seek remedies for Your violations of this Public License. 152 | 153 | c. For the avoidance of doubt, the Licensor may also offer the Licensed Material under separate terms or conditions or stop distributing the Licensed Material at any time; however, doing so will not terminate this Public License. 154 | 155 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public License.
156 | 157 | ### Section 7 – Other Terms and Conditions. 158 | 159 | a. The Licensor shall not be bound by any additional or different terms or conditions communicated by You unless expressly agreed. 160 | 161 | b. Any arrangements, understandings, or agreements regarding the Licensed Material not stated herein are separate from and independent of the terms and conditions of this Public License. 162 | 163 | ### Section 8 – Interpretation. 164 | 165 | a. For the avoidance of doubt, this Public License does not, and shall not be interpreted to, reduce, limit, restrict, or impose conditions on any use of the Licensed Material that could lawfully be made without permission under this Public License. 166 | 167 | b. To the extent possible, if any provision of this Public License is deemed unenforceable, it shall be automatically reformed to the minimum extent necessary to make it enforceable. If the provision cannot be reformed, it shall be severed from this Public License without affecting the enforceability of the remaining terms and conditions. 168 | 169 | c. No term or condition of this Public License will be waived and no failure to comply consented to unless expressly agreed to by the Licensor. 170 | 171 | d. Nothing in this Public License constitutes or may be interpreted as a limitation upon, or waiver of, any privileges and immunities that apply to the Licensor or You, including from the legal processes of any jurisdiction or authority. 172 | 173 | ``` 174 | Creative Commons is not a party to its public licenses. 
Notwithstanding, Creative Commons may elect to apply one of its public licenses to material it publishes and in those instances will be considered the “Licensor.” Except for the limited purpose of indicating that material is shared under a Creative Commons public license or as otherwise permitted by the Creative Commons policies published at [creativecommons.org/policies](http://creativecommons.org/policies), Creative Commons does not authorize the use of the trademark “Creative Commons” or any other trademark or logo of Creative Commons without its prior written consent including, without limitation, in connection with any unauthorized modifications to any of its public licenses or any other arrangements, understandings, or agreements concerning use of licensed material. For the avoidance of doubt, this paragraph does not form part of the public licenses. 175 | 176 | Creative Commons may be contacted at [creativecommons.org](http://creativecommons.org/). 177 | ``` 178 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | **The code base is no longer maintained.** 2 | 3 | **Please check here for an improved implementation of MUNIT: https://github.com/NVlabs/imaginaire/tree/master/projects/munit** 4 | 5 | [![License CC BY-NC-SA 4.0](https://img.shields.io/badge/license-CC4.0-blue.svg)](https://raw.githubusercontent.com/NVIDIA/FastPhotoStyle/master/LICENSE.md) 6 | ![Python 2.7](https://img.shields.io/badge/python-2.7-green.svg) 7 | ![Python 3.6](https://img.shields.io/badge/python-3.6-green.svg) 8 | ## MUNIT: Multimodal UNsupervised Image-to-image Translation 9 | 10 | ### License 11 | 12 | Copyright (C) 2018 NVIDIA Corporation. All rights reserved. 13 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
14 | 15 | For commercial use, please consult [NVIDIA Research Inquiries](https://www.nvidia.com/en-us/research/inquiries/). 16 | 17 | ### Code usage 18 | 19 | Please check out the [user manual page](USAGE.md). 20 | 21 | ### Paper 22 | 23 | [Xun Huang](http://www.cs.cornell.edu/~xhuang/), [Ming-Yu Liu](http://mingyuliu.net/), [Serge Belongie](https://vision.cornell.edu/se3/people/serge-belongie/), [Jan Kautz](http://jankautz.com/), "[Multimodal Unsupervised Image-to-Image Translation](https://arxiv.org/abs/1804.04732)", ECCV 2018 24 | 25 | ### Results Video 26 | [![](results/video.jpg)](https://youtu.be/ab64TWzWn40) 27 | 28 | ### Edges to Shoes/handbags Translation 29 | 30 | ![](results/edges2shoes_handbags.jpg) 31 | 32 | ### Animal Image Translation 33 | 34 | ![](results/animal.jpg) 35 | 36 | ### Street Scene Translation 37 | 38 | ![](results/street.jpg) 39 | 40 | ### Yosemite Summer to Winter Translation (HD) 41 | 42 | ![](results/summer2winter_yosemite.jpg) 43 | 44 | ### Example-guided Image Translation 45 | 46 | ![](results/example_guided.jpg) 47 | 48 | ### Other Implementations 49 | 50 | [MUNIT-Tensorflow](https://github.com/taki0112/MUNIT-Tensorflow) by [Junho Kim](https://github.com/taki0112) 51 | 52 | [MUNIT-keras](https://github.com/shaoanlu/MUNIT-keras) by [shaoanlu](https://github.com/shaoanlu) 53 | 54 | ### Citation 55 | 56 | If you find this code useful for your research, please cite our paper: 57 | 58 | ``` 59 | @inproceedings{huang2018munit, 60 | title={Multimodal Unsupervised Image-to-image Translation}, 61 | author={Huang, Xun and Liu, Ming-Yu and Belongie, Serge and Kautz, Jan}, 62 | booktitle={ECCV}, 63 | year={2018} 64 | } 65 | ``` 66 | 67 | 68 | -------------------------------------------------------------------------------- /TUTORIAL.md: -------------------------------------------------------------------------------- 1 | [![License CC BY-NC-SA 
4.0](https://img.shields.io/badge/license-CC4.0-blue.svg)](https://raw.githubusercontent.com/NVIDIA/FastPhotoStyle/master/LICENSE.md) 2 | ![Python 2.7](https://img.shields.io/badge/python-2.7-green.svg) 3 | ![Python 3.6](https://img.shields.io/badge/python-3.6-green.svg) 4 | ## MUNIT Tutorial 5 | 6 | In this short tutorial, we will guide you through setting up the system environment for running the MUNIT (multimodal unsupervised image-to-image translation) software and then show several usage examples. 7 | 8 | ### Background 9 | 10 | Most of the existing unsupervised/unpaired image-to-image translation algorithms assume a unimodal mapping function between two image domains. That is, for a given input image in domain A, the model can only map it to one corresponding image in domain B. This is undesirable since in many cases the mapping function should be multimodal or many-to-many. For example, for an input summer image, a summer-to-winter translation model should be able to synthesize various winter images that would correspond to the input summer image. These images could differ in the amount of snow accumulation, but they all represent valid translations of the input image. In the ideal case, given an input image, an image translation model should be able to map the input image to a distribution of output images. This is precisely the goal of MUNIT. 11 | 12 | ### Algorithm 13 | 14 | 15 | 16 | MUNIT is based on the partially-shared latent space assumption as illustrated in (a) of the above image. Basically, it assumes that the latent representation of an image can be decomposed into two parts, where one represents the content of the image that is shared across domains, while the other represents the style of the image that is not shared across domains. To realize this assumption, MUNIT uses 3 networks for each domain, which are 17 | 18 | 1. content encoder (for extracting a domain-shared latent code, content code) 19 | 2.
style encoder (for extracting a domain-specific latent code, style code) 20 | 3. decoder (for generating an image using a content code and a style code) 21 | 22 | At test time, as illustrated in (b) of the above image, when we want to translate an input image in the 1st domain (source domain) to a corresponding image in the 2nd domain (target domain), MUNIT first uses the content encoder in the source domain to extract a content code, combines it with a randomly sampled style code from the target domain, and feeds them to the decoder in the target domain to generate the translation. By sampling different style codes, MUNIT generates different translations. Since the style space is a continuous space, MUNIT essentially maps an input image in the source domain to a distribution of images in the target domain. 23 | 24 | ### Requirements 25 | 26 | 27 | - Hardware: PC with NVIDIA Titan GPU. For high-resolution images, you need NVIDIA Tesla P100 or V100 GPUs, which have 16GB+ GPU memory. 28 | - Software: *Ubuntu 16.04*, *CUDA 9.1*, *Anaconda3*, *pytorch 0.4.1* 29 | - System package 30 | - `sudo apt-get install -y axel imagemagick` (Only used for demo) 31 | - Python package 32 | - `conda install pytorch=0.4.1 torchvision cuda91 -y -c pytorch` 33 | - `conda install -y -c anaconda pip` 34 | - `conda install -y -c anaconda pyyaml` 35 | - `pip install tensorboard tensorboardX` 36 | 37 | ### Docker Image 38 | 39 | We also provide a [Dockerfile](Dockerfile) for building an environment for running the MUNIT code. 40 | 41 | 1. Install docker-ce. Follow the instructions on the [Docker page](https://docs.docker.com/install/linux/docker-ce/ubuntu/#install-docker-ce-1). 42 | 2. Install nvidia-docker. Follow the instructions on the [NVIDIA-DOCKER README page](https://github.com/NVIDIA/nvidia-docker). 43 | 3. Build the docker image `docker build -t your-docker-image:v1.0 .` 44 | 4.
Run an interactive session `docker run -v YOUR_PATH:YOUR_PATH --runtime=nvidia -i -t your-docker-image:v1.0 /bin/bash` 45 | 5. `cd YOUR_PATH` 46 | 6. Follow the rest of the tutorial. 47 | 48 | ### Training 49 | 50 | We provide several training scripts as usage examples. They are located under the `scripts` folder. 51 | - `bash scripts/demo_train_edges2handbags.sh` to train a model for multimodal sketches of handbags to images of handbags translation. 52 | - `bash scripts/demo_train_edges2shoes.sh` to train a model for multimodal sketches of shoes to images of shoes translation. 53 | - `bash scripts/demo_train_summer2winter_yosemite256.sh` to train a model for multimodal Yosemite summer 256x256 images to Yosemite winter 256x256 image translation. 54 | 55 | If you break down the command lines in the scripts, you will find that to train a multimodal unsupervised image-to-image translation model you need to: 56 | 57 | 1. Download the dataset you want to use. 58 | 59 | 2. Set up the yaml file. Check out `configs/demo_edges2handbags_folder.yaml` for folder-based dataset organization. Change the `data_root` field to the path of your downloaded dataset. For list-based dataset organization, check out `configs/demo_edges2handbags_list.yaml`. 60 | 61 | 3. Start training 62 | ``` 63 | python train.py --config configs/edges2handbags_folder.yaml 64 | ``` 65 | 66 | 4. Intermediate image outputs and model binary files are stored in `outputs/edges2handbags_folder` 67 | 68 | ### Testing 69 | 70 | First, download our pretrained models for the edges2shoes task and put them in the `models` folder.
71 | 72 | ### Pretrained models 73 | 74 | | Dataset | Model Link | 75 | |-------------|----------------| 76 | | edges2shoes | [model](https://drive.google.com/drive/folders/10IEa7gibOWmQQuJUIUOkh-CV4cm6k8__?usp=sharing) | 77 | | edges2handbags | coming soon | 78 | | summer2winter_yosemite256 | coming soon | 79 | 80 | 81 | #### Multimodal Translation 82 | 83 | Run the following command to translate edges to shoes 84 | 85 | python test.py --config configs/edges2shoes_folder.yaml --input inputs/edges2shoes_edge.jpg --output_folder results/edges2shoes --checkpoint models/edges2shoes.pt --a2b 1 86 | 87 | The results are stored in `results/edges2shoes` folder. By default, it produces 10 random translation outputs. 88 | 89 | | Input | Translation 1 | Translation 2 | Translation 3 | Translation 4 | Translation 5 | 90 | |-------|---------------|---------------|---------------|---------------|---------------| 91 | | | | | | | | 92 | 93 | 94 | #### Example-guided Translation 95 | 96 | The above command outputs diverse shoes from an edge input. In addition, it is possible to control the style of output using an example shoe image. 97 | 98 | python test.py --config configs/edges2shoes_folder.yaml --input inputs/edges2shoes_edge.jpg --output_folder results --checkpoint models/edges2shoes.pt --a2b 1 --style inputs/edges2shoes_shoe.jpg 99 | 100 | | Input Photo | Style Photo | Output Photo | 101 | |-------|---------------|---------------| 102 | | | | | 103 | 104 | ### Yosemite Summer2Winter HD dataset 105 | 106 | Coming soon. 
107 | 108 | 109 | -------------------------------------------------------------------------------- /USAGE.md: -------------------------------------------------------------------------------- 1 | [![License CC BY-NC-SA 4.0](https://img.shields.io/badge/license-CC4.0-blue.svg)](https://raw.githubusercontent.com/NVIDIA/FastPhotoStyle/master/LICENSE.md) 2 | ![Python 2.7](https://img.shields.io/badge/python-2.7-green.svg) 3 | ![Python 3.6](https://img.shields.io/badge/python-3.6-green.svg) 4 | ## MUNIT: Multimodal UNsupervised Image-to-image Translation 5 | 6 | ### License 7 | 8 | Copyright (C) 2018 NVIDIA Corporation. All rights reserved. 9 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 10 | 11 | ### Dependency 12 | 13 | 14 | pytorch, yaml, tensorboard (from https://github.com/dmlc/tensorboard), and tensorboardX (from https://github.com/lanpa/tensorboard-pytorch). 15 | 16 | 17 | The code base was developed using Anaconda with the following packages: 18 | ``` 19 | conda install pytorch=0.4.1 torchvision cuda91 -c pytorch; 20 | conda install -y -c anaconda pip; 21 | conda install -y -c anaconda pyyaml; 22 | pip install tensorboard tensorboardX; 23 | ``` 24 | 25 | We also provide a [Dockerfile](Dockerfile) for building an environment for running the MUNIT code. 26 | 27 | ### Example Usage 28 | 29 | #### Testing 30 | 31 | First, download the [pretrained models](https://drive.google.com/drive/folders/10IEa7gibOWmQQuJUIUOkh-CV4cm6k8__?usp=sharing) and put them in the `models` folder. 32 | 33 | ###### Multimodal Translation 34 | 35 | Run the following command to translate edges to shoes: 36 | 37 | python test.py --config configs/edges2shoes_folder.yaml --input inputs/edges2shoes_edge.jpg --output_folder outputs --checkpoint models/edges2shoes.pt --a2b 1 38 | 39 | The results are stored in the `outputs` folder. By default, it produces 10 random translation outputs.
40 | 41 | ###### Example-guided Translation 42 | 43 | The above command outputs diverse shoes from an edge input. In addition, it is possible to control the style of the output using an example shoe image. 44 | 45 | python test.py --config configs/edges2shoes_folder.yaml --input inputs/edges2shoes_edge.jpg --output_folder outputs --checkpoint models/edges2shoes.pt --a2b 1 --style inputs/edges2shoes_shoe.jpg 46 | 47 | 48 | #### Training 49 | 1. Download the dataset you want to use. For example, you can use the edges2shoes dataset provided by [Zhu et al.](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix) 50 | 51 | 2. Set up the yaml file. Check out `configs/edges2handbags_folder.yaml` for folder-based dataset organization. Change the `data_root` field to the path of your downloaded dataset. For list-based dataset organization, check out `configs/demo_edges2handbags_list.yaml`. 52 | 53 | 3. Start training: 54 | ``` 55 | python train.py --config configs/edges2handbags_folder.yaml 56 | ``` 57 | 58 | 4. Intermediate image outputs and model binary files are stored in `outputs/edges2handbags_folder`. 59 | -------------------------------------------------------------------------------- /configs/demo_edges2handbags_folder.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2017 NVIDIA Corporation. All rights reserved. 2 | # Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
3 | 4 | # logger options 5 | image_save_iter: 1000 # How often do you want to save output images during training 6 | image_display_iter: 100 # How often do you want to display output images during training 7 | display_size: 8 # How many images do you want to display each time 8 | snapshot_save_iter: 10000 # How often do you want to save trained models 9 | log_iter: 1 # How often do you want to log the training stats 10 | 11 | # optimization options 12 | max_iter: 1000000 # maximum number of training iterations 13 | batch_size: 1 # batch size 14 | weight_decay: 0.0001 # weight decay 15 | beta1: 0.5 # Adam parameter 16 | beta2: 0.999 # Adam parameter 17 | init: kaiming # initialization [gaussian/kaiming/xavier/orthogonal] 18 | lr: 0.0001 # initial learning rate 19 | lr_policy: step # learning rate scheduler 20 | step_size: 100000 # how often to decay learning rate 21 | gamma: 0.5 # how much to decay learning rate 22 | gan_w: 1 # weight of adversarial loss 23 | recon_x_w: 10 # weight of image reconstruction loss 24 | recon_s_w: 1 # weight of style reconstruction loss 25 | recon_c_w: 1 # weight of content reconstruction loss 26 | recon_x_cyc_w: 0 # weight of explicit style augmented cycle consistency loss 27 | vgg_w: 0 # weight of domain-invariant perceptual loss 28 | 29 | # model options 30 | gen: 31 | dim: 64 # number of filters in the bottommost layer 32 | mlp_dim: 256 # number of filters in MLP 33 | style_dim: 8 # length of style code 34 | activ: relu # activation function [relu/lrelu/prelu/selu/tanh] 35 | n_downsample: 2 # number of downsampling layers in content encoder 36 | n_res: 4 # number of residual blocks in content encoder/decoder 37 | pad_type: reflect # padding type [zero/reflect] 38 | dis: 39 | dim: 64 # number of filters in the bottommost layer 40 | norm: none # normalization layer [none/bn/in/ln] 41 | activ: lrelu # activation function [relu/lrelu/prelu/selu/tanh] 42 | n_layer: 4 # number of layers in D 43 | gan_type: lsgan # GAN loss [lsgan/nsgan] 44 
| num_scales: 3 # number of scales 45 | pad_type: reflect # padding type [zero/reflect] 46 | 47 | # data options 48 | input_dim_a: 3 # number of image channels [1/3] 49 | input_dim_b: 3 # number of image channels [1/3] 50 | num_workers: 8 # number of data loading threads 51 | new_size: 256 # first resize the shortest image side to this size 52 | crop_image_height: 256 # random crop image of this height 53 | crop_image_width: 256 # random crop image of this width 54 | data_root: ./datasets/demo_edges2handbags/ # dataset folder location -------------------------------------------------------------------------------- /configs/demo_edges2handbags_list.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2017 NVIDIA Corporation. All rights reserved. 2 | # Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 3 | 4 | # logger options 5 | image_save_iter: 1000 # How often do you want to save output images during training 6 | image_display_iter: 100 # How often do you want to display output images during training 7 | display_size: 8 # How many images do you want to display each time 8 | snapshot_save_iter: 10000 # How often do you want to save trained models 9 | log_iter: 1 # How often do you want to log the training stats 10 | 11 | # optimization options 12 | max_iter: 1000000 # maximum number of training iterations 13 | batch_size: 1 # batch size 14 | weight_decay: 0.0001 # weight decay 15 | beta1: 0.5 # Adam parameter 16 | beta2: 0.999 # Adam parameter 17 | init: kaiming # initialization [gaussian/kaiming/xavier/orthogonal] 18 | lr: 0.0001 # initial learning rate 19 | lr_policy: step # learning rate scheduler 20 | step_size: 100000 # how often to decay learning rate 21 | gamma: 0.5 # how much to decay learning rate 22 | gan_w: 1 # weight of adversarial loss 23 | recon_x_w: 10 # weight of image reconstruction loss 24 | recon_s_w: 1 # weight of style reconstruction 
loss 25 | recon_c_w: 1 # weight of content reconstruction loss 26 | recon_x_cyc_w: 0 # weight of explicit style augmented cycle consistency loss 27 | vgg_w: 0 # weight of domain-invariant perceptual loss 28 | 29 | # model options 30 | gen: 31 | dim: 64 # number of filters in the bottommost layer 32 | mlp_dim: 256 # number of filters in MLP 33 | style_dim: 8 # length of style code 34 | activ: relu # activation function [relu/lrelu/prelu/selu/tanh] 35 | n_downsample: 2 # number of downsampling layers in content encoder 36 | n_res: 4 # number of residual blocks in content encoder/decoder 37 | pad_type: reflect # padding type [zero/reflect] 38 | dis: 39 | dim: 64 # number of filters in the bottommost layer 40 | norm: none # normalization layer [none/bn/in/ln] 41 | activ: lrelu # activation function [relu/lrelu/prelu/selu/tanh] 42 | n_layer: 4 # number of layers in D 43 | gan_type: lsgan # GAN loss [lsgan/nsgan] 44 | num_scales: 3 # number of scales 45 | pad_type: reflect # padding type [zero/reflect] 46 | 47 | # data options 48 | input_dim_a: 3 # number of image channels [1/3] 49 | input_dim_b: 3 # number of image channels [1/3] 50 | num_workers: 8 # number of data loading threads 51 | new_size: 256 # first resize the shortest image side to this size 52 | crop_image_height: 256 # random crop image of this height 53 | crop_image_width: 256 # random crop image of this width 54 | 55 | data_folder_train_a: ./datasets/demo_edges2handbags/trainA 56 | data_list_train_a: ./datasets/demo_edges2handbags/list_trainA.txt 57 | data_folder_test_a: ./datasets/demo_edges2handbags/testA 58 | data_list_test_a: ./datasets/demo_edges2handbags/list_testA.txt 59 | data_folder_train_b: ./datasets/demo_edges2handbags/trainB 60 | data_list_train_b: ./datasets/demo_edges2handbags/list_trainB.txt 61 | data_folder_test_b: ./datasets/demo_edges2handbags/testB 62 | data_list_test_b: ./datasets/demo_edges2handbags/list_testB.txt 63 | 
-------------------------------------------------------------------------------- /configs/edges2handbags_folder.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2017 NVIDIA Corporation. All rights reserved. 2 | # Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 3 | 4 | # logger options 5 | image_save_iter: 10000 # How often do you want to save output images during training 6 | image_display_iter: 500 # How often do you want to display output images during training 7 | display_size: 16 # How many images do you want to display each time 8 | snapshot_save_iter: 10000 # How often do you want to save trained models 9 | log_iter: 10 # How often do you want to log the training stats 10 | 11 | # optimization options 12 | max_iter: 1000000 # maximum number of training iterations 13 | batch_size: 1 # batch size 14 | weight_decay: 0.0001 # weight decay 15 | beta1: 0.5 # Adam parameter 16 | beta2: 0.999 # Adam parameter 17 | init: kaiming # initialization [gaussian/kaiming/xavier/orthogonal] 18 | lr: 0.0001 # initial learning rate 19 | lr_policy: step # learning rate scheduler 20 | step_size: 100000 # how often to decay learning rate 21 | gamma: 0.5 # how much to decay learning rate 22 | gan_w: 1 # weight of adversarial loss 23 | recon_x_w: 10 # weight of image reconstruction loss 24 | recon_s_w: 1 # weight of style reconstruction loss 25 | recon_c_w: 1 # weight of content reconstruction loss 26 | recon_x_cyc_w: 0 # weight of explicit style augmented cycle consistency loss 27 | vgg_w: 0 # weight of domain-invariant perceptual loss 28 | 29 | # model options 30 | gen: 31 | dim: 64 # number of filters in the bottommost layer 32 | mlp_dim: 256 # number of filters in MLP 33 | style_dim: 8 # length of style code 34 | activ: relu # activation function [relu/lrelu/prelu/selu/tanh] 35 | n_downsample: 2 # number of downsampling layers in content encoder 36 | n_res: 4 # 
number of residual blocks in content encoder/decoder 37 | pad_type: reflect # padding type [zero/reflect] 38 | dis: 39 | dim: 64 # number of filters in the bottommost layer 40 | norm: none # normalization layer [none/bn/in/ln] 41 | activ: lrelu # activation function [relu/lrelu/prelu/selu/tanh] 42 | n_layer: 4 # number of layers in D 43 | gan_type: lsgan # GAN loss [lsgan/nsgan] 44 | num_scales: 3 # number of scales 45 | pad_type: reflect # padding type [zero/reflect] 46 | 47 | # data options 48 | input_dim_a: 3 # number of image channels [1/3] 49 | input_dim_b: 3 # number of image channels [1/3] 50 | num_workers: 8 # number of data loading threads 51 | new_size: 256 # first resize the shortest image side to this size 52 | crop_image_height: 256 # random crop image of this height 53 | crop_image_width: 256 # random crop image of this width 54 | data_root: ./datasets/edges2handbags/ # dataset folder location -------------------------------------------------------------------------------- /configs/edges2shoes_folder.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2017 NVIDIA Corporation. All rights reserved. 2 | # Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
3 | 4 | # logger options 5 | image_save_iter: 10000 # How often do you want to save output images during training 6 | image_display_iter: 500 # How often do you want to display output images during training 7 | display_size: 16 # How many images do you want to display each time 8 | snapshot_save_iter: 10000 # How often do you want to save trained models 9 | log_iter: 10 # How often do you want to log the training stats 10 | 11 | # optimization options 12 | max_iter: 1000000 # maximum number of training iterations 13 | batch_size: 1 # batch size 14 | weight_decay: 0.0001 # weight decay 15 | beta1: 0.5 # Adam parameter 16 | beta2: 0.999 # Adam parameter 17 | init: kaiming # initialization [gaussian/kaiming/xavier/orthogonal] 18 | lr: 0.0001 # initial learning rate 19 | lr_policy: step # learning rate scheduler 20 | step_size: 100000 # how often to decay learning rate 21 | gamma: 0.5 # how much to decay learning rate 22 | gan_w: 1 # weight of adversarial loss 23 | recon_x_w: 10 # weight of image reconstruction loss 24 | recon_s_w: 1 # weight of style reconstruction loss 25 | recon_c_w: 1 # weight of content reconstruction loss 26 | recon_x_cyc_w: 0 # weight of explicit style augmented cycle consistency loss 27 | vgg_w: 0 # weight of domain-invariant perceptual loss 28 | 29 | # model options 30 | gen: 31 | dim: 64 # number of filters in the bottommost layer 32 | mlp_dim: 256 # number of filters in MLP 33 | style_dim: 8 # length of style code 34 | activ: relu # activation function [relu/lrelu/prelu/selu/tanh] 35 | n_downsample: 2 # number of downsampling layers in content encoder 36 | n_res: 4 # number of residual blocks in content encoder/decoder 37 | pad_type: reflect # padding type [zero/reflect] 38 | dis: 39 | dim: 64 # number of filters in the bottommost layer 40 | norm: none # normalization layer [none/bn/in/ln] 41 | activ: lrelu # activation function [relu/lrelu/prelu/selu/tanh] 42 | n_layer: 4 # number of layers in D 43 | gan_type: lsgan # GAN loss [lsgan/nsgan] 
44 | num_scales: 3 # number of scales 45 | pad_type: reflect # padding type [zero/reflect] 46 | 47 | # data options 48 | input_dim_a: 3 # number of image channels [1/3] 49 | input_dim_b: 3 # number of image channels [1/3] 50 | num_workers: 8 # number of data loading threads 51 | new_size: 256 # first resize the shortest image side to this size 52 | crop_image_height: 256 # random crop image of this height 53 | crop_image_width: 256 # random crop image of this width 54 | data_root: ./datasets/edges2shoes/ # dataset folder location -------------------------------------------------------------------------------- /configs/summer2winter_yosemite256_folder.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2018 NVIDIA Corporation. All rights reserved. 2 | # Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 3 | 4 | # logger options 5 | image_save_iter: 10000 # How often do you want to save output images during training 6 | image_display_iter: 100 # How often do you want to display output images during training 7 | display_size: 16 # How many images do you want to display each time 8 | snapshot_save_iter: 10000 # How often do you want to save trained models 9 | log_iter: 1 # How often do you want to log the training stats 10 | 11 | # optimization options 12 | max_iter: 1000000 # maximum number of training iterations 13 | batch_size: 1 # batch size 14 | weight_decay: 0.0001 # weight decay 15 | beta1: 0.5 # Adam parameter 16 | beta2: 0.999 # Adam parameter 17 | init: kaiming # initialization [gaussian/kaiming/xavier/orthogonal] 18 | lr: 0.0001 # initial learning rate 19 | lr_policy: step # learning rate scheduler 20 | step_size: 100000 # how often to decay learning rate 21 | gamma: 0.5 # how much to decay learning rate 22 | gan_w: 1 # weight of adversarial loss 23 | recon_x_w: 10 # weight of image reconstruction loss 24 | recon_s_w: 1 # weight of style 
reconstruction loss 25 | recon_c_w: 1 # weight of content reconstruction loss 26 | recon_x_cyc_w: 10 # weight of explicit style augmented cycle consistency loss 27 | vgg_w: 0 # weight of domain-invariant perceptual loss 28 | 29 | # model options 30 | gen: 31 | dim: 64 # number of filters in the bottommost layer 32 | mlp_dim: 256 # number of filters in MLP 33 | style_dim: 8 # length of style code 34 | activ: relu # activation function [relu/lrelu/prelu/selu/tanh] 35 | n_downsample: 2 # number of downsampling layers in content encoder 36 | n_res: 4 # number of residual blocks in content encoder/decoder 37 | pad_type: reflect # padding type [zero/reflect] 38 | dis: 39 | dim: 64 # number of filters in the bottommost layer 40 | norm: none # normalization layer [none/bn/in/ln] 41 | activ: lrelu # activation function [relu/lrelu/prelu/selu/tanh] 42 | n_layer: 4 # number of layers in D 43 | gan_type: lsgan # GAN loss [lsgan/nsgan] 44 | num_scales: 3 # number of scales 45 | pad_type: reflect # padding type [zero/reflect] 46 | 47 | # data options 48 | input_dim_a: 3 # number of image channels [1/3] 49 | input_dim_b: 3 # number of image channels [1/3] 50 | num_workers: 8 # number of data loading threads 51 | new_size: 256 # first resize the shortest image side to this size 52 | crop_image_height: 256 # random crop image of this height 53 | crop_image_width: 256 # random crop image of this width 54 | data_root: ./datasets/summer2winter_yosemite256/summer2winter_yosemite # dataset folder location -------------------------------------------------------------------------------- /configs/synthia2cityscape_folder.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2017 NVIDIA Corporation. All rights reserved. 2 | # Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
3 | 4 | # logger options 5 | image_save_iter: 10000 # How often do you want to save output images during training 6 | image_display_iter: 500 # How often do you want to display output images during training 7 | display_size: 16 # How many images do you want to display each time 8 | snapshot_save_iter: 10000 # How often do you want to save trained models 9 | log_iter: 10 # How often do you want to log the training stats 10 | 11 | # optimization options 12 | max_iter: 1000000 # maximum number of training iterations 13 | batch_size: 1 # batch size 14 | weight_decay: 0.0001 # weight decay 15 | beta1: 0.5 # Adam parameter 16 | beta2: 0.999 # Adam parameter 17 | init: kaiming # initialization [gaussian/kaiming/xavier/orthogonal] 18 | lr: 0.0001 # initial learning rate 19 | lr_policy: step # learning rate scheduler 20 | step_size: 100000 # how often to decay learning rate 21 | gamma: 0.5 # how much to decay learning rate 22 | gan_w: 1 # weight of adversarial loss 23 | recon_x_w: 10 # weight of image reconstruction loss 24 | recon_s_w: 1 # weight of style reconstruction loss 25 | recon_c_w: 1 # weight of content reconstruction loss 26 | recon_x_cyc_w: 10 # weight of explicit style augmented cycle consistency loss 27 | vgg_w: 1 # weight of domain-invariant perceptual loss 28 | 29 | # model options 30 | gen: 31 | dim: 64 # number of filters in the bottommost layer 32 | mlp_dim: 256 # number of filters in MLP 33 | style_dim: 8 # length of style code 34 | activ: relu # activation function [relu/lrelu/prelu/selu/tanh] 35 | n_downsample: 2 # number of downsampling layers in content encoder 36 | n_res: 4 # number of residual blocks in content encoder/decoder 37 | pad_type: reflect # padding type [zero/reflect] 38 | dis: 39 | dim: 64 # number of filters in the bottommost layer 40 | norm: none # normalization layer [none/bn/in/ln] 41 | activ: lrelu # activation function [relu/lrelu/prelu/selu/tanh] 42 | n_layer: 4 # number of layers in D 43 | gan_type: lsgan # GAN loss 
[lsgan/nsgan] 44 | num_scales: 3 # number of scales 45 | pad_type: reflect # padding type [zero/reflect] 46 | 47 | # data options 48 | input_dim_a: 3 # number of image channels [1/3] 49 | input_dim_b: 3 # number of image channels [1/3] 50 | num_workers: 8 # number of data loading threads 51 | new_size: 512 # first resize the shortest image side to this size 52 | crop_image_height: 512 # random crop image of this height 53 | crop_image_width: 512 # random crop image of this width 54 | data_root: ./datasets/synthia2cityscape/ # dataset folder location 55 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2018 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | import torch.utils.data as data 6 | import os.path 7 | 8 | def default_loader(path): 9 | return Image.open(path).convert('RGB') # Image is bound by the 'from PIL import Image' further down this module; it is resolved when the function runs 10 | 11 | 12 | def default_flist_reader(flist): 13 | """ 14 | Read an image list file and return its entries as a list of strings. flist format: one image path per line (impath\nimpath\n ...). Each line is whitespace-stripped and kept whole, so an extra 'label' column is not split off and blank lines become '' entries. 15 | """ 16 | imlist = [] 17 | with open(flist, 'r') as rf: 18 | for line in rf.readlines(): 19 | impath = line.strip() 20 | imlist.append(impath) 21 | 22 | return imlist 23 | 24 | 25 | class ImageFilelist(data.Dataset): 26 | def __init__(self, root, flist, transform=None, 27 | flist_reader=default_flist_reader, loader=default_loader): 28 | self.root = root 29 | self.imlist = flist_reader(flist) # the list file is opened as given; only the listed image paths are joined with root 30 | self.transform = transform 31 | self.loader = loader 32 | 33 | def __getitem__(self, index): 34 | impath = self.imlist[index] 35 | img = self.loader(os.path.join(self.root, impath)) 36 | if self.transform is not None: 37 | img = self.transform(img) 38 | 39 | return img 40 | 41 | def __len__(self): 42 | return len(self.imlist) 43 | 44 | 45 | class ImageLabelFilelist(data.Dataset): 46 | def
__init__(self, root, flist, transform=None, 47 | flist_reader=default_flist_reader, loader=default_loader): 48 | self.root = root 49 | self.imlist = flist_reader(os.path.join(self.root, flist)) # unlike ImageFilelist, the list file itself is resolved relative to root 50 | self.transform = transform 51 | self.loader = loader 52 | self.classes = sorted(list(set([path.split('/')[0] for path in self.imlist]))) # class name = first '/'-separated component of each listed path 53 | self.class_to_idx = {self.classes[i]: i for i in range(len(self.classes))} # classes are sorted, so label indices follow sorted class-name order 54 | self.imgs = [(impath, self.class_to_idx[impath.split('/')[0]]) for impath in self.imlist] 55 | 56 | def __getitem__(self, index): 57 | impath, label = self.imgs[index] 58 | img = self.loader(os.path.join(self.root, impath)) 59 | if self.transform is not None: 60 | img = self.transform(img) 61 | return img, label 62 | 63 | def __len__(self): 64 | return len(self.imgs) 65 | 66 | ############################################################################### 67 | # Code from 68 | # https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py 69 | # Modified the original code so that it also loads images from the current 70 | # directory as well as the subdirectories 71 | ############################################################################### 72 | 73 | import torch.utils.data as data 74 | 75 | from PIL import Image 76 | import os 77 | import os.path 78 | 79 | IMG_EXTENSIONS = [ 80 | '.jpg', '.JPG', '.jpeg', '.JPEG', 81 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 82 | ] 83 | 84 | 85 | def is_image_file(filename): 86 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) # case-sensitive suffix match; only the casings listed in IMG_EXTENSIONS are accepted 87 | 88 | 89 | def make_dataset(dir): 90 | images = [] 91 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 92 | 93 | for root, _, fnames in sorted(os.walk(dir)): # os.walk recurses, so images in dir and all of its subdirectories are collected 94 | for fname in fnames: 95 | if is_image_file(fname): 96 | path = os.path.join(root, fname) 97 | images.append(path) 98 | 99 | return images 100 | 101 | 102 | class ImageFolder(data.Dataset): 103 | 104 | def __init__(self, root, transform=None, return_paths=False,
105 | loader=default_loader): 106 | imgs = sorted(make_dataset(root)) # sorting the full file paths gives a deterministic index order 107 | if len(imgs) == 0: 108 | raise(RuntimeError("Found 0 images in: " + root + "\n" 109 | "Supported image extensions are: " + 110 | ",".join(IMG_EXTENSIONS))) 111 | 112 | self.root = root 113 | self.imgs = imgs 114 | self.transform = transform 115 | self.return_paths = return_paths 116 | self.loader = loader 117 | 118 | def __getitem__(self, index): 119 | path = self.imgs[index] 120 | img = self.loader(path) 121 | if self.transform is not None: 122 | img = self.transform(img) 123 | if self.return_paths: 124 | return img, path 125 | else: 126 | return img 127 | 128 | def __len__(self): 129 | return len(self.imgs) 130 | -------------------------------------------------------------------------------- /datasets/demo_edges2handbags/list_testA.txt: -------------------------------------------------------------------------------- 1 | ./00002.jpg 2 | ./00006.jpg 3 | ./00004.jpg 4 | ./00000.jpg 5 | ./00007.jpg 6 | ./00001.jpg 7 | ./00003.jpg 8 | ./00005.jpg 9 | ./00008.jpg 10 | -------------------------------------------------------------------------------- /datasets/demo_edges2handbags/list_testB.txt: -------------------------------------------------------------------------------- 1 | ./00002.jpg 2 | ./00006.jpg 3 | ./00004.jpg 4 | ./00000.jpg 5 | ./00007.jpg 6 | ./00001.jpg 7 | ./00003.jpg 8 | ./00005.jpg 9 | ./00008.jpg 10 | -------------------------------------------------------------------------------- /datasets/demo_edges2handbags/list_trainA.txt: -------------------------------------------------------------------------------- 1 | ./00002.jpg 2 | ./00006.jpg 3 | ./00004.jpg 4 | ./00000.jpg 5 | ./00010.jpg 6 | ./00014.jpg 7 | ./00008.jpg 8 | ./00012.jpg 9 | -------------------------------------------------------------------------------- /datasets/demo_edges2handbags/list_trainB.txt: -------------------------------------------------------------------------------- 1 | ./00009.jpg 2 | ./00007.jpg 3 |
./00001.jpg 4 | ./00003.jpg 5 | ./00013.jpg 6 | ./00005.jpg 7 | ./00011.jpg 8 | ./00015.jpg 9 | -------------------------------------------------------------------------------- /datasets/demo_edges2handbags/testA/00000.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/datasets/demo_edges2handbags/testA/00000.jpg -------------------------------------------------------------------------------- /datasets/demo_edges2handbags/testA/00001.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/datasets/demo_edges2handbags/testA/00001.jpg -------------------------------------------------------------------------------- /datasets/demo_edges2handbags/testA/00002.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/datasets/demo_edges2handbags/testA/00002.jpg -------------------------------------------------------------------------------- /datasets/demo_edges2handbags/testA/00003.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/datasets/demo_edges2handbags/testA/00003.jpg -------------------------------------------------------------------------------- /datasets/demo_edges2handbags/testA/00004.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/datasets/demo_edges2handbags/testA/00004.jpg -------------------------------------------------------------------------------- /datasets/demo_edges2handbags/testA/00005.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/datasets/demo_edges2handbags/testA/00005.jpg -------------------------------------------------------------------------------- /datasets/demo_edges2handbags/testA/00006.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/datasets/demo_edges2handbags/testA/00006.jpg -------------------------------------------------------------------------------- /datasets/demo_edges2handbags/testA/00007.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/datasets/demo_edges2handbags/testA/00007.jpg -------------------------------------------------------------------------------- /datasets/demo_edges2handbags/testA/00008.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/datasets/demo_edges2handbags/testA/00008.jpg -------------------------------------------------------------------------------- /datasets/demo_edges2handbags/testB/00000.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/datasets/demo_edges2handbags/testB/00000.jpg -------------------------------------------------------------------------------- /datasets/demo_edges2handbags/testB/00001.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/datasets/demo_edges2handbags/testB/00001.jpg 
-------------------------------------------------------------------------------- /datasets/demo_edges2handbags/testB/00002.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/datasets/demo_edges2handbags/testB/00002.jpg -------------------------------------------------------------------------------- /datasets/demo_edges2handbags/testB/00003.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/datasets/demo_edges2handbags/testB/00003.jpg -------------------------------------------------------------------------------- /datasets/demo_edges2handbags/testB/00004.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/datasets/demo_edges2handbags/testB/00004.jpg -------------------------------------------------------------------------------- /datasets/demo_edges2handbags/testB/00005.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/datasets/demo_edges2handbags/testB/00005.jpg -------------------------------------------------------------------------------- /datasets/demo_edges2handbags/testB/00006.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/datasets/demo_edges2handbags/testB/00006.jpg -------------------------------------------------------------------------------- /datasets/demo_edges2handbags/testB/00007.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/datasets/demo_edges2handbags/testB/00007.jpg -------------------------------------------------------------------------------- /datasets/demo_edges2handbags/testB/00008.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/datasets/demo_edges2handbags/testB/00008.jpg -------------------------------------------------------------------------------- /datasets/demo_edges2handbags/trainA/00000.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/datasets/demo_edges2handbags/trainA/00000.jpg -------------------------------------------------------------------------------- /datasets/demo_edges2handbags/trainA/00002.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/datasets/demo_edges2handbags/trainA/00002.jpg -------------------------------------------------------------------------------- /datasets/demo_edges2handbags/trainA/00004.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/datasets/demo_edges2handbags/trainA/00004.jpg -------------------------------------------------------------------------------- /datasets/demo_edges2handbags/trainA/00006.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/datasets/demo_edges2handbags/trainA/00006.jpg -------------------------------------------------------------------------------- /datasets/demo_edges2handbags/trainA/00008.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/datasets/demo_edges2handbags/trainA/00008.jpg -------------------------------------------------------------------------------- /datasets/demo_edges2handbags/trainA/00010.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/datasets/demo_edges2handbags/trainA/00010.jpg -------------------------------------------------------------------------------- /datasets/demo_edges2handbags/trainA/00012.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/datasets/demo_edges2handbags/trainA/00012.jpg -------------------------------------------------------------------------------- /datasets/demo_edges2handbags/trainA/00014.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/datasets/demo_edges2handbags/trainA/00014.jpg -------------------------------------------------------------------------------- /datasets/demo_edges2handbags/trainB/00001.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/datasets/demo_edges2handbags/trainB/00001.jpg -------------------------------------------------------------------------------- /datasets/demo_edges2handbags/trainB/00003.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/datasets/demo_edges2handbags/trainB/00003.jpg 
-------------------------------------------------------------------------------- /datasets/demo_edges2handbags/trainB/00005.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/datasets/demo_edges2handbags/trainB/00005.jpg -------------------------------------------------------------------------------- /datasets/demo_edges2handbags/trainB/00007.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/datasets/demo_edges2handbags/trainB/00007.jpg -------------------------------------------------------------------------------- /datasets/demo_edges2handbags/trainB/00009.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/datasets/demo_edges2handbags/trainB/00009.jpg -------------------------------------------------------------------------------- /datasets/demo_edges2handbags/trainB/00011.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/datasets/demo_edges2handbags/trainB/00011.jpg -------------------------------------------------------------------------------- /datasets/demo_edges2handbags/trainB/00013.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/datasets/demo_edges2handbags/trainB/00013.jpg -------------------------------------------------------------------------------- /datasets/demo_edges2handbags/trainB/00015.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/datasets/demo_edges2handbags/trainB/00015.jpg -------------------------------------------------------------------------------- /docs/munit_assumption.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/docs/munit_assumption.jpg -------------------------------------------------------------------------------- /inputs/edges2handbags_edge.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/inputs/edges2handbags_edge.jpg -------------------------------------------------------------------------------- /inputs/edges2handbags_handbag.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/inputs/edges2handbags_handbag.jpg -------------------------------------------------------------------------------- /inputs/edges2shoes_edge.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/inputs/edges2shoes_edge.jpg -------------------------------------------------------------------------------- /inputs/edges2shoes_shoe.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/inputs/edges2shoes_shoe.jpg -------------------------------------------------------------------------------- /networks.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2018 NVIDIA Corporation. All rights reserved. 
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | from torch import nn 6 | from torch.autograd import Variable 7 | import torch 8 | import torch.nn.functional as F 9 | try: 10 | from itertools import izip as zip 11 | except ImportError: # will be 3.x series 12 | pass 13 | 14 | ################################################################################## 15 | # Discriminator 16 | ################################################################################## 17 | 18 | class MsImageDis(nn.Module): 19 | # Multi-scale discriminator architecture 20 | def __init__(self, input_dim, params): 21 | super(MsImageDis, self).__init__() 22 | self.n_layer = params['n_layer'] 23 | self.gan_type = params['gan_type'] 24 | self.dim = params['dim'] 25 | self.norm = params['norm'] 26 | self.activ = params['activ'] 27 | self.num_scales = params['num_scales'] 28 | self.pad_type = params['pad_type'] 29 | self.input_dim = input_dim 30 | self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False) 31 | self.cnns = nn.ModuleList() 32 | for _ in range(self.num_scales): 33 | self.cnns.append(self._make_net()) 34 | 35 | def _make_net(self): 36 | dim = self.dim 37 | cnn_x = [] 38 | cnn_x += [Conv2dBlock(self.input_dim, dim, 4, 2, 1, norm='none', activation=self.activ, pad_type=self.pad_type)] 39 | for i in range(self.n_layer - 1): 40 | cnn_x += [Conv2dBlock(dim, dim * 2, 4, 2, 1, norm=self.norm, activation=self.activ, pad_type=self.pad_type)] 41 | dim *= 2 42 | cnn_x += [nn.Conv2d(dim, 1, 1, 1, 0)] 43 | cnn_x = nn.Sequential(*cnn_x) 44 | return cnn_x 45 | 46 | def forward(self, x): 47 | outputs = [] 48 | for model in self.cnns: 49 | outputs.append(model(x)) 50 | x = self.downsample(x) 51 | return outputs 52 | 53 | def calc_dis_loss(self, input_fake, input_real): 54 | # calculate the loss to train D 55 | outs0 = self.forward(input_fake) 56 | outs1 = self.forward(input_real) 57 | loss = 0 58 | 
59 | for it, (out0, out1) in enumerate(zip(outs0, outs1)): 60 | if self.gan_type == 'lsgan': 61 | loss += torch.mean((out0 - 0)**2) + torch.mean((out1 - 1)**2) 62 | elif self.gan_type == 'nsgan': 63 | all0 = Variable(torch.zeros_like(out0.data).cuda(), requires_grad=False) 64 | all1 = Variable(torch.ones_like(out1.data).cuda(), requires_grad=False) 65 | loss += torch.mean(F.binary_cross_entropy(F.sigmoid(out0), all0) + 66 | F.binary_cross_entropy(F.sigmoid(out1), all1)) 67 | else: 68 | assert 0, "Unsupported GAN type: {}".format(self.gan_type) 69 | return loss 70 | 71 | def calc_gen_loss(self, input_fake): 72 | # calculate the loss to train G 73 | outs0 = self.forward(input_fake) 74 | loss = 0 75 | for it, (out0) in enumerate(outs0): 76 | if self.gan_type == 'lsgan': 77 | loss += torch.mean((out0 - 1)**2) # LSGAN 78 | elif self.gan_type == 'nsgan': 79 | all1 = Variable(torch.ones_like(out0.data).cuda(), requires_grad=False) 80 | loss += torch.mean(F.binary_cross_entropy(F.sigmoid(out0), all1)) 81 | else: 82 | assert 0, "Unsupported GAN type: {}".format(self.gan_type) 83 | return loss 84 | 85 | ################################################################################## 86 | # Generator 87 | ################################################################################## 88 | 89 | class AdaINGen(nn.Module): 90 | # AdaIN auto-encoder architecture 91 | def __init__(self, input_dim, params): 92 | super(AdaINGen, self).__init__() 93 | dim = params['dim'] 94 | style_dim = params['style_dim'] 95 | n_downsample = params['n_downsample'] 96 | n_res = params['n_res'] 97 | activ = params['activ'] 98 | pad_type = params['pad_type'] 99 | mlp_dim = params['mlp_dim'] 100 | 101 | # style encoder 102 | self.enc_style = StyleEncoder(4, input_dim, dim, style_dim, norm='none', activ=activ, pad_type=pad_type) 103 | 104 | # content encoder 105 | self.enc_content = ContentEncoder(n_downsample, n_res, input_dim, dim, 'in', activ, pad_type=pad_type) 106 | self.dec = 
Decoder(n_downsample, n_res, self.enc_content.output_dim, input_dim, res_norm='adain', activ=activ, pad_type=pad_type) 107 | 108 | # MLP to generate AdaIN parameters 109 | self.mlp = MLP(style_dim, self.get_num_adain_params(self.dec), mlp_dim, 3, norm='none', activ=activ) 110 | 111 | def forward(self, images): 112 | # reconstruct an image 113 | content, style_fake = self.encode(images) 114 | images_recon = self.decode(content, style_fake) 115 | return images_recon 116 | 117 | def encode(self, images): 118 | # encode an image to its content and style codes 119 | style_fake = self.enc_style(images) 120 | content = self.enc_content(images) 121 | return content, style_fake 122 | 123 | def decode(self, content, style): 124 | # decode content and style codes to an image 125 | adain_params = self.mlp(style) 126 | self.assign_adain_params(adain_params, self.dec) 127 | images = self.dec(content) 128 | return images 129 | 130 | def assign_adain_params(self, adain_params, model): 131 | # assign the adain_params to the AdaIN layers in model 132 | for m in model.modules(): 133 | if m.__class__.__name__ == "AdaptiveInstanceNorm2d": 134 | mean = adain_params[:, :m.num_features] 135 | std = adain_params[:, m.num_features:2*m.num_features] 136 | m.bias = mean.contiguous().view(-1) 137 | m.weight = std.contiguous().view(-1) 138 | if adain_params.size(1) > 2*m.num_features: 139 | adain_params = adain_params[:, 2*m.num_features:] 140 | 141 | def get_num_adain_params(self, model): 142 | # return the number of AdaIN parameters needed by the model 143 | num_adain_params = 0 144 | for m in model.modules(): 145 | if m.__class__.__name__ == "AdaptiveInstanceNorm2d": 146 | num_adain_params += 2*m.num_features 147 | return num_adain_params 148 | 149 | 150 | class VAEGen(nn.Module): 151 | # VAE architecture 152 | def __init__(self, input_dim, params): 153 | super(VAEGen, self).__init__() 154 | dim = params['dim'] 155 | n_downsample = params['n_downsample'] 156 | n_res = params['n_res'] 157 | 
activ = params['activ'] 158 | pad_type = params['pad_type'] 159 | 160 | # content encoder 161 | self.enc = ContentEncoder(n_downsample, n_res, input_dim, dim, 'in', activ, pad_type=pad_type) 162 | self.dec = Decoder(n_downsample, n_res, self.enc.output_dim, input_dim, res_norm='in', activ=activ, pad_type=pad_type) 163 | 164 | def forward(self, images): 165 | # This is a reduced VAE implementation where we assume the outputs are multivariate Gaussian distribution with mean = hiddens and std_dev = all ones. 166 | hiddens = self.encode(images) 167 | if self.training == True: 168 | noise = Variable(torch.randn(hiddens.size()).cuda(hiddens.data.get_device())) 169 | images_recon = self.decode(hiddens + noise) 170 | else: 171 | images_recon = self.decode(hiddens) 172 | return images_recon, hiddens 173 | 174 | def encode(self, images): 175 | hiddens = self.enc(images) 176 | noise = Variable(torch.randn(hiddens.size()).cuda(hiddens.data.get_device())) 177 | return hiddens, noise 178 | 179 | def decode(self, hiddens): 180 | images = self.dec(hiddens) 181 | return images 182 | 183 | 184 | ################################################################################## 185 | # Encoder and Decoders 186 | ################################################################################## 187 | 188 | class StyleEncoder(nn.Module): 189 | def __init__(self, n_downsample, input_dim, dim, style_dim, norm, activ, pad_type): 190 | super(StyleEncoder, self).__init__() 191 | self.model = [] 192 | self.model += [Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)] 193 | for i in range(2): 194 | self.model += [Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)] 195 | dim *= 2 196 | for i in range(n_downsample - 2): 197 | self.model += [Conv2dBlock(dim, dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)] 198 | self.model += [nn.AdaptiveAvgPool2d(1)] # global average pooling 199 | self.model += [nn.Conv2d(dim, 
style_dim, 1, 1, 0)] 200 | self.model = nn.Sequential(*self.model) 201 | self.output_dim = dim 202 | 203 | def forward(self, x): 204 | return self.model(x) 205 | 206 | class ContentEncoder(nn.Module): 207 | def __init__(self, n_downsample, n_res, input_dim, dim, norm, activ, pad_type): 208 | super(ContentEncoder, self).__init__() 209 | self.model = [] 210 | self.model += [Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)] 211 | # downsampling blocks 212 | for i in range(n_downsample): 213 | self.model += [Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)] 214 | dim *= 2 215 | # residual blocks 216 | self.model += [ResBlocks(n_res, dim, norm=norm, activation=activ, pad_type=pad_type)] 217 | self.model = nn.Sequential(*self.model) 218 | self.output_dim = dim 219 | 220 | def forward(self, x): 221 | return self.model(x) 222 | 223 | class Decoder(nn.Module): 224 | def __init__(self, n_upsample, n_res, dim, output_dim, res_norm='adain', activ='relu', pad_type='zero'): 225 | super(Decoder, self).__init__() 226 | 227 | self.model = [] 228 | # AdaIN residual blocks 229 | self.model += [ResBlocks(n_res, dim, res_norm, activ, pad_type=pad_type)] 230 | # upsampling blocks 231 | for i in range(n_upsample): 232 | self.model += [nn.Upsample(scale_factor=2), 233 | Conv2dBlock(dim, dim // 2, 5, 1, 2, norm='ln', activation=activ, pad_type=pad_type)] 234 | dim //= 2 235 | # use reflection padding in the last conv layer 236 | self.model += [Conv2dBlock(dim, output_dim, 7, 1, 3, norm='none', activation='tanh', pad_type=pad_type)] 237 | self.model = nn.Sequential(*self.model) 238 | 239 | def forward(self, x): 240 | return self.model(x) 241 | 242 | ################################################################################## 243 | # Sequential Models 244 | ################################################################################## 245 | class ResBlocks(nn.Module): 246 | def __init__(self, num_blocks, dim, 
norm='in', activation='relu', pad_type='zero'): 247 | super(ResBlocks, self).__init__() 248 | self.model = [] 249 | for i in range(num_blocks): 250 | self.model += [ResBlock(dim, norm=norm, activation=activation, pad_type=pad_type)] 251 | self.model = nn.Sequential(*self.model) 252 | 253 | def forward(self, x): 254 | return self.model(x) 255 | 256 | class MLP(nn.Module): 257 | def __init__(self, input_dim, output_dim, dim, n_blk, norm='none', activ='relu'): 258 | 259 | super(MLP, self).__init__() 260 | self.model = [] 261 | self.model += [LinearBlock(input_dim, dim, norm=norm, activation=activ)] 262 | for i in range(n_blk - 2): 263 | self.model += [LinearBlock(dim, dim, norm=norm, activation=activ)] 264 | self.model += [LinearBlock(dim, output_dim, norm='none', activation='none')] # no output activations 265 | self.model = nn.Sequential(*self.model) 266 | 267 | def forward(self, x): 268 | return self.model(x.view(x.size(0), -1)) 269 | 270 | ################################################################################## 271 | # Basic Blocks 272 | ################################################################################## 273 | class ResBlock(nn.Module): 274 | def __init__(self, dim, norm='in', activation='relu', pad_type='zero'): 275 | super(ResBlock, self).__init__() 276 | 277 | model = [] 278 | model += [Conv2dBlock(dim ,dim, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type)] 279 | model += [Conv2dBlock(dim ,dim, 3, 1, 1, norm=norm, activation='none', pad_type=pad_type)] 280 | self.model = nn.Sequential(*model) 281 | 282 | def forward(self, x): 283 | residual = x 284 | out = self.model(x) 285 | out += residual 286 | return out 287 | 288 | class Conv2dBlock(nn.Module): 289 | def __init__(self, input_dim ,output_dim, kernel_size, stride, 290 | padding=0, norm='none', activation='relu', pad_type='zero'): 291 | super(Conv2dBlock, self).__init__() 292 | self.use_bias = True 293 | # initialize padding 294 | if pad_type == 'reflect': 295 | self.pad 
= nn.ReflectionPad2d(padding) 296 | elif pad_type == 'replicate': 297 | self.pad = nn.ReplicationPad2d(padding) 298 | elif pad_type == 'zero': 299 | self.pad = nn.ZeroPad2d(padding) 300 | else: 301 | assert 0, "Unsupported padding type: {}".format(pad_type) 302 | 303 | # initialize normalization 304 | norm_dim = output_dim 305 | if norm == 'bn': 306 | self.norm = nn.BatchNorm2d(norm_dim) 307 | elif norm == 'in': 308 | #self.norm = nn.InstanceNorm2d(norm_dim, track_running_stats=True) 309 | self.norm = nn.InstanceNorm2d(norm_dim) 310 | elif norm == 'ln': 311 | self.norm = LayerNorm(norm_dim) 312 | elif norm == 'adain': 313 | self.norm = AdaptiveInstanceNorm2d(norm_dim) 314 | elif norm == 'none' or norm == 'sn': 315 | self.norm = None 316 | else: 317 | assert 0, "Unsupported normalization: {}".format(norm) 318 | 319 | # initialize activation 320 | if activation == 'relu': 321 | self.activation = nn.ReLU(inplace=True) 322 | elif activation == 'lrelu': 323 | self.activation = nn.LeakyReLU(0.2, inplace=True) 324 | elif activation == 'prelu': 325 | self.activation = nn.PReLU() 326 | elif activation == 'selu': 327 | self.activation = nn.SELU(inplace=True) 328 | elif activation == 'tanh': 329 | self.activation = nn.Tanh() 330 | elif activation == 'none': 331 | self.activation = None 332 | else: 333 | assert 0, "Unsupported activation: {}".format(activation) 334 | 335 | # initialize convolution 336 | if norm == 'sn': 337 | self.conv = SpectralNorm(nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias)) 338 | else: 339 | self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias) 340 | 341 | def forward(self, x): 342 | x = self.conv(self.pad(x)) 343 | if self.norm: 344 | x = self.norm(x) 345 | if self.activation: 346 | x = self.activation(x) 347 | return x 348 | 349 | class LinearBlock(nn.Module): 350 | def __init__(self, input_dim, output_dim, norm='none', activation='relu'): 351 | super(LinearBlock, self).__init__() 352 | 
use_bias = True 353 | # initialize fully connected layer 354 | if norm == 'sn': 355 | self.fc = SpectralNorm(nn.Linear(input_dim, output_dim, bias=use_bias)) 356 | else: 357 | self.fc = nn.Linear(input_dim, output_dim, bias=use_bias) 358 | 359 | # initialize normalization 360 | norm_dim = output_dim 361 | if norm == 'bn': 362 | self.norm = nn.BatchNorm1d(norm_dim) 363 | elif norm == 'in': 364 | self.norm = nn.InstanceNorm1d(norm_dim) 365 | elif norm == 'ln': 366 | self.norm = LayerNorm(norm_dim) 367 | elif norm == 'none' or norm == 'sn': 368 | self.norm = None 369 | else: 370 | assert 0, "Unsupported normalization: {}".format(norm) 371 | 372 | # initialize activation 373 | if activation == 'relu': 374 | self.activation = nn.ReLU(inplace=True) 375 | elif activation == 'lrelu': 376 | self.activation = nn.LeakyReLU(0.2, inplace=True) 377 | elif activation == 'prelu': 378 | self.activation = nn.PReLU() 379 | elif activation == 'selu': 380 | self.activation = nn.SELU(inplace=True) 381 | elif activation == 'tanh': 382 | self.activation = nn.Tanh() 383 | elif activation == 'none': 384 | self.activation = None 385 | else: 386 | assert 0, "Unsupported activation: {}".format(activation) 387 | 388 | def forward(self, x): 389 | out = self.fc(x) 390 | if self.norm: 391 | out = self.norm(out) 392 | if self.activation: 393 | out = self.activation(out) 394 | return out 395 | 396 | ################################################################################## 397 | # VGG network definition 398 | ################################################################################## 399 | class Vgg16(nn.Module): 400 | def __init__(self): 401 | super(Vgg16, self).__init__() 402 | self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1) 403 | self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) 404 | 405 | self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1) 406 | self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1) 407 | 
408 | self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1) 409 | self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) 410 | self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) 411 | 412 | self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1) 413 | self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) 414 | self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) 415 | 416 | self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) 417 | self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) 418 | self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) 419 | 420 | def forward(self, X): 421 | h = F.relu(self.conv1_1(X), inplace=True) 422 | h = F.relu(self.conv1_2(h), inplace=True) 423 | # relu1_2 = h 424 | h = F.max_pool2d(h, kernel_size=2, stride=2) 425 | 426 | h = F.relu(self.conv2_1(h), inplace=True) 427 | h = F.relu(self.conv2_2(h), inplace=True) 428 | # relu2_2 = h 429 | h = F.max_pool2d(h, kernel_size=2, stride=2) 430 | 431 | h = F.relu(self.conv3_1(h), inplace=True) 432 | h = F.relu(self.conv3_2(h), inplace=True) 433 | h = F.relu(self.conv3_3(h), inplace=True) 434 | # relu3_3 = h 435 | h = F.max_pool2d(h, kernel_size=2, stride=2) 436 | 437 | h = F.relu(self.conv4_1(h), inplace=True) 438 | h = F.relu(self.conv4_2(h), inplace=True) 439 | h = F.relu(self.conv4_3(h), inplace=True) 440 | # relu4_3 = h 441 | 442 | h = F.relu(self.conv5_1(h), inplace=True) 443 | h = F.relu(self.conv5_2(h), inplace=True) 444 | h = F.relu(self.conv5_3(h), inplace=True) 445 | relu5_3 = h 446 | 447 | return relu5_3 448 | # return [relu1_2, relu2_2, relu3_3, relu4_3] 449 | 450 | ################################################################################## 451 | # Normalization layers 452 | ################################################################################## 453 | class AdaptiveInstanceNorm2d(nn.Module): 454 | def 
__init__(self, num_features, eps=1e-5, momentum=0.1): 455 | super(AdaptiveInstanceNorm2d, self).__init__() 456 | self.num_features = num_features 457 | self.eps = eps 458 | self.momentum = momentum 459 | # weight and bias are dynamically assigned 460 | self.weight = None 461 | self.bias = None 462 | # just dummy buffers, not used 463 | self.register_buffer('running_mean', torch.zeros(num_features)) 464 | self.register_buffer('running_var', torch.ones(num_features)) 465 | 466 | def forward(self, x): 467 | assert self.weight is not None and self.bias is not None, "Please assign weight and bias before calling AdaIN!" 468 | b, c = x.size(0), x.size(1) 469 | running_mean = self.running_mean.repeat(b) 470 | running_var = self.running_var.repeat(b) 471 | 472 | # Apply instance norm 473 | x_reshaped = x.contiguous().view(1, b * c, *x.size()[2:]) 474 | 475 | out = F.batch_norm( 476 | x_reshaped, running_mean, running_var, self.weight, self.bias, 477 | True, self.momentum, self.eps) 478 | 479 | return out.view(b, c, *x.size()[2:]) 480 | 481 | def __repr__(self): 482 | return self.__class__.__name__ + '(' + str(self.num_features) + ')' 483 | 484 | 485 | class LayerNorm(nn.Module): 486 | def __init__(self, num_features, eps=1e-5, affine=True): 487 | super(LayerNorm, self).__init__() 488 | self.num_features = num_features 489 | self.affine = affine 490 | self.eps = eps 491 | 492 | if self.affine: 493 | self.gamma = nn.Parameter(torch.Tensor(num_features).uniform_()) 494 | self.beta = nn.Parameter(torch.zeros(num_features)) 495 | 496 | def forward(self, x): 497 | shape = [-1] + [1] * (x.dim() - 1) 498 | # print(x.size()) 499 | if x.size(0) == 1: 500 | # These two lines run much faster in pytorch 0.4 than the two lines listed below. 
501 | mean = x.view(-1).mean().view(*shape) 502 | std = x.view(-1).std().view(*shape) 503 | else: 504 | mean = x.view(x.size(0), -1).mean(1).view(*shape) 505 | std = x.view(x.size(0), -1).std(1).view(*shape) 506 | 507 | x = (x - mean) / (std + self.eps) 508 | 509 | if self.affine: 510 | shape = [1, -1] + [1] * (x.dim() - 2) 511 | x = x * self.gamma.view(*shape) + self.beta.view(*shape) 512 | return x 513 | 514 | def l2normalize(v, eps=1e-12): 515 | return v / (v.norm() + eps) 516 | 517 | 518 | class SpectralNorm(nn.Module): 519 | """ 520 | Based on the paper "Spectral Normalization for Generative Adversarial Networks" by Takeru Miyato, Toshiki Kataoka, Masanori Koyama, Yuichi Yoshida 521 | and the Pytorch implementation https://github.com/christiancosgrove/pytorch-spectral-normalization-gan 522 | """ 523 | def __init__(self, module, name='weight', power_iterations=1): 524 | super(SpectralNorm, self).__init__() 525 | self.module = module 526 | self.name = name 527 | self.power_iterations = power_iterations 528 | if not self._made_params(): 529 | self._make_params() 530 | 531 | def _update_u_v(self): 532 | u = getattr(self.module, self.name + "_u") 533 | v = getattr(self.module, self.name + "_v") 534 | w = getattr(self.module, self.name + "_bar") 535 | 536 | height = w.data.shape[0] 537 | for _ in range(self.power_iterations): 538 | v.data = l2normalize(torch.mv(torch.t(w.view(height,-1).data), u.data)) 539 | u.data = l2normalize(torch.mv(w.view(height,-1).data, v.data)) 540 | 541 | # sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data)) 542 | sigma = u.dot(w.view(height, -1).mv(v)) 543 | setattr(self.module, self.name, w / sigma.expand_as(w)) 544 | 545 | def _made_params(self): 546 | try: 547 | u = getattr(self.module, self.name + "_u") 548 | v = getattr(self.module, self.name + "_v") 549 | w = getattr(self.module, self.name + "_bar") 550 | return True 551 | except AttributeError: 552 | return False 553 | 554 | 555 | def _make_params(self): 556 | w = 
getattr(self.module, self.name) 557 | 558 | height = w.data.shape[0] 559 | width = w.view(height, -1).data.shape[1] 560 | 561 | u = nn.Parameter(w.data.new(height).normal_(0, 1), requires_grad=False) 562 | v = nn.Parameter(w.data.new(width).normal_(0, 1), requires_grad=False) 563 | u.data = l2normalize(u.data) 564 | v.data = l2normalize(v.data) 565 | w_bar = nn.Parameter(w.data) 566 | 567 | del self.module._parameters[self.name] 568 | 569 | self.module.register_parameter(self.name + "_u", u) 570 | self.module.register_parameter(self.name + "_v", v) 571 | self.module.register_parameter(self.name + "_bar", w_bar) 572 | 573 | 574 | def forward(self, *args): 575 | self._update_u_v() 576 | return self.module.forward(*args) -------------------------------------------------------------------------------- /results/animal.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/results/animal.jpg -------------------------------------------------------------------------------- /results/edges2handbags/input.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/results/edges2handbags/input.jpg -------------------------------------------------------------------------------- /results/edges2handbags/output000.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/results/edges2handbags/output000.jpg -------------------------------------------------------------------------------- /results/edges2handbags/output001.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/results/edges2handbags/output001.jpg 
-------------------------------------------------------------------------------- /results/edges2handbags/output002.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/results/edges2handbags/output002.jpg -------------------------------------------------------------------------------- /results/edges2handbags/output003.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/results/edges2handbags/output003.jpg -------------------------------------------------------------------------------- /results/edges2handbags/output004.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/results/edges2handbags/output004.jpg -------------------------------------------------------------------------------- /results/edges2handbags/output005.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/results/edges2handbags/output005.jpg -------------------------------------------------------------------------------- /results/edges2handbags/output006.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/results/edges2handbags/output006.jpg -------------------------------------------------------------------------------- /results/edges2handbags/output007.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/results/edges2handbags/output007.jpg 
-------------------------------------------------------------------------------- /results/edges2handbags/output008.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/results/edges2handbags/output008.jpg -------------------------------------------------------------------------------- /results/edges2handbags/output009.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/results/edges2handbags/output009.jpg -------------------------------------------------------------------------------- /results/edges2shoes/input.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/results/edges2shoes/input.jpg -------------------------------------------------------------------------------- /results/edges2shoes/output000.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/results/edges2shoes/output000.jpg -------------------------------------------------------------------------------- /results/edges2shoes/output001.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/results/edges2shoes/output001.jpg -------------------------------------------------------------------------------- /results/edges2shoes/output002.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/results/edges2shoes/output002.jpg 
-------------------------------------------------------------------------------- /results/edges2shoes/output003.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/results/edges2shoes/output003.jpg -------------------------------------------------------------------------------- /results/edges2shoes/output004.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/results/edges2shoes/output004.jpg -------------------------------------------------------------------------------- /results/edges2shoes/output005.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/results/edges2shoes/output005.jpg -------------------------------------------------------------------------------- /results/edges2shoes/output006.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/results/edges2shoes/output006.jpg -------------------------------------------------------------------------------- /results/edges2shoes/output007.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/results/edges2shoes/output007.jpg -------------------------------------------------------------------------------- /results/edges2shoes/output008.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/results/edges2shoes/output008.jpg 
-------------------------------------------------------------------------------- /results/edges2shoes/output009.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/results/edges2shoes/output009.jpg -------------------------------------------------------------------------------- /results/edges2shoes_handbags.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/results/edges2shoes_handbags.jpg -------------------------------------------------------------------------------- /results/example_guided.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/results/example_guided.jpg -------------------------------------------------------------------------------- /results/input.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/results/input.jpg -------------------------------------------------------------------------------- /results/output000.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/results/output000.jpg -------------------------------------------------------------------------------- /results/street.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/results/street.jpg -------------------------------------------------------------------------------- /results/summer2winter_yosemite.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/results/summer2winter_yosemite.jpg -------------------------------------------------------------------------------- /results/video.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/results/video.jpg -------------------------------------------------------------------------------- /scripts/demo_train_edges2handbags.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | rm datasets/edges2handbags -rf  # start from a clean dataset folder 3 | mkdir datasets/edges2handbags -p 4 | axel -n 1 https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/edges2handbags.tar.gz --output=datasets/edges2handbags/edges2handbags.tar.gz  # download the paired edges2handbags archive 5 | tar -zxvf datasets/edges2handbags/edges2handbags.tar.gz -C datasets/  # NOTE(review): assumes the archive unpacks to edges2handbags/{train,val}, which the loops below read 6 | mkdir datasets/edges2handbags/train1 -p 7 | mkdir datasets/edges2handbags/train0 -p 8 | mkdir datasets/edges2handbags/test1 -p 9 | mkdir datasets/edges2handbags/test0 -p 10 | for f in datasets/edges2handbags/train/*; do convert -quality 100 -crop 50%x100% +repage $f datasets/edges2handbags/train%d/${f##*/}; done;  # cut each paired image into left/right halves; ImageMagick fills %d with the tile index (0/1), matching train0/ and train1/ created above 11 | for f in datasets/edges2handbags/val/*; do convert -quality 100 -crop 50%x100% +repage $f datasets/edges2handbags/test%d/${f##*/}; done;  # same split for the validation images 12 | mv datasets/edges2handbags/train0 datasets/edges2handbags/trainA  # rename the numbered halves to the trainA/trainB/testA/testB layout 13 | mv 
datasets/edges2handbags/train1 datasets/edges2handbags/trainB 14 | mv datasets/edges2handbags/test0 datasets/edges2handbags/testA 15 | mv datasets/edges2handbags/test1 datasets/edges2handbags/testB 16 | python train.py --config configs/edges2handbags_folder.yaml 17 | 18 | -------------------------------------------------------------------------------- /scripts/demo_train_edges2shoes.sh:
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | rm datasets/edges2shoes -rf  # start from a clean dataset folder 3 | mkdir datasets/edges2shoes -p 4 | axel -n 1 https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/edges2shoes.tar.gz --output=datasets/edges2shoes/edges2shoes.tar.gz  # download the paired edges2shoes archive 5 | tar -zxvf datasets/edges2shoes/edges2shoes.tar.gz -C datasets 6 | mkdir datasets/edges2shoes/train1 -p 7 | mkdir datasets/edges2shoes/train0 -p 8 | mkdir datasets/edges2shoes/test1 -p 9 | mkdir datasets/edges2shoes/test0 -p 10 | for f in datasets/edges2shoes/train/*; do convert -quality 100 -crop 50%x100% +repage $f datasets/edges2shoes/train%d/${f##*/}; done;  # cut each paired image into left/right halves; %d is the tile index (0/1), matching train0/ and train1/ 11 | for f in datasets/edges2shoes/val/*; do convert -quality 100 -crop 50%x100% +repage $f datasets/edges2shoes/test%d/${f##*/}; done;  # same split for the validation images 12 | mv datasets/edges2shoes/train0 datasets/edges2shoes/trainA  # rename the numbered halves to trainA/trainB/testA/testB 13 | mv datasets/edges2shoes/train1 datasets/edges2shoes/trainB 14 | mv datasets/edges2shoes/test0 datasets/edges2shoes/testA 15 | mv datasets/edges2shoes/test1 datasets/edges2shoes/testB 16 | python train.py --config configs/edges2shoes_folder.yaml 17 | -------------------------------------------------------------------------------- /scripts/demo_train_summer2winter_yosemite256.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | rm datasets/summer2winter_yosemite256 -rf  # -rf like the sibling scripts (rm has no -p option), so a previous download is actually removed before unzip 3 | mkdir datasets/summer2winter_yosemite256 -p 4 | axel -n 1 https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/summer2winter_yosemite.zip --output=datasets/summer2winter_yosemite256/summer2winter_yosemite.zip 5 | unzip 
datasets/summer2winter_yosemite256/summer2winter_yosemite.zip -d datasets/summer2winter_yosemite256 6 | python train.py --config configs/summer2winter_yosemite256_folder.yaml 7 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 |
""" 2 | Copyright (C) 2018 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | from __future__ import print_function 6 | from utils import get_config, pytorch03_to_pytorch04 7 | from trainer import MUNIT_Trainer, UNIT_Trainer 8 | import argparse 9 | from torch.autograd import Variable 10 | import torchvision.utils as vutils 11 | import sys 12 | import torch 13 | import os 14 | from torchvision import transforms 15 | from PIL import Image 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--config', type=str, help="net configuration") 19 | parser.add_argument('--input', type=str, help="input image path") 20 | parser.add_argument('--output_folder', type=str, help="output image path") 21 | parser.add_argument('--checkpoint', type=str, help="checkpoint of autoencoders") 22 | parser.add_argument('--style', type=str, default='', help="style image path") 23 | parser.add_argument('--a2b', type=int, default=1, help="1 for a2b and 0 for b2a") 24 | parser.add_argument('--seed', type=int, default=10, help="random seed") 25 | parser.add_argument('--num_style',type=int, default=10, help="number of styles to sample") 26 | parser.add_argument('--synchronized', action='store_true', help="whether use synchronized style code or not") 27 | parser.add_argument('--output_only', action='store_true', help="whether use synchronized style code or not") 28 | parser.add_argument('--output_path', type=str, default='.', help="path for logs, checkpoints, and VGG model weight") 29 | parser.add_argument('--trainer', type=str, default='MUNIT', help="MUNIT|UNIT") 30 | opts = parser.parse_args() 31 | 32 | 33 | 34 | torch.manual_seed(opts.seed) 35 | torch.cuda.manual_seed(opts.seed) 36 | if not os.path.exists(opts.output_folder): 37 | os.makedirs(opts.output_folder) 38 | 39 | # Load experiment setting 40 | config = get_config(opts.config) 41 | opts.num_style = 1 if 
opts.style != '' else opts.num_style 42 | 43 | # Setup model and data loader 44 | config['vgg_model_path'] = opts.output_path 45 | if opts.trainer == 'MUNIT': 46 | style_dim = config['gen']['style_dim'] 47 | trainer = MUNIT_Trainer(config) 48 | elif opts.trainer == 'UNIT': 49 | trainer = UNIT_Trainer(config) 50 | else: 51 | sys.exit("Only support MUNIT|UNIT") 52 | 53 | try: 54 | state_dict = torch.load(opts.checkpoint) 55 | trainer.gen_a.load_state_dict(state_dict['a']) 56 | trainer.gen_b.load_state_dict(state_dict['b']) 57 | except: 58 | state_dict = pytorch03_to_pytorch04(torch.load(opts.checkpoint), opts.trainer) 59 | trainer.gen_a.load_state_dict(state_dict['a']) 60 | trainer.gen_b.load_state_dict(state_dict['b']) 61 | 62 | trainer.cuda() 63 | trainer.eval() 64 | encode = trainer.gen_a.encode if opts.a2b else trainer.gen_b.encode # encode function 65 | style_encode = trainer.gen_b.encode if opts.a2b else trainer.gen_a.encode # encode function 66 | decode = trainer.gen_b.decode if opts.a2b else trainer.gen_a.decode # decode function 67 | 68 | if 'new_size' in config: 69 | new_size = config['new_size'] 70 | else: 71 | if opts.a2b==1: 72 | new_size = config['new_size_a'] 73 | else: 74 | new_size = config['new_size_b'] 75 | 76 | with torch.no_grad(): 77 | transform = transforms.Compose([transforms.Resize(new_size), 78 | transforms.ToTensor(), 79 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 80 | image = Variable(transform(Image.open(opts.input).convert('RGB')).unsqueeze(0).cuda()) 81 | style_image = Variable(transform(Image.open(opts.style).convert('RGB')).unsqueeze(0).cuda()) if opts.style != '' else None 82 | 83 | # Start testing 84 | content, _ = encode(image) 85 | 86 | if opts.trainer == 'MUNIT': 87 | style_rand = Variable(torch.randn(opts.num_style, style_dim, 1, 1).cuda()) 88 | if opts.style != '': 89 | _, style = style_encode(style_image) 90 | else: 91 | style = style_rand 92 | for j in range(opts.num_style): 93 | s = style[j].unsqueeze(0) 94 | 
outputs = decode(content, s) 95 | outputs = (outputs + 1) / 2. 96 | path = os.path.join(opts.output_folder, 'output{:03d}.jpg'.format(j)) 97 | vutils.save_image(outputs.data, path, padding=0, normalize=True) 98 | elif opts.trainer == 'UNIT': 99 | outputs = decode(content) 100 | outputs = (outputs + 1) / 2. 101 | path = os.path.join(opts.output_folder, 'output.jpg') 102 | vutils.save_image(outputs.data, path, padding=0, normalize=True) 103 | else: 104 | pass 105 | 106 | if not opts.output_only: 107 | # also save input images 108 | vutils.save_image(image.data, os.path.join(opts.output_folder, 'input.jpg'), padding=0, normalize=True) 109 | 110 | -------------------------------------------------------------------------------- /test_batch.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2018 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
4 | """ 5 | from __future__ import print_function 6 | from utils import get_config, get_data_loader_folder, pytorch03_to_pytorch04, load_inception 7 | from trainer import MUNIT_Trainer, UNIT_Trainer 8 | from torch import nn 9 | from scipy.stats import entropy 10 | import torch.nn.functional as F 11 | import argparse 12 | from torch.autograd import Variable 13 | from data import ImageFolder 14 | import numpy as np 15 | import torchvision.utils as vutils 16 | try: 17 | from itertools import izip as zip 18 | except ImportError: # will be 3.x series 19 | pass 20 | import sys 21 | import torch 22 | import os 23 | 24 | 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument('--config', type=str, default='configs/edges2handbags_folder', help='Path to the config file.') 27 | parser.add_argument('--input_folder', type=str, help="input image folder") 28 | parser.add_argument('--output_folder', type=str, help="output image folder") 29 | parser.add_argument('--checkpoint', type=str, help="checkpoint of autoencoders") 30 | parser.add_argument('--a2b', type=int, help="1 for a2b and 0 for b2a", default=1) 31 | parser.add_argument('--seed', type=int, default=1, help="random seed") 32 | parser.add_argument('--num_style',type=int, default=10, help="number of styles to sample") 33 | parser.add_argument('--synchronized', action='store_true', help="whether use synchronized style code or not") 34 | parser.add_argument('--output_only', action='store_true', help="whether only save the output images or also save the input images") 35 | parser.add_argument('--output_path', type=str, default='.', help="path for logs, checkpoints, and VGG model weight") 36 | parser.add_argument('--trainer', type=str, default='MUNIT', help="MUNIT|UNIT") 37 | parser.add_argument('--compute_IS', action='store_true', help="whether to compute Inception Score or not") 38 | parser.add_argument('--compute_CIS', action='store_true', help="whether to compute Conditional Inception Score or not") 39 | 
parser.add_argument('--inception_a', type=str, default='.', help="path to the pretrained inception network for domain A") 40 | parser.add_argument('--inception_b', type=str, default='.', help="path to the pretrained inception network for domain B") 41 | 42 | opts = parser.parse_args() 43 | 44 | 45 | torch.manual_seed(opts.seed) 46 | torch.cuda.manual_seed(opts.seed) 47 | 48 | # Load experiment setting 49 | config = get_config(opts.config) 50 | input_dim = config['input_dim_a'] if opts.a2b else config['input_dim_b'] 51 | 52 | # Load the inception networks if we need to compute IS or CIIS 53 | if opts.compute_IS or opts.compute_IS: 54 | inception = load_inception(opts.inception_b) if opts.a2b else load_inception(opts.inception_a) 55 | # freeze the inception models and set eval mode 56 | inception.eval() 57 | for param in inception.parameters(): 58 | param.requires_grad = False 59 | inception_up = nn.Upsample(size=(299, 299), mode='bilinear') 60 | 61 | # Setup model and data loader 62 | image_names = ImageFolder(opts.input_folder, transform=None, return_paths=True) 63 | data_loader = get_data_loader_folder(opts.input_folder, 1, False, new_size=config['new_size_a'], crop=False) 64 | 65 | config['vgg_model_path'] = opts.output_path 66 | if opts.trainer == 'MUNIT': 67 | style_dim = config['gen']['style_dim'] 68 | trainer = MUNIT_Trainer(config) 69 | elif opts.trainer == 'UNIT': 70 | trainer = UNIT_Trainer(config) 71 | else: 72 | sys.exit("Only support MUNIT|UNIT") 73 | 74 | try: 75 | state_dict = torch.load(opts.checkpoint) 76 | trainer.gen_a.load_state_dict(state_dict['a']) 77 | trainer.gen_b.load_state_dict(state_dict['b']) 78 | except: 79 | state_dict = pytorch03_to_pytorch04(torch.load(opts.checkpoint), opts.trainer) 80 | trainer.gen_a.load_state_dict(state_dict['a']) 81 | trainer.gen_b.load_state_dict(state_dict['b']) 82 | 83 | trainer.cuda() 84 | trainer.eval() 85 | encode = trainer.gen_a.encode if opts.a2b else trainer.gen_b.encode # encode function 86 | decode = 
trainer.gen_b.decode if opts.a2b else trainer.gen_a.decode # decode function 87 | 88 | if opts.compute_IS: 89 | IS = [] 90 | all_preds = [] 91 | if opts.compute_CIS: 92 | CIS = [] 93 | 94 | if opts.trainer == 'MUNIT': 95 | # Start testing 96 | style_fixed = Variable(torch.randn(opts.num_style, style_dim, 1, 1).cuda(), volatile=True) 97 | for i, (images, names) in enumerate(zip(data_loader, image_names)): 98 | if opts.compute_CIS: 99 | cur_preds = [] 100 | print(names[1]) 101 | images = Variable(images.cuda(), volatile=True) 102 | content, _ = encode(images) 103 | style = style_fixed if opts.synchronized else Variable(torch.randn(opts.num_style, style_dim, 1, 1).cuda(), volatile=True) 104 | for j in range(opts.num_style): 105 | s = style[j].unsqueeze(0) 106 | outputs = decode(content, s) 107 | outputs = (outputs + 1) / 2. 108 | if opts.compute_IS or opts.compute_CIS: 109 | pred = F.softmax(inception(inception_up(outputs)), dim=1).cpu().data.numpy() # get the predicted class distribution 110 | if opts.compute_IS: 111 | all_preds.append(pred) 112 | if opts.compute_CIS: 113 | cur_preds.append(pred) 114 | # path = os.path.join(opts.output_folder, 'input{:03d}_output{:03d}.jpg'.format(i, j)) 115 | basename = os.path.basename(names[1]) 116 | path = os.path.join(opts.output_folder+"_%02d"%j,basename) 117 | if not os.path.exists(os.path.dirname(path)): 118 | os.makedirs(os.path.dirname(path)) 119 | vutils.save_image(outputs.data, path, padding=0, normalize=True) 120 | if opts.compute_CIS: 121 | cur_preds = np.concatenate(cur_preds, 0) 122 | py = np.sum(cur_preds, axis=0) # prior is computed from outputs given a specific input 123 | for j in range(cur_preds.shape[0]): 124 | pyx = cur_preds[j, :] 125 | CIS.append(entropy(pyx, py)) 126 | if not opts.output_only: 127 | # also save input images 128 | vutils.save_image(images.data, os.path.join(opts.output_folder, 'input{:03d}.jpg'.format(i)), padding=0, normalize=True) 129 | if opts.compute_IS: 130 | all_preds = 
np.concatenate(all_preds, 0) 131 | py = np.sum(all_preds, axis=0) # prior is computed from all outputs 132 | for j in range(all_preds.shape[0]): 133 | pyx = all_preds[j, :] 134 | IS.append(entropy(pyx, py)) 135 | 136 | if opts.compute_IS: 137 | print("Inception Score: {}".format(np.exp(np.mean(IS)))) 138 | if opts.compute_CIS: 139 | print("conditional Inception Score: {}".format(np.exp(np.mean(CIS)))) 140 | 141 | elif opts.trainer == 'UNIT': 142 | # Start testing 143 | for i, (images, names) in enumerate(zip(data_loader, image_names)): 144 | print(names[1]) 145 | images = Variable(images.cuda(), volatile=True) 146 | content, _ = encode(images) 147 | 148 | outputs = decode(content) 149 | outputs = (outputs + 1) / 2. 150 | # path = os.path.join(opts.output_folder, 'input{:03d}_output{:03d}.jpg'.format(i, j)) 151 | basename = os.path.basename(names[1]) 152 | path = os.path.join(opts.output_folder,basename) 153 | if not os.path.exists(os.path.dirname(path)): 154 | os.makedirs(os.path.dirname(path)) 155 | vutils.save_image(outputs.data, path, padding=0, normalize=True) 156 | if not opts.output_only: 157 | # also save input images 158 | vutils.save_image(images.data, os.path.join(opts.output_folder, 'input{:03d}.jpg'.format(i)), padding=0, normalize=True) 159 | else: 160 | pass 161 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2018 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
4 | """ 5 | from utils import get_all_data_loaders, prepare_sub_folder, write_html, write_loss, get_config, write_2images, Timer 6 | import argparse 7 | from torch.autograd import Variable 8 | from trainer import MUNIT_Trainer, UNIT_Trainer 9 | import torch.backends.cudnn as cudnn 10 | import torch 11 | try: 12 | from itertools import izip as zip 13 | except ImportError: # will be 3.x series 14 | pass 15 | import os 16 | import sys 17 | import tensorboardX 18 | import shutil 19 | 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument('--config', type=str, default='configs/edges2handbags_folder.yaml', help='Path to the config file.') 22 | parser.add_argument('--output_path', type=str, default='.', help="outputs path") 23 | parser.add_argument("--resume", action="store_true") 24 | parser.add_argument('--trainer', type=str, default='MUNIT', help="MUNIT|UNIT") 25 | opts = parser.parse_args() 26 | 27 | cudnn.benchmark = True  # let cuDNN benchmark and cache the fastest conv algorithms (pays off when input sizes stay fixed) 28 | 29 | # Load experiment setting 30 | config = get_config(opts.config) 31 | max_iter = config['max_iter'] 32 | display_size = config['display_size'] 33 | config['vgg_model_path'] = opts.output_path 34 | 35 | # Setup model and data loader 36 | if opts.trainer == 'MUNIT': 37 | trainer = MUNIT_Trainer(config) 38 | elif opts.trainer == 'UNIT': 39 | trainer = UNIT_Trainer(config) 40 | else: 41 | sys.exit("Only support MUNIT|UNIT") 42 | trainer.cuda() 43 | train_loader_a, train_loader_b, test_loader_a, test_loader_b = get_all_data_loaders(config) 44 | train_display_images_a = torch.stack([train_loader_a.dataset[i] for i in range(display_size)]).cuda()  # first display_size samples of each split, kept on the GPU for the periodic trainer.sample() snapshots 45 | train_display_images_b = torch.stack([train_loader_b.dataset[i] for i in range(display_size)]).cuda() 46 | test_display_images_a = torch.stack([test_loader_a.dataset[i] for i in range(display_size)]).cuda() 47 | 
test_display_images_b = torch.stack([test_loader_b.dataset[i] for i in range(display_size)]).cuda() 48 | 49 | # Setup logger and output folders 50 | model_name =
os.path.splitext(os.path.basename(opts.config))[0] 51 | train_writer = tensorboardX.SummaryWriter(os.path.join(opts.output_path + "/logs", model_name)) 52 | output_directory = os.path.join(opts.output_path + "/outputs", model_name) 53 | checkpoint_directory, image_directory = prepare_sub_folder(output_directory) 54 | shutil.copy(opts.config, os.path.join(output_directory, 'config.yaml')) # copy config file to output folder 55 | 56 | # Start training 57 | iterations = trainer.resume(checkpoint_directory, hyperparameters=config) if opts.resume else 0  # resume() returns the iteration count restored from the latest checkpoint 58 | while True: 59 | for it, (images_a, images_b) in enumerate(zip(train_loader_a, train_loader_b)):  # zip stops at the shorter loader; the outer while loop then restarts both 60 | trainer.update_learning_rate()  # NOTE(review): schedulers step before the optimizers do; PyTorch >= 1.1 expects the reverse order and warns — fine for the pinned 0.4.1, re-check if upgrading 61 | images_a, images_b = images_a.cuda().detach(), images_b.cuda().detach() 62 | 63 | with Timer("Elapsed time in update: %f"): 64 | # Main training code 65 | trainer.dis_update(images_a, images_b, config) 66 | trainer.gen_update(images_a, images_b, config) 67 | torch.cuda.synchronize()  # wait for queued GPU work so the Timer measures the full update 68 | 69 | # Dump training stats in log file 70 | if (iterations + 1) % config['log_iter'] == 0: 71 | print("Iteration: %08d/%08d" % (iterations + 1, max_iter)) 72 | write_loss(iterations, trainer, train_writer) 73 | 74 | # Write images 75 | if (iterations + 1) % config['image_save_iter'] == 0: 76 | with torch.no_grad(): 77 | test_image_outputs = trainer.sample(test_display_images_a, test_display_images_b) 78 | train_image_outputs = trainer.sample(train_display_images_a, train_display_images_b) 79 | write_2images(test_image_outputs, display_size, image_directory, 'test_%08d' % (iterations + 1)) 80 | write_2images(train_image_outputs, display_size, image_directory, 'train_%08d' % 
(iterations + 1)) 81 | # HTML 82 | write_html(output_directory + "/index.html", iterations + 1, config['image_save_iter'], 'images') 83 | 84 | if (iterations + 1) % config['image_display_iter'] == 0: 85 | with torch.no_grad(): 86 | image_outputs = trainer.sample(train_display_images_a, train_display_images_b) 87 |
write_2images(image_outputs, display_size, image_directory, 'train_current') 88 | 89 | # Save network weights 90 | if (iterations + 1) % config['snapshot_save_iter'] == 0: 91 | trainer.save(checkpoint_directory, iterations) 92 | 93 | iterations += 1 94 | if iterations >= max_iter: 95 | sys.exit('Finish training') 96 | 97 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | from networks import AdaINGen, MsImageDis, VAEGen 6 | from utils import weights_init, get_model_list, vgg_preprocess, load_vgg16, get_scheduler 7 | from torch.autograd import Variable 8 | import torch 9 | import torch.nn as nn 10 | import os 11 | 12 | class MUNIT_Trainer(nn.Module): 13 | def __init__(self, hyperparameters): 14 | super(MUNIT_Trainer, self).__init__() 15 | lr = hyperparameters['lr'] 16 | # Initiate the networks 17 | self.gen_a = AdaINGen(hyperparameters['input_dim_a'], hyperparameters['gen']) # auto-encoder for domain a 18 | self.gen_b = AdaINGen(hyperparameters['input_dim_b'], hyperparameters['gen']) # auto-encoder for domain b 19 | self.dis_a = MsImageDis(hyperparameters['input_dim_a'], hyperparameters['dis']) # discriminator for domain a 20 | self.dis_b = MsImageDis(hyperparameters['input_dim_b'], hyperparameters['dis']) # discriminator for domain b 21 | self.instancenorm = nn.InstanceNorm2d(512, affine=False) 22 | self.style_dim = hyperparameters['gen']['style_dim'] 23 | 24 | # fix the noise used in sampling 25 | display_size = int(hyperparameters['display_size']) 26 | self.s_a = torch.randn(display_size, self.style_dim, 1, 1).cuda() 27 | self.s_b = torch.randn(display_size, self.style_dim, 1, 1).cuda() 28 | 29 | # Setup the optimizers 30 | beta1 = 
hyperparameters['beta1'] 31 | beta2 = hyperparameters['beta2'] 32 | dis_params = list(self.dis_a.parameters()) + list(self.dis_b.parameters()) 33 | gen_params = list(self.gen_a.parameters()) + list(self.gen_b.parameters()) 34 | self.dis_opt = torch.optim.Adam([p for p in dis_params if p.requires_grad], 35 | lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay']) 36 | self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad], 37 | lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay']) 38 | self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters) 39 | self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters) 40 | 41 | # Network weight initialization 42 | self.apply(weights_init(hyperparameters['init'])) 43 | self.dis_a.apply(weights_init('gaussian')) 44 | self.dis_b.apply(weights_init('gaussian')) 45 | 46 | # Load VGG model if needed 47 | if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0: 48 | self.vgg = load_vgg16(hyperparameters['vgg_model_path'] + '/models') 49 | self.vgg.eval() 50 | for param in self.vgg.parameters(): 51 | param.requires_grad = False 52 | 53 | def recon_criterion(self, input, target): 54 | return torch.mean(torch.abs(input - target)) 55 | 56 | def forward(self, x_a, x_b): 57 | self.eval() 58 | s_a = Variable(self.s_a) 59 | s_b = Variable(self.s_b) 60 | c_a, s_a_fake = self.gen_a.encode(x_a) 61 | c_b, s_b_fake = self.gen_b.encode(x_b) 62 | x_ba = self.gen_a.decode(c_b, s_a) 63 | x_ab = self.gen_b.decode(c_a, s_b) 64 | self.train() 65 | return x_ab, x_ba 66 | 67 | def gen_update(self, x_a, x_b, hyperparameters): 68 | self.gen_opt.zero_grad() 69 | s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda()) 70 | s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda()) 71 | # encode 72 | c_a, s_a_prime = self.gen_a.encode(x_a) 73 | c_b, s_b_prime = self.gen_b.encode(x_b) 74 | # decode (within domain) 75 | x_a_recon = 
self.gen_a.decode(c_a, s_a_prime) 76 | x_b_recon = self.gen_b.decode(c_b, s_b_prime) 77 | # decode (cross domain) 78 | x_ba = self.gen_a.decode(c_b, s_a) 79 | x_ab = self.gen_b.decode(c_a, s_b) 80 | # encode again 81 | c_b_recon, s_a_recon = self.gen_a.encode(x_ba) 82 | c_a_recon, s_b_recon = self.gen_b.encode(x_ab) 83 | # decode again (if needed) 84 | x_aba = self.gen_a.decode(c_a_recon, s_a_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None 85 | x_bab = self.gen_b.decode(c_b_recon, s_b_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None 86 | 87 | # reconstruction loss 88 | self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a) 89 | self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b) 90 | self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a) 91 | self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b) 92 | self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c_a) 93 | self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, c_b) 94 | self.loss_gen_cycrecon_x_a = self.recon_criterion(x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0 95 | self.loss_gen_cycrecon_x_b = self.recon_criterion(x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0 96 | # GAN loss 97 | self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba) 98 | self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab) 99 | # domain-invariant perceptual loss 100 | self.loss_gen_vgg_a = self.compute_vgg_loss(self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0 101 | self.loss_gen_vgg_b = self.compute_vgg_loss(self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0 102 | # total loss 103 | self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \ 104 | hyperparameters['gan_w'] * self.loss_gen_adv_b + \ 105 | hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \ 106 | hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \ 107 | hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \ 108 | 
hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \ 109 | hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \ 110 | hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b + \ 111 | hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \ 112 | hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \ 113 | hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \ 114 | hyperparameters['vgg_w'] * self.loss_gen_vgg_b 115 | self.loss_gen_total.backward() 116 | self.gen_opt.step() 117 | 118 | def compute_vgg_loss(self, vgg, img, target): 119 | img_vgg = vgg_preprocess(img) 120 | target_vgg = vgg_preprocess(target) 121 | img_fea = vgg(img_vgg) 122 | target_fea = vgg(target_vgg) 123 | return torch.mean((self.instancenorm(img_fea) - self.instancenorm(target_fea)) ** 2) 124 | 125 | def sample(self, x_a, x_b): 126 | self.eval() 127 | s_a1 = Variable(self.s_a) 128 | s_b1 = Variable(self.s_b) 129 | s_a2 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda()) 130 | s_b2 = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda()) 131 | x_a_recon, x_b_recon, x_ba1, x_ba2, x_ab1, x_ab2 = [], [], [], [], [], [] 132 | for i in range(x_a.size(0)): 133 | c_a, s_a_fake = self.gen_a.encode(x_a[i].unsqueeze(0)) 134 | c_b, s_b_fake = self.gen_b.encode(x_b[i].unsqueeze(0)) 135 | x_a_recon.append(self.gen_a.decode(c_a, s_a_fake)) 136 | x_b_recon.append(self.gen_b.decode(c_b, s_b_fake)) 137 | x_ba1.append(self.gen_a.decode(c_b, s_a1[i].unsqueeze(0))) 138 | x_ba2.append(self.gen_a.decode(c_b, s_a2[i].unsqueeze(0))) 139 | x_ab1.append(self.gen_b.decode(c_a, s_b1[i].unsqueeze(0))) 140 | x_ab2.append(self.gen_b.decode(c_a, s_b2[i].unsqueeze(0))) 141 | x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon) 142 | x_ba1, x_ba2 = torch.cat(x_ba1), torch.cat(x_ba2) 143 | x_ab1, x_ab2 = torch.cat(x_ab1), torch.cat(x_ab2) 144 | self.train() 145 | return x_a, x_a_recon, x_ab1, x_ab2, x_b, x_b_recon, x_ba1, x_ba2 146 | 147 | def dis_update(self, x_a, 
x_b, hyperparameters): 148 | self.dis_opt.zero_grad() 149 | s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda()) 150 | s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda()) 151 | # encode 152 | c_a, _ = self.gen_a.encode(x_a) 153 | c_b, _ = self.gen_b.encode(x_b) 154 | # decode (cross domain) 155 | x_ba = self.gen_a.decode(c_b, s_a) 156 | x_ab = self.gen_b.decode(c_a, s_b) 157 | # D loss 158 | self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a) 159 | self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b) 160 | self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + hyperparameters['gan_w'] * self.loss_dis_b 161 | self.loss_dis_total.backward() 162 | self.dis_opt.step() 163 | 164 | def update_learning_rate(self): 165 | if self.dis_scheduler is not None: 166 | self.dis_scheduler.step() 167 | if self.gen_scheduler is not None: 168 | self.gen_scheduler.step() 169 | 170 | def resume(self, checkpoint_dir, hyperparameters): 171 | # Load generators 172 | last_model_name = get_model_list(checkpoint_dir, "gen") 173 | state_dict = torch.load(last_model_name) 174 | self.gen_a.load_state_dict(state_dict['a']) 175 | self.gen_b.load_state_dict(state_dict['b']) 176 | iterations = int(last_model_name[-11:-3]) 177 | # Load discriminators 178 | last_model_name = get_model_list(checkpoint_dir, "dis") 179 | state_dict = torch.load(last_model_name) 180 | self.dis_a.load_state_dict(state_dict['a']) 181 | self.dis_b.load_state_dict(state_dict['b']) 182 | # Load optimizers 183 | state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt')) 184 | self.dis_opt.load_state_dict(state_dict['dis']) 185 | self.gen_opt.load_state_dict(state_dict['gen']) 186 | # Reinitilize schedulers 187 | self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations) 188 | self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations) 189 | print('Resume from iteration %d' % iterations) 190 | return iterations 191 
| 192 | def save(self, snapshot_dir, iterations): 193 | # Save generators, discriminators, and optimizers 194 | gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1)) 195 | dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1)) 196 | opt_name = os.path.join(snapshot_dir, 'optimizer.pt') 197 | torch.save({'a': self.gen_a.state_dict(), 'b': self.gen_b.state_dict()}, gen_name) 198 | torch.save({'a': self.dis_a.state_dict(), 'b': self.dis_b.state_dict()}, dis_name) 199 | torch.save({'gen': self.gen_opt.state_dict(), 'dis': self.dis_opt.state_dict()}, opt_name) 200 | 201 | 202 | class UNIT_Trainer(nn.Module): 203 | def __init__(self, hyperparameters): 204 | super(UNIT_Trainer, self).__init__() 205 | lr = hyperparameters['lr'] 206 | # Initiate the networks 207 | self.gen_a = VAEGen(hyperparameters['input_dim_a'], hyperparameters['gen']) # auto-encoder for domain a 208 | self.gen_b = VAEGen(hyperparameters['input_dim_b'], hyperparameters['gen']) # auto-encoder for domain b 209 | self.dis_a = MsImageDis(hyperparameters['input_dim_a'], hyperparameters['dis']) # discriminator for domain a 210 | self.dis_b = MsImageDis(hyperparameters['input_dim_b'], hyperparameters['dis']) # discriminator for domain b 211 | self.instancenorm = nn.InstanceNorm2d(512, affine=False) 212 | 213 | # Setup the optimizers 214 | beta1 = hyperparameters['beta1'] 215 | beta2 = hyperparameters['beta2'] 216 | dis_params = list(self.dis_a.parameters()) + list(self.dis_b.parameters()) 217 | gen_params = list(self.gen_a.parameters()) + list(self.gen_b.parameters()) 218 | self.dis_opt = torch.optim.Adam([p for p in dis_params if p.requires_grad], 219 | lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay']) 220 | self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad], 221 | lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay']) 222 | self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters) 223 | 
self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters) 224 | 225 | # Network weight initialization 226 | self.apply(weights_init(hyperparameters['init'])) 227 | self.dis_a.apply(weights_init('gaussian')) 228 | self.dis_b.apply(weights_init('gaussian')) 229 | 230 | # Load VGG model if needed 231 | if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0: 232 | self.vgg = load_vgg16(hyperparameters['vgg_model_path'] + '/models') 233 | self.vgg.eval() 234 | for param in self.vgg.parameters(): 235 | param.requires_grad = False 236 | 237 | def recon_criterion(self, input, target): 238 | return torch.mean(torch.abs(input - target)) 239 | 240 | def forward(self, x_a, x_b): 241 | self.eval() 242 | h_a, _ = self.gen_a.encode(x_a) 243 | h_b, _ = self.gen_b.encode(x_b) 244 | x_ba = self.gen_a.decode(h_b) 245 | x_ab = self.gen_b.decode(h_a) 246 | self.train() 247 | return x_ab, x_ba 248 | 249 | def __compute_kl(self, mu): 250 | # def _compute_kl(self, mu, sd): 251 | # mu_2 = torch.pow(mu, 2) 252 | # sd_2 = torch.pow(sd, 2) 253 | # encoding_loss = (mu_2 + sd_2 - torch.log(sd_2)).sum() / mu_2.size(0) 254 | # return encoding_loss 255 | mu_2 = torch.pow(mu, 2) 256 | encoding_loss = torch.mean(mu_2) 257 | return encoding_loss 258 | 259 | def gen_update(self, x_a, x_b, hyperparameters): 260 | self.gen_opt.zero_grad() 261 | # encode 262 | h_a, n_a = self.gen_a.encode(x_a) 263 | h_b, n_b = self.gen_b.encode(x_b) 264 | # decode (within domain) 265 | x_a_recon = self.gen_a.decode(h_a + n_a) 266 | x_b_recon = self.gen_b.decode(h_b + n_b) 267 | # decode (cross domain) 268 | x_ba = self.gen_a.decode(h_b + n_b) 269 | x_ab = self.gen_b.decode(h_a + n_a) 270 | # encode again 271 | h_b_recon, n_b_recon = self.gen_a.encode(x_ba) 272 | h_a_recon, n_a_recon = self.gen_b.encode(x_ab) 273 | # decode again (if needed) 274 | x_aba = self.gen_a.decode(h_a_recon + n_a_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None 275 | x_bab = self.gen_b.decode(h_b_recon + n_b_recon) 
if hyperparameters['recon_x_cyc_w'] > 0 else None 276 | 277 | # reconstruction loss 278 | self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a) 279 | self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b) 280 | self.loss_gen_recon_kl_a = self.__compute_kl(h_a) 281 | self.loss_gen_recon_kl_b = self.__compute_kl(h_b) 282 | self.loss_gen_cyc_x_a = self.recon_criterion(x_aba, x_a) 283 | self.loss_gen_cyc_x_b = self.recon_criterion(x_bab, x_b) 284 | self.loss_gen_recon_kl_cyc_aba = self.__compute_kl(h_a_recon) 285 | self.loss_gen_recon_kl_cyc_bab = self.__compute_kl(h_b_recon) 286 | # GAN loss 287 | self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba) 288 | self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab) 289 | # domain-invariant perceptual loss 290 | self.loss_gen_vgg_a = self.compute_vgg_loss(self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0 291 | self.loss_gen_vgg_b = self.compute_vgg_loss(self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0 292 | # total loss 293 | self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \ 294 | hyperparameters['gan_w'] * self.loss_gen_adv_b + \ 295 | hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \ 296 | hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_a + \ 297 | hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \ 298 | hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_b + \ 299 | hyperparameters['recon_x_cyc_w'] * self.loss_gen_cyc_x_a + \ 300 | hyperparameters['recon_kl_cyc_w'] * self.loss_gen_recon_kl_cyc_aba + \ 301 | hyperparameters['recon_x_cyc_w'] * self.loss_gen_cyc_x_b + \ 302 | hyperparameters['recon_kl_cyc_w'] * self.loss_gen_recon_kl_cyc_bab + \ 303 | hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \ 304 | hyperparameters['vgg_w'] * self.loss_gen_vgg_b 305 | self.loss_gen_total.backward() 306 | self.gen_opt.step() 307 | 308 | def compute_vgg_loss(self, vgg, img, target): 309 | img_vgg = vgg_preprocess(img) 310 | target_vgg = 
vgg_preprocess(target) 311 | img_fea = vgg(img_vgg) 312 | target_fea = vgg(target_vgg) 313 | return torch.mean((self.instancenorm(img_fea) - self.instancenorm(target_fea)) ** 2) 314 | 315 | def sample(self, x_a, x_b): 316 | self.eval() 317 | x_a_recon, x_b_recon, x_ba, x_ab = [], [], [], [] 318 | for i in range(x_a.size(0)): 319 | h_a, _ = self.gen_a.encode(x_a[i].unsqueeze(0)) 320 | h_b, _ = self.gen_b.encode(x_b[i].unsqueeze(0)) 321 | x_a_recon.append(self.gen_a.decode(h_a)) 322 | x_b_recon.append(self.gen_b.decode(h_b)) 323 | x_ba.append(self.gen_a.decode(h_b)) 324 | x_ab.append(self.gen_b.decode(h_a)) 325 | x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon) 326 | x_ba = torch.cat(x_ba) 327 | x_ab = torch.cat(x_ab) 328 | self.train() 329 | return x_a, x_a_recon, x_ab, x_b, x_b_recon, x_ba 330 | 331 | def dis_update(self, x_a, x_b, hyperparameters): 332 | self.dis_opt.zero_grad() 333 | # encode 334 | h_a, n_a = self.gen_a.encode(x_a) 335 | h_b, n_b = self.gen_b.encode(x_b) 336 | # decode (cross domain) 337 | x_ba = self.gen_a.decode(h_b + n_b) 338 | x_ab = self.gen_b.decode(h_a + n_a) 339 | # D loss 340 | self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a) 341 | self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b) 342 | self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + hyperparameters['gan_w'] * self.loss_dis_b 343 | self.loss_dis_total.backward() 344 | self.dis_opt.step() 345 | 346 | def update_learning_rate(self): 347 | if self.dis_scheduler is not None: 348 | self.dis_scheduler.step() 349 | if self.gen_scheduler is not None: 350 | self.gen_scheduler.step() 351 | 352 | def resume(self, checkpoint_dir, hyperparameters): 353 | # Load generators 354 | last_model_name = get_model_list(checkpoint_dir, "gen") 355 | state_dict = torch.load(last_model_name) 356 | self.gen_a.load_state_dict(state_dict['a']) 357 | self.gen_b.load_state_dict(state_dict['b']) 358 | iterations = int(last_model_name[-11:-3]) 359 | # Load 
discriminators 360 | last_model_name = get_model_list(checkpoint_dir, "dis") 361 | state_dict = torch.load(last_model_name) 362 | self.dis_a.load_state_dict(state_dict['a']) 363 | self.dis_b.load_state_dict(state_dict['b']) 364 | # Load optimizers 365 | state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt')) 366 | self.dis_opt.load_state_dict(state_dict['dis']) 367 | self.gen_opt.load_state_dict(state_dict['gen']) 368 | # Reinitilize schedulers 369 | self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations) 370 | self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations) 371 | print('Resume from iteration %d' % iterations) 372 | return iterations 373 | 374 | def save(self, snapshot_dir, iterations): 375 | # Save generators, discriminators, and optimizers 376 | gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1)) 377 | dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1)) 378 | opt_name = os.path.join(snapshot_dir, 'optimizer.pt') 379 | torch.save({'a': self.gen_a.state_dict(), 'b': self.gen_b.state_dict()}, gen_name) 380 | torch.save({'a': self.dis_a.state_dict(), 'b': self.dis_b.state_dict()}, dis_name) 381 | torch.save({'gen': self.gen_opt.state_dict(), 'dis': self.dis_opt.state_dict()}, opt_name) 382 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2018 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
4 | """ 5 | from torch.utils.serialization import load_lua 6 | from torch.utils.data import DataLoader 7 | from networks import Vgg16 8 | from torch.autograd import Variable 9 | from torch.optim import lr_scheduler 10 | from torchvision import transforms 11 | from data import ImageFilelist, ImageFolder 12 | import torch 13 | import torch.nn as nn 14 | import os 15 | import math 16 | import torchvision.utils as vutils 17 | import yaml 18 | import numpy as np 19 | import torch.nn.init as init 20 | import time 21 | # Methods 22 | # get_all_data_loaders : primary data loader interface (load trainA, testA, trainB, testB) 23 | # get_data_loader_list : list-based data loader 24 | # get_data_loader_folder : folder-based data loader 25 | # get_config : load yaml file 26 | # eformat : 27 | # write_2images : save output image 28 | # prepare_sub_folder : create checkpoints and images folders for saving outputs 29 | # write_one_row_html : write one row of the html file for output images 30 | # write_html : create the html file. 
# write_loss : log the trainer's loss/grad/nwd attributes to a summary writer
# slerp : spherical linear interpolation between two vectors
# get_slerp_interp : latent codes interpolated with slerp between random endpoints
# get_model_list : path of the latest checkpoint in a folder
# load_vgg16 : load the VGG16 network used by the perceptual loss
# load_inception : load an Inception-v3 classifier from a checkpoint
# vgg_preprocess : convert [-1, 1] RGB images into VGG's input format
# get_scheduler : build the learning-rate scheduler named by the config
# weights_init : build a weight-initialization function for Module.apply

def get_all_data_loaders(conf):
    """Create the four data loaders described by the configuration dict ``conf``.

    With a ``data_root`` key, images come from its trainA/testA/trainB/testB
    sub-folders; otherwise from per-split ``data_folder_*`` / ``data_list_*`` entries.
    Training loaders crop to (crop_image_height, crop_image_width); test loaders crop
    to a new_size x new_size square. Either ``new_size`` or both ``new_size_a`` and
    ``new_size_b`` must be present.

    Returns:
        (train_loader_a, train_loader_b, test_loader_a, test_loader_b)
    """
    batch_size = conf['batch_size']
    num_workers = conf['num_workers']
    if 'new_size' in conf:
        new_size_a = new_size_b = conf['new_size']
    else:
        new_size_a = conf['new_size_a']
        new_size_b = conf['new_size_b']
    height = conf['crop_image_height']
    width = conf['crop_image_width']
    from_folders = 'data_root' in conf

    def make_loader(split, domain, is_train, new_size):
        # split is 'train' or 'test', domain is 'a' or 'b'
        crop_h, crop_w = (height, width) if is_train else (new_size, new_size)
        if from_folders:
            folder = os.path.join(conf['data_root'], split + domain.upper())
            return get_data_loader_folder(folder, batch_size, is_train,
                                          new_size, crop_h, crop_w, num_workers, True)
        key = split + '_' + domain
        return get_data_loader_list(conf['data_folder_' + key], conf['data_list_' + key],
                                    batch_size, is_train, new_size, crop_h, crop_w, num_workers, True)

    train_loader_a = make_loader('train', 'a', True, new_size_a)
    test_loader_a = make_loader('test', 'a', False, new_size_a)
    train_loader_b = make_loader('train', 'b', True, new_size_b)
    test_loader_b = make_loader('test', 'b', False, new_size_b)
    return train_loader_a, train_loader_b, test_loader_a, test_loader_b


def _build_transform(train, new_size=None, height=256, width=256, crop=True):
    """Build the preprocessing pipeline shared by the list- and folder-based loaders.

    Applied in order: RandomHorizontalFlip (only if ``train``), Resize(new_size)
    (only if ``new_size`` is not None), RandomCrop((height, width)) (only if
    ``crop``), ToTensor, then Normalize with mean/std 0.5 per channel, which maps
    ToTensor's [0, 1] range to [-1, 1].
    """
    steps = []
    if train:
        steps.append(transforms.RandomHorizontalFlip())
    if new_size is not None:
        steps.append(transforms.Resize(new_size))
    if crop:
        steps.append(transforms.RandomCrop((height, width)))
    steps.append(transforms.ToTensor())
    steps.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
    return transforms.Compose(steps)


def get_data_loader_list(root, file_list, batch_size, train, new_size=None,
                         height=256, width=256, num_workers=4, crop=True):
    """Return a DataLoader over ``ImageFilelist(root, file_list)``.

    Args:
        root: image folder handed to ImageFilelist.
        file_list: list file handed to ImageFilelist.
        batch_size: images per batch; an incomplete final batch is dropped.
        train: enables random horizontal flips and shuffling.
        new_size: optional Resize target applied before cropping.
        height, width: RandomCrop size, used only when ``crop`` is True.
        num_workers: number of DataLoader worker processes.
        crop: whether to apply RandomCrop((height, width)).
    """
    transform = _build_transform(train, new_size, height, width, crop)
    dataset = ImageFilelist(root, file_list, transform=transform)
    return DataLoader(dataset=dataset, batch_size=batch_size, shuffle=train,
                      drop_last=True, num_workers=num_workers)


def get_data_loader_folder(input_folder, batch_size, train, new_size=None,
                           height=256, width=256, num_workers=4, crop=True):
    """Return a DataLoader over ``ImageFolder(input_folder)``.

    All other arguments behave as in get_data_loader_list.
    """
    transform = _build_transform(train, new_size, height, width, crop)
    dataset = ImageFolder(input_folder, transform=transform)
    return DataLoader(dataset=dataset, batch_size=batch_size, shuffle=train,
                      drop_last=True, num_workers=num_workers)


def get_config(config):
    """Parse the YAML file at path ``config`` and return its content."""
    with open(config, 'r') as stream:
        # yaml.load without an explicit Loader is deprecated since PyYAML 5.1 and an
        # error in PyYAML 6; safe_load also refuses to build arbitrary Python objects.
        # NOTE(review): assumes configs only use plain YAML types -- confirm.
        return yaml.safe_load(stream)


def eformat(f, prec):
    """Format ``f`` in scientific notation with ``prec`` mantissa decimals and a
    compact exponent, e.g. eformat(12345.678, 2) == '1.23e4'."""
    mantissa, exp = ('%.*e' % (prec, f)).split('e')
    # int() drops the exponent's '+' sign and leading zeros ('+04' -> 4).
    return '%se%d' % (mantissa, int(exp))


def __write_images(image_outputs, display_image_num, file_name):
    """Save the first ``display_image_num`` images of each tensor in ``image_outputs``
    as a single grid image at ``file_name`` (``display_image_num`` images per row)."""
    # expand() broadcasts single-channel (gray-scale) batches to 3 channels.
    image_outputs = [images.expand(-1, 3, -1, -1) for images in image_outputs]
    image_tensor = torch.cat([images[:display_image_num] for images in image_outputs], 0)
    image_grid = vutils.make_grid(image_tensor.data, nrow=display_image_num, padding=0, normalize=True)
    vutils.save_image(image_grid, file_name, nrow=1)


def write_2images(image_outputs, display_image_num, image_directory, postfix):
    """Save the first half of ``image_outputs`` as gen_a2b_<postfix>.jpg and the
    second half as gen_b2a_<postfix>.jpg inside ``image_directory``."""
    n = len(image_outputs)
    __write_images(image_outputs[0:n//2], display_image_num, '%s/gen_a2b_%s.jpg' % (image_directory, postfix))
    __write_images(image_outputs[n//2:n], display_image_num, '%s/gen_b2a_%s.jpg' % (image_directory, postfix))


def prepare_sub_folder(output_directory):
    """Create (if missing) the 'images' and 'checkpoints' sub-folders of
    ``output_directory`` and return (checkpoint_directory, image_directory)."""
    image_directory = os.path.join(output_directory, 'images')
    checkpoint_directory = os.path.join(output_directory, 'checkpoints')
    for directory in (image_directory, checkpoint_directory):
        if not os.path.exists(directory):
            print("Creating directory: {}".format(directory))
            os.makedirs(directory)
    return checkpoint_directory, image_directory


def write_one_row_html(html_file, iterations, img_filename, all_size):
    """Append to the open ``html_file`` a heading naming the iteration and image file,
    followed by the image (linked to itself) displayed ``all_size`` pixels wide."""
    html_file.write("<h3>iteration [%d] (%s)</h3>" % (iterations, img_filename.split('/')[-1]))
    html_file.write("""
        <p><a href="%s">
          <img src="%s" style="width:%dpx">
        </a><br>
        <p>
        """ % (img_filename, img_filename, all_size))


def write_html(filename, iterations, image_save_iterations, image_directory, all_size=1536):
    """Write ``filename`` as an HTML index of the sample images in ``image_directory``.

    The page shows the current a2b/b2a training samples, then (newest first) the test
    and train samples saved at every multiple of ``image_save_iterations`` between
    ``image_save_iterations`` and ``iterations``.
    """
    # 'with' closes the file even if a write fails part-way.
    with open(filename, "w") as html_file:
        html_file.write('''
    <!DOCTYPE html>
    <html>
    <head>
      <title>Experiment name = %s</title>
      <meta http-equiv="refresh" content="30">
    </head>
    <body>
    ''' % os.path.basename(filename))
        html_file.write("<h3>current</h3>")
        write_one_row_html(html_file, iterations, '%s/gen_a2b_train_current.jpg' % (image_directory), all_size)
        write_one_row_html(html_file, iterations, '%s/gen_b2a_train_current.jpg' % (image_directory), all_size)
        # Step directly over the multiples of image_save_iterations (newest first)
        # instead of testing every iteration for divisibility.
        last_saved = iterations - iterations % image_save_iterations
        for j in range(last_saved, image_save_iterations - 1, -image_save_iterations):
            write_one_row_html(html_file, j, '%s/gen_a2b_test_%08d.jpg' % (image_directory, j), all_size)
            write_one_row_html(html_file, j, '%s/gen_b2a_test_%08d.jpg' % (image_directory, j), all_size)
            write_one_row_html(html_file, j, '%s/gen_a2b_train_%08d.jpg' % (image_directory, j), all_size)
            write_one_row_html(html_file, j, '%s/gen_b2a_train_%08d.jpg' % (image_directory, j), all_size)
        html_file.write("</body></html>")


def write_loss(iterations, trainer, train_writer):
    """Log every non-callable, non-dunder attribute of ``trainer`` whose name contains
    'loss', 'grad' or 'nwd' as a scalar at step ``iterations + 1``."""
    for name in dir(trainer):
        if name.startswith("__") or not any(tag in name for tag in ('loss', 'grad', 'nwd')):
            continue
        value = getattr(trainer, name)
        if not callable(value):
            train_writer.add_scalar(name, value, iterations + 1)


def slerp(val, low, high):
    """Spherical linear interpolation from ``low`` (val=0) to ``high`` (val=1).

    Origin: K. Shoemake, "Animating Rotation with Quaternion Curves" (1985).
    Latent-space use: T. White, "Sampling Generative Networks",
    https://arxiv.org/abs/1609.04468
    Code: https://github.com/soumith/dcgan.torch/issues/14, Tom White
    """
    cos_omega = np.dot(low / np.linalg.norm(low), high / np.linalg.norm(high))
    # Rounding can push |cos| slightly above 1, which would make arccos return NaN.
    omega = np.arccos(np.clip(cos_omega, -1.0, 1.0))
    so = np.sin(omega)
    if so == 0:
        # Parallel endpoints: the slerp formula would divide by zero; lerp instead.
        return (1.0 - val) * low + val * high
    return np.sin((1.0 - val) * omega) / so * low + np.sin(val * omega) / so * high


def get_slerp_interp(nb_latents, nb_interp, z_dim):
    """Return ``nb_latents * nb_interp`` float32 latent codes of shape
    (N, z_dim, 1, 1): for each latent, ``nb_interp`` slerp steps between two
    random normal vectors.

    modified from: PyTorch inference for "Progressive Growing of GANs" with CelebA snapshot
    https://github.com/ptrblck/prog_gans_pytorch_inference
    """
    rows = [np.empty(shape=(0, z_dim), dtype=np.float32)]
    for _ in range(nb_latents):
        low = np.random.randn(z_dim)
        high = np.random.randn(z_dim)  # alternative: low + np.random.randn(512) * 0.7
        interp_vals = np.linspace(0, 1, num=nb_interp)
        rows.append(np.array([slerp(v, low, high) for v in interp_vals], dtype=np.float32))
    # A single vstack at the end avoids re-copying the accumulated array per latent.
    latent_interps = np.vstack(rows)
    return latent_interps[:, :, np.newaxis, np.newaxis]


def get_model_list(dirname, key):
    """Return the path of the lexicographically last '.pt' file in ``dirname`` whose
    name contains ``key``, or None if the folder is missing or holds no such file.

    Checkpoints are saved as zero-padded names such as 'gen_%08d.pt', so the
    lexicographically last match is the one with the highest iteration.
    """
    if not os.path.exists(dirname):
        return None
    models = sorted(os.path.join(dirname, f) for f in os.listdir(dirname)
                    if os.path.isfile(os.path.join(dirname, f)) and key in f and ".pt" in f)
    if not models:
        return None
    return models[-1]


def load_vgg16(model_dir):
    """Return a Vgg16 whose weights are read from <model_dir>/vgg16.weight.

    On first use the Torch7 file vgg16.t7 is downloaded with wget (unless already
    present), its parameters are copied into a fresh Vgg16, and the result is cached
    as vgg16.weight.
    Uses the model from https://github.com/abhiskk/fast-neural-style/blob/master/neural_style/utils.py
    NOTE(review): load_lua (torch.utils.serialization) only exists in PyTorch < 1.0.
    """
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    weight_path = os.path.join(model_dir, 'vgg16.weight')
    if not os.path.exists(weight_path):
        t7_path = os.path.join(model_dir, 'vgg16.t7')
        if not os.path.exists(t7_path):
            os.system('wget https://www.dropbox.com/s/76l3rt4kyi3s8x7/vgg16.t7?dl=1 -O ' + t7_path)
        vgglua = load_lua(t7_path)
        vgg = Vgg16()
        for (src, dst) in zip(vgglua.parameters()[0], vgg.parameters()):
            dst.data[:] = src
        torch.save(vgg.state_dict(), weight_path)
    vgg = Vgg16()
    vgg.load_state_dict(torch.load(weight_path))
    return vgg


def load_inception(model_path):
    """Build an Inception-v3 whose final fc layer matches the checkpoint at
    ``model_path``, load the checkpoint, and freeze every parameter."""
    # inception_v3 is not imported at module level; import it where it is needed.
    from torchvision.models import inception_v3
    state_dict = torch.load(model_path)
    model = inception_v3(pretrained=False, transform_input=True)
    model.aux_logits = False
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, state_dict['fc.weight'].size(0))
    model.load_state_dict(state_dict)
    for param in model.parameters():
        param.requires_grad = False
    return model


def vgg_preprocess(batch):
    """Convert RGB images in [-1, 1] (channels on dim 1) to BGR in [0, 255] with the
    ImageNet BGR channel means subtracted, the input format of the VGG weights.

    The mean tensor is created on ``batch``'s own device and dtype, so both CPU and
    GPU inputs work.
    """
    (r, g, b) = torch.chunk(batch, 3, dim=1)
    batch = torch.cat((b, g, r), dim=1)  # convert RGB to BGR
    batch = (batch + 1) * 255 * 0.5  # [-1, 1] -> [0, 255]
    mean = torch.tensor([103.939, 116.779, 123.680],
                        dtype=batch.dtype, device=batch.device).view(1, 3, 1, 1)
    return batch - mean  # subtract mean (broadcast over N, H, W)


def get_scheduler(optimizer, hyperparameters, iterations=-1):
    """Return the LR scheduler selected by hyperparameters['lr_policy'].

    'constant' (or no 'lr_policy' key) -> None; 'step' -> StepLR using 'step_size'
    and 'gamma', resumed at ``last_epoch=iterations``.

    Raises:
        NotImplementedError: for any other policy.
    """
    policy = hyperparameters['lr_policy'] if 'lr_policy' in hyperparameters else 'constant'
    if policy == 'constant':
        return None  # constant learning rate: no scheduler
    if policy == 'step':
        return lr_scheduler.StepLR(optimizer, step_size=hyperparameters['step_size'],
                                   gamma=hyperparameters['gamma'], last_epoch=iterations)
    raise NotImplementedError('learning rate policy [%s] is not implemented' % policy)


def weights_init(init_type='gaussian'):
    """Return a function for ``Module.apply`` that initializes the weights of
    modules whose class name starts with 'Conv' or 'Linear' using ``init_type``
    ('gaussian', 'xavier', 'kaiming', 'orthogonal' or 'default' = untouched) and
    zeroes their bias. Other init types trigger an AssertionError."""
    def init_fun(m):
        classname = m.__class__.__name__
        if (classname.find('Conv') == 0 or classname.find('Linear') == 0) and hasattr(m, 'weight'):
            if init_type == 'gaussian':
                init.normal_(m.weight.data, 0.0, 0.02)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=math.sqrt(2))
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=math.sqrt(2))
            elif init_type == 'default':
                pass
            else:
                assert 0, "Unsupported initialization: {}".format(init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)

    return init_fun


class Timer:
    """Context manager that prints ``msg % elapsed_seconds`` when the block exits."""

    def __init__(self, msg):
        self.msg = msg
        self.start_time = None

    def __enter__(self):
        self.start_time = time.time()
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        print(self.msg % (time.time() - self.start_time))


def pytorch03_to_pytorch04(state_dict_base, trainer_name):
    """Convert a PyTorch 0.3 generator checkpoint for use with PyTorch 0.4.

    Removes the '*.norm.running_mean' / '*.norm.running_var' entries listed below
    from both ``state_dict_base['a']`` and ``state_dict_base['b']`` (the input is not
    modified). For 'MUNIT' only the content encoder ('enc_content') entries are
    removed; for any other trainer name (UNIT) the 'enc' entries and the
    'dec.model.0' entries are removed.

    Returns:
        dict with keys 'a' and 'b' holding the cleaned state dicts.
    """
    def _nested_layers(prefix):
        # '<prefix>.model.<block>.model.<layer>' for blocks 0-3 and layers 0-1
        return ['%s.model.%d.model.%d' % (prefix, block, layer)
                for block in range(4) for layer in range(2)]

    def _encoder_layers(enc):
        # layers model.0-model.2 plus the nested sub-blocks under model.3
        return ['%s.model.%d' % (enc, i) for i in range(3)] + _nested_layers('%s.model.3' % enc)

    if trainer_name == 'MUNIT':
        layers = _encoder_layers('enc_content')
    else:
        layers = _encoder_layers('enc') + _nested_layers('dec.model.0')
    stale_suffixes = tuple('%s.norm.%s' % (layer, stat)
                           for layer in layers for stat in ('running_mean', 'running_var'))

    def _strip(gen_state):
        # copy() keeps the mapping type of the original state dict
        cleaned = gen_state.copy()
        for key in gen_state:
            if key.endswith(stale_suffixes):
                del cleaned[key]
        return cleaned

    state_dict = dict()
    state_dict['a'] = _strip(state_dict_base['a'])
    state_dict['b'] = _strip(state_dict_base['b'])
    return state_dict