├── .gitignore
├── Dockerfile
├── LICENSE.md
├── README.md
├── TUTORIAL.md
├── USAGE.md
├── configs
│   ├── demo_edges2handbags_folder.yaml
│   ├── demo_edges2handbags_list.yaml
│   ├── edges2handbags_folder.yaml
│   ├── edges2shoes_folder.yaml
│   ├── summer2winter_yosemite256_folder.yaml
│   └── synthia2cityscape_folder.yaml
├── data.py
├── datasets
│   └── demo_edges2handbags
│       ├── list_testA.txt
│       ├── list_testB.txt
│       ├── list_trainA.txt
│       ├── list_trainB.txt
│       ├── testA
│       │   ├── 00000.jpg
│       │   ├── 00001.jpg
│       │   ├── 00002.jpg
│       │   ├── 00003.jpg
│       │   ├── 00004.jpg
│       │   ├── 00005.jpg
│       │   ├── 00006.jpg
│       │   ├── 00007.jpg
│       │   └── 00008.jpg
│       ├── testB
│       │   ├── 00000.jpg
│       │   ├── 00001.jpg
│       │   ├── 00002.jpg
│       │   ├── 00003.jpg
│       │   ├── 00004.jpg
│       │   ├── 00005.jpg
│       │   ├── 00006.jpg
│       │   ├── 00007.jpg
│       │   └── 00008.jpg
│       ├── trainA
│       │   ├── 00000.jpg
│       │   ├── 00002.jpg
│       │   ├── 00004.jpg
│       │   ├── 00006.jpg
│       │   ├── 00008.jpg
│       │   ├── 00010.jpg
│       │   ├── 00012.jpg
│       │   └── 00014.jpg
│       └── trainB
│           ├── 00001.jpg
│           ├── 00003.jpg
│           ├── 00005.jpg
│           ├── 00007.jpg
│           ├── 00009.jpg
│           ├── 00011.jpg
│           ├── 00013.jpg
│           └── 00015.jpg
├── docs
│   └── munit_assumption.jpg
├── inputs
│   ├── edges2handbags_edge.jpg
│   ├── edges2handbags_handbag.jpg
│   ├── edges2shoes_edge.jpg
│   └── edges2shoes_shoe.jpg
├── networks.py
├── results
│   ├── animal.jpg
│   ├── edges2handbags
│   │   ├── input.jpg
│   │   ├── output000.jpg
│   │   ├── output001.jpg
│   │   ├── output002.jpg
│   │   ├── output003.jpg
│   │   ├── output004.jpg
│   │   ├── output005.jpg
│   │   ├── output006.jpg
│   │   ├── output007.jpg
│   │   ├── output008.jpg
│   │   └── output009.jpg
│   ├── edges2shoes
│   │   ├── input.jpg
│   │   ├── output000.jpg
│   │   ├── output001.jpg
│   │   ├── output002.jpg
│   │   ├── output003.jpg
│   │   ├── output004.jpg
│   │   ├── output005.jpg
│   │   ├── output006.jpg
│   │   ├── output007.jpg
│   │   ├── output008.jpg
│   │   └── output009.jpg
│   ├── edges2shoes_handbags.jpg
│   ├── example_guided.jpg
│   ├── input.jpg
│   ├── output000.jpg
│   ├── street.jpg
│   ├── summer2winter_yosemite.jpg
│   └── video.jpg
├── scripts
│   ├── demo_train_edges2handbags.sh
│   ├── demo_train_edges2shoes.sh
│   └── demo_train_summer2winter_yosemite256.sh
├── test.py
├── test_batch.py
├── train.py
├── trainer.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | outputs/
2 | src/models/
3 | data/
4 | logs/
5 | .idea/
6 | dgx/scripts/
7 | .ipynb_checkpoints/
8 | src/*jpg
9 | notebooks/.ipynb_checkpoints/*
10 | exps/
11 | src/yaml_generator.py
12 | *.tar.gz
13 | *ipynb
14 | *.zip
15 | *.pkl
16 | *.pyc
17 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM nvidia/cuda:9.1-cudnn7-runtime-ubuntu16.04
2 | # Set anaconda path
3 | ENV ANACONDA /opt/anaconda
4 | ENV PATH $ANACONDA/bin:$PATH
5 | RUN apt-get update && apt-get install -y --no-install-recommends \
6 | wget \
7 | libopencv-dev \
8 | python-opencv \
9 | build-essential \
10 | cmake \
11 | git \
12 | curl \
13 | ca-certificates \
14 | libjpeg-dev \
15 | libpng-dev \
16 | axel \
17 | zip \
18 | unzip
19 | RUN wget https://repo.continuum.io/archive/Anaconda3-5.0.1-Linux-x86_64.sh -P /tmp
20 | RUN bash /tmp/Anaconda3-5.0.1-Linux-x86_64.sh -b -p $ANACONDA
21 | RUN rm /tmp/Anaconda3-5.0.1-Linux-x86_64.sh -rf
22 | RUN conda install -y pytorch=0.4.1 torchvision cuda91 -c pytorch
23 | RUN conda install -y -c anaconda pip
24 | RUN conda install -y -c anaconda yaml
25 | RUN pip install tensorboard tensorboardX;
26 |
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | ## creative commons
2 |
3 | # Attribution-NonCommercial-ShareAlike 4.0 International
4 |
5 | Creative Commons Corporation (“Creative Commons”) is not a law firm and does not provide legal services or legal advice. Distribution of Creative Commons public licenses does not create a lawyer-client or other relationship. Creative Commons makes its licenses and related information available on an “as-is” basis. Creative Commons gives no warranties regarding its licenses, any material licensed under their terms and conditions, or any related information. Creative Commons disclaims all liability for damages resulting from their use to the fullest extent possible.
6 |
7 | ### Using Creative Commons Public Licenses
8 |
9 | Creative Commons public licenses provide a standard set of terms and conditions that creators and other rights holders may use to share original works of authorship and other material subject to copyright and certain other rights specified in the public license below. The following considerations are for informational purposes only, are not exhaustive, and do not form part of our licenses.
10 |
11 | * __Considerations for licensors:__ Our public licenses are intended for use by those authorized to give the public permission to use material in ways otherwise restricted by copyright and certain other rights. Our licenses are irrevocable. Licensors should read and understand the terms and conditions of the license they choose before applying it. Licensors should also secure all rights necessary before applying our licenses so that the public can reuse the material as expected. Licensors should clearly mark any material not subject to the license. This includes other CC-licensed material, or material used under an exception or limitation to copyright. [More considerations for licensors](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensors).
12 |
13 | * __Considerations for the public:__ By using one of our public licenses, a licensor grants the public permission to use the licensed material under specified terms and conditions. If the licensor’s permission is not necessary for any reason–for example, because of any applicable exception or limitation to copyright–then that use is not regulated by the license. Our licenses grant only permissions under copyright and certain other rights that a licensor has authority to grant. Use of the licensed material may still be restricted for other reasons, including because others have copyright or other rights in the material. A licensor may make special requests, such as asking that all changes be marked or described. Although not required by our licenses, you are encouraged to respect those requests where reasonable. [More considerations for the public](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensees).
14 |
15 | ## Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License
16 |
17 | By exercising the Licensed Rights (defined below), You accept and agree to be bound by the terms and conditions of this Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License ("Public License"). To the extent this Public License may be interpreted as a contract, You are granted the Licensed Rights in consideration of Your acceptance of these terms and conditions, and the Licensor grants You such rights in consideration of benefits the Licensor receives from making the Licensed Material available under these terms and conditions.
18 |
19 | ### Section 1 – Definitions.
20 |
21 | a. __Adapted Material__ means material subject to Copyright and Similar Rights that is derived from or based upon the Licensed Material and in which the Licensed Material is translated, altered, arranged, transformed, or otherwise modified in a manner requiring permission under the Copyright and Similar Rights held by the Licensor. For purposes of this Public License, where the Licensed Material is a musical work, performance, or sound recording, Adapted Material is always produced where the Licensed Material is synched in timed relation with a moving image.
22 |
23 | b. __Adapter's License__ means the license You apply to Your Copyright and Similar Rights in Your contributions to Adapted Material in accordance with the terms and conditions of this Public License.
24 |
25 | c. __BY-NC-SA Compatible License__ means a license listed at [creativecommons.org/compatiblelicenses](http://creativecommons.org/compatiblelicenses), approved by Creative Commons as essentially the equivalent of this Public License.
26 |
27 | d. __Copyright and Similar Rights__ means copyright and/or similar rights closely related to copyright including, without limitation, performance, broadcast, sound recording, and Sui Generis Database Rights, without regard to how the rights are labeled or categorized. For purposes of this Public License, the rights specified in Section 2(b)(1)-(2) are not Copyright and Similar Rights.
28 |
29 | e. __Effective Technological Measures__ means those measures that, in the absence of proper authority, may not be circumvented under laws fulfilling obligations under Article 11 of the WIPO Copyright Treaty adopted on December 20, 1996, and/or similar international agreements.
30 |
31 | f. __Exceptions and Limitations__ means fair use, fair dealing, and/or any other exception or limitation to Copyright and Similar Rights that applies to Your use of the Licensed Material.
32 |
33 | g. __License Elements__ means the license attributes listed in the name of a Creative Commons Public License. The License Elements of this Public License are Attribution, NonCommercial, and ShareAlike.
34 |
35 | h. __Licensed Material__ means the artistic or literary work, database, or other material to which the Licensor applied this Public License.
36 |
37 | i. __Licensed Rights__ means the rights granted to You subject to the terms and conditions of this Public License, which are limited to all Copyright and Similar Rights that apply to Your use of the Licensed Material and that the Licensor has authority to license.
38 |
39 | j. __Licensor__ means the individual(s) or entity(ies) granting rights under this Public License.
40 |
41 | k. __NonCommercial__ means not primarily intended for or directed towards commercial advantage or monetary compensation. For purposes of this Public License, the exchange of the Licensed Material for other material subject to Copyright and Similar Rights by digital file-sharing or similar means is NonCommercial provided there is no payment of monetary compensation in connection with the exchange.
42 |
43 | l. __Share__ means to provide material to the public by any means or process that requires permission under the Licensed Rights, such as reproduction, public display, public performance, distribution, dissemination, communication, or importation, and to make material available to the public including in ways that members of the public may access the material from a place and at a time individually chosen by them.
44 |
45 | m. __Sui Generis Database Rights__ means rights other than copyright resulting from Directive 96/9/EC of the European Parliament and of the Council of 11 March 1996 on the legal protection of databases, as amended and/or succeeded, as well as other essentially equivalent rights anywhere in the world.
46 |
47 | n. __You__ means the individual or entity exercising the Licensed Rights under this Public License. Your has a corresponding meaning.
48 |
49 | ### Section 2 – Scope.
50 |
51 | a. ___License grant.___
52 |
53 | 1. Subject to the terms and conditions of this Public License, the Licensor hereby grants You a worldwide, royalty-free, non-sublicensable, non-exclusive, irrevocable license to exercise the Licensed Rights in the Licensed Material to:
54 |
55 | A. reproduce and Share the Licensed Material, in whole or in part, for NonCommercial purposes only; and
56 |
57 | B. produce, reproduce, and Share Adapted Material for NonCommercial purposes only.
58 |
59 | 2. __Exceptions and Limitations.__ For the avoidance of doubt, where Exceptions and Limitations apply to Your use, this Public License does not apply, and You do not need to comply with its terms and conditions.
60 |
61 | 3. __Term.__ The term of this Public License is specified in Section 6(a).
62 |
63 | 4. __Media and formats; technical modifications allowed.__ The Licensor authorizes You to exercise the Licensed Rights in all media and formats whether now known or hereafter created, and to make technical modifications necessary to do so. The Licensor waives and/or agrees not to assert any right or authority to forbid You from making technical modifications necessary to exercise the Licensed Rights, including technical modifications necessary to circumvent Effective Technological Measures. For purposes of this Public License, simply making modifications authorized by this Section 2(a)(4) never produces Adapted Material.
64 |
65 | 5. __Downstream recipients.__
66 |
67 | A. __Offer from the Licensor – Licensed Material.__ Every recipient of the Licensed Material automatically receives an offer from the Licensor to exercise the Licensed Rights under the terms and conditions of this Public License.
68 |
69 | B. __Additional offer from the Licensor – Adapted Material.__ Every recipient of Adapted Material from You automatically receives an offer from the Licensor to exercise the Licensed Rights in the Adapted Material under the conditions of the Adapter’s License You apply.
70 |
71 | C. __No downstream restrictions.__ You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, the Licensed Material if doing so restricts exercise of the Licensed Rights by any recipient of the Licensed Material.
72 |
73 | 6. __No endorsement.__ Nothing in this Public License constitutes or may be construed as permission to assert or imply that You are, or that Your use of the Licensed Material is, connected with, or sponsored, endorsed, or granted official status by, the Licensor or others designated to receive attribution as provided in Section 3(a)(1)(A)(i).
74 |
75 | b. ___Other rights.___
76 |
77 | 1. Moral rights, such as the right of integrity, are not licensed under this Public License, nor are publicity, privacy, and/or other similar personality rights; however, to the extent possible, the Licensor waives and/or agrees not to assert any such rights held by the Licensor to the limited extent necessary to allow You to exercise the Licensed Rights, but not otherwise.
78 |
79 | 2. Patent and trademark rights are not licensed under this Public License.
80 |
81 | 3. To the extent possible, the Licensor waives any right to collect royalties from You for the exercise of the Licensed Rights, whether directly or through a collecting society under any voluntary or waivable statutory or compulsory licensing scheme. In all other cases the Licensor expressly reserves any right to collect such royalties, including when the Licensed Material is used other than for NonCommercial purposes.
82 |
83 | ### Section 3 – License Conditions.
84 |
85 | Your exercise of the Licensed Rights is expressly made subject to the following conditions.
86 |
87 | a. ___Attribution.___
88 |
89 | 1. If You Share the Licensed Material (including in modified form), You must:
90 |
91 | A. retain the following if it is supplied by the Licensor with the Licensed Material:
92 |
93 | i. identification of the creator(s) of the Licensed Material and any others designated to receive attribution, in any reasonable manner requested by the Licensor (including by pseudonym if designated);
94 |
95 | ii. a copyright notice;
96 |
97 | iii. a notice that refers to this Public License;
98 |
99 | iv. a notice that refers to the disclaimer of warranties;
100 |
101 | v. a URI or hyperlink to the Licensed Material to the extent reasonably practicable;
102 |
103 | B. indicate if You modified the Licensed Material and retain an indication of any previous modifications; and
104 |
105 | C. indicate the Licensed Material is licensed under this Public License, and include the text of, or the URI or hyperlink to, this Public License.
106 |
107 | 2. You may satisfy the conditions in Section 3(a)(1) in any reasonable manner based on the medium, means, and context in which You Share the Licensed Material. For example, it may be reasonable to satisfy the conditions by providing a URI or hyperlink to a resource that includes the required information.
108 |
109 | 3. If requested by the Licensor, You must remove any of the information required by Section 3(a)(1)(A) to the extent reasonably practicable.
110 |
111 | b. ___ShareAlike.___
112 |
113 | In addition to the conditions in Section 3(a), if You Share Adapted Material You produce, the following conditions also apply.
114 |
115 | 1. The Adapter’s License You apply must be a Creative Commons license with the same License Elements, this version or later, or a BY-NC-SA Compatible License.
116 |
117 | 2. You must include the text of, or the URI or hyperlink to, the Adapter's License You apply. You may satisfy this condition in any reasonable manner based on the medium, means, and context in which You Share Adapted Material.
118 |
119 | 3. You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, Adapted Material that restrict exercise of the rights granted under the Adapter's License You apply.
120 |
121 | ### Section 4 – Sui Generis Database Rights.
122 |
123 | Where the Licensed Rights include Sui Generis Database Rights that apply to Your use of the Licensed Material:
124 |
125 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right to extract, reuse, reproduce, and Share all or a substantial portion of the contents of the database for NonCommercial purposes only;
126 |
127 | b. if You include all or a substantial portion of the database contents in a database in which You have Sui Generis Database Rights, then the database in which You have Sui Generis Database Rights (but not its individual contents) is Adapted Material, including for purposes of Section 3(b); and
128 |
129 | c. You must comply with the conditions in Section 3(a) if You Share all or a substantial portion of the contents of the database.
130 |
131 | For the avoidance of doubt, this Section 4 supplements and does not replace Your obligations under this Public License where the Licensed Rights include other Copyright and Similar Rights.
132 |
133 | ### Section 5 – Disclaimer of Warranties and Limitation of Liability.
134 |
135 | a. __Unless otherwise separately undertaken by the Licensor, to the extent possible, the Licensor offers the Licensed Material as-is and as-available, and makes no representations or warranties of any kind concerning the Licensed Material, whether express, implied, statutory, or other. This includes, without limitation, warranties of title, merchantability, fitness for a particular purpose, non-infringement, absence of latent or other defects, accuracy, or the presence or absence of errors, whether or not known or discoverable. Where disclaimers of warranties are not allowed in full or in part, this disclaimer may not apply to You.__
136 |
137 | b. __To the extent possible, in no event will the Licensor be liable to You on any legal theory (including, without limitation, negligence) or otherwise for any direct, special, indirect, incidental, consequential, punitive, exemplary, or other losses, costs, expenses, or damages arising out of this Public License or use of the Licensed Material, even if the Licensor has been advised of the possibility of such losses, costs, expenses, or damages. Where a limitation of liability is not allowed in full or in part, this limitation may not apply to You.__
138 |
139 | c. The disclaimer of warranties and limitation of liability provided above shall be interpreted in a manner that, to the extent possible, most closely approximates an absolute disclaimer and waiver of all liability.
140 |
141 | ### Section 6 – Term and Termination.
142 |
143 | a. This Public License applies for the term of the Copyright and Similar Rights licensed here. However, if You fail to comply with this Public License, then Your rights under this Public License terminate automatically.
144 |
145 | b. Where Your right to use the Licensed Material has terminated under Section 6(a), it reinstates:
146 |
147 | 1. automatically as of the date the violation is cured, provided it is cured within 30 days of Your discovery of the violation; or
148 |
149 | 2. upon express reinstatement by the Licensor.
150 |
151 | For the avoidance of doubt, this Section 6(b) does not affect any right the Licensor may have to seek remedies for Your violations of this Public License.
152 |
153 | c. For the avoidance of doubt, the Licensor may also offer the Licensed Material under separate terms or conditions or stop distributing the Licensed Material at any time; however, doing so will not terminate this Public License.
154 |
155 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public License.
156 |
157 | ### Section 7 – Other Terms and Conditions.
158 |
159 | a. The Licensor shall not be bound by any additional or different terms or conditions communicated by You unless expressly agreed.
160 |
161 | b. Any arrangements, understandings, or agreements regarding the Licensed Material not stated herein are separate from and independent of the terms and conditions of this Public License.
162 |
163 | ### Section 8 – Interpretation.
164 |
165 | a. For the avoidance of doubt, this Public License does not, and shall not be interpreted to, reduce, limit, restrict, or impose conditions on any use of the Licensed Material that could lawfully be made without permission under this Public License.
166 |
167 | b. To the extent possible, if any provision of this Public License is deemed unenforceable, it shall be automatically reformed to the minimum extent necessary to make it enforceable. If the provision cannot be reformed, it shall be severed from this Public License without affecting the enforceability of the remaining terms and conditions.
168 |
169 | c. No term or condition of this Public License will be waived and no failure to comply consented to unless expressly agreed to by the Licensor.
170 |
171 | d. Nothing in this Public License constitutes or may be interpreted as a limitation upon, or waiver of, any privileges and immunities that apply to the Licensor or You, including from the legal processes of any jurisdiction or authority.
172 |
173 | ```
174 | Creative Commons is not a party to its public licenses. Notwithstanding, Creative Commons may elect to apply one of its public licenses to material it publishes and in those instances will be considered the “Licensor.” Except for the limited purpose of indicating that material is shared under a Creative Commons public license or as otherwise permitted by the Creative Commons policies published at [creativecommons.org/policies](http://creativecommons.org/policies), Creative Commons does not authorize the use of the trademark “Creative Commons” or any other trademark or logo of Creative Commons without its prior written consent including, without limitation, in connection with any unauthorized modifications to any of its public licenses or any other arrangements, understandings, or agreements concerning use of licensed material. For the avoidance of doubt, this paragraph does not form part of the public licenses.
175 |
176 | Creative Commons may be contacted at [creativecommons.org](http://creativecommons.org/).
177 | ```
178 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | **The code base is no longer maintained.**
2 |
3 | **Please check here for an improved implementation of MUNIT: https://github.com/NVlabs/imaginaire/tree/master/projects/munit**
4 |
5 | [](https://raw.githubusercontent.com/NVIDIA/FastPhotoStyle/master/LICENSE.md)
6 | 
7 | 
8 | ## MUNIT: Multimodal UNsupervised Image-to-image Translation
9 |
10 | ### License
11 |
12 | Copyright (C) 2018 NVIDIA Corporation. All rights reserved.
13 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
14 |
15 | For commercial use, please consult [NVIDIA Research Inquiries](https://www.nvidia.com/en-us/research/inquiries/).
16 |
17 | ### Code usage
18 |
19 | Please check out the [user manual page](USAGE.md).
20 |
21 | ### Paper
22 |
23 | [Xun Huang](http://www.cs.cornell.edu/~xhuang/), [Ming-Yu Liu](http://mingyuliu.net/), [Serge Belongie](https://vision.cornell.edu/se3/people/serge-belongie/), [Jan Kautz](http://jankautz.com/), "[Multimodal Unsupervised Image-to-Image Translation](https://arxiv.org/abs/1804.04732)", ECCV 2018
24 |
25 | ### Results Video
26 | [](https://youtu.be/ab64TWzWn40)
27 |
28 | ### Edges to Shoes/handbags Translation
29 |
30 | 
31 |
32 | ### Animal Image Translation
33 |
34 | 
35 |
36 | ### Street Scene Translation
37 |
38 | 
39 |
40 | ### Yosemite Summer to Winter Translation (HD)
41 |
42 | 
43 |
44 | ### Example-guided Image Translation
45 |
46 | 
47 |
48 | ### Other Implementations
49 |
50 | [MUNIT-Tensorflow](https://github.com/taki0112/MUNIT-Tensorflow) by [Junho Kim](https://github.com/taki0112)
51 |
52 | [MUNIT-keras](https://github.com/shaoanlu/MUNIT-keras) by [shaoanlu](https://github.com/shaoanlu)
53 |
54 | ### Citation
55 |
56 | If you find this code useful for your research, please cite our paper:
57 |
58 | ```
59 | @inproceedings{huang2018munit,
60 | title={Multimodal Unsupervised Image-to-image Translation},
61 | author={Huang, Xun and Liu, Ming-Yu and Belongie, Serge and Kautz, Jan},
62 | booktitle={ECCV},
63 | year={2018}
64 | }
65 | ```
66 |
67 |
68 |
--------------------------------------------------------------------------------
/TUTORIAL.md:
--------------------------------------------------------------------------------
1 | [](https://raw.githubusercontent.com/NVIDIA/FastPhotoStyle/master/LICENSE.md)
2 | 
3 | 
4 | ## MUNIT Tutorial
5 |
6 | In this short tutorial, we will guide you through setting up the system environment for running the MUNIT (Multimodal UNsupervised Image-to-image Translation) software and then walk through several usage examples.
7 |
8 | ### Background
9 |
10 | Most existing unsupervised/unpaired image-to-image translation algorithms assume a unimodal mapping function between two image domains. That is, for a given input image in domain A, the model can only map it to one corresponding image in domain B. This is undesirable since in many cases the mapping should be multimodal, or many-to-many. For example, given an input summer image, a summer-to-winter translation model should be able to synthesize various winter images that all correspond to the input. These images could differ in the amount of snow accumulation, but they all represent valid translations of the input image. In the ideal case, given an input image, an image translation model should be able to map it to a distribution of output images. This is precisely the goal of MUNIT.
11 |
12 | ### Algorithm
13 |
14 |
15 |
16 | MUNIT is based on the partially-shared latent space assumption, as illustrated in (a) of the above image. It assumes that the latent representation of an image can be decomposed into two parts: one represents the content of the image, which is shared across domains, while the other represents the style of the image, which is not shared across domains. To realize this assumption, MUNIT uses 3 networks for each domain:
17 |
18 | 1. content encoder (for extracting a domain-shared latent code, content code)
19 | 2. style encoder (for extracting a domain-specific latent code, style code)
20 | 3. decoder (for generating an image using a content code and a style code)
21 |
22 | At test time, as illustrated in (b) of the above image, we want to translate an input image in the 1st domain (source domain) to a corresponding image in the 2nd domain (target domain). MUNIT first uses the content encoder in the source domain to extract a content code, combines it with a randomly sampled style code from the target domain, and feeds them to the decoder in the target domain to generate the translation. By sampling different style codes, MUNIT generates different translations. Since the style space is a continuous space, MUNIT essentially maps an input image in the source domain to a distribution of images in the target domain.
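Schematically, this test-time procedure boils down to a few lines of PyTorch. The sketch below is illustrative only: `content_encoder_a`, `decoder_b`, and the style-code shape are placeholder assumptions rather than the exact API of this repository (see `networks.py` and `test.py` for the real implementation).

```python
import torch

# Illustrative sketch of multimodal translation from domain A to domain B.
# content_encoder_a and decoder_b stand for trained modules of the MUNIT
# generator; style_dim matches the `style_dim` value in the yaml configs (8).
def translate_a2b(x_a, content_encoder_a, decoder_b, style_dim=8, num_outputs=10):
    with torch.no_grad():
        content = content_encoder_a(x_a)          # domain-shared content code
        translations = []
        for _ in range(num_outputs):
            # sample a random style code from the target-domain style space
            s_b = torch.randn(x_a.size(0), style_dim, 1, 1, device=x_a.device)
            translations.append(decoder_b(content, s_b))
    return translations
```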
23 |
24 | ### Requirements
25 |
26 |
27 | - Hardware: PC with NVIDIA Titan GPU. For large resolution images, you need NVIDIA Tesla P100 or V100 GPUs, which have 16GB+ GPU memory.
28 | - Software: *Ubuntu 16.04*, *CUDA 9.1*, *Anaconda3*, *pytorch 0.4.1*
29 | - System package
30 | - `sudo apt-get install -y axel imagemagick` (Only used for demo)
31 | - Python package
32 | - `conda install pytorch=0.4.1 torchvision cuda91 -y -c pytorch`
33 | - `conda install -y -c anaconda pip`
34 | - `conda install -y -c anaconda pyyaml`
35 | - `pip install tensorboard tensorboardX`
36 |
37 | ### Docker Image
38 |
39 | We also provide a [Dockerfile](Dockerfile) for building an environment for running the MUNIT code.
40 |
41 | 1. Install docker-ce. Follow the instructions on the [Docker page](https://docs.docker.com/install/linux/docker-ce/ubuntu/#install-docker-ce-1).
42 | 2. Install nvidia-docker. Follow the instructions on the [NVIDIA-DOCKER README page](https://github.com/NVIDIA/nvidia-docker).
43 | 3. Build the docker image `docker build -t your-docker-image:v1.0 .`
44 | 4. Run an interactive session `docker run -v YOUR_PATH:YOUR_PATH --runtime=nvidia -i -t your-docker-image:v1.0 /bin/bash`
45 | 5. `cd YOUR_PATH`
46 | 6. Follow the rest of the tutorial.
47 |
48 | ### Training
49 |
50 | We provide several training scripts as usage examples. They are located in the `scripts` folder.
51 | - `bash scripts/demo_train_edges2handbags.sh` to train a multimodal model that translates sketches of handbags to images of handbags.
52 | - `bash scripts/demo_train_edges2shoes.sh` to train a multimodal model that translates sketches of shoes to images of shoes.
53 | - `bash scripts/demo_train_summer2winter_yosemite256.sh` to train a multimodal model that translates 256x256 Yosemite summer images to 256x256 Yosemite winter images.
54 |
55 | If you break down the commands in the scripts, you will find that training a multimodal unsupervised image-to-image translation model involves the following steps:
56 |
57 | 1. Download the dataset you want to use.
58 |
59 | 2. Set up the yaml file. Check out `configs/demo_edges2handbags_folder.yaml` for folder-based dataset organization. Change the `data_root` field to the path of your downloaded dataset. For list-based dataset organization, check out `configs/demo_edges2handbags_list.yaml` (a minimal config-loading sketch is shown after this list).
60 |
61 | 3. Start training
62 | ```
63 | python train.py --config configs/edges2handbags_folder.yaml
64 | ```
65 |
66 | 4. Intermediate image outputs and model binary files are stored in `outputs/edges2handbags_folder`
67 |
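For reference, each yaml config is just a dictionary of the options documented in the `configs` folder. The sketch below loads one with PyYAML directly (independent of the training script); the custom dataset path is a hypothetical example.

```python
import yaml

# Load a training config and point it at a custom dataset location.
with open("configs/demo_edges2handbags_folder.yaml", "r") as f:
    config = yaml.safe_load(f)

config["data_root"] = "/path/to/my/edges2handbags"      # hypothetical path
print(config["max_iter"], config["batch_size"], config["gen"]["style_dim"])  # 1000000 1 8
```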
68 | ### Testing
69 |
70 | First, download our pretrained models for the edges2shoes task and put them in the `models` folder.
71 |
72 | ### Pretrained models
73 |
74 | | Dataset | Model Link |
75 | |-------------|----------------|
76 | | edges2shoes | [model](https://drive.google.com/drive/folders/10IEa7gibOWmQQuJUIUOkh-CV4cm6k8__?usp=sharing) |
77 | | edges2handbags | coming soon |
78 | | summer2winter_yosemite256 | coming soon |
79 |
80 |
81 | #### Multimodal Translation
82 |
83 | Run the following command to translate edges to shoes
84 |
85 | python test.py --config configs/edges2shoes_folder.yaml --input inputs/edges2shoes_edge.jpg --output_folder results/edges2shoes --checkpoint models/edges2shoes.pt --a2b 1
86 |
87 | The results are stored in the `results/edges2shoes` folder. By default, the command produces 10 random translation outputs.
88 |
89 | | Input | Translation 1 | Translation 2 | Translation 3 | Translation 4 | Translation 5 |
90 | |-------|---------------|---------------|---------------|---------------|---------------|
91 | |
|
|
|
|
|
|
92 |
93 |
94 | #### Example-guided Translation
95 |
96 | The above command outputs diverse shoes from an edge input. In addition, it is possible to control the style of the output using an example shoe image.
97 |
98 | python test.py --config configs/edges2shoes_folder.yaml --input inputs/edges2shoes_edge.jpg --output_folder results --checkpoint models/edges2shoes.pt --a2b 1 --style inputs/edges2shoes_shoe.jpg
99 |
100 | | Input Photo | Style Photo | Output Photo |
101 | |-------|---------------|---------------|
102 | |
|
|
|
103 |
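Conceptually, the only difference from the multimodal case above is where the style code comes from: it is extracted from the example image by the target-domain style encoder instead of being sampled from a prior. A minimal sketch, with the same caveat that the module names are placeholders rather than the exact API of `networks.py`:

```python
import torch

# Illustrative sketch of example-guided translation from domain A to domain B.
# style_encoder_b extracts the style code from a reference image in domain B.
def translate_a2b_guided(x_a, x_b_style, content_encoder_a, style_encoder_b, decoder_b):
    with torch.no_grad():
        content = content_encoder_a(x_a)     # content code of the input edge image
        s_b = style_encoder_b(x_b_style)     # style code of the example shoe image
        return decoder_b(content, s_b)       # translation that mimics the example's style
```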
104 | ### Yosemite Summer2Winter HD dataset
105 |
106 | Coming soon.
107 |
108 |
109 |
--------------------------------------------------------------------------------
/USAGE.md:
--------------------------------------------------------------------------------
1 | [](https://raw.githubusercontent.com/NVIDIA/FastPhotoStyle/master/LICENSE.md)
2 | 
3 | 
4 | ## MUNIT: Multimodal UNsupervised Image-to-image Translation
5 |
6 | ### License
7 |
8 | Copyright (C) 2018 NVIDIA Corporation. All rights reserved.
9 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
10 |
11 | ### Dependency
12 |
13 |
14 | pytorch, yaml, tensorboard (from https://github.com/dmlc/tensorboard), and tensorboardX (from https://github.com/lanpa/tensorboard-pytorch).
15 |
16 |
17 | The code base was developed using Anaconda with the following packages.
18 | ```
19 | conda install pytorch=0.4.1 torchvision cuda91 -c pytorch;
20 | conda install -y -c anaconda pip;
21 | conda install -y -c anaconda pyyaml;
22 | pip install tensorboard tensorboardX;
23 | ```
24 |
25 | We also provide a [Dockerfile](Dockerfile) for building an environment for running the MUNIT code.
26 |
27 | ### Example Usage
28 |
29 | #### Testing
30 |
31 | First, download the [pretrained models](https://drive.google.com/drive/folders/10IEa7gibOWmQQuJUIUOkh-CV4cm6k8__?usp=sharing) and put them in the `models` folder.
32 |
33 | ###### Multimodal Translation
34 |
35 | Run the following command to translate edges to shoes
36 |
37 | python test.py --config configs/edges2shoes_folder.yaml --input inputs/edge.jpg --output_folder outputs --checkpoint models/edges2shoes.pt --a2b 1
38 |
39 | The results are stored in the `outputs` folder. By default, the command produces 10 random translation outputs.
40 |
41 | ###### Example-guided Translation
42 |
43 | The above command outputs diverse shoes from an edge input. In addition, it is possible to control the style of the output using an example shoe image.
44 |
45 | python test.py --config configs/edges2shoes_folder.yaml --input inputs/edge.jpg --output_folder outputs --checkpoint models/edges2shoes.pt --a2b 1 --style inputs/shoe.jpg
46 |
47 |
48 | #### Training
49 | 1. Download the dataset you want to use. For example, you can use the edges2shoes dataset provided by [Zhu et al.](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix)
50 |
51 | 2. Set up the yaml file. Check out `configs/edges2handbags_folder.yaml` for folder-based dataset organization. Change the `data_root` field to the path of your downloaded dataset. For list-based dataset organization, check out `configs/edges2handbags_list.yaml`
52 |
53 | 3. Start training
54 | ```
55 | python train.py --config configs/edges2handbags_folder.yaml
56 | ```
57 |
58 | 4. Intermediate image outputs and model binary files are stored in `outputs/edges2handbags_folder`
59 |
--------------------------------------------------------------------------------
/configs/demo_edges2handbags_folder.yaml:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
2 | # Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
3 |
4 | # logger options
5 | image_save_iter: 1000 # How often do you want to save output images during training
6 | image_display_iter: 100 # How often do you want to display output images during training
7 | display_size: 8 # How many images do you want to display each time
8 | snapshot_save_iter: 10000 # How often do you want to save trained models
9 | log_iter: 1 # How often do you want to log the training stats
10 |
11 | # optimization options
12 | max_iter: 1000000 # maximum number of training iterations
13 | batch_size: 1 # batch size
14 | weight_decay: 0.0001 # weight decay
15 | beta1: 0.5 # Adam parameter
16 | beta2: 0.999 # Adam parameter
17 | init: kaiming # initialization [gaussian/kaiming/xavier/orthogonal]
18 | lr: 0.0001 # initial learning rate
19 | lr_policy: step # learning rate scheduler
20 | step_size: 100000 # how often to decay learning rate
21 | gamma: 0.5 # how much to decay learning rate
22 | gan_w: 1 # weight of adversarial loss
23 | recon_x_w: 10 # weight of image reconstruction loss
24 | recon_s_w: 1 # weight of style reconstruction loss
25 | recon_c_w: 1 # weight of content reconstruction loss
26 | recon_x_cyc_w: 0 # weight of explicit style augmented cycle consistency loss
27 | vgg_w: 0 # weight of domain-invariant perceptual loss
28 |
29 | # model options
30 | gen:
31 | dim: 64 # number of filters in the bottommost layer
32 | mlp_dim: 256 # number of filters in MLP
33 | style_dim: 8 # length of style code
34 | activ: relu # activation function [relu/lrelu/prelu/selu/tanh]
35 | n_downsample: 2 # number of downsampling layers in content encoder
36 | n_res: 4 # number of residual blocks in content encoder/decoder
37 | pad_type: reflect # padding type [zero/reflect]
38 | dis:
39 | dim: 64 # number of filters in the bottommost layer
40 | norm: none # normalization layer [none/bn/in/ln]
41 | activ: lrelu # activation function [relu/lrelu/prelu/selu/tanh]
42 | n_layer: 4 # number of layers in D
43 | gan_type: lsgan # GAN loss [lsgan/nsgan]
44 | num_scales: 3 # number of scales
45 | pad_type: reflect # padding type [zero/reflect]
46 |
47 | # data options
48 | input_dim_a: 3 # number of image channels [1/3]
49 | input_dim_b: 3 # number of image channels [1/3]
50 | num_workers: 8 # number of data loading threads
51 | new_size: 256 # first resize the shortest image side to this size
52 | crop_image_height: 256 # random crop image of this height
53 | crop_image_width: 256 # random crop image of this width
54 | data_root: ./datasets/demo_edges2handbags/ # dataset folder location
--------------------------------------------------------------------------------
/configs/demo_edges2handbags_list.yaml:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
2 | # Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
3 |
4 | # logger options
5 | image_save_iter: 1000 # How often do you want to save output images during training
6 | image_display_iter: 100 # How often do you want to display output images during training
7 | display_size: 8 # How many images do you want to display each time
8 | snapshot_save_iter: 10000 # How often do you want to save trained models
9 | log_iter: 1 # How often do you want to log the training stats
10 |
11 | # optimization options
12 | max_iter: 1000000 # maximum number of training iterations
13 | batch_size: 1 # batch size
14 | weight_decay: 0.0001 # weight decay
15 | beta1: 0.5 # Adam parameter
16 | beta2: 0.999 # Adam parameter
17 | init: kaiming # initialization [gaussian/kaiming/xavier/orthogonal]
18 | lr: 0.0001 # initial learning rate
19 | lr_policy: step # learning rate scheduler
20 | step_size: 100000 # how often to decay learning rate
21 | gamma: 0.5 # how much to decay learning rate
22 | gan_w: 1 # weight of adversarial loss
23 | recon_x_w: 10 # weight of image reconstruction loss
24 | recon_s_w: 1 # weight of style reconstruction loss
25 | recon_c_w: 1 # weight of content reconstruction loss
26 | recon_x_cyc_w: 0 # weight of explicit style augmented cycle consistency loss
27 | vgg_w: 0 # weight of domain-invariant perceptual loss
28 |
29 | # model options
30 | gen:
31 | dim: 64 # number of filters in the bottommost layer
32 | mlp_dim: 256 # number of filters in MLP
33 | style_dim: 8 # length of style code
34 | activ: relu # activation function [relu/lrelu/prelu/selu/tanh]
35 | n_downsample: 2 # number of downsampling layers in content encoder
36 | n_res: 4 # number of residual blocks in content encoder/decoder
37 | pad_type: reflect # padding type [zero/reflect]
38 | dis:
39 | dim: 64 # number of filters in the bottommost layer
40 | norm: none # normalization layer [none/bn/in/ln]
41 | activ: lrelu # activation function [relu/lrelu/prelu/selu/tanh]
42 | n_layer: 4 # number of layers in D
43 | gan_type: lsgan # GAN loss [lsgan/nsgan]
44 | num_scales: 3 # number of scales
45 | pad_type: reflect # padding type [zero/reflect]
46 |
47 | # data options
48 | input_dim_a: 3 # number of image channels [1/3]
49 | input_dim_b: 3 # number of image channels [1/3]
50 | num_workers: 8 # number of data loading threads
51 | new_size: 256 # first resize the shortest image side to this size
52 | crop_image_height: 256 # random crop image of this height
53 | crop_image_width: 256 # random crop image of this width
54 |
55 | data_folder_train_a: ./datasets/demo_edges2handbags/trainA
56 | data_list_train_a: ./datasets/demo_edges2handbags/list_trainA.txt
57 | data_folder_test_a: ./datasets/demo_edges2handbags/testA
58 | data_list_test_a: ./datasets/demo_edges2handbags/list_testA.txt
59 | data_folder_train_b: ./datasets/demo_edges2handbags/trainB
60 | data_list_train_b: ./datasets/demo_edges2handbags/list_trainB.txt
61 | data_folder_test_b: ./datasets/demo_edges2handbags/testB
62 | data_list_test_b: ./datasets/demo_edges2handbags/list_testB.txt
63 |
--------------------------------------------------------------------------------
/configs/edges2handbags_folder.yaml:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
2 | # Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
3 |
4 | # logger options
5 | image_save_iter: 10000 # How often do you want to save output images during training
6 | image_display_iter: 500 # How often do you want to display output images during training
7 | display_size: 16 # How many images do you want to display each time
8 | snapshot_save_iter: 10000 # How often do you want to save trained models
9 | log_iter: 10 # How often do you want to log the training stats
10 |
11 | # optimization options
12 | max_iter: 1000000 # maximum number of training iterations
13 | batch_size: 1 # batch size
14 | weight_decay: 0.0001 # weight decay
15 | beta1: 0.5 # Adam parameter
16 | beta2: 0.999 # Adam parameter
17 | init: kaiming # initialization [gaussian/kaiming/xavier/orthogonal]
18 | lr: 0.0001 # initial learning rate
19 | lr_policy: step # learning rate scheduler
20 | step_size: 100000 # how often to decay learning rate
21 | gamma: 0.5 # how much to decay learning rate
22 | gan_w: 1 # weight of adversarial loss
23 | recon_x_w: 10 # weight of image reconstruction loss
24 | recon_s_w: 1 # weight of style reconstruction loss
25 | recon_c_w: 1 # weight of content reconstruction loss
26 | recon_x_cyc_w: 0 # weight of explicit style augmented cycle consistency loss
27 | vgg_w: 0 # weight of domain-invariant perceptual loss
28 |
29 | # model options
30 | gen:
31 | dim: 64 # number of filters in the bottommost layer
32 | mlp_dim: 256 # number of filters in MLP
33 | style_dim: 8 # length of style code
34 | activ: relu # activation function [relu/lrelu/prelu/selu/tanh]
35 | n_downsample: 2 # number of downsampling layers in content encoder
36 | n_res: 4 # number of residual blocks in content encoder/decoder
37 | pad_type: reflect # padding type [zero/reflect]
38 | dis:
39 | dim: 64 # number of filters in the bottommost layer
40 | norm: none # normalization layer [none/bn/in/ln]
41 | activ: lrelu # activation function [relu/lrelu/prelu/selu/tanh]
42 | n_layer: 4 # number of layers in D
43 | gan_type: lsgan # GAN loss [lsgan/nsgan]
44 | num_scales: 3 # number of scales
45 | pad_type: reflect # padding type [zero/reflect]
46 |
47 | # data options
48 | input_dim_a: 3 # number of image channels [1/3]
49 | input_dim_b: 3 # number of image channels [1/3]
50 | num_workers: 8 # number of data loading threads
51 | new_size: 256 # first resize the shortest image side to this size
52 | crop_image_height: 256 # random crop image of this height
53 | crop_image_width: 256 # random crop image of this width
54 | data_root: ./datasets/edges2handbags/ # dataset folder location
--------------------------------------------------------------------------------
/configs/edges2shoes_folder.yaml:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
2 | # Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
3 |
4 | # logger options
5 | image_save_iter: 10000 # How often do you want to save output images during training
6 | image_display_iter: 500 # How often do you want to display output images during training
7 | display_size: 16 # How many images do you want to display each time
8 | snapshot_save_iter: 10000 # How often do you want to save trained models
9 | log_iter: 10 # How often do you want to log the training stats
10 |
11 | # optimization options
12 | max_iter: 1000000 # maximum number of training iterations
13 | batch_size: 1 # batch size
14 | weight_decay: 0.0001 # weight decay
15 | beta1: 0.5 # Adam parameter
16 | beta2: 0.999 # Adam parameter
17 | init: kaiming # initialization [gaussian/kaiming/xavier/orthogonal]
18 | lr: 0.0001 # initial learning rate
19 | lr_policy: step # learning rate scheduler
20 | step_size: 100000 # how often to decay learning rate
21 | gamma: 0.5 # how much to decay learning rate
22 | gan_w: 1 # weight of adversarial loss
23 | recon_x_w: 10 # weight of image reconstruction loss
24 | recon_s_w: 1 # weight of style reconstruction loss
25 | recon_c_w: 1 # weight of content reconstruction loss
26 | recon_x_cyc_w: 0 # weight of explicit style augmented cycle consistency loss
27 | vgg_w: 0 # weight of domain-invariant perceptual loss
28 |
29 | # model options
30 | gen:
31 | dim: 64 # number of filters in the bottommost layer
32 | mlp_dim: 256 # number of filters in MLP
33 | style_dim: 8 # length of style code
34 | activ: relu # activation function [relu/lrelu/prelu/selu/tanh]
35 | n_downsample: 2 # number of downsampling layers in content encoder
36 | n_res: 4 # number of residual blocks in content encoder/decoder
37 | pad_type: reflect # padding type [zero/reflect]
38 | dis:
39 | dim: 64 # number of filters in the bottommost layer
40 | norm: none # normalization layer [none/bn/in/ln]
41 | activ: lrelu # activation function [relu/lrelu/prelu/selu/tanh]
42 | n_layer: 4 # number of layers in D
43 | gan_type: lsgan # GAN loss [lsgan/nsgan]
44 | num_scales: 3 # number of scales
45 | pad_type: reflect # padding type [zero/reflect]
46 |
47 | # data options
48 | input_dim_a: 3 # number of image channels [1/3]
49 | input_dim_b: 3 # number of image channels [1/3]
50 | num_workers: 8 # number of data loading threads
51 | new_size: 256 # first resize the shortest image side to this size
52 | crop_image_height: 256 # random crop image of this height
53 | crop_image_width: 256 # random crop image of this width
54 | data_root: ./datasets/edges2shoes/ # dataset folder location
--------------------------------------------------------------------------------
/configs/summer2winter_yosemite256_folder.yaml:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 NVIDIA Corporation. All rights reserved.
2 | # Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
3 |
4 | # logger options
5 | image_save_iter: 10000 # How often do you want to save output images during training
6 | image_display_iter: 100 # How often do you want to display output images during training
7 | display_size: 16 # How many images do you want to display each time
8 | snapshot_save_iter: 10000 # How often do you want to save trained models
9 | log_iter: 1 # How often do you want to log the training stats
10 |
11 | # optimization options
12 | max_iter: 1000000 # maximum number of training iterations
13 | batch_size: 1 # batch size
14 | weight_decay: 0.0001 # weight decay
15 | beta1: 0.5 # Adam parameter
16 | beta2: 0.999 # Adam parameter
17 | init: kaiming # initialization [gaussian/kaiming/xavier/orthogonal]
18 | lr: 0.0001 # initial learning rate
19 | lr_policy: step # learning rate scheduler
20 | step_size: 100000 # how often to decay learning rate
21 | gamma: 0.5 # how much to decay learning rate
22 | gan_w: 1 # weight of adversarial loss
23 | recon_x_w: 10 # weight of image reconstruction loss
24 | recon_s_w: 1 # weight of style reconstruction loss
25 | recon_c_w: 1 # weight of content reconstruction loss
26 | recon_x_cyc_w: 10 # weight of explicit style augmented cycle consistency loss
27 | vgg_w: 0 # weight of domain-invariant perceptual loss
28 |
29 | # model options
30 | gen:
31 | dim: 64 # number of filters in the bottommost layer
32 | mlp_dim: 256 # number of filters in MLP
33 | style_dim: 8 # length of style code
34 | activ: relu # activation function [relu/lrelu/prelu/selu/tanh]
35 | n_downsample: 2 # number of downsampling layers in content encoder
36 | n_res: 4 # number of residual blocks in content encoder/decoder
37 | pad_type: reflect # padding type [zero/reflect]
38 | dis:
39 | dim: 64 # number of filters in the bottommost layer
40 | norm: none # normalization layer [none/bn/in/ln]
41 | activ: lrelu # activation function [relu/lrelu/prelu/selu/tanh]
42 | n_layer: 4 # number of layers in D
43 | gan_type: lsgan # GAN loss [lsgan/nsgan]
44 | num_scales: 3 # number of scales
45 | pad_type: reflect # padding type [zero/reflect]
46 |
47 | # data options
48 | input_dim_a: 3 # number of image channels [1/3]
49 | input_dim_b: 3 # number of image channels [1/3]
50 | num_workers: 8 # number of data loading threads
51 | new_size: 256 # first resize the shortest image side to this size
52 | crop_image_height: 256 # random crop image of this height
53 | crop_image_width: 256 # random crop image of this width
54 | data_root: ./datasets/summer2winter_yosemite256/summer2winter_yosemite # dataset folder location
--------------------------------------------------------------------------------
/configs/synthia2cityscape_folder.yaml:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
2 | # Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
3 |
4 | # logger options
5 | image_save_iter: 10000 # How often do you want to save output images during training
6 | image_display_iter: 500 # How often do you want to display output images during training
7 | display_size: 16 # How many images do you want to display each time
8 | snapshot_save_iter: 10000 # How often do you want to save trained models
9 | log_iter: 10 # How often do you want to log the training stats
10 |
11 | # optimization options
12 | max_iter: 1000000 # maximum number of training iterations
13 | batch_size: 1 # batch size
14 | weight_decay: 0.0001 # weight decay
15 | beta1: 0.5 # Adam parameter
16 | beta2: 0.999 # Adam parameter
17 | init: kaiming # initialization [gaussian/kaiming/xavier/orthogonal]
18 | lr: 0.0001 # initial learning rate
19 | lr_policy: step # learning rate scheduler
20 | step_size: 100000 # how often to decay learning rate
21 | gamma: 0.5 # how much to decay learning rate
22 | gan_w: 1 # weight of adversarial loss
23 | recon_x_w: 10 # weight of image reconstruction loss
24 | recon_s_w: 1 # weight of style reconstruction loss
25 | recon_c_w: 1 # weight of content reconstruction loss
26 | recon_x_cyc_w: 10 # weight of explicit style augmented cycle consistency loss
27 | vgg_w: 1 # weight of domain-invariant perceptual loss
28 |
29 | # model options
30 | gen:
31 | dim: 64 # number of filters in the bottommost layer
32 | mlp_dim: 256 # number of filters in MLP
33 | style_dim: 8 # length of style code
34 | activ: relu # activation function [relu/lrelu/prelu/selu/tanh]
35 | n_downsample: 2 # number of downsampling layers in content encoder
36 | n_res: 4 # number of residual blocks in content encoder/decoder
37 | pad_type: reflect # padding type [zero/reflect]
38 | dis:
39 | dim: 64 # number of filters in the bottommost layer
40 | norm: none # normalization layer [none/bn/in/ln]
41 | activ: lrelu # activation function [relu/lrelu/prelu/selu/tanh]
42 | n_layer: 4 # number of layers in D
43 | gan_type: lsgan # GAN loss [lsgan/nsgan]
44 | num_scales: 3 # number of scales
45 | pad_type: reflect # padding type [zero/reflect]
46 |
47 | # data options
48 | input_dim_a: 3 # number of image channels [1/3]
49 | input_dim_b: 3 # number of image channels [1/3]
50 | num_workers: 8 # number of data loading threads
51 | new_size: 512 # first resize the shortest image side to this size
52 | crop_image_height: 512 # random crop image of this height
53 | crop_image_width: 512 # random crop image of this width
54 | data_root: ./datasets/synthia2cityscape/ # dataset folder location
55 |
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2018 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """
5 | import torch.utils.data as data
6 | import os.path
7 |
8 | def default_loader(path):
9 | return Image.open(path).convert('RGB')
10 |
11 |
12 | def default_flist_reader(flist):
13 | """
14 | flist format: impath label\nimpath label\n ...(same to caffe's filelist)
15 | """
16 | imlist = []
17 | with open(flist, 'r') as rf:
18 | for line in rf.readlines():
19 | impath = line.strip()
20 | imlist.append(impath)
21 |
22 | return imlist
23 |
24 |
25 | class ImageFilelist(data.Dataset):
26 | def __init__(self, root, flist, transform=None,
27 | flist_reader=default_flist_reader, loader=default_loader):
28 | self.root = root
29 | self.imlist = flist_reader(flist)
30 | self.transform = transform
31 | self.loader = loader
32 |
33 | def __getitem__(self, index):
34 | impath = self.imlist[index]
35 | img = self.loader(os.path.join(self.root, impath))
36 | if self.transform is not None:
37 | img = self.transform(img)
38 |
39 | return img
40 |
41 | def __len__(self):
42 | return len(self.imlist)
43 |
44 |
45 | class ImageLabelFilelist(data.Dataset):
46 | def __init__(self, root, flist, transform=None,
47 | flist_reader=default_flist_reader, loader=default_loader):
48 | self.root = root
49 | self.imlist = flist_reader(os.path.join(self.root, flist))
50 | self.transform = transform
51 | self.loader = loader
52 | self.classes = sorted(list(set([path.split('/')[0] for path in self.imlist])))
53 | self.class_to_idx = {self.classes[i]: i for i in range(len(self.classes))}
54 | self.imgs = [(impath, self.class_to_idx[impath.split('/')[0]]) for impath in self.imlist]
55 |
56 | def __getitem__(self, index):
57 | impath, label = self.imgs[index]
58 | img = self.loader(os.path.join(self.root, impath))
59 | if self.transform is not None:
60 | img = self.transform(img)
61 | return img, label
62 |
63 | def __len__(self):
64 | return len(self.imgs)
65 |
66 | ###############################################################################
67 | # Code from
68 | # https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py
69 | # Modified the original code so that it also loads images from the current
70 | # directory as well as the subdirectories
71 | ###############################################################################
72 |
73 | import torch.utils.data as data
74 |
75 | from PIL import Image
76 | import os
77 | import os.path
78 |
79 | IMG_EXTENSIONS = [
80 | '.jpg', '.JPG', '.jpeg', '.JPEG',
81 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
82 | ]
83 |
84 |
85 | def is_image_file(filename):
86 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
87 |
88 |
89 | def make_dataset(dir):
90 | images = []
91 | assert os.path.isdir(dir), '%s is not a valid directory' % dir
92 |
93 | for root, _, fnames in sorted(os.walk(dir)):
94 | for fname in fnames:
95 | if is_image_file(fname):
96 | path = os.path.join(root, fname)
97 | images.append(path)
98 |
99 | return images
100 |
101 |
102 | class ImageFolder(data.Dataset):
103 |
104 | def __init__(self, root, transform=None, return_paths=False,
105 | loader=default_loader):
106 | imgs = sorted(make_dataset(root))
107 | if len(imgs) == 0:
108 | raise(RuntimeError("Found 0 images in: " + root + "\n"
109 | "Supported image extensions are: " +
110 | ",".join(IMG_EXTENSIONS)))
111 |
112 | self.root = root
113 | self.imgs = imgs
114 | self.transform = transform
115 | self.return_paths = return_paths
116 | self.loader = loader
117 |
118 | def __getitem__(self, index):
119 | path = self.imgs[index]
120 | img = self.loader(path)
121 | if self.transform is not None:
122 | img = self.transform(img)
123 | if self.return_paths:
124 | return img, path
125 | else:
126 | return img
127 |
128 | def __len__(self):
129 | return len(self.imgs)
130 |
--------------------------------------------------------------------------------
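A minimal usage sketch (not part of the repository) for the dataset classes defined in data.py above, wrapping them in a standard torch DataLoader. The paths follow the demo_edges2handbags layout shipped with the repo; the transform mirrors the one used in test.py.

    import torch.utils.data as data
    from torchvision import transforms
    from data import ImageFolder, ImageFilelist

    tf = transforms.Compose([transforms.Resize(256),
                             transforms.ToTensor(),
                             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    # ImageFolder walks the directory (and its subdirectories) for image files.
    folder_set = ImageFolder('datasets/demo_edges2handbags/trainA', transform=tf)

    # ImageFilelist reads relative paths, one per line, from a list file.
    list_set = ImageFilelist(root='datasets/demo_edges2handbags/trainA',
                             flist='datasets/demo_edges2handbags/list_trainA.txt',
                             transform=tf)

    loader = data.DataLoader(folder_set, batch_size=1, shuffle=True)
    for img in loader:
        print(img.shape)  # e.g. torch.Size([1, 3, 256, 256]) for 256x256 inputs
        break
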
/datasets/demo_edges2handbags/list_testA.txt:
--------------------------------------------------------------------------------
1 | ./00002.jpg
2 | ./00006.jpg
3 | ./00004.jpg
4 | ./00000.jpg
5 | ./00007.jpg
6 | ./00001.jpg
7 | ./00003.jpg
8 | ./00005.jpg
9 | ./00008.jpg
10 |
--------------------------------------------------------------------------------
/datasets/demo_edges2handbags/list_testB.txt:
--------------------------------------------------------------------------------
1 | ./00002.jpg
2 | ./00006.jpg
3 | ./00004.jpg
4 | ./00000.jpg
5 | ./00007.jpg
6 | ./00001.jpg
7 | ./00003.jpg
8 | ./00005.jpg
9 | ./00008.jpg
10 |
--------------------------------------------------------------------------------
/datasets/demo_edges2handbags/list_trainA.txt:
--------------------------------------------------------------------------------
1 | ./00002.jpg
2 | ./00006.jpg
3 | ./00004.jpg
4 | ./00000.jpg
5 | ./00010.jpg
6 | ./00014.jpg
7 | ./00008.jpg
8 | ./00012.jpg
9 |
--------------------------------------------------------------------------------
/datasets/demo_edges2handbags/list_trainB.txt:
--------------------------------------------------------------------------------
1 | ./00009.jpg
2 | ./00007.jpg
3 | ./00001.jpg
4 | ./00003.jpg
5 | ./00013.jpg
6 | ./00005.jpg
7 | ./00011.jpg
8 | ./00015.jpg
9 |
--------------------------------------------------------------------------------
/datasets/demo_edges2handbags/testA/00000.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/datasets/demo_edges2handbags/testA/00000.jpg
--------------------------------------------------------------------------------
/datasets/demo_edges2handbags/testA/00001.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/datasets/demo_edges2handbags/testA/00001.jpg
--------------------------------------------------------------------------------
/datasets/demo_edges2handbags/testA/00002.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/datasets/demo_edges2handbags/testA/00002.jpg
--------------------------------------------------------------------------------
/datasets/demo_edges2handbags/testA/00003.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/datasets/demo_edges2handbags/testA/00003.jpg
--------------------------------------------------------------------------------
/datasets/demo_edges2handbags/testA/00004.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/datasets/demo_edges2handbags/testA/00004.jpg
--------------------------------------------------------------------------------
/datasets/demo_edges2handbags/testA/00005.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/datasets/demo_edges2handbags/testA/00005.jpg
--------------------------------------------------------------------------------
/datasets/demo_edges2handbags/testA/00006.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/datasets/demo_edges2handbags/testA/00006.jpg
--------------------------------------------------------------------------------
/datasets/demo_edges2handbags/testA/00007.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/datasets/demo_edges2handbags/testA/00007.jpg
--------------------------------------------------------------------------------
/datasets/demo_edges2handbags/testA/00008.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/datasets/demo_edges2handbags/testA/00008.jpg
--------------------------------------------------------------------------------
/datasets/demo_edges2handbags/testB/00000.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/datasets/demo_edges2handbags/testB/00000.jpg
--------------------------------------------------------------------------------
/datasets/demo_edges2handbags/testB/00001.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/datasets/demo_edges2handbags/testB/00001.jpg
--------------------------------------------------------------------------------
/datasets/demo_edges2handbags/testB/00002.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/datasets/demo_edges2handbags/testB/00002.jpg
--------------------------------------------------------------------------------
/datasets/demo_edges2handbags/testB/00003.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/datasets/demo_edges2handbags/testB/00003.jpg
--------------------------------------------------------------------------------
/datasets/demo_edges2handbags/testB/00004.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/datasets/demo_edges2handbags/testB/00004.jpg
--------------------------------------------------------------------------------
/datasets/demo_edges2handbags/testB/00005.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/datasets/demo_edges2handbags/testB/00005.jpg
--------------------------------------------------------------------------------
/datasets/demo_edges2handbags/testB/00006.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/datasets/demo_edges2handbags/testB/00006.jpg
--------------------------------------------------------------------------------
/datasets/demo_edges2handbags/testB/00007.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/datasets/demo_edges2handbags/testB/00007.jpg
--------------------------------------------------------------------------------
/datasets/demo_edges2handbags/testB/00008.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/datasets/demo_edges2handbags/testB/00008.jpg
--------------------------------------------------------------------------------
/datasets/demo_edges2handbags/trainA/00000.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/datasets/demo_edges2handbags/trainA/00000.jpg
--------------------------------------------------------------------------------
/datasets/demo_edges2handbags/trainA/00002.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/datasets/demo_edges2handbags/trainA/00002.jpg
--------------------------------------------------------------------------------
/datasets/demo_edges2handbags/trainA/00004.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/datasets/demo_edges2handbags/trainA/00004.jpg
--------------------------------------------------------------------------------
/datasets/demo_edges2handbags/trainA/00006.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/datasets/demo_edges2handbags/trainA/00006.jpg
--------------------------------------------------------------------------------
/datasets/demo_edges2handbags/trainA/00008.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/datasets/demo_edges2handbags/trainA/00008.jpg
--------------------------------------------------------------------------------
/datasets/demo_edges2handbags/trainA/00010.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/datasets/demo_edges2handbags/trainA/00010.jpg
--------------------------------------------------------------------------------
/datasets/demo_edges2handbags/trainA/00012.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/datasets/demo_edges2handbags/trainA/00012.jpg
--------------------------------------------------------------------------------
/datasets/demo_edges2handbags/trainA/00014.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/datasets/demo_edges2handbags/trainA/00014.jpg
--------------------------------------------------------------------------------
/datasets/demo_edges2handbags/trainB/00001.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/datasets/demo_edges2handbags/trainB/00001.jpg
--------------------------------------------------------------------------------
/datasets/demo_edges2handbags/trainB/00003.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/datasets/demo_edges2handbags/trainB/00003.jpg
--------------------------------------------------------------------------------
/datasets/demo_edges2handbags/trainB/00005.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/datasets/demo_edges2handbags/trainB/00005.jpg
--------------------------------------------------------------------------------
/datasets/demo_edges2handbags/trainB/00007.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/datasets/demo_edges2handbags/trainB/00007.jpg
--------------------------------------------------------------------------------
/datasets/demo_edges2handbags/trainB/00009.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/datasets/demo_edges2handbags/trainB/00009.jpg
--------------------------------------------------------------------------------
/datasets/demo_edges2handbags/trainB/00011.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/datasets/demo_edges2handbags/trainB/00011.jpg
--------------------------------------------------------------------------------
/datasets/demo_edges2handbags/trainB/00013.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/datasets/demo_edges2handbags/trainB/00013.jpg
--------------------------------------------------------------------------------
/datasets/demo_edges2handbags/trainB/00015.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/datasets/demo_edges2handbags/trainB/00015.jpg
--------------------------------------------------------------------------------
/docs/munit_assumption.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/docs/munit_assumption.jpg
--------------------------------------------------------------------------------
/inputs/edges2handbags_edge.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/inputs/edges2handbags_edge.jpg
--------------------------------------------------------------------------------
/inputs/edges2handbags_handbag.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/inputs/edges2handbags_handbag.jpg
--------------------------------------------------------------------------------
/inputs/edges2shoes_edge.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/inputs/edges2shoes_edge.jpg
--------------------------------------------------------------------------------
/inputs/edges2shoes_shoe.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/inputs/edges2shoes_shoe.jpg
--------------------------------------------------------------------------------
/networks.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2018 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """
5 | from torch import nn
6 | from torch.autograd import Variable
7 | import torch
8 | import torch.nn.functional as F
9 | try:
10 | from itertools import izip as zip
11 | except ImportError: # will be 3.x series
12 | pass
13 |
14 | ##################################################################################
15 | # Discriminator
16 | ##################################################################################
17 |
18 | class MsImageDis(nn.Module):
19 | # Multi-scale discriminator architecture
20 | def __init__(self, input_dim, params):
21 | super(MsImageDis, self).__init__()
22 | self.n_layer = params['n_layer']
23 | self.gan_type = params['gan_type']
24 | self.dim = params['dim']
25 | self.norm = params['norm']
26 | self.activ = params['activ']
27 | self.num_scales = params['num_scales']
28 | self.pad_type = params['pad_type']
29 | self.input_dim = input_dim
30 | self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)
31 | self.cnns = nn.ModuleList()
32 | for _ in range(self.num_scales):
33 | self.cnns.append(self._make_net())
34 |
35 | def _make_net(self):
36 | dim = self.dim
37 | cnn_x = []
38 | cnn_x += [Conv2dBlock(self.input_dim, dim, 4, 2, 1, norm='none', activation=self.activ, pad_type=self.pad_type)]
39 | for i in range(self.n_layer - 1):
40 | cnn_x += [Conv2dBlock(dim, dim * 2, 4, 2, 1, norm=self.norm, activation=self.activ, pad_type=self.pad_type)]
41 | dim *= 2
42 | cnn_x += [nn.Conv2d(dim, 1, 1, 1, 0)]
43 | cnn_x = nn.Sequential(*cnn_x)
44 | return cnn_x
45 |
46 | def forward(self, x):
47 | outputs = []
48 | for model in self.cnns:
49 | outputs.append(model(x))
50 | x = self.downsample(x)
51 | return outputs
52 |
53 | def calc_dis_loss(self, input_fake, input_real):
54 | # calculate the loss to train D
55 | outs0 = self.forward(input_fake)
56 | outs1 = self.forward(input_real)
57 | loss = 0
58 |
59 | for it, (out0, out1) in enumerate(zip(outs0, outs1)):
60 | if self.gan_type == 'lsgan':
61 | loss += torch.mean((out0 - 0)**2) + torch.mean((out1 - 1)**2)
62 | elif self.gan_type == 'nsgan':
63 | all0 = Variable(torch.zeros_like(out0.data).cuda(), requires_grad=False)
64 | all1 = Variable(torch.ones_like(out1.data).cuda(), requires_grad=False)
65 | loss += torch.mean(F.binary_cross_entropy(F.sigmoid(out0), all0) +
66 | F.binary_cross_entropy(F.sigmoid(out1), all1))
67 | else:
68 | assert 0, "Unsupported GAN type: {}".format(self.gan_type)
69 | return loss
70 |
71 | def calc_gen_loss(self, input_fake):
72 | # calculate the loss to train G
73 | outs0 = self.forward(input_fake)
74 | loss = 0
75 | for it, (out0) in enumerate(outs0):
76 | if self.gan_type == 'lsgan':
77 | loss += torch.mean((out0 - 1)**2) # LSGAN
78 | elif self.gan_type == 'nsgan':
79 | all1 = Variable(torch.ones_like(out0.data).cuda(), requires_grad=False)
80 | loss += torch.mean(F.binary_cross_entropy(F.sigmoid(out0), all1))
81 | else:
82 | assert 0, "Unsupported GAN type: {}".format(self.gan_type)
83 | return loss
84 |
85 | ##################################################################################
86 | # Generator
87 | ##################################################################################
88 |
89 | class AdaINGen(nn.Module):
90 | # AdaIN auto-encoder architecture
91 | def __init__(self, input_dim, params):
92 | super(AdaINGen, self).__init__()
93 | dim = params['dim']
94 | style_dim = params['style_dim']
95 | n_downsample = params['n_downsample']
96 | n_res = params['n_res']
97 | activ = params['activ']
98 | pad_type = params['pad_type']
99 | mlp_dim = params['mlp_dim']
100 |
101 | # style encoder
102 | self.enc_style = StyleEncoder(4, input_dim, dim, style_dim, norm='none', activ=activ, pad_type=pad_type)
103 |
104 | # content encoder
105 | self.enc_content = ContentEncoder(n_downsample, n_res, input_dim, dim, 'in', activ, pad_type=pad_type)
106 | self.dec = Decoder(n_downsample, n_res, self.enc_content.output_dim, input_dim, res_norm='adain', activ=activ, pad_type=pad_type)
107 |
108 | # MLP to generate AdaIN parameters
109 | self.mlp = MLP(style_dim, self.get_num_adain_params(self.dec), mlp_dim, 3, norm='none', activ=activ)
110 |
111 | def forward(self, images):
112 | # reconstruct an image
113 | content, style_fake = self.encode(images)
114 | images_recon = self.decode(content, style_fake)
115 | return images_recon
116 |
117 | def encode(self, images):
118 | # encode an image to its content and style codes
119 | style_fake = self.enc_style(images)
120 | content = self.enc_content(images)
121 | return content, style_fake
122 |
123 | def decode(self, content, style):
124 | # decode content and style codes to an image
125 | adain_params = self.mlp(style)
126 | self.assign_adain_params(adain_params, self.dec)
127 | images = self.dec(content)
128 | return images
129 |
130 | def assign_adain_params(self, adain_params, model):
131 | # assign the adain_params to the AdaIN layers in model
132 | for m in model.modules():
133 | if m.__class__.__name__ == "AdaptiveInstanceNorm2d":
134 | mean = adain_params[:, :m.num_features]
135 | std = adain_params[:, m.num_features:2*m.num_features]
136 | m.bias = mean.contiguous().view(-1)
137 | m.weight = std.contiguous().view(-1)
138 | if adain_params.size(1) > 2*m.num_features:
139 | adain_params = adain_params[:, 2*m.num_features:]
140 |
141 | def get_num_adain_params(self, model):
142 | # return the number of AdaIN parameters needed by the model
143 | num_adain_params = 0
144 | for m in model.modules():
145 | if m.__class__.__name__ == "AdaptiveInstanceNorm2d":
146 | num_adain_params += 2*m.num_features
147 | return num_adain_params
148 |
149 |
150 | class VAEGen(nn.Module):
151 | # VAE architecture
152 | def __init__(self, input_dim, params):
153 | super(VAEGen, self).__init__()
154 | dim = params['dim']
155 | n_downsample = params['n_downsample']
156 | n_res = params['n_res']
157 | activ = params['activ']
158 | pad_type = params['pad_type']
159 |
160 | # content encoder
161 | self.enc = ContentEncoder(n_downsample, n_res, input_dim, dim, 'in', activ, pad_type=pad_type)
162 | self.dec = Decoder(n_downsample, n_res, self.enc.output_dim, input_dim, res_norm='in', activ=activ, pad_type=pad_type)
163 |
164 | def forward(self, images):
165 |         # This is a reduced VAE implementation: we assume the latent code follows a multivariate Gaussian distribution with mean = hiddens and std_dev = all ones.
166 |         hiddens, _ = self.encode(images)  # encode() returns (hiddens, noise); only hiddens is needed here
167 |         if self.training:
168 | noise = Variable(torch.randn(hiddens.size()).cuda(hiddens.data.get_device()))
169 | images_recon = self.decode(hiddens + noise)
170 | else:
171 | images_recon = self.decode(hiddens)
172 | return images_recon, hiddens
173 |
174 | def encode(self, images):
175 | hiddens = self.enc(images)
176 | noise = Variable(torch.randn(hiddens.size()).cuda(hiddens.data.get_device()))
177 | return hiddens, noise
178 |
179 | def decode(self, hiddens):
180 | images = self.dec(hiddens)
181 | return images
182 |
183 |
184 | ##################################################################################
185 | # Encoder and Decoders
186 | ##################################################################################
187 |
188 | class StyleEncoder(nn.Module):
189 | def __init__(self, n_downsample, input_dim, dim, style_dim, norm, activ, pad_type):
190 | super(StyleEncoder, self).__init__()
191 | self.model = []
192 | self.model += [Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)]
193 | for i in range(2):
194 | self.model += [Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)]
195 | dim *= 2
196 | for i in range(n_downsample - 2):
197 | self.model += [Conv2dBlock(dim, dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)]
198 | self.model += [nn.AdaptiveAvgPool2d(1)] # global average pooling
199 | self.model += [nn.Conv2d(dim, style_dim, 1, 1, 0)]
200 | self.model = nn.Sequential(*self.model)
201 | self.output_dim = dim
202 |
203 | def forward(self, x):
204 | return self.model(x)
205 |
206 | class ContentEncoder(nn.Module):
207 | def __init__(self, n_downsample, n_res, input_dim, dim, norm, activ, pad_type):
208 | super(ContentEncoder, self).__init__()
209 | self.model = []
210 | self.model += [Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)]
211 | # downsampling blocks
212 | for i in range(n_downsample):
213 | self.model += [Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)]
214 | dim *= 2
215 | # residual blocks
216 | self.model += [ResBlocks(n_res, dim, norm=norm, activation=activ, pad_type=pad_type)]
217 | self.model = nn.Sequential(*self.model)
218 | self.output_dim = dim
219 |
220 | def forward(self, x):
221 | return self.model(x)
222 |
223 | class Decoder(nn.Module):
224 | def __init__(self, n_upsample, n_res, dim, output_dim, res_norm='adain', activ='relu', pad_type='zero'):
225 | super(Decoder, self).__init__()
226 |
227 | self.model = []
228 | # AdaIN residual blocks
229 | self.model += [ResBlocks(n_res, dim, res_norm, activ, pad_type=pad_type)]
230 | # upsampling blocks
231 | for i in range(n_upsample):
232 | self.model += [nn.Upsample(scale_factor=2),
233 | Conv2dBlock(dim, dim // 2, 5, 1, 2, norm='ln', activation=activ, pad_type=pad_type)]
234 | dim //= 2
235 | # use reflection padding in the last conv layer
236 | self.model += [Conv2dBlock(dim, output_dim, 7, 1, 3, norm='none', activation='tanh', pad_type=pad_type)]
237 | self.model = nn.Sequential(*self.model)
238 |
239 | def forward(self, x):
240 | return self.model(x)
241 |
242 | ##################################################################################
243 | # Sequential Models
244 | ##################################################################################
245 | class ResBlocks(nn.Module):
246 | def __init__(self, num_blocks, dim, norm='in', activation='relu', pad_type='zero'):
247 | super(ResBlocks, self).__init__()
248 | self.model = []
249 | for i in range(num_blocks):
250 | self.model += [ResBlock(dim, norm=norm, activation=activation, pad_type=pad_type)]
251 | self.model = nn.Sequential(*self.model)
252 |
253 | def forward(self, x):
254 | return self.model(x)
255 |
256 | class MLP(nn.Module):
257 | def __init__(self, input_dim, output_dim, dim, n_blk, norm='none', activ='relu'):
258 |
259 | super(MLP, self).__init__()
260 | self.model = []
261 | self.model += [LinearBlock(input_dim, dim, norm=norm, activation=activ)]
262 | for i in range(n_blk - 2):
263 | self.model += [LinearBlock(dim, dim, norm=norm, activation=activ)]
264 | self.model += [LinearBlock(dim, output_dim, norm='none', activation='none')] # no output activations
265 | self.model = nn.Sequential(*self.model)
266 |
267 | def forward(self, x):
268 | return self.model(x.view(x.size(0), -1))
269 |
270 | ##################################################################################
271 | # Basic Blocks
272 | ##################################################################################
273 | class ResBlock(nn.Module):
274 | def __init__(self, dim, norm='in', activation='relu', pad_type='zero'):
275 | super(ResBlock, self).__init__()
276 |
277 | model = []
278 |         model += [Conv2dBlock(dim, dim, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type)]
279 |         model += [Conv2dBlock(dim, dim, 3, 1, 1, norm=norm, activation='none', pad_type=pad_type)]
280 | self.model = nn.Sequential(*model)
281 |
282 | def forward(self, x):
283 | residual = x
284 | out = self.model(x)
285 | out += residual
286 | return out
287 |
288 | class Conv2dBlock(nn.Module):
289 |     def __init__(self, input_dim, output_dim, kernel_size, stride,
290 | padding=0, norm='none', activation='relu', pad_type='zero'):
291 | super(Conv2dBlock, self).__init__()
292 | self.use_bias = True
293 | # initialize padding
294 | if pad_type == 'reflect':
295 | self.pad = nn.ReflectionPad2d(padding)
296 | elif pad_type == 'replicate':
297 | self.pad = nn.ReplicationPad2d(padding)
298 | elif pad_type == 'zero':
299 | self.pad = nn.ZeroPad2d(padding)
300 | else:
301 | assert 0, "Unsupported padding type: {}".format(pad_type)
302 |
303 | # initialize normalization
304 | norm_dim = output_dim
305 | if norm == 'bn':
306 | self.norm = nn.BatchNorm2d(norm_dim)
307 | elif norm == 'in':
308 | #self.norm = nn.InstanceNorm2d(norm_dim, track_running_stats=True)
309 | self.norm = nn.InstanceNorm2d(norm_dim)
310 | elif norm == 'ln':
311 | self.norm = LayerNorm(norm_dim)
312 | elif norm == 'adain':
313 | self.norm = AdaptiveInstanceNorm2d(norm_dim)
314 | elif norm == 'none' or norm == 'sn':
315 | self.norm = None
316 | else:
317 | assert 0, "Unsupported normalization: {}".format(norm)
318 |
319 | # initialize activation
320 | if activation == 'relu':
321 | self.activation = nn.ReLU(inplace=True)
322 | elif activation == 'lrelu':
323 | self.activation = nn.LeakyReLU(0.2, inplace=True)
324 | elif activation == 'prelu':
325 | self.activation = nn.PReLU()
326 | elif activation == 'selu':
327 | self.activation = nn.SELU(inplace=True)
328 | elif activation == 'tanh':
329 | self.activation = nn.Tanh()
330 | elif activation == 'none':
331 | self.activation = None
332 | else:
333 | assert 0, "Unsupported activation: {}".format(activation)
334 |
335 | # initialize convolution
336 | if norm == 'sn':
337 | self.conv = SpectralNorm(nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias))
338 | else:
339 | self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias)
340 |
341 | def forward(self, x):
342 | x = self.conv(self.pad(x))
343 | if self.norm:
344 | x = self.norm(x)
345 | if self.activation:
346 | x = self.activation(x)
347 | return x
348 |
349 | class LinearBlock(nn.Module):
350 | def __init__(self, input_dim, output_dim, norm='none', activation='relu'):
351 | super(LinearBlock, self).__init__()
352 | use_bias = True
353 | # initialize fully connected layer
354 | if norm == 'sn':
355 | self.fc = SpectralNorm(nn.Linear(input_dim, output_dim, bias=use_bias))
356 | else:
357 | self.fc = nn.Linear(input_dim, output_dim, bias=use_bias)
358 |
359 | # initialize normalization
360 | norm_dim = output_dim
361 | if norm == 'bn':
362 | self.norm = nn.BatchNorm1d(norm_dim)
363 | elif norm == 'in':
364 | self.norm = nn.InstanceNorm1d(norm_dim)
365 | elif norm == 'ln':
366 | self.norm = LayerNorm(norm_dim)
367 | elif norm == 'none' or norm == 'sn':
368 | self.norm = None
369 | else:
370 | assert 0, "Unsupported normalization: {}".format(norm)
371 |
372 | # initialize activation
373 | if activation == 'relu':
374 | self.activation = nn.ReLU(inplace=True)
375 | elif activation == 'lrelu':
376 | self.activation = nn.LeakyReLU(0.2, inplace=True)
377 | elif activation == 'prelu':
378 | self.activation = nn.PReLU()
379 | elif activation == 'selu':
380 | self.activation = nn.SELU(inplace=True)
381 | elif activation == 'tanh':
382 | self.activation = nn.Tanh()
383 | elif activation == 'none':
384 | self.activation = None
385 | else:
386 | assert 0, "Unsupported activation: {}".format(activation)
387 |
388 | def forward(self, x):
389 | out = self.fc(x)
390 | if self.norm:
391 | out = self.norm(out)
392 | if self.activation:
393 | out = self.activation(out)
394 | return out
395 |
396 | ##################################################################################
397 | # VGG network definition
398 | ##################################################################################
399 | class Vgg16(nn.Module):
400 | def __init__(self):
401 | super(Vgg16, self).__init__()
402 | self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
403 | self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
404 |
405 | self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
406 | self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
407 |
408 | self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
409 | self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
410 | self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
411 |
412 | self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
413 | self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
414 | self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
415 |
416 | self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
417 | self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
418 | self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
419 |
420 | def forward(self, X):
421 | h = F.relu(self.conv1_1(X), inplace=True)
422 | h = F.relu(self.conv1_2(h), inplace=True)
423 | # relu1_2 = h
424 | h = F.max_pool2d(h, kernel_size=2, stride=2)
425 |
426 | h = F.relu(self.conv2_1(h), inplace=True)
427 | h = F.relu(self.conv2_2(h), inplace=True)
428 | # relu2_2 = h
429 | h = F.max_pool2d(h, kernel_size=2, stride=2)
430 |
431 | h = F.relu(self.conv3_1(h), inplace=True)
432 | h = F.relu(self.conv3_2(h), inplace=True)
433 | h = F.relu(self.conv3_3(h), inplace=True)
434 | # relu3_3 = h
435 | h = F.max_pool2d(h, kernel_size=2, stride=2)
436 |
437 | h = F.relu(self.conv4_1(h), inplace=True)
438 | h = F.relu(self.conv4_2(h), inplace=True)
439 | h = F.relu(self.conv4_3(h), inplace=True)
440 | # relu4_3 = h
441 |
442 | h = F.relu(self.conv5_1(h), inplace=True)
443 | h = F.relu(self.conv5_2(h), inplace=True)
444 | h = F.relu(self.conv5_3(h), inplace=True)
445 | relu5_3 = h
446 |
447 | return relu5_3
448 | # return [relu1_2, relu2_2, relu3_3, relu4_3]
449 |
450 | ##################################################################################
451 | # Normalization layers
452 | ##################################################################################
453 | class AdaptiveInstanceNorm2d(nn.Module):
454 | def __init__(self, num_features, eps=1e-5, momentum=0.1):
455 | super(AdaptiveInstanceNorm2d, self).__init__()
456 | self.num_features = num_features
457 | self.eps = eps
458 | self.momentum = momentum
459 | # weight and bias are dynamically assigned
460 | self.weight = None
461 | self.bias = None
462 | # just dummy buffers, not used
463 | self.register_buffer('running_mean', torch.zeros(num_features))
464 | self.register_buffer('running_var', torch.ones(num_features))
465 |
466 | def forward(self, x):
467 | assert self.weight is not None and self.bias is not None, "Please assign weight and bias before calling AdaIN!"
468 | b, c = x.size(0), x.size(1)
469 | running_mean = self.running_mean.repeat(b)
470 | running_var = self.running_var.repeat(b)
471 |
472 | # Apply instance norm
473 | x_reshaped = x.contiguous().view(1, b * c, *x.size()[2:])
474 |
475 | out = F.batch_norm(
476 | x_reshaped, running_mean, running_var, self.weight, self.bias,
477 | True, self.momentum, self.eps)
478 |
479 | return out.view(b, c, *x.size()[2:])
480 |
481 | def __repr__(self):
482 | return self.__class__.__name__ + '(' + str(self.num_features) + ')'
483 |
484 |
485 | class LayerNorm(nn.Module):
486 | def __init__(self, num_features, eps=1e-5, affine=True):
487 | super(LayerNorm, self).__init__()
488 | self.num_features = num_features
489 | self.affine = affine
490 | self.eps = eps
491 |
492 | if self.affine:
493 | self.gamma = nn.Parameter(torch.Tensor(num_features).uniform_())
494 | self.beta = nn.Parameter(torch.zeros(num_features))
495 |
496 | def forward(self, x):
497 | shape = [-1] + [1] * (x.dim() - 1)
498 | # print(x.size())
499 | if x.size(0) == 1:
500 | # These two lines run much faster in pytorch 0.4 than the two lines listed below.
501 | mean = x.view(-1).mean().view(*shape)
502 | std = x.view(-1).std().view(*shape)
503 | else:
504 | mean = x.view(x.size(0), -1).mean(1).view(*shape)
505 | std = x.view(x.size(0), -1).std(1).view(*shape)
506 |
507 | x = (x - mean) / (std + self.eps)
508 |
509 | if self.affine:
510 | shape = [1, -1] + [1] * (x.dim() - 2)
511 | x = x * self.gamma.view(*shape) + self.beta.view(*shape)
512 | return x
513 |
514 | def l2normalize(v, eps=1e-12):
515 | return v / (v.norm() + eps)
516 |
517 |
518 | class SpectralNorm(nn.Module):
519 | """
520 | Based on the paper "Spectral Normalization for Generative Adversarial Networks" by Takeru Miyato, Toshiki Kataoka, Masanori Koyama, Yuichi Yoshida
521 | and the Pytorch implementation https://github.com/christiancosgrove/pytorch-spectral-normalization-gan
522 | """
523 | def __init__(self, module, name='weight', power_iterations=1):
524 | super(SpectralNorm, self).__init__()
525 | self.module = module
526 | self.name = name
527 | self.power_iterations = power_iterations
528 | if not self._made_params():
529 | self._make_params()
530 |
531 | def _update_u_v(self):
532 | u = getattr(self.module, self.name + "_u")
533 | v = getattr(self.module, self.name + "_v")
534 | w = getattr(self.module, self.name + "_bar")
535 |
536 | height = w.data.shape[0]
537 | for _ in range(self.power_iterations):
538 | v.data = l2normalize(torch.mv(torch.t(w.view(height,-1).data), u.data))
539 | u.data = l2normalize(torch.mv(w.view(height,-1).data, v.data))
540 |
541 | # sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data))
542 | sigma = u.dot(w.view(height, -1).mv(v))
543 | setattr(self.module, self.name, w / sigma.expand_as(w))
544 |
545 | def _made_params(self):
546 | try:
547 | u = getattr(self.module, self.name + "_u")
548 | v = getattr(self.module, self.name + "_v")
549 | w = getattr(self.module, self.name + "_bar")
550 | return True
551 | except AttributeError:
552 | return False
553 |
554 |
555 | def _make_params(self):
556 | w = getattr(self.module, self.name)
557 |
558 | height = w.data.shape[0]
559 | width = w.view(height, -1).data.shape[1]
560 |
561 | u = nn.Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
562 | v = nn.Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
563 | u.data = l2normalize(u.data)
564 | v.data = l2normalize(v.data)
565 | w_bar = nn.Parameter(w.data)
566 |
567 | del self.module._parameters[self.name]
568 |
569 | self.module.register_parameter(self.name + "_u", u)
570 | self.module.register_parameter(self.name + "_v", v)
571 | self.module.register_parameter(self.name + "_bar", w_bar)
572 |
573 |
574 | def forward(self, *args):
575 | self._update_u_v()
576 | return self.module.forward(*args)
--------------------------------------------------------------------------------
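A minimal sketch (not part of the repository) of how the modules above fit together: AdaINGen encodes an image into a content code and a style code, decoding the content with a freshly sampled style produces a new translation, and MsImageDis scores the result at several scales. The parameter names match the constructors above; the numeric values are illustrative placeholders in the spirit of the configs/ YAML files, not copied from any specific one.

    import torch
    from networks import AdaINGen, MsImageDis

    gen_params = {'dim': 64, 'style_dim': 8, 'n_downsample': 2, 'n_res': 4,
                  'activ': 'relu', 'pad_type': 'reflect', 'mlp_dim': 256}
    dis_params = {'n_layer': 4, 'gan_type': 'lsgan', 'dim': 64, 'norm': 'none',
                  'activ': 'lrelu', 'num_scales': 3, 'pad_type': 'reflect'}

    gen = AdaINGen(input_dim=3, params=gen_params)
    dis = MsImageDis(input_dim=3, params=dis_params)

    with torch.no_grad():
        x = torch.randn(1, 3, 256, 256)                  # stand-in for an RGB image in [-1, 1]
        content, style = gen.encode(x)                   # content code + style code
        s_rand = torch.randn(1, gen_params['style_dim'], 1, 1)
        x_new = gen.decode(content, s_rand)              # same content, resampled style
        d_outs = dis(x_new)                              # one logit map per discriminator scale
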
/results/animal.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/results/animal.jpg
--------------------------------------------------------------------------------
/results/edges2handbags/input.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/results/edges2handbags/input.jpg
--------------------------------------------------------------------------------
/results/edges2handbags/output000.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/results/edges2handbags/output000.jpg
--------------------------------------------------------------------------------
/results/edges2handbags/output001.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/results/edges2handbags/output001.jpg
--------------------------------------------------------------------------------
/results/edges2handbags/output002.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/results/edges2handbags/output002.jpg
--------------------------------------------------------------------------------
/results/edges2handbags/output003.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/results/edges2handbags/output003.jpg
--------------------------------------------------------------------------------
/results/edges2handbags/output004.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/results/edges2handbags/output004.jpg
--------------------------------------------------------------------------------
/results/edges2handbags/output005.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/results/edges2handbags/output005.jpg
--------------------------------------------------------------------------------
/results/edges2handbags/output006.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/results/edges2handbags/output006.jpg
--------------------------------------------------------------------------------
/results/edges2handbags/output007.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/results/edges2handbags/output007.jpg
--------------------------------------------------------------------------------
/results/edges2handbags/output008.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/results/edges2handbags/output008.jpg
--------------------------------------------------------------------------------
/results/edges2handbags/output009.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/results/edges2handbags/output009.jpg
--------------------------------------------------------------------------------
/results/edges2shoes/input.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/results/edges2shoes/input.jpg
--------------------------------------------------------------------------------
/results/edges2shoes/output000.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/results/edges2shoes/output000.jpg
--------------------------------------------------------------------------------
/results/edges2shoes/output001.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/results/edges2shoes/output001.jpg
--------------------------------------------------------------------------------
/results/edges2shoes/output002.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/results/edges2shoes/output002.jpg
--------------------------------------------------------------------------------
/results/edges2shoes/output003.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/results/edges2shoes/output003.jpg
--------------------------------------------------------------------------------
/results/edges2shoes/output004.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/results/edges2shoes/output004.jpg
--------------------------------------------------------------------------------
/results/edges2shoes/output005.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/results/edges2shoes/output005.jpg
--------------------------------------------------------------------------------
/results/edges2shoes/output006.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/results/edges2shoes/output006.jpg
--------------------------------------------------------------------------------
/results/edges2shoes/output007.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/results/edges2shoes/output007.jpg
--------------------------------------------------------------------------------
/results/edges2shoes/output008.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/results/edges2shoes/output008.jpg
--------------------------------------------------------------------------------
/results/edges2shoes/output009.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/results/edges2shoes/output009.jpg
--------------------------------------------------------------------------------
/results/edges2shoes_handbags.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/results/edges2shoes_handbags.jpg
--------------------------------------------------------------------------------
/results/example_guided.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/results/example_guided.jpg
--------------------------------------------------------------------------------
/results/input.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/results/input.jpg
--------------------------------------------------------------------------------
/results/output000.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/results/output000.jpg
--------------------------------------------------------------------------------
/results/street.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/results/street.jpg
--------------------------------------------------------------------------------
/results/summer2winter_yosemite.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/results/summer2winter_yosemite.jpg
--------------------------------------------------------------------------------
/results/video.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/MUNIT/a82e222bc359892bd0f522d7a0f1573f3ec4a485/results/video.jpg
--------------------------------------------------------------------------------
/scripts/demo_train_edges2handbags.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | rm datasets/edges2handbags -rf
3 | mkdir datasets/edges2handbags -p
4 | axel -n 1 https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/edges2handbags.tar.gz --output=datasets/edges2handbags/edges2handbags.tar.gz
5 | tar -zxvf datasets/edges2handbags/edges2handbags.tar.gz -C datasets/
6 | mkdir datasets/edges2handbags/train1 -p
7 | mkdir datasets/edges2handbags/train0 -p
8 | mkdir datasets/edges2handbags/test1 -p
9 | mkdir datasets/edges2handbags/test0 -p
10 | for f in datasets/edges2handbags/train/*; do convert -quality 100 -crop 50%x100% +repage $f datasets/edges2handbags/train%d/${f##*/}; done;
11 | for f in datasets/edges2handbags/val/*; do convert -quality 100 -crop 50%x100% +repage $f datasets/edges2handbags/test%d/${f##*/}; done;
12 | mv datasets/edges2handbags/train0 datasets/edges2handbags/trainA
13 | mv datasets/edges2handbags/train1 datasets/edges2handbags/trainB
14 | mv datasets/edges2handbags/test0 datasets/edges2handbags/testA
15 | mv datasets/edges2handbags/test1 datasets/edges2handbags/testB
16 | python train.py --config configs/edges2handbags_folder.yaml
17 |
18 |
--------------------------------------------------------------------------------
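The two convert -quality 100 -crop 50%x100% +repage loops in the script above split each paired pix2pix image into its left tile (written to train0/test0, later renamed to domain A) and right tile (train1/test1, later domain B). For readers without ImageMagick, a rough Python/PIL equivalent of that step (a sketch, not part of the repository; directory names follow the script):

    import os
    from PIL import Image

    def split_pairs(src_dir, dst_a, dst_b):
        """Split each side-by-side pix2pix image into left (A) and right (B) halves."""
        os.makedirs(dst_a, exist_ok=True)
        os.makedirs(dst_b, exist_ok=True)
        for name in sorted(os.listdir(src_dir)):
            img = Image.open(os.path.join(src_dir, name))
            w, h = img.size
            img.crop((0, 0, w // 2, h)).save(os.path.join(dst_a, name), quality=100)
            img.crop((w // 2, 0, w, h)).save(os.path.join(dst_b, name), quality=100)

    # Mirrors the convert/mv steps above: train/ -> trainA + trainB, val/ -> testA + testB.
    split_pairs('datasets/edges2handbags/train',
                'datasets/edges2handbags/trainA', 'datasets/edges2handbags/trainB')
    split_pairs('datasets/edges2handbags/val',
                'datasets/edges2handbags/testA', 'datasets/edges2handbags/testB')
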
/scripts/demo_train_edges2shoes.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | rm datasets/edges2shoes -rf
3 | mkdir datasets/edges2shoes -p
4 | axel -n 1 https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/edges2shoes.tar.gz --output=datasets/edges2shoes/edges2shoes.tar.gz
5 | tar -zxvf datasets/edges2shoes/edges2shoes.tar.gz -C datasets
6 | mkdir datasets/edges2shoes/train1 -p
7 | mkdir datasets/edges2shoes/train0 -p
8 | mkdir datasets/edges2shoes/test1 -p
9 | mkdir datasets/edges2shoes/test0 -p
10 | for f in datasets/edges2shoes/train/*; do convert -quality 100 -crop 50%x100% +repage $f datasets/edges2shoes/train%d/${f##*/}; done;
11 | for f in datasets/edges2shoes/val/*; do convert -quality 100 -crop 50%x100% +repage $f datasets/edges2shoes/test%d/${f##*/}; done;
12 | mv datasets/edges2shoes/train0 datasets/edges2shoes/trainA
13 | mv datasets/edges2shoes/train1 datasets/edges2shoes/trainB
14 | mv datasets/edges2shoes/test0 datasets/edges2shoes/testA
15 | mv datasets/edges2shoes/test1 datasets/edges2shoes/testB
16 | python train.py --config configs/edges2shoes_folder.yaml
17 |
--------------------------------------------------------------------------------
/scripts/demo_train_summer2winter_yosemite256.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | rm datasets/summer2winter_yosemite256 -rf
3 | mkdir datasets/summer2winter_yosemite256 -p
4 | axel -n 1 https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/summer2winter_yosemite.zip --output=datasets/summer2winter_yosemite256/summer2winter_yosemite.zip
5 | unzip datasets/summer2winter_yosemite256/summer2winter_yosemite.zip -d datasets/summer2winter_yosemite256
6 | python train.py --config configs/summer2winter_yosemite256_folder.yaml
7 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2018 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """
5 | from __future__ import print_function
6 | from utils import get_config, pytorch03_to_pytorch04
7 | from trainer import MUNIT_Trainer, UNIT_Trainer
8 | import argparse
9 | from torch.autograd import Variable
10 | import torchvision.utils as vutils
11 | import sys
12 | import torch
13 | import os
14 | from torchvision import transforms
15 | from PIL import Image
16 |
17 | parser = argparse.ArgumentParser()
18 | parser.add_argument('--config', type=str, help="net configuration")
19 | parser.add_argument('--input', type=str, help="input image path")
20 | parser.add_argument('--output_folder', type=str, help="output image path")
21 | parser.add_argument('--checkpoint', type=str, help="checkpoint of autoencoders")
22 | parser.add_argument('--style', type=str, default='', help="style image path")
23 | parser.add_argument('--a2b', type=int, default=1, help="1 for a2b and 0 for b2a")
24 | parser.add_argument('--seed', type=int, default=10, help="random seed")
25 | parser.add_argument('--num_style',type=int, default=10, help="number of styles to sample")
26 | parser.add_argument('--synchronized', action='store_true', help="whether to use synchronized style codes or not")
27 | parser.add_argument('--output_only', action='store_true', help="whether to save only the output images or also the input image")
28 | parser.add_argument('--output_path', type=str, default='.', help="path for logs, checkpoints, and VGG model weight")
29 | parser.add_argument('--trainer', type=str, default='MUNIT', help="MUNIT|UNIT")
30 | opts = parser.parse_args()
31 |
32 |
33 |
34 | torch.manual_seed(opts.seed)
35 | torch.cuda.manual_seed(opts.seed)
36 | if not os.path.exists(opts.output_folder):
37 | os.makedirs(opts.output_folder)
38 |
39 | # Load experiment setting
40 | config = get_config(opts.config)
41 | opts.num_style = 1 if opts.style != '' else opts.num_style
42 |
43 | # Setup model and data loader
44 | config['vgg_model_path'] = opts.output_path
45 | if opts.trainer == 'MUNIT':
46 | style_dim = config['gen']['style_dim']
47 | trainer = MUNIT_Trainer(config)
48 | elif opts.trainer == 'UNIT':
49 | trainer = UNIT_Trainer(config)
50 | else:
51 | sys.exit("Only support MUNIT|UNIT")
52 |
53 | try:
54 | state_dict = torch.load(opts.checkpoint)
55 | trainer.gen_a.load_state_dict(state_dict['a'])
56 | trainer.gen_b.load_state_dict(state_dict['b'])
57 | except:
58 | state_dict = pytorch03_to_pytorch04(torch.load(opts.checkpoint), opts.trainer)
59 | trainer.gen_a.load_state_dict(state_dict['a'])
60 | trainer.gen_b.load_state_dict(state_dict['b'])
61 |
62 | trainer.cuda()
63 | trainer.eval()
64 | encode = trainer.gen_a.encode if opts.a2b else trainer.gen_b.encode # encode function
65 | style_encode = trainer.gen_b.encode if opts.a2b else trainer.gen_a.encode # encode function
66 | decode = trainer.gen_b.decode if opts.a2b else trainer.gen_a.decode # decode function
67 |
68 | if 'new_size' in config:
69 | new_size = config['new_size']
70 | else:
71 | if opts.a2b==1:
72 | new_size = config['new_size_a']
73 | else:
74 | new_size = config['new_size_b']
75 |
76 | with torch.no_grad():
77 | transform = transforms.Compose([transforms.Resize(new_size),
78 | transforms.ToTensor(),
79 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
80 | image = Variable(transform(Image.open(opts.input).convert('RGB')).unsqueeze(0).cuda())
81 | style_image = Variable(transform(Image.open(opts.style).convert('RGB')).unsqueeze(0).cuda()) if opts.style != '' else None
82 |
83 | # Start testing
84 | content, _ = encode(image)
85 |
86 | if opts.trainer == 'MUNIT':
87 | style_rand = Variable(torch.randn(opts.num_style, style_dim, 1, 1).cuda())
88 | if opts.style != '':
89 | _, style = style_encode(style_image)
90 | else:
91 | style = style_rand
92 | for j in range(opts.num_style):
93 | s = style[j].unsqueeze(0)
94 | outputs = decode(content, s)
95 | outputs = (outputs + 1) / 2.
96 | path = os.path.join(opts.output_folder, 'output{:03d}.jpg'.format(j))
97 | vutils.save_image(outputs.data, path, padding=0, normalize=True)
98 | elif opts.trainer == 'UNIT':
99 | outputs = decode(content)
100 | outputs = (outputs + 1) / 2.
101 | path = os.path.join(opts.output_folder, 'output.jpg')
102 | vutils.save_image(outputs.data, path, padding=0, normalize=True)
103 | else:
104 | pass
105 |
106 | if not opts.output_only:
107 | # also save input images
108 | vutils.save_image(image.data, os.path.join(opts.output_folder, 'input.jpg'), padding=0, normalize=True)
109 |
110 |
--------------------------------------------------------------------------------
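A small standalone check (not part of the repository) of the normalization round trip used in test.py above: Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) maps pixel values from [0, 1] to [-1, 1] to match the generator's tanh output range, which is why the decoded outputs are rescaled with (outputs + 1) / 2 before being saved.

    import torch
    from torchvision import transforms

    norm = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    x = torch.rand(3, 8, 8)        # stand-in image tensor in [0, 1]
    y = norm(x.clone())            # now in [-1, 1]
    x_back = (y + 1) / 2.          # the rescaling test.py applies before vutils.save_image
    assert torch.allclose(x, x_back, atol=1e-6)
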
/test_batch.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2018 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """
5 | from __future__ import print_function
6 | from utils import get_config, get_data_loader_folder, pytorch03_to_pytorch04, load_inception
7 | from trainer import MUNIT_Trainer, UNIT_Trainer
8 | from torch import nn
9 | from scipy.stats import entropy
10 | import torch.nn.functional as F
11 | import argparse
12 | from torch.autograd import Variable
13 | from data import ImageFolder
14 | import numpy as np
15 | import torchvision.utils as vutils
16 | try:
17 | from itertools import izip as zip
18 | except ImportError: # will be 3.x series
19 | pass
20 | import sys
21 | import torch
22 | import os
23 |
24 |
25 | parser = argparse.ArgumentParser()
26 | parser.add_argument('--config', type=str, default='configs/edges2handbags_folder', help='Path to the config file.')
27 | parser.add_argument('--input_folder', type=str, help="input image folder")
28 | parser.add_argument('--output_folder', type=str, help="output image folder")
29 | parser.add_argument('--checkpoint', type=str, help="checkpoint of autoencoders")
30 | parser.add_argument('--a2b', type=int, help="1 for a2b and 0 for b2a", default=1)
31 | parser.add_argument('--seed', type=int, default=1, help="random seed")
32 | parser.add_argument('--num_style',type=int, default=10, help="number of styles to sample")
33 | parser.add_argument('--synchronized', action='store_true', help="whether use synchronized style code or not")
34 | parser.add_argument('--output_only', action='store_true', help="whether only save the output images or also save the input images")
35 | parser.add_argument('--output_path', type=str, default='.', help="path for logs, checkpoints, and VGG model weight")
36 | parser.add_argument('--trainer', type=str, default='MUNIT', help="MUNIT|UNIT")
37 | parser.add_argument('--compute_IS', action='store_true', help="whether to compute Inception Score or not")
38 | parser.add_argument('--compute_CIS', action='store_true', help="whether to compute Conditional Inception Score or not")
39 | parser.add_argument('--inception_a', type=str, default='.', help="path to the pretrained inception network for domain A")
40 | parser.add_argument('--inception_b', type=str, default='.', help="path to the pretrained inception network for domain B")
41 |
42 | opts = parser.parse_args()
43 |
44 |
45 | torch.manual_seed(opts.seed)
46 | torch.cuda.manual_seed(opts.seed)
47 |
48 | # Load experiment setting
49 | config = get_config(opts.config)
50 | input_dim = config['input_dim_a'] if opts.a2b else config['input_dim_b']
51 |
52 | # Load the inception networks if we need to compute IS or CIS
53 | if opts.compute_IS or opts.compute_CIS:
54 | inception = load_inception(opts.inception_b) if opts.a2b else load_inception(opts.inception_a)
55 | # freeze the inception models and set eval mode
56 | inception.eval()
57 | for param in inception.parameters():
58 | param.requires_grad = False
59 | inception_up = nn.Upsample(size=(299, 299), mode='bilinear')
60 |
61 | # Setup model and data loader
62 | image_names = ImageFolder(opts.input_folder, transform=None, return_paths=True)
63 | data_loader = get_data_loader_folder(opts.input_folder, 1, False, new_size=config['new_size_a'], crop=False)
64 |
65 | config['vgg_model_path'] = opts.output_path
66 | if opts.trainer == 'MUNIT':
67 | style_dim = config['gen']['style_dim']
68 | trainer = MUNIT_Trainer(config)
69 | elif opts.trainer == 'UNIT':
70 | trainer = UNIT_Trainer(config)
71 | else:
72 | sys.exit("Only MUNIT|UNIT trainers are supported")
73 |
74 | try:
75 | state_dict = torch.load(opts.checkpoint)
76 | trainer.gen_a.load_state_dict(state_dict['a'])
77 | trainer.gen_b.load_state_dict(state_dict['b'])
78 | except:
79 | state_dict = pytorch03_to_pytorch04(torch.load(opts.checkpoint), opts.trainer)
80 | trainer.gen_a.load_state_dict(state_dict['a'])
81 | trainer.gen_b.load_state_dict(state_dict['b'])
82 |
83 | trainer.cuda()
84 | trainer.eval()
85 | encode = trainer.gen_a.encode if opts.a2b else trainer.gen_b.encode # encode function
86 | decode = trainer.gen_b.decode if opts.a2b else trainer.gen_a.decode # decode function
87 |
88 | if opts.compute_IS:
89 | IS = []
90 | all_preds = []
91 | if opts.compute_CIS:
92 | CIS = []
93 |
94 | if opts.trainer == 'MUNIT':
95 | # Start testing
96 | style_fixed = Variable(torch.randn(opts.num_style, style_dim, 1, 1).cuda(), volatile=True)
97 | for i, (images, names) in enumerate(zip(data_loader, image_names)):
98 | if opts.compute_CIS:
99 | cur_preds = []
100 | print(names[1])
101 | images = Variable(images.cuda(), volatile=True)
102 | content, _ = encode(images)
103 | style = style_fixed if opts.synchronized else Variable(torch.randn(opts.num_style, style_dim, 1, 1).cuda(), volatile=True)
104 | for j in range(opts.num_style):
105 | s = style[j].unsqueeze(0)
106 | outputs = decode(content, s)
107 | outputs = (outputs + 1) / 2.
108 | if opts.compute_IS or opts.compute_CIS:
109 | pred = F.softmax(inception(inception_up(outputs)), dim=1).cpu().data.numpy() # get the predicted class distribution
110 | if opts.compute_IS:
111 | all_preds.append(pred)
112 | if opts.compute_CIS:
113 | cur_preds.append(pred)
114 | # path = os.path.join(opts.output_folder, 'input{:03d}_output{:03d}.jpg'.format(i, j))
115 | basename = os.path.basename(names[1])
116 | path = os.path.join(opts.output_folder+"_%02d"%j,basename)
117 | if not os.path.exists(os.path.dirname(path)):
118 | os.makedirs(os.path.dirname(path))
119 | vutils.save_image(outputs.data, path, padding=0, normalize=True)
120 | if opts.compute_CIS:
121 | cur_preds = np.concatenate(cur_preds, 0)
122 | py = np.sum(cur_preds, axis=0) # prior is computed from outputs given a specific input
123 | for j in range(cur_preds.shape[0]):
124 | pyx = cur_preds[j, :]
125 | CIS.append(entropy(pyx, py))
126 | if not opts.output_only:
127 | # also save input images
128 | vutils.save_image(images.data, os.path.join(opts.output_folder, 'input{:03d}.jpg'.format(i)), padding=0, normalize=True)
129 | if opts.compute_IS:
130 | all_preds = np.concatenate(all_preds, 0)
131 | py = np.sum(all_preds, axis=0) # prior is computed from all outputs
132 | for j in range(all_preds.shape[0]):
133 | pyx = all_preds[j, :]
134 | IS.append(entropy(pyx, py))
135 |
136 | if opts.compute_IS:
137 | print("Inception Score: {}".format(np.exp(np.mean(IS))))
138 | if opts.compute_CIS:
139 | print("conditional Inception Score: {}".format(np.exp(np.mean(CIS))))
140 |
141 | elif opts.trainer == 'UNIT':
142 | # Start testing
143 | for i, (images, names) in enumerate(zip(data_loader, image_names)):
144 | print(names[1])
145 | images = Variable(images.cuda(), volatile=True)
146 | content, _ = encode(images)
147 |
148 | outputs = decode(content)
149 | outputs = (outputs + 1) / 2.
150 | # path = os.path.join(opts.output_folder, 'input{:03d}_output{:03d}.jpg'.format(i, j))
151 | basename = os.path.basename(names[1])
152 | path = os.path.join(opts.output_folder,basename)
153 | if not os.path.exists(os.path.dirname(path)):
154 | os.makedirs(os.path.dirname(path))
155 | vutils.save_image(outputs.data, path, padding=0, normalize=True)
156 | if not opts.output_only:
157 | # also save input images
158 | vutils.save_image(images.data, os.path.join(opts.output_folder, 'input{:03d}.jpg'.format(i)), padding=0, normalize=True)
159 | else:
160 | pass
161 |
--------------------------------------------------------------------------------
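Note: test_batch.py accumulates the softmax outputs of the pretrained Inception network and only converts them into scores at the end. Each entropy(pyx, py) call is the KL divergence between the conditional class distribution p(y|x) and the marginal p(y), and the score is the exponential of the mean KL. A self-contained sketch of the IS computation (array names and shapes are illustrative):

import numpy as np
from scipy.stats import entropy

def inception_score(all_preds):
    # all_preds: (N, num_classes) softmax outputs for N generated images.
    all_preds = np.asarray(all_preds)
    py = all_preds.sum(axis=0)  # unnormalized marginal p(y); entropy() normalizes it
    kl_terms = [entropy(all_preds[i], py) for i in range(all_preds.shape[0])]  # KL(p(y|x) || p(y))
    return float(np.exp(np.mean(kl_terms)))

CIS follows the same formula, except that the marginal p(y) is computed only over the outputs generated from a single input image, so it measures diversity conditioned on that input.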
/train.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2018 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """
5 | from utils import get_all_data_loaders, prepare_sub_folder, write_html, write_loss, get_config, write_2images, Timer
6 | import argparse
7 | from torch.autograd import Variable
8 | from trainer import MUNIT_Trainer, UNIT_Trainer
9 | import torch.backends.cudnn as cudnn
10 | import torch
11 | try:
12 | from itertools import izip as zip
13 | except ImportError: # will be 3.x series
14 | pass
15 | import os
16 | import sys
17 | import tensorboardX
18 | import shutil
19 |
20 | parser = argparse.ArgumentParser()
21 | parser.add_argument('--config', type=str, default='configs/edges2handbags_folder.yaml', help='Path to the config file.')
22 | parser.add_argument('--output_path', type=str, default='.', help="outputs path")
23 | parser.add_argument("--resume", action="store_true")
24 | parser.add_argument('--trainer', type=str, default='MUNIT', help="MUNIT|UNIT")
25 | opts = parser.parse_args()
26 |
27 | cudnn.benchmark = True
28 |
29 | # Load experiment setting
30 | config = get_config(opts.config)
31 | max_iter = config['max_iter']
32 | display_size = config['display_size']
33 | config['vgg_model_path'] = opts.output_path
34 |
35 | # Setup model and data loader
36 | if opts.trainer == 'MUNIT':
37 | trainer = MUNIT_Trainer(config)
38 | elif opts.trainer == 'UNIT':
39 | trainer = UNIT_Trainer(config)
40 | else:
41 | sys.exit("Only MUNIT|UNIT trainers are supported")
42 | trainer.cuda()
43 | train_loader_a, train_loader_b, test_loader_a, test_loader_b = get_all_data_loaders(config)
44 | train_display_images_a = torch.stack([train_loader_a.dataset[i] for i in range(display_size)]).cuda()
45 | train_display_images_b = torch.stack([train_loader_b.dataset[i] for i in range(display_size)]).cuda()
46 | test_display_images_a = torch.stack([test_loader_a.dataset[i] for i in range(display_size)]).cuda()
47 | test_display_images_b = torch.stack([test_loader_b.dataset[i] for i in range(display_size)]).cuda()
48 |
49 | # Setup logger and output folders
50 | model_name = os.path.splitext(os.path.basename(opts.config))[0]
51 | train_writer = tensorboardX.SummaryWriter(os.path.join(opts.output_path + "/logs", model_name))
52 | output_directory = os.path.join(opts.output_path + "/outputs", model_name)
53 | checkpoint_directory, image_directory = prepare_sub_folder(output_directory)
54 | shutil.copy(opts.config, os.path.join(output_directory, 'config.yaml')) # copy config file to output folder
55 |
56 | # Start training
57 | iterations = trainer.resume(checkpoint_directory, hyperparameters=config) if opts.resume else 0
58 | while True:
59 | for it, (images_a, images_b) in enumerate(zip(train_loader_a, train_loader_b)):
60 | trainer.update_learning_rate()
61 | images_a, images_b = images_a.cuda().detach(), images_b.cuda().detach()
62 |
63 | with Timer("Elapsed time in update: %f"):
64 | # Main training code
65 | trainer.dis_update(images_a, images_b, config)
66 | trainer.gen_update(images_a, images_b, config)
67 | torch.cuda.synchronize()
68 |
69 | # Dump training stats in log file
70 | if (iterations + 1) % config['log_iter'] == 0:
71 | print("Iteration: %08d/%08d" % (iterations + 1, max_iter))
72 | write_loss(iterations, trainer, train_writer)
73 |
74 | # Write images
75 | if (iterations + 1) % config['image_save_iter'] == 0:
76 | with torch.no_grad():
77 | test_image_outputs = trainer.sample(test_display_images_a, test_display_images_b)
78 | train_image_outputs = trainer.sample(train_display_images_a, train_display_images_b)
79 | write_2images(test_image_outputs, display_size, image_directory, 'test_%08d' % (iterations + 1))
80 | write_2images(train_image_outputs, display_size, image_directory, 'train_%08d' % (iterations + 1))
81 | # HTML
82 | write_html(output_directory + "/index.html", iterations + 1, config['image_save_iter'], 'images')
83 |
84 | if (iterations + 1) % config['image_display_iter'] == 0:
85 | with torch.no_grad():
86 | image_outputs = trainer.sample(train_display_images_a, train_display_images_b)
87 | write_2images(image_outputs, display_size, image_directory, 'train_current')
88 |
89 | # Save network weights
90 | if (iterations + 1) % config['snapshot_save_iter'] == 0:
91 | trainer.save(checkpoint_directory, iterations)
92 |
93 | iterations += 1
94 | if iterations >= max_iter:
95 | sys.exit('Finished training')
96 |
97 |
--------------------------------------------------------------------------------
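Note: train.py is driven entirely by the YAML config. The dictionary below is a hedged sketch of the iteration-control and optimizer keys that train.py and the trainers read; the values are placeholders, the real settings (plus the data, gen/dis architecture, and loss-weight keys used elsewhere in these files) live under configs/.

# Placeholder values only; see configs/*.yaml for the settings shipped with the repo.
minimal_config_keys = {
    'max_iter': 1000000,          # stop training after this many iterations
    'display_size': 16,           # number of images used in the saved visualizations
    'log_iter': 10,               # write losses to tensorboard every N iterations
    'image_save_iter': 10000,     # dump test/train image grids every N iterations
    'image_display_iter': 100,    # refresh the 'train_current' images every N iterations
    'snapshot_save_iter': 10000,  # save gen/dis/optimizer checkpoints every N iterations
    'batch_size': 1,              # data loading (utils.get_all_data_loaders)
    'num_workers': 8,
    'lr': 0.0001, 'beta1': 0.5, 'beta2': 0.999,  # Adam settings used by both trainers
    'weight_decay': 0.0001,       # optimizer weight decay
    'init': 'kaiming',            # generator weight initialization
}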
/trainer.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """
5 | from networks import AdaINGen, MsImageDis, VAEGen
6 | from utils import weights_init, get_model_list, vgg_preprocess, load_vgg16, get_scheduler
7 | from torch.autograd import Variable
8 | import torch
9 | import torch.nn as nn
10 | import os
11 |
12 | class MUNIT_Trainer(nn.Module):
13 | def __init__(self, hyperparameters):
14 | super(MUNIT_Trainer, self).__init__()
15 | lr = hyperparameters['lr']
16 | # Initiate the networks
17 | self.gen_a = AdaINGen(hyperparameters['input_dim_a'], hyperparameters['gen']) # auto-encoder for domain a
18 | self.gen_b = AdaINGen(hyperparameters['input_dim_b'], hyperparameters['gen']) # auto-encoder for domain b
19 | self.dis_a = MsImageDis(hyperparameters['input_dim_a'], hyperparameters['dis']) # discriminator for domain a
20 | self.dis_b = MsImageDis(hyperparameters['input_dim_b'], hyperparameters['dis']) # discriminator for domain b
21 | self.instancenorm = nn.InstanceNorm2d(512, affine=False)
22 | self.style_dim = hyperparameters['gen']['style_dim']
23 |
24 | # fix the noise used in sampling
25 | display_size = int(hyperparameters['display_size'])
26 | self.s_a = torch.randn(display_size, self.style_dim, 1, 1).cuda()
27 | self.s_b = torch.randn(display_size, self.style_dim, 1, 1).cuda()
28 |
29 | # Setup the optimizers
30 | beta1 = hyperparameters['beta1']
31 | beta2 = hyperparameters['beta2']
32 | dis_params = list(self.dis_a.parameters()) + list(self.dis_b.parameters())
33 | gen_params = list(self.gen_a.parameters()) + list(self.gen_b.parameters())
34 | self.dis_opt = torch.optim.Adam([p for p in dis_params if p.requires_grad],
35 | lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
36 | self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad],
37 | lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
38 | self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
39 | self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)
40 |
41 | # Network weight initialization
42 | self.apply(weights_init(hyperparameters['init']))
43 | self.dis_a.apply(weights_init('gaussian'))
44 | self.dis_b.apply(weights_init('gaussian'))
45 |
46 | # Load VGG model if needed
47 | if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
48 | self.vgg = load_vgg16(hyperparameters['vgg_model_path'] + '/models')
49 | self.vgg.eval()
50 | for param in self.vgg.parameters():
51 | param.requires_grad = False
52 |
53 | def recon_criterion(self, input, target):
54 | return torch.mean(torch.abs(input - target))
55 |
56 | def forward(self, x_a, x_b):
57 | self.eval()
58 | s_a = Variable(self.s_a)
59 | s_b = Variable(self.s_b)
60 | c_a, s_a_fake = self.gen_a.encode(x_a)
61 | c_b, s_b_fake = self.gen_b.encode(x_b)
62 | x_ba = self.gen_a.decode(c_b, s_a)
63 | x_ab = self.gen_b.decode(c_a, s_b)
64 | self.train()
65 | return x_ab, x_ba
66 |
67 | def gen_update(self, x_a, x_b, hyperparameters):
68 | self.gen_opt.zero_grad()
69 | s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
70 | s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
71 | # encode
72 | c_a, s_a_prime = self.gen_a.encode(x_a)
73 | c_b, s_b_prime = self.gen_b.encode(x_b)
74 | # decode (within domain)
75 | x_a_recon = self.gen_a.decode(c_a, s_a_prime)
76 | x_b_recon = self.gen_b.decode(c_b, s_b_prime)
77 | # decode (cross domain)
78 | x_ba = self.gen_a.decode(c_b, s_a)
79 | x_ab = self.gen_b.decode(c_a, s_b)
80 | # encode again
81 | c_b_recon, s_a_recon = self.gen_a.encode(x_ba)
82 | c_a_recon, s_b_recon = self.gen_b.encode(x_ab)
83 | # decode again (if needed)
84 | x_aba = self.gen_a.decode(c_a_recon, s_a_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None
85 | x_bab = self.gen_b.decode(c_b_recon, s_b_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None
86 |
87 | # reconstruction loss
88 | self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
89 | self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
90 | self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a)
91 | self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b)
92 | self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c_a)
93 | self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, c_b)
94 | self.loss_gen_cycrecon_x_a = self.recon_criterion(x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0
95 | self.loss_gen_cycrecon_x_b = self.recon_criterion(x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0
96 | # GAN loss
97 | self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
98 | self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)
99 | # domain-invariant perceptual loss
100 | self.loss_gen_vgg_a = self.compute_vgg_loss(self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0
101 | self.loss_gen_vgg_b = self.compute_vgg_loss(self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0
102 | # total loss
103 | self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
104 | hyperparameters['gan_w'] * self.loss_gen_adv_b + \
105 | hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
106 | hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
107 | hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \
108 | hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
109 | hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
110 | hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b + \
111 | hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
112 | hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \
113 | hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \
114 | hyperparameters['vgg_w'] * self.loss_gen_vgg_b
115 | self.loss_gen_total.backward()
116 | self.gen_opt.step()
117 |
118 | def compute_vgg_loss(self, vgg, img, target):
119 | img_vgg = vgg_preprocess(img)
120 | target_vgg = vgg_preprocess(target)
121 | img_fea = vgg(img_vgg)
122 | target_fea = vgg(target_vgg)
123 | return torch.mean((self.instancenorm(img_fea) - self.instancenorm(target_fea)) ** 2)
124 |
125 | def sample(self, x_a, x_b):
126 | self.eval()
127 | s_a1 = Variable(self.s_a)
128 | s_b1 = Variable(self.s_b)
129 | s_a2 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
130 | s_b2 = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
131 | x_a_recon, x_b_recon, x_ba1, x_ba2, x_ab1, x_ab2 = [], [], [], [], [], []
132 | for i in range(x_a.size(0)):
133 | c_a, s_a_fake = self.gen_a.encode(x_a[i].unsqueeze(0))
134 | c_b, s_b_fake = self.gen_b.encode(x_b[i].unsqueeze(0))
135 | x_a_recon.append(self.gen_a.decode(c_a, s_a_fake))
136 | x_b_recon.append(self.gen_b.decode(c_b, s_b_fake))
137 | x_ba1.append(self.gen_a.decode(c_b, s_a1[i].unsqueeze(0)))
138 | x_ba2.append(self.gen_a.decode(c_b, s_a2[i].unsqueeze(0)))
139 | x_ab1.append(self.gen_b.decode(c_a, s_b1[i].unsqueeze(0)))
140 | x_ab2.append(self.gen_b.decode(c_a, s_b2[i].unsqueeze(0)))
141 | x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
142 | x_ba1, x_ba2 = torch.cat(x_ba1), torch.cat(x_ba2)
143 | x_ab1, x_ab2 = torch.cat(x_ab1), torch.cat(x_ab2)
144 | self.train()
145 | return x_a, x_a_recon, x_ab1, x_ab2, x_b, x_b_recon, x_ba1, x_ba2
146 |
147 | def dis_update(self, x_a, x_b, hyperparameters):
148 | self.dis_opt.zero_grad()
149 | s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
150 | s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
151 | # encode
152 | c_a, _ = self.gen_a.encode(x_a)
153 | c_b, _ = self.gen_b.encode(x_b)
154 | # decode (cross domain)
155 | x_ba = self.gen_a.decode(c_b, s_a)
156 | x_ab = self.gen_b.decode(c_a, s_b)
157 | # D loss
158 | self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
159 | self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)
160 | self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + hyperparameters['gan_w'] * self.loss_dis_b
161 | self.loss_dis_total.backward()
162 | self.dis_opt.step()
163 |
164 | def update_learning_rate(self):
165 | if self.dis_scheduler is not None:
166 | self.dis_scheduler.step()
167 | if self.gen_scheduler is not None:
168 | self.gen_scheduler.step()
169 |
170 | def resume(self, checkpoint_dir, hyperparameters):
171 | # Load generators
172 | last_model_name = get_model_list(checkpoint_dir, "gen")
173 | state_dict = torch.load(last_model_name)
174 | self.gen_a.load_state_dict(state_dict['a'])
175 | self.gen_b.load_state_dict(state_dict['b'])
176 | iterations = int(last_model_name[-11:-3])
177 | # Load discriminators
178 | last_model_name = get_model_list(checkpoint_dir, "dis")
179 | state_dict = torch.load(last_model_name)
180 | self.dis_a.load_state_dict(state_dict['a'])
181 | self.dis_b.load_state_dict(state_dict['b'])
182 | # Load optimizers
183 | state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
184 | self.dis_opt.load_state_dict(state_dict['dis'])
185 | self.gen_opt.load_state_dict(state_dict['gen'])
186 | # Reinitialize schedulers
187 | self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations)
188 | self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations)
189 | print('Resume from iteration %d' % iterations)
190 | return iterations
191 |
192 | def save(self, snapshot_dir, iterations):
193 | # Save generators, discriminators, and optimizers
194 | gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
195 | dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
196 | opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
197 | torch.save({'a': self.gen_a.state_dict(), 'b': self.gen_b.state_dict()}, gen_name)
198 | torch.save({'a': self.dis_a.state_dict(), 'b': self.dis_b.state_dict()}, dis_name)
199 | torch.save({'gen': self.gen_opt.state_dict(), 'dis': self.dis_opt.state_dict()}, opt_name)
200 |
201 |
202 | class UNIT_Trainer(nn.Module):
203 | def __init__(self, hyperparameters):
204 | super(UNIT_Trainer, self).__init__()
205 | lr = hyperparameters['lr']
206 | # Initiate the networks
207 | self.gen_a = VAEGen(hyperparameters['input_dim_a'], hyperparameters['gen']) # auto-encoder for domain a
208 | self.gen_b = VAEGen(hyperparameters['input_dim_b'], hyperparameters['gen']) # auto-encoder for domain b
209 | self.dis_a = MsImageDis(hyperparameters['input_dim_a'], hyperparameters['dis']) # discriminator for domain a
210 | self.dis_b = MsImageDis(hyperparameters['input_dim_b'], hyperparameters['dis']) # discriminator for domain b
211 | self.instancenorm = nn.InstanceNorm2d(512, affine=False)
212 |
213 | # Setup the optimizers
214 | beta1 = hyperparameters['beta1']
215 | beta2 = hyperparameters['beta2']
216 | dis_params = list(self.dis_a.parameters()) + list(self.dis_b.parameters())
217 | gen_params = list(self.gen_a.parameters()) + list(self.gen_b.parameters())
218 | self.dis_opt = torch.optim.Adam([p for p in dis_params if p.requires_grad],
219 | lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
220 | self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad],
221 | lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
222 | self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
223 | self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)
224 |
225 | # Network weight initialization
226 | self.apply(weights_init(hyperparameters['init']))
227 | self.dis_a.apply(weights_init('gaussian'))
228 | self.dis_b.apply(weights_init('gaussian'))
229 |
230 | # Load VGG model if needed
231 | if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
232 | self.vgg = load_vgg16(hyperparameters['vgg_model_path'] + '/models')
233 | self.vgg.eval()
234 | for param in self.vgg.parameters():
235 | param.requires_grad = False
236 |
237 | def recon_criterion(self, input, target):
238 | return torch.mean(torch.abs(input - target))
239 |
240 | def forward(self, x_a, x_b):
241 | self.eval()
242 | h_a, _ = self.gen_a.encode(x_a)
243 | h_b, _ = self.gen_b.encode(x_b)
244 | x_ba = self.gen_a.decode(h_b)
245 | x_ab = self.gen_b.decode(h_a)
246 | self.train()
247 | return x_ab, x_ba
248 |
249 | def __compute_kl(self, mu):
250 | # def _compute_kl(self, mu, sd):
251 | # mu_2 = torch.pow(mu, 2)
252 | # sd_2 = torch.pow(sd, 2)
253 | # encoding_loss = (mu_2 + sd_2 - torch.log(sd_2)).sum() / mu_2.size(0)
254 | # return encoding_loss
255 | mu_2 = torch.pow(mu, 2)
256 | encoding_loss = torch.mean(mu_2)
257 | return encoding_loss
258 |
259 | def gen_update(self, x_a, x_b, hyperparameters):
260 | self.gen_opt.zero_grad()
261 | # encode
262 | h_a, n_a = self.gen_a.encode(x_a)
263 | h_b, n_b = self.gen_b.encode(x_b)
264 | # decode (within domain)
265 | x_a_recon = self.gen_a.decode(h_a + n_a)
266 | x_b_recon = self.gen_b.decode(h_b + n_b)
267 | # decode (cross domain)
268 | x_ba = self.gen_a.decode(h_b + n_b)
269 | x_ab = self.gen_b.decode(h_a + n_a)
270 | # encode again
271 | h_b_recon, n_b_recon = self.gen_a.encode(x_ba)
272 | h_a_recon, n_a_recon = self.gen_b.encode(x_ab)
273 | # decode again (if needed)
274 | x_aba = self.gen_a.decode(h_a_recon + n_a_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None
275 | x_bab = self.gen_b.decode(h_b_recon + n_b_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None
276 |
277 | # reconstruction loss
278 | self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
279 | self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
280 | self.loss_gen_recon_kl_a = self.__compute_kl(h_a)
281 | self.loss_gen_recon_kl_b = self.__compute_kl(h_b)
282 | self.loss_gen_cyc_x_a = self.recon_criterion(x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0
283 | self.loss_gen_cyc_x_b = self.recon_criterion(x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0
284 | self.loss_gen_recon_kl_cyc_aba = self.__compute_kl(h_a_recon)
285 | self.loss_gen_recon_kl_cyc_bab = self.__compute_kl(h_b_recon)
286 | # GAN loss
287 | self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
288 | self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)
289 | # domain-invariant perceptual loss
290 | self.loss_gen_vgg_a = self.compute_vgg_loss(self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0
291 | self.loss_gen_vgg_b = self.compute_vgg_loss(self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0
292 | # total loss
293 | self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
294 | hyperparameters['gan_w'] * self.loss_gen_adv_b + \
295 | hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
296 | hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_a + \
297 | hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
298 | hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_b + \
299 | hyperparameters['recon_x_cyc_w'] * self.loss_gen_cyc_x_a + \
300 | hyperparameters['recon_kl_cyc_w'] * self.loss_gen_recon_kl_cyc_aba + \
301 | hyperparameters['recon_x_cyc_w'] * self.loss_gen_cyc_x_b + \
302 | hyperparameters['recon_kl_cyc_w'] * self.loss_gen_recon_kl_cyc_bab + \
303 | hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \
304 | hyperparameters['vgg_w'] * self.loss_gen_vgg_b
305 | self.loss_gen_total.backward()
306 | self.gen_opt.step()
307 |
308 | def compute_vgg_loss(self, vgg, img, target):
309 | img_vgg = vgg_preprocess(img)
310 | target_vgg = vgg_preprocess(target)
311 | img_fea = vgg(img_vgg)
312 | target_fea = vgg(target_vgg)
313 | return torch.mean((self.instancenorm(img_fea) - self.instancenorm(target_fea)) ** 2)
314 |
315 | def sample(self, x_a, x_b):
316 | self.eval()
317 | x_a_recon, x_b_recon, x_ba, x_ab = [], [], [], []
318 | for i in range(x_a.size(0)):
319 | h_a, _ = self.gen_a.encode(x_a[i].unsqueeze(0))
320 | h_b, _ = self.gen_b.encode(x_b[i].unsqueeze(0))
321 | x_a_recon.append(self.gen_a.decode(h_a))
322 | x_b_recon.append(self.gen_b.decode(h_b))
323 | x_ba.append(self.gen_a.decode(h_b))
324 | x_ab.append(self.gen_b.decode(h_a))
325 | x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
326 | x_ba = torch.cat(x_ba)
327 | x_ab = torch.cat(x_ab)
328 | self.train()
329 | return x_a, x_a_recon, x_ab, x_b, x_b_recon, x_ba
330 |
331 | def dis_update(self, x_a, x_b, hyperparameters):
332 | self.dis_opt.zero_grad()
333 | # encode
334 | h_a, n_a = self.gen_a.encode(x_a)
335 | h_b, n_b = self.gen_b.encode(x_b)
336 | # decode (cross domain)
337 | x_ba = self.gen_a.decode(h_b + n_b)
338 | x_ab = self.gen_b.decode(h_a + n_a)
339 | # D loss
340 | self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
341 | self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)
342 | self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + hyperparameters['gan_w'] * self.loss_dis_b
343 | self.loss_dis_total.backward()
344 | self.dis_opt.step()
345 |
346 | def update_learning_rate(self):
347 | if self.dis_scheduler is not None:
348 | self.dis_scheduler.step()
349 | if self.gen_scheduler is not None:
350 | self.gen_scheduler.step()
351 |
352 | def resume(self, checkpoint_dir, hyperparameters):
353 | # Load generators
354 | last_model_name = get_model_list(checkpoint_dir, "gen")
355 | state_dict = torch.load(last_model_name)
356 | self.gen_a.load_state_dict(state_dict['a'])
357 | self.gen_b.load_state_dict(state_dict['b'])
358 | iterations = int(last_model_name[-11:-3])
359 | # Load discriminators
360 | last_model_name = get_model_list(checkpoint_dir, "dis")
361 | state_dict = torch.load(last_model_name)
362 | self.dis_a.load_state_dict(state_dict['a'])
363 | self.dis_b.load_state_dict(state_dict['b'])
364 | # Load optimizers
365 | state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
366 | self.dis_opt.load_state_dict(state_dict['dis'])
367 | self.gen_opt.load_state_dict(state_dict['gen'])
368 | # Reinitialize schedulers
369 | self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations)
370 | self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations)
371 | print('Resume from iteration %d' % iterations)
372 | return iterations
373 |
374 | def save(self, snapshot_dir, iterations):
375 | # Save generators, discriminators, and optimizers
376 | gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
377 | dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
378 | opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
379 | torch.save({'a': self.gen_a.state_dict(), 'b': self.gen_b.state_dict()}, gen_name)
380 | torch.save({'a': self.dis_a.state_dict(), 'b': self.dis_b.state_dict()}, dis_name)
381 | torch.save({'gen': self.gen_opt.state_dict(), 'dis': self.dis_opt.state_dict()}, opt_name)
382 |
--------------------------------------------------------------------------------
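Note: both trainers expose the same update interface (dis_update followed by gen_update), which train.py calls once per batch. The snippet below is an illustrative smoke test of MUNIT_Trainer on random tensors, not part of the repo; it assumes the config file exists and, if the config enables the perceptual loss (vgg_w > 0), that the VGG weights are available under vgg_model_path.

import torch
from trainer import MUNIT_Trainer
from utils import get_config

config = get_config('configs/edges2handbags_folder.yaml')
config['vgg_model_path'] = '.'  # same override train.py applies before building the trainer

trainer = MUNIT_Trainer(config).cuda()
x_a = torch.rand(1, config['input_dim_a'], 256, 256).cuda() * 2 - 1  # fake domain-A batch in [-1, 1]
x_b = torch.rand(1, config['input_dim_b'], 256, 256).cuda() * 2 - 1  # fake domain-B batch in [-1, 1]

trainer.dis_update(x_a, x_b, config)  # discriminator step first, as in train.py
trainer.gen_update(x_a, x_b, config)  # then the generator step
print(float(trainer.loss_dis_total), float(trainer.loss_gen_total))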
/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2018 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """
5 | from torch.utils.serialization import load_lua
6 | from torch.utils.data import DataLoader
7 | from networks import Vgg16
8 | from torch.autograd import Variable
9 | from torch.optim import lr_scheduler
10 | from torchvision import transforms
11 | from data import ImageFilelist, ImageFolder
12 | import torch
13 | import torch.nn as nn
14 | import os
15 | import math
16 | import torchvision.utils as vutils
17 | import yaml
18 | import numpy as np
19 | import torch.nn.init as init
20 | import time
21 | # Methods
22 | # get_all_data_loaders : primary data loader interface (load trainA, testA, trainB, testB)
23 | # get_data_loader_list : list-based data loader
24 | # get_data_loader_folder : folder-based data loader
25 | # get_config : load yaml file
26 | # eformat : format a float in scientific notation with a given precision
27 | # write_2images : save output image
28 | # prepare_sub_folder : create checkpoints and images folders for saving outputs
29 | # write_one_row_html : write one row of the html file for output images
30 | # write_html : create the html file.
31 | # write_loss : write the current training losses to the tensorboard log
32 | # slerp : spherical linear interpolation between two vectors
33 | # get_slerp_interp : build a sequence of slerp-interpolated codes
34 | # get_model_list : return the latest checkpoint file with a given prefix
35 | # load_vgg16 : load the VGG16 model used for the perceptual loss
36 | # load_inception : load the Inception model used for IS/CIS
37 | # vgg_preprocess : convert images to the input range expected by VGG
38 | # get_scheduler : build the learning-rate scheduler for an optimizer
39 | # weights_init : return a weight-initialization function
40 |
41 | def get_all_data_loaders(conf):
42 | batch_size = conf['batch_size']
43 | num_workers = conf['num_workers']
44 | if 'new_size' in conf:
45 | new_size_a = new_size_b = conf['new_size']
46 | else:
47 | new_size_a = conf['new_size_a']
48 | new_size_b = conf['new_size_b']
49 | height = conf['crop_image_height']
50 | width = conf['crop_image_width']
51 |
52 | if 'data_root' in conf:
53 | train_loader_a = get_data_loader_folder(os.path.join(conf['data_root'], 'trainA'), batch_size, True,
54 | new_size_a, height, width, num_workers, True)
55 | test_loader_a = get_data_loader_folder(os.path.join(conf['data_root'], 'testA'), batch_size, False,
56 | new_size_a, new_size_a, new_size_a, num_workers, True)
57 | train_loader_b = get_data_loader_folder(os.path.join(conf['data_root'], 'trainB'), batch_size, True,
58 | new_size_b, height, width, num_workers, True)
59 | test_loader_b = get_data_loader_folder(os.path.join(conf['data_root'], 'testB'), batch_size, False,
60 | new_size_b, new_size_b, new_size_b, num_workers, True)
61 | else:
62 | train_loader_a = get_data_loader_list(conf['data_folder_train_a'], conf['data_list_train_a'], batch_size, True,
63 | new_size_a, height, width, num_workers, True)
64 | test_loader_a = get_data_loader_list(conf['data_folder_test_a'], conf['data_list_test_a'], batch_size, False,
65 | new_size_a, new_size_a, new_size_a, num_workers, True)
66 | train_loader_b = get_data_loader_list(conf['data_folder_train_b'], conf['data_list_train_b'], batch_size, True,
67 | new_size_b, height, width, num_workers, True)
68 | test_loader_b = get_data_loader_list(conf['data_folder_test_b'], conf['data_list_test_b'], batch_size, False,
69 | new_size_b, new_size_b, new_size_b, num_workers, True)
70 | return train_loader_a, train_loader_b, test_loader_a, test_loader_b
71 |
72 |
73 | def get_data_loader_list(root, file_list, batch_size, train, new_size=None,
74 | height=256, width=256, num_workers=4, crop=True):
75 | transform_list = [transforms.ToTensor(),
76 | transforms.Normalize((0.5, 0.5, 0.5),
77 | (0.5, 0.5, 0.5))]
78 | transform_list = [transforms.RandomCrop((height, width))] + transform_list if crop else transform_list
79 | transform_list = [transforms.Resize(new_size)] + transform_list if new_size is not None else transform_list
80 | transform_list = [transforms.RandomHorizontalFlip()] + transform_list if train else transform_list
81 | transform = transforms.Compose(transform_list)
82 | dataset = ImageFilelist(root, file_list, transform=transform)
83 | loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=train, drop_last=True, num_workers=num_workers)
84 | return loader
85 |
86 | def get_data_loader_folder(input_folder, batch_size, train, new_size=None,
87 | height=256, width=256, num_workers=4, crop=True):
88 | transform_list = [transforms.ToTensor(),
89 | transforms.Normalize((0.5, 0.5, 0.5),
90 | (0.5, 0.5, 0.5))]
91 | transform_list = [transforms.RandomCrop((height, width))] + transform_list if crop else transform_list
92 | transform_list = [transforms.Resize(new_size)] + transform_list if new_size is not None else transform_list
93 | transform_list = [transforms.RandomHorizontalFlip()] + transform_list if train else transform_list
94 | transform = transforms.Compose(transform_list)
95 | dataset = ImageFolder(input_folder, transform=transform)
96 | loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=train, drop_last=True, num_workers=num_workers)
97 | return loader
98 |
99 |
100 | def get_config(config):
101 | with open(config, 'r') as stream:
102 | return yaml.load(stream)
103 |
104 |
105 | def eformat(f, prec):
106 | s = "%.*e"%(prec, f)
107 | mantissa, exp = s.split('e')
108 | # add 1 to digits as 1 is taken by sign +/-
109 | return "%se%d"%(mantissa, int(exp))
110 |
111 |
112 | def __write_images(image_outputs, display_image_num, file_name):
113 | image_outputs = [images.expand(-1, 3, -1, -1) for images in image_outputs] # expand gray-scale images to 3 channels
114 | image_tensor = torch.cat([images[:display_image_num] for images in image_outputs], 0)
115 | image_grid = vutils.make_grid(image_tensor.data, nrow=display_image_num, padding=0, normalize=True)
116 | vutils.save_image(image_grid, file_name, nrow=1)
117 |
118 |
119 | def write_2images(image_outputs, display_image_num, image_directory, postfix):
120 | n = len(image_outputs)
121 | __write_images(image_outputs[0:n//2], display_image_num, '%s/gen_a2b_%s.jpg' % (image_directory, postfix))
122 | __write_images(image_outputs[n//2:n], display_image_num, '%s/gen_b2a_%s.jpg' % (image_directory, postfix))
123 |
124 |
125 | def prepare_sub_folder(output_directory):
126 | image_directory = os.path.join(output_directory, 'images')
127 | if not os.path.exists(image_directory):
128 | print("Creating directory: {}".format(image_directory))
129 | os.makedirs(image_directory)
130 | checkpoint_directory = os.path.join(output_directory, 'checkpoints')
131 | if not os.path.exists(checkpoint_directory):
132 | print("Creating directory: {}".format(checkpoint_directory))
133 | os.makedirs(checkpoint_directory)
134 | return checkpoint_directory, image_directory
135 |
136 |
137 | def write_one_row_html(html_file, iterations, img_filename, all_size):
138 | html_file.write("<h3>iteration [%d] (%s)</h3>" % (iterations, img_filename.split('/')[-1]))
139 | html_file.write("""
140 | <p><a href="%s">
141 | <img src="%s" style="width:%dpx">
142 | </a><br>
143 | <p>
144 | """ % (img_filename, img_filename, all_size))
145 | return
146 |
147 |
148 | def write_html(filename, iterations, image_save_iterations, image_directory, all_size=1536):
149 | html_file = open(filename, "w")
150 | html_file.write('''
151 | <!DOCTYPE html>
152 | <html>
153 | <head>
154 |