├── .gitignore ├── Dockerfile ├── LICENSE.md ├── README.md ├── TUTORIAL.md ├── configs ├── unit_edges2handbags_folder.yaml ├── unit_edges2shoes_folder.yaml ├── unit_gta2city_folder.yaml ├── unit_gta2city_list.yaml └── unit_summer2winter_yosemite256_folder.yaml ├── data.py ├── datasets └── gta2city │ ├── list_testA.txt │ ├── list_testB.txt │ ├── list_trainA.txt │ ├── list_trainB.txt │ ├── testA │ ├── 00001.jpg │ ├── 00002.jpg │ ├── 00003.jpg │ ├── 00004.jpg │ └── 00005.jpg │ ├── testB │ ├── aachen_000000_000019_leftImg8bit.jpg │ ├── aachen_000001_000019_leftImg8bit.jpg │ ├── aachen_000002_000019_leftImg8bit.jpg │ ├── aachen_000003_000019_leftImg8bit.jpg │ └── aachen_000004_000019_leftImg8bit.jpg │ ├── trainA │ ├── 00001.jpg │ ├── 00002.jpg │ ├── 00003.jpg │ ├── 00004.jpg │ └── 00005.jpg │ └── trainB │ ├── aachen_000000_000019_leftImg8bit.jpg │ ├── aachen_000001_000019_leftImg8bit.jpg │ ├── aachen_000002_000019_leftImg8bit.jpg │ ├── aachen_000003_000019_leftImg8bit.jpg │ └── aachen_000004_000019_leftImg8bit.jpg ├── docs ├── cat_species.gif ├── cat_trans.png ├── day2night.gif ├── dog_breed.gif ├── dog_trans.png ├── faces.png ├── shared-latent-space.png ├── snowy2summery.gif ├── street_scene.png ├── two-minute-paper.png └── unit_nips_2017.pdf ├── inputs ├── city_example.jpg └── gta_example.jpg ├── networks.py ├── results ├── city2gta │ ├── input.jpg │ └── output.jpg └── gta2city │ ├── input.jpg │ └── output.jpg ├── scripts ├── unit_demo_train_edges2handbags.sh ├── unit_demo_train_edges2shoes.sh └── unit_demo_train_summer2winter_yosemite256.sh ├── test.py ├── test_batch.py ├── train.py ├── trainer.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | outputs/ 2 | src/models/ 3 | data/ 4 | logs/ 5 | .idea/ 6 | dgx/scripts/ 7 | .ipynb_checkpoints/ 8 | src/*jpg 9 | notebooks/.ipynb_checkpoints/* 10 | exps/ 11 | src/yaml_generator.py 12 | *.tar.gz 13 | *ipynb 14 | *.zip 15 | *.pkl 16 | *.pyc 17 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:9.1-cudnn7-runtime-ubuntu16.04 2 | # Set anaconda path 3 | ENV ANACONDA /opt/anaconda 4 | ENV PATH $ANACONDA/bin:$PATH 5 | RUN apt-get update && apt-get install -y --no-install-recommends \ 6 | wget \ 7 | libopencv-dev \ 8 | python-opencv \ 9 | build-essential \ 10 | cmake \ 11 | git \ 12 | curl \ 13 | ca-certificates \ 14 | libjpeg-dev \ 15 | libpng-dev \ 16 | axel \ 17 | zip \ 18 | unzip 19 | RUN wget https://repo.continuum.io/archive/Anaconda3-5.0.1-Linux-x86_64.sh -P /tmp 20 | RUN bash /tmp/Anaconda3-5.0.1-Linux-x86_64.sh -b -p $ANACONDA 21 | RUN rm /tmp/Anaconda3-5.0.1-Linux-x86_64.sh -rf 22 | RUN conda install -y pytorch=0.4.1 torchvision cuda91 -c pytorch 23 | RUN conda install -y -c anaconda pip 24 | RUN conda install -y -c anaconda yaml 25 | RUN pip install tensorboard tensorboardX; 26 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | ## creative commons 2 | 3 | # Attribution-NonCommercial-ShareAlike 4.0 International 4 | 5 | Creative Commons Corporation (“Creative Commons”) is not a law firm and does not provide legal services or legal advice. Distribution of Creative Commons public licenses does not create a lawyer-client or other relationship. 
Creative Commons makes its licenses and related information available on an “as-is” basis. Creative Commons gives no warranties regarding its licenses, any material licensed under their terms and conditions, or any related information. Creative Commons disclaims all liability for damages resulting from their use to the fullest extent possible. 6 | 7 | ### Using Creative Commons Public Licenses 8 | 9 | Creative Commons public licenses provide a standard set of terms and conditions that creators and other rights holders may use to share original works of authorship and other material subject to copyright and certain other rights specified in the public license below. The following considerations are for informational purposes only, are not exhaustive, and do not form part of our licenses. 10 | 11 | * __Considerations for licensors:__ Our public licenses are intended for use by those authorized to give the public permission to use material in ways otherwise restricted by copyright and certain other rights. Our licenses are irrevocable. Licensors should read and understand the terms and conditions of the license they choose before applying it. Licensors should also secure all rights necessary before applying our licenses so that the public can reuse the material as expected. Licensors should clearly mark any material not subject to the license. This includes other CC-licensed material, or material used under an exception or limitation to copyright. [More considerations for licensors](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensors). 12 | 13 | * __Considerations for the public:__ By using one of our public licenses, a licensor grants the public permission to use the licensed material under specified terms and conditions. If the licensor’s permission is not necessary for any reason–for example, because of any applicable exception or limitation to copyright–then that use is not regulated by the license. Our licenses grant only permissions under copyright and certain other rights that a licensor has authority to grant. Use of the licensed material may still be restricted for other reasons, including because others have copyright or other rights in the material. A licensor may make special requests, such as asking that all changes be marked or described. Although not required by our licenses, you are encouraged to respect those requests where reasonable. [More considerations for the public](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensees). 14 | 15 | ## Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License 16 | 17 | By exercising the Licensed Rights (defined below), You accept and agree to be bound by the terms and conditions of this Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License ("Public License"). To the extent this Public License may be interpreted as a contract, You are granted the Licensed Rights in consideration of Your acceptance of these terms and conditions, and the Licensor grants You such rights in consideration of benefits the Licensor receives from making the Licensed Material available under these terms and conditions. 18 | 19 | ### Section 1 – Definitions. 20 | 21 | a. 
__Adapted Material__ means material subject to Copyright and Similar Rights that is derived from or based upon the Licensed Material and in which the Licensed Material is translated, altered, arranged, transformed, or otherwise modified in a manner requiring permission under the Copyright and Similar Rights held by the Licensor. For purposes of this Public License, where the Licensed Material is a musical work, performance, or sound recording, Adapted Material is always produced where the Licensed Material is synched in timed relation with a moving image. 22 | 23 | b. __Adapter's License__ means the license You apply to Your Copyright and Similar Rights in Your contributions to Adapted Material in accordance with the terms and conditions of this Public License. 24 | 25 | c. __BY-NC-SA Compatible License__ means a license listed at [creativecommons.org/compatiblelicenses](http://creativecommons.org/compatiblelicenses), approved by Creative Commons as essentially the equivalent of this Public License. 26 | 27 | d. __Copyright and Similar Rights__ means copyright and/or similar rights closely related to copyright including, without limitation, performance, broadcast, sound recording, and Sui Generis Database Rights, without regard to how the rights are labeled or categorized. For purposes of this Public License, the rights specified in Section 2(b)(1)-(2) are not Copyright and Similar Rights. 28 | 29 | e. __Effective Technological Measures__ means those measures that, in the absence of proper authority, may not be circumvented under laws fulfilling obligations under Article 11 of the WIPO Copyright Treaty adopted on December 20, 1996, and/or similar international agreements. 30 | 31 | f. __Exceptions and Limitations__ means fair use, fair dealing, and/or any other exception or limitation to Copyright and Similar Rights that applies to Your use of the Licensed Material. 32 | 33 | g. __License Elements__ means the license attributes listed in the name of a Creative Commons Public License. The License Elements of this Public License are Attribution, NonCommercial, and ShareAlike. 34 | 35 | h. __Licensed Material__ means the artistic or literary work, database, or other material to which the Licensor applied this Public License. 36 | 37 | i. __Licensed Rights__ means the rights granted to You subject to the terms and conditions of this Public License, which are limited to all Copyright and Similar Rights that apply to Your use of the Licensed Material and that the Licensor has authority to license. 38 | 39 | h. __Licensor__ means the individual(s) or entity(ies) granting rights under this Public License. 40 | 41 | i. __NonCommercial__ means not primarily intended for or directed towards commercial advantage or monetary compensation. For purposes of this Public License, the exchange of the Licensed Material for other material subject to Copyright and Similar Rights by digital file-sharing or similar means is NonCommercial provided there is no payment of monetary compensation in connection with the exchange. 42 | 43 | j. __Share__ means to provide material to the public by any means or process that requires permission under the Licensed Rights, such as reproduction, public display, public performance, distribution, dissemination, communication, or importation, and to make material available to the public including in ways that members of the public may access the material from a place and at a time individually chosen by them. 44 | 45 | k. 
__Sui Generis Database Rights__ means rights other than copyright resulting from Directive 96/9/EC of the European Parliament and of the Council of 11 March 1996 on the legal protection of databases, as amended and/or succeeded, as well as other essentially equivalent rights anywhere in the world. 46 | 47 | l. __You__ means the individual or entity exercising the Licensed Rights under this Public License. Your has a corresponding meaning. 48 | 49 | ### Section 2 – Scope. 50 | 51 | a. ___License grant.___ 52 | 53 | 1. Subject to the terms and conditions of this Public License, the Licensor hereby grants You a worldwide, royalty-free, non-sublicensable, non-exclusive, irrevocable license to exercise the Licensed Rights in the Licensed Material to: 54 | 55 | A. reproduce and Share the Licensed Material, in whole or in part, for NonCommercial purposes only; and 56 | 57 | B. produce, reproduce, and Share Adapted Material for NonCommercial purposes only. 58 | 59 | 2. __Exceptions and Limitations.__ For the avoidance of doubt, where Exceptions and Limitations apply to Your use, this Public License does not apply, and You do not need to comply with its terms and conditions. 60 | 61 | 3. __Term.__ The term of this Public License is specified in Section 6(a). 62 | 63 | 4. __Media and formats; technical modifications allowed.__ The Licensor authorizes You to exercise the Licensed Rights in all media and formats whether now known or hereafter created, and to make technical modifications necessary to do so. The Licensor waives and/or agrees not to assert any right or authority to forbid You from making technical modifications necessary to exercise the Licensed Rights, including technical modifications necessary to circumvent Effective Technological Measures. For purposes of this Public License, simply making modifications authorized by this Section 2(a)(4) never produces Adapted Material. 64 | 65 | 5. __Downstream recipients.__ 66 | 67 | A. __Offer from the Licensor – Licensed Material.__ Every recipient of the Licensed Material automatically receives an offer from the Licensor to exercise the Licensed Rights under the terms and conditions of this Public License. 68 | 69 | B. __Additional offer from the Licensor – Adapted Material.__ Every recipient of Adapted Material from You automatically receives an offer from the Licensor to exercise the Licensed Rights in the Adapted Material under the conditions of the Adapter’s License You apply. 70 | 71 | C. __No downstream restrictions.__ You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, the Licensed Material if doing so restricts exercise of the Licensed Rights by any recipient of the Licensed Material. 72 | 73 | 6. __No endorsement.__ Nothing in this Public License constitutes or may be construed as permission to assert or imply that You are, or that Your use of the Licensed Material is, connected with, or sponsored, endorsed, or granted official status by, the Licensor or others designated to receive attribution as provided in Section 3(a)(1)(A)(i). 74 | 75 | b. ___Other rights.___ 76 | 77 | 1. Moral rights, such as the right of integrity, are not licensed under this Public License, nor are publicity, privacy, and/or other similar personality rights; however, to the extent possible, the Licensor waives and/or agrees not to assert any such rights held by the Licensor to the limited extent necessary to allow You to exercise the Licensed Rights, but not otherwise. 78 | 79 | 2. 
Patent and trademark rights are not licensed under this Public License. 80 | 81 | 3. To the extent possible, the Licensor waives any right to collect royalties from You for the exercise of the Licensed Rights, whether directly or through a collecting society under any voluntary or waivable statutory or compulsory licensing scheme. In all other cases the Licensor expressly reserves any right to collect such royalties, including when the Licensed Material is used other than for NonCommercial purposes. 82 | 83 | ### Section 3 – License Conditions. 84 | 85 | Your exercise of the Licensed Rights is expressly made subject to the following conditions. 86 | 87 | a. ___Attribution.___ 88 | 89 | 1. If You Share the Licensed Material (including in modified form), You must: 90 | 91 | A. retain the following if it is supplied by the Licensor with the Licensed Material: 92 | 93 | i. identification of the creator(s) of the Licensed Material and any others designated to receive attribution, in any reasonable manner requested by the Licensor (including by pseudonym if designated); 94 | 95 | ii. a copyright notice; 96 | 97 | iii. a notice that refers to this Public License; 98 | 99 | iv. a notice that refers to the disclaimer of warranties; 100 | 101 | v. a URI or hyperlink to the Licensed Material to the extent reasonably practicable; 102 | 103 | B. indicate if You modified the Licensed Material and retain an indication of any previous modifications; and 104 | 105 | C. indicate the Licensed Material is licensed under this Public License, and include the text of, or the URI or hyperlink to, this Public License. 106 | 107 | 2. You may satisfy the conditions in Section 3(a)(1) in any reasonable manner based on the medium, means, and context in which You Share the Licensed Material. For example, it may be reasonable to satisfy the conditions by providing a URI or hyperlink to a resource that includes the required information. 108 | 109 | 3. If requested by the Licensor, You must remove any of the information required by Section 3(a)(1)(A) to the extent reasonably practicable. 110 | 111 | b. ___ShareAlike.___ 112 | 113 | In addition to the conditions in Section 3(a), if You Share Adapted Material You produce, the following conditions also apply. 114 | 115 | 1. The Adapter’s License You apply must be a Creative Commons license with the same License Elements, this version or later, or a BY-NC-SA Compatible License. 116 | 117 | 2. You must include the text of, or the URI or hyperlink to, the Adapter's License You apply. You may satisfy this condition in any reasonable manner based on the medium, means, and context in which You Share Adapted Material. 118 | 119 | 3. You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, Adapted Material that restrict exercise of the rights granted under the Adapter's License You apply. 120 | 121 | ### Section 4 – Sui Generis Database Rights. 122 | 123 | Where the Licensed Rights include Sui Generis Database Rights that apply to Your use of the Licensed Material: 124 | 125 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right to extract, reuse, reproduce, and Share all or a substantial portion of the contents of the database for NonCommercial purposes only; 126 | 127 | b. 
if You include all or a substantial portion of the database contents in a database in which You have Sui Generis Database Rights, then the database in which You have Sui Generis Database Rights (but not its individual contents) is Adapted Material, including for purposes of Section 3(b); and 128 | 129 | c. You must comply with the conditions in Section 3(a) if You Share all or a substantial portion of the contents of the database. 130 | 131 | For the avoidance of doubt, this Section 4 supplements and does not replace Your obligations under this Public License where the Licensed Rights include other Copyright and Similar Rights. 132 | 133 | ### Section 5 – Disclaimer of Warranties and Limitation of Liability. 134 | 135 | a. __Unless otherwise separately undertaken by the Licensor, to the extent possible, the Licensor offers the Licensed Material as-is and as-available, and makes no representations or warranties of any kind concerning the Licensed Material, whether express, implied, statutory, or other. This includes, without limitation, warranties of title, merchantability, fitness for a particular purpose, non-infringement, absence of latent or other defects, accuracy, or the presence or absence of errors, whether or not known or discoverable. Where disclaimers of warranties are not allowed in full or in part, this disclaimer may not apply to You.__ 136 | 137 | b. __To the extent possible, in no event will the Licensor be liable to You on any legal theory (including, without limitation, negligence) or otherwise for any direct, special, indirect, incidental, consequential, punitive, exemplary, or other losses, costs, expenses, or damages arising out of this Public License or use of the Licensed Material, even if the Licensor has been advised of the possibility of such losses, costs, expenses, or damages. Where a limitation of liability is not allowed in full or in part, this limitation may not apply to You.__ 138 | 139 | c. The disclaimer of warranties and limitation of liability provided above shall be interpreted in a manner that, to the extent possible, most closely approximates an absolute disclaimer and waiver of all liability. 140 | 141 | ### Section 6 – Term and Termination. 142 | 143 | a. This Public License applies for the term of the Copyright and Similar Rights licensed here. However, if You fail to comply with this Public License, then Your rights under this Public License terminate automatically. 144 | 145 | b. Where Your right to use the Licensed Material has terminated under Section 6(a), it reinstates: 146 | 147 | 1. automatically as of the date the violation is cured, provided it is cured within 30 days of Your discovery of the violation; or 148 | 149 | 2. automatically as of the date the violation is cured, provided it is cured within 30 days of Your discovery of the violation; or 150 | 151 | For the avoidance of doubt, this Section 6(b) does not affect any right the Licensor may have to seek remedies for Your violations of this Public License. 152 | 153 | c. For the avoidance of doubt, the Licensor may also offer the Licensed Material under separate terms or conditions or stop distributing the Licensed Material at any time; however, doing so will not terminate this Public License. 154 | 155 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public License. 156 | 157 | ### Section 7 – Other Terms and Conditions. 158 | 159 | a. The Licensor shall not be bound by any additional or different terms or conditions communicated by You unless expressly agreed. 
160 | 161 | b. Any arrangements, understandings, or agreements regarding the Licensed Material not stated herein are separate from and independent of the terms and conditions of this Public License. 162 | 163 | ### Section 8 – Interpretation. 164 | 165 | a. For the avoidance of doubt, this Public License does not, and shall not be interpreted to, reduce, limit, restrict, or impose conditions on any use of the Licensed Material that could lawfully be made without permission under this Public License. 166 | 167 | b. To the extent possible, if any provision of this Public License is deemed unenforceable, it shall be automatically reformed to the minimum extent necessary to make it enforceable. If the provision cannot be reformed, it shall be severed from this Public License without affecting the enforceability of the remaining terms and conditions. 168 | 169 | c. No term or condition of this Public License will be waived and no failure to comply consented to unless expressly agreed to by the Licensor. 170 | 171 | d. Nothing in this Public License constitutes or may be interpreted as a limitation upon, or waiver of, any privileges and immunities that apply to the Licensor or You, including from the legal processes of any jurisdiction or authority. 172 | 173 | ``` 174 | Creative Commons is not a party to its public licenses. Notwithstanding, Creative Commons may elect to apply one of its public licenses to material it publishes and in those instances will be considered the “Licensor.” Except for the limited purpose of indicating that material is shared under a Creative Commons public license or as otherwise permitted by the Creative Commons policies published at [creativecommons.org/policies](http://creativecommons.org/policies), Creative Commons does not authorize the use of the trademark “Creative Commons” or any other trademark or logo of Creative Commons without its prior written consent including, without limitation, in connection with any unauthorized modifications to any of its public licenses or any other arrangements, understandings, or agreements concerning use of licensed material. For the avoidance of doubt, this paragraph does not form part of the public licenses. 175 | 176 | Creative Commons may be contacted at [creativecommons.org](http://creativecommons.org/). 177 | ``` 178 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![License CC BY-NC-SA 4.0](https://img.shields.io/badge/license-CC4.0-blue.svg)](https://raw.githubusercontent.com/NVIDIA/FastPhotoStyle/master/LICENSE.md) 2 | ![Python 2.7](https://img.shields.io/badge/python-2.7-green.svg) 3 | ## UNIT: UNsupervised Image-to-image Translation Networks 4 | 5 | ## New implementation available at imaginaire repository 6 | 7 | We have a reimplementation of the UNIT method that is more performant. It is avaiable at [Imaginaire](https://github.com/NVlabs/imaginaire) 8 | 9 | ### License 10 | 11 | Copyright (C) 2018 NVIDIA Corporation. All rights reserved. 12 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 13 | 14 | ### Code usage 15 | 16 | - Please check out our [tutorial](TUTORIAL.md). 17 | 18 | - For multimodal (or many-to-many) image translation, please check out our new work on [MUNIT](https://github.com/NVlabs/MUNIT). 19 | 20 | ### What's new. 21 | 22 | - 05-02-2018: We now adapt [MUNIT](https://github.com/NVlabs/MUNIT) code structure. 
To reproduce the experimental results in the NIPS paper, please check out the [version_02 branch](https://github.com/mingyuliutw/UNIT/tree/version_02).
23 |
24 | - 12-21-2017: Released a pre-trained SYNTHIA-to-Cityscapes image translation model. See [TUTORIAL.md](TUTORIAL.md) for usage examples.
25 |
26 | - 12-14-2017: Added the multi-scale discriminators described in the [pix2pixHD](https://arxiv.org/pdf/1711.11585.pdf) paper. To use them, simply set the discriminator name to COCOMsDis.
27 |
28 | ### Paper
29 |
30 | [Ming-Yu Liu, Thomas Breuel, Jan Kautz, "Unsupervised Image-to-Image Translation Networks" NIPS 2017 Spotlight, arXiv:1703.00848 2017](https://arxiv.org/abs/1703.00848)
31 |
32 | #### Two Minute Paper Summary
33 | [![](./docs/two-minute-paper.png)](https://youtu.be/dqxqbvyOnMY) (We thank the Two Minute Papers channel for summarizing our work.)
34 |
35 | #### The Shared Latent Space Assumption
36 | [![](./docs/shared-latent-space.png)](https://www.youtube.com/watch?v=nlyXoX2aIek)
37 |
38 | #### Result Videos
39 |
40 | More image results are available in the [Google Photo Album](https://photos.app.goo.gl/5x7oIifLh2BVJemb2).
41 |
42 | *Left: input.* **Right: neural network generated.** Resolution: 640x480
43 |
44 | ![](./docs/snowy2summery.gif)
45 |
46 | *Left: input.* **Right: neural network generated.** Resolution: 640x480
47 |
48 | ![](./docs/day2night.gif)
49 | ![](./docs/dog_breed.gif)
50 | ![](./docs/cat_species.gif)
51 |
52 | - [Snowy2Summery-01](https://youtu.be/9VC0c3pndbI)
53 | - [Snowy2Summery-02](https://youtu.be/eUBiiBS1mj0)
54 | - [Day2Night-01](https://youtu.be/Z_Rxf0TfBJE)
55 | - [Day2Night-02](https://youtu.be/mmj3iRIQw1k)
56 | - [Translation Between 5 dog breeds](https://youtu.be/3a6Jc7PabB4)
57 | - [Translation Between 6 cat species](https://youtu.be/Bwq7BmQ1Vbc)
58 |
59 | #### Street Scene Image Translation
60 | From the first row to the fourth row, we show example results on day-to-night, sunny-to-rainy, summery-to-snowy, and real-to-synthetic image translation (in both directions).
61 |
62 | For each image pair, *left is the input image*; **right is the machine-generated image.**
63 |
64 | ![](./docs/street_scene.png)
65 |
66 | #### Dog Breed Image Translation
67 |
68 | ![](./docs/dog_trans.png)
69 |
70 | #### Cat Species Image Translation
71 |
72 | ![](./docs/cat_trans.png)
73 |
74 | #### Attribute-based Face Image Translation
75 |
76 | ![](./docs/faces.png)
77 |
78 |
79 |
80 |
81 |
-------------------------------------------------------------------------------- /TUTORIAL.md: --------------------------------------------------------------------------------
1 | [![License CC BY-NC-SA 4.0](https://img.shields.io/badge/license-CC4.0-blue.svg)](https://raw.githubusercontent.com/NVIDIA/FastPhotoStyle/master/LICENSE.md)
2 | ![Python 2.7](https://img.shields.io/badge/python-2.7-green.svg)
3 | ![Python 3.6](https://img.shields.io/badge/python-3.6-green.svg)
4 | ## UNIT Tutorial
5 |
6 | In this short tutorial, we will guide you through setting up the system environment for running the UNIT (UNsupervised Image-to-image Translation) software and then show several usage examples.
7 |
8 | ### Background
9 |
10 | Unsupervised image-to-image translation concerns learning a translation model that maps an input image in the source domain to a corresponding image in the target domain without any paired supervision of the mapping. Typically, the training data consists of two sets of images, one from each domain.
Paired examples of corresponding images across the two domains are unavailable. For example, to learn a summer-to-winter translation mapping, the model only has access to a dataset of summer images and a dataset of winter images during training.
11 |
12 | ### Algorithm
13 |
14 | ![](./docs/shared-latent-space.png)
15 |
16 | The unsupervised image-to-image translation problem is ill-posed: it aims at discovering the joint distribution from samples of the marginal distributions alone. From coupling theory in probability, we know that there exist infinitely many joint distributions consistent with two given marginal distributions. To find the target solution, one has to incorporate the right inductive bias, i.e., additional assumptions. UNIT is based on the shared-latent space assumption illustrated in the figure above: it assumes that a pair of corresponding images in the two domains maps to the same latent code. Although no pairs of corresponding images are available during training, we assume such pairs exist and use a network capacity constraint to encourage the model to discover the true joint distribution.
17 |
18 | As shown in the figure above, UNIT consists of four networks, two for each domain:
19 |
20 | 1. source domain encoder (extracts a domain-shared latent code from an image in the source domain)
21 | 2. source domain decoder (generates an image in the source domain from a latent code, whether the code came from the source or the target domain)
22 | 3. target domain encoder (extracts a domain-shared latent code from an image in the target domain)
23 | 4. target domain decoder (generates an image in the target domain from a latent code, whether the code came from the source or the target domain)
24 |
25 | At test time, to translate an image from the source domain to the target domain, UNIT uses the source domain encoder to encode the image into a shared latent code and then uses the target domain decoder to generate the corresponding image in the target domain.
26 |
27 | ### Requirements
28 |
29 | - Hardware: PC with an NVIDIA Titan GPU. For large-resolution images, you need an NVIDIA Tesla P100 or V100 GPU, which has 16GB+ of GPU memory.
30 | - Software: *Ubuntu 16.04*, *CUDA 9.1*, *Anaconda3*, *pytorch 0.4.1*
31 | - System packages
32 |   - `sudo apt-get install -y axel imagemagick` (only used for the demo)
33 | - Python packages
34 |   - `conda install pytorch=0.4.1 torchvision cuda91 -y -c pytorch`
35 |   - `conda install -y -c anaconda pip`
36 |   - `conda install -y -c anaconda pyyaml`
37 |   - `pip install tensorboard tensorboardX`
38 |
39 | ### Docker Image
40 |
41 | We also provide a [Dockerfile](Dockerfile) for building an environment for running the UNIT code.
42 |
43 | 1. Install docker-ce. Follow the instructions on the [Docker page](https://docs.docker.com/install/linux/docker-ce/ubuntu/#install-docker-ce-1).
44 | 2. Install nvidia-docker. Follow the instructions on the [NVIDIA-DOCKER README page](https://github.com/NVIDIA/nvidia-docker).
45 | 3. Build the docker image: `docker build -t your-docker-image:v1.0 .`
46 | 4. Run an interactive session: `docker run -v YOUR_PATH:YOUR_PATH --runtime=nvidia -i -t your-docker-image:v1.0 /bin/bash`
47 | 5. `cd YOUR_PATH`
48 | 6. Follow the rest of the tutorial.
49 |
50 |
51 | ### Training
52 |
53 | We provide several training scripts as usage examples. They are located under the `scripts` folder and are listed below, after a brief conceptual sketch of the translation step.
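Before running the training scripts, it may help to see in code what the four trained networks described in the Algorithm section do. The snippet below is a minimal PyTorch-style sketch of the test-time translation step (encode with the source-domain encoder, decode with the target-domain decoder). All class names and layer configurations here are illustrative assumptions and are not the modules actually defined in `networks.py`; the real UNIT encoders are additionally trained with VAE-style KL regularization (see the `recon_kl_*` weights in the configs), which this sketch omits.

```python
# Minimal sketch of UNIT's test-time translation step. TinyEncoder/TinyDecoder
# are illustrative stand-ins, not this repository's actual network classes.
import torch
import torch.nn as nn


class TinyEncoder(nn.Module):
    """Maps an image to a code in the shared latent space."""
    def __init__(self, in_ch=3, dim=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, dim, 7, stride=1, padding=3), nn.ReLU(inplace=True),
            nn.Conv2d(dim, 2 * dim, 4, stride=2, padding=1), nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.net(x)


class TinyDecoder(nn.Module):
    """Maps a shared latent code back to an image."""
    def __init__(self, out_ch=3, dim=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.ConvTranspose2d(2 * dim, dim, 4, stride=2, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(dim, out_ch, 7, stride=1, padding=3), nn.Tanh(),
        )

    def forward(self, z):
        return self.net(z)


# One encoder/decoder pair per domain; the two pairs share the latent space.
enc_a, dec_a = TinyEncoder(), TinyDecoder()  # source domain, e.g. GTA5
enc_b, dec_b = TinyEncoder(), TinyDecoder()  # target domain, e.g. Cityscapes

x_a = torch.randn(1, 3, 256, 256)            # stand-in for a source-domain image
with torch.no_grad():
    z = enc_a(x_a)    # source encoder -> shared latent code
    x_ab = dec_b(z)   # target decoder -> translated image (a to b)
    x_aa = dec_a(z)   # source decoder -> reconstruction of the input
print(x_ab.shape)     # torch.Size([1, 3, 256, 256])
```

With that picture in mind, the provided demo training scripts are: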
54 | - `bash scripts/unit_demo_train_edges2handbags.sh` trains a model that translates sketches of handbags into images of handbags.
55 | - `bash scripts/unit_demo_train_edges2shoes.sh` trains a model that translates sketches of shoes into images of shoes.
56 | - `bash scripts/unit_demo_train_summer2winter_yosemite256.sh` trains a model that translates 256x256 Yosemite summer images into 256x256 Yosemite winter images.
57 |
58 | 1. Download the dataset you want to use. For example, you can use the GTA5 dataset provided by [Richter et al.](https://download.visinf.tu-darmstadt.de/data/from_games/) and the Cityscapes dataset provided by [Cordts et al.](https://www.cityscapes-dataset.com/).
59 |
60 | 2. Set up the YAML file. Check out `configs/unit_gta2city_folder.yaml` for folder-based dataset organization and change the `data_root` field to the path of your downloaded dataset. For list-based dataset organization, check out `configs/unit_gta2city_list.yaml`.
61 |
62 | 3. Start training:
63 | ```
64 | python train.py --trainer UNIT --config configs/unit_gta2city_folder.yaml
65 | ```
66 |
67 | 4. Intermediate image outputs and model binary files are stored in `outputs/unit_gta2city_folder`.
68 |
69 |
70 | ### Testing
71 |
72 | First, download our pretrained models for the gta2cityscape task and put them in the `models` folder.
73 |
74 | #### Pretrained models
75 |
76 | | Dataset | Model Link |
77 | |-------------|----------------|
78 | | gta2cityscape | [model](https://drive.google.com/open?id=1R9MH_p8tDmUsIAjKCu-jgoilWgANfObx) |
79 |
80 | #### Translation
81 |
82 | If you have not done so already, download the [pretrained models](https://drive.google.com/open?id=1R9MH_p8tDmUsIAjKCu-jgoilWgANfObx) and put them in the `models` folder.
83 |
84 | Run the following command to translate GTA5 images to Cityscapes images:
85 |
86 |     python test.py --trainer UNIT --config configs/unit_gta2city_list.yaml --input inputs/gta_example.jpg --output_folder results/gta2city --checkpoint models/unit_gta2city.pt --a2b 1
87 |
88 | The results are stored in the `results/gta2city` folder. You should see images like the following.
89 |
90 | | Input Photo | Output Photo |
91 | |-------------|--------------|
92 | | ![](results/gta2city/input.jpg) | ![](results/gta2city/output.jpg) |
93 |
94 | Run the following command to translate Cityscapes images to GTA5 images:
95 |
96 |     python test.py --trainer UNIT --config configs/unit_gta2city_list.yaml --input inputs/city_example.jpg --output_folder results/city2gta --checkpoint models/unit_gta2city.pt --a2b 0
97 |
98 | The results are stored in the `results/city2gta` folder. You should see images like the following.
99 |
100 | | Input Photo | Output Photo |
101 | |-------------|--------------|
102 | | ![](results/city2gta/input.jpg) | ![](results/city2gta/output.jpg) |
103 |
104 |
-------------------------------------------------------------------------------- /configs/unit_edges2handbags_folder.yaml: --------------------------------------------------------------------------------
1 | # Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
2 | # Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
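# Usage note (added for clarity): configs like this one are passed to the trainer
# via the --config flag, for example:
#   python train.py --trainer UNIT --config configs/unit_edges2handbags_folder.yaml
# Per the tutorial, intermediate images and checkpoints are then written under
# outputs/ (e.g. outputs/unit_gta2city_folder for the gta2city config).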
3 | 4 | # logger options 5 | image_save_iter: 10000 # How often do you want to save output images during training 6 | image_display_iter: 500 # How often do you want to display output images during training 7 | display_size: 16 # How many images do you want to display each time 8 | snapshot_save_iter: 10000 # How often do you want to save trained models 9 | log_iter: 10 # How often do you want to log the training stats 10 | 11 | # optimization options 12 | max_iter: 1000000 # maximum number of training iterations 13 | batch_size: 1 # batch size 14 | weight_decay: 0.0001 # weight decay 15 | beta1: 0.5 # Adam parameter 16 | beta2: 0.999 # Adam parameter 17 | init: kaiming # initialization [gaussian/kaiming/xavier/orthogonal] 18 | lr: 0.0001 # initial learning rate 19 | lr_policy: step # learning rate scheduler 20 | step_size: 100000 # how often to decay learning rate 21 | gamma: 0.5 # how much to decay learning rate 22 | gan_w: 1 # weight of adversarial loss 23 | recon_x_w: 10 # weight of image reconstruction loss 24 | recon_h_w: 0 # weight of hidden reconstruction loss 25 | recon_kl_w: 0.01 # weight of KL loss for reconstruction 26 | recon_x_cyc_w: 10 # weight of cycle consistency loss 27 | recon_kl_cyc_w: 0.01 # weight of KL loss for cycle consistency 28 | vgg_w: 0 # weight of domain-invariant perceptual loss 29 | 30 | # model options 31 | gen: 32 | dim: 64 # number of filters in the bottommost layer 33 | activ: relu # activation function [relu/lrelu/prelu/selu/tanh] 34 | n_downsample: 2 # number of downsampling layers in content encoder 35 | n_res: 4 # number of residual blocks in content encoder/decoder 36 | pad_type: reflect # padding type [zero/reflect] 37 | dis: 38 | dim: 64 # number of filters in the bottommost layer 39 | norm: none # normalization layer [none/bn/in/ln] 40 | activ: lrelu # activation function [relu/lrelu/prelu/selu/tanh] 41 | n_layer: 4 # number of layers in D 42 | gan_type: lsgan # GAN loss [lsgan/nsgan] 43 | num_scales: 3 # number of scales 44 | pad_type: reflect # padding type [zero/reflect] 45 | 46 | # data options 47 | input_dim_a: 3 # number of image channels [1/3] 48 | input_dim_b: 3 # number of image channels [1/3] 49 | num_workers: 8 # number of data loading threads 50 | new_size: 256 # first resize the shortest image side to this size 51 | crop_image_height: 256 # random crop image of this height 52 | crop_image_width: 256 # random crop image of this width 53 | data_root: ./datasets/edges2handbags/ # dataset folder location -------------------------------------------------------------------------------- /configs/unit_edges2shoes_folder.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2017 NVIDIA Corporation. All rights reserved. 2 | # Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
3 | 4 | # logger options 5 | image_save_iter: 10000 # How often do you want to save output images during training 6 | image_display_iter: 500 # How often do you want to display output images during training 7 | display_size: 16 # How many images do you want to display each time 8 | snapshot_save_iter: 10000 # How often do you want to save trained models 9 | log_iter: 10 # How often do you want to log the training stats 10 | 11 | # optimization options 12 | max_iter: 1000000 # maximum number of training iterations 13 | batch_size: 1 # batch size 14 | weight_decay: 0.0001 # weight decay 15 | beta1: 0.5 # Adam parameter 16 | beta2: 0.999 # Adam parameter 17 | init: kaiming # initialization [gaussian/kaiming/xavier/orthogonal] 18 | lr: 0.0001 # initial learning rate 19 | lr_policy: step # learning rate scheduler 20 | step_size: 100000 # how often to decay learning rate 21 | gamma: 0.5 # how much to decay learning rate 22 | gan_w: 1 # weight of adversarial loss 23 | recon_x_w: 10 # weight of image reconstruction loss 24 | recon_h_w: 0 # weight of hidden reconstruction loss 25 | recon_kl_w: 0.01 # weight of KL loss for reconstruction 26 | recon_x_cyc_w: 10 # weight of cycle consistency loss 27 | recon_kl_cyc_w: 0.01 # weight of KL loss for cycle consistency 28 | vgg_w: 0 # weight of domain-invariant perceptual loss 29 | 30 | # model options 31 | gen: 32 | dim: 64 # number of filters in the bottommost layer 33 | activ: relu # activation function [relu/lrelu/prelu/selu/tanh] 34 | n_downsample: 2 # number of downsampling layers in content encoder 35 | n_res: 4 # number of residual blocks in content encoder/decoder 36 | pad_type: reflect # padding type [zero/reflect] 37 | dis: 38 | dim: 64 # number of filters in the bottommost layer 39 | norm: none # normalization layer [none/bn/in/ln] 40 | activ: lrelu # activation function [relu/lrelu/prelu/selu/tanh] 41 | n_layer: 4 # number of layers in D 42 | gan_type: lsgan # GAN loss [lsgan/nsgan] 43 | num_scales: 3 # number of scales 44 | pad_type: reflect # padding type [zero/reflect] 45 | 46 | # data options 47 | input_dim_a: 3 # number of image channels [1/3] 48 | input_dim_b: 3 # number of image channels [1/3] 49 | num_workers: 8 # number of data loading threads 50 | new_size: 256 # first resize the shortest image side to this size 51 | crop_image_height: 256 # random crop image of this height 52 | crop_image_width: 256 # random crop image of this width 53 | data_root: ./datasets/edges2shoes/ # dataset folder location -------------------------------------------------------------------------------- /configs/unit_gta2city_folder.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2017 NVIDIA Corporation. All rights reserved. 2 | # Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
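# Note on folder-based dataset organization (inferred from the bundled
# datasets/gta2city example): data_root is expected to contain trainA/, trainB/,
# testA/ and testB/ subfolders, with domain-A images (GTA5 frames) in the *A
# folders and domain-B images (Cityscapes frames) in the *B folders.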
3 | 4 | # logger options 5 | image_save_iter: 10000 # How often do you want to save output images during training 6 | image_display_iter: 100 # How often do you want to display output images during training 7 | display_size: 4 # How many images do you want to display each time 8 | snapshot_save_iter: 10000 # How often do you want to save trained models 9 | log_iter: 1 # How often do you want to log the training stats 10 | 11 | # optimization options 12 | max_iter: 1000000 # maximum number of training iterations 13 | batch_size: 1 # batch size 14 | weight_decay: 0.0001 # weight decay 15 | beta1: 0.5 # Adam parameter 16 | beta2: 0.999 # Adam parameter 17 | init: kaiming # initialization [gaussian/kaiming/xavier/orthogonal] 18 | lr: 0.0001 # initial learning rate 19 | lr_policy: step # learning rate scheduler 20 | step_size: 100000 # how often to decay learning rate 21 | gamma: 0.5 # how much to decay learning rate 22 | gan_w: 1 # weight of adversarial loss 23 | recon_x_w: 10 # weight of image reconstruction loss 24 | recon_h_w: 0 # weight of hidden reconstruction loss 25 | recon_kl_w: 0.01 # weight of KL loss for reconstruction 26 | recon_x_cyc_w: 10 # weight of cycle consistency loss 27 | recon_kl_cyc_w: 0.01 # weight of KL loss for cycle consistency 28 | vgg_w: 1 # weight of domain-invariant perceptual loss 29 | 30 | # model options 31 | gen: 32 | dim: 64 # number of filters in the bottommost layer 33 | activ: relu # activation function [relu/lrelu/prelu/selu/tanh] 34 | n_downsample: 2 # number of downsampling layers in content encoder 35 | n_res: 4 # number of residual blocks in content encoder/decoder 36 | pad_type: reflect # padding type [zero/reflect] 37 | dis: 38 | dim: 64 # number of filters in the bottommost layer 39 | norm: none # normalization layer [none/bn/in/ln] 40 | activ: lrelu # activation function [relu/lrelu/prelu/selu/tanh] 41 | n_layer: 4 # number of layers in D 42 | gan_type: lsgan # GAN loss [lsgan/nsgan] 43 | num_scales: 3 # number of scales 44 | pad_type: reflect # padding type [zero/reflect] 45 | 46 | # data options 47 | input_dim_a: 3 # number of image channels [1/3] 48 | input_dim_b: 3 # number of image channels [1/3] 49 | num_workers: 8 # number of data loading threads 50 | new_size: 256 # first resize the shortest image side to this size 51 | crop_image_height: 256 # random crop image of this height 52 | crop_image_width: 256 # random crop image of this width 53 | 54 | data_root: ./datasets/gta2city/ # dataset folder location -------------------------------------------------------------------------------- /configs/unit_gta2city_list.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2017 NVIDIA Corporation. All rights reserved. 2 | # Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
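# Note on list-based dataset organization (inferred from the file-list reader in
# data.py and the example lists under datasets/gta2city): each data_list_* file
# below contains one image path per line, relative to the matching data_folder_*
# directory.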
3 | 4 | # logger options 5 | image_save_iter: 10000 # How often do you want to save output images during training 6 | image_display_iter: 100 # How often do you want to display output images during training 7 | display_size: 4 # How many images do you want to display each time 8 | snapshot_save_iter: 10000 # How often do you want to save trained models 9 | log_iter: 1 # How often do you want to log the training stats 10 | 11 | # optimization options 12 | max_iter: 1000000 # maximum number of training iterations 13 | batch_size: 1 # batch size 14 | weight_decay: 0.0001 # weight decay 15 | beta1: 0.5 # Adam parameter 16 | beta2: 0.999 # Adam parameter 17 | init: kaiming # initialization [gaussian/kaiming/xavier/orthogonal] 18 | lr: 0.0001 # initial learning rate 19 | lr_policy: step # learning rate scheduler 20 | step_size: 100000 # how often to decay learning rate 21 | gamma: 0.5 # how much to decay learning rate 22 | gan_w: 1 # weight of adversarial loss 23 | recon_x_w: 10 # weight of image reconstruction loss 24 | recon_h_w: 0 # weight of hidden reconstruction loss 25 | recon_kl_w: 0.01 # weight of KL loss for reconstruction 26 | recon_x_cyc_w: 10 # weight of cycle consistency loss 27 | recon_kl_cyc_w: 0.01 # weight of KL loss for cycle consistency 28 | vgg_w: 1 # weight of domain-invariant perceptual loss 29 | 30 | # model options 31 | gen: 32 | dim: 64 # number of filters in the bottommost layer 33 | activ: relu # activation function [relu/lrelu/prelu/selu/tanh] 34 | n_downsample: 2 # number of downsampling layers in content encoder 35 | n_res: 4 # number of residual blocks in content encoder/decoder 36 | pad_type: reflect # padding type [zero/reflect] 37 | dis: 38 | dim: 64 # number of filters in the bottommost layer 39 | norm: none # normalization layer [none/bn/in/ln] 40 | activ: lrelu # activation function [relu/lrelu/prelu/selu/tanh] 41 | n_layer: 4 # number of layers in D 42 | gan_type: lsgan # GAN loss [lsgan/nsgan] 43 | num_scales: 3 # number of scales 44 | pad_type: reflect # padding type [zero/reflect] 45 | 46 | # data options 47 | input_dim_a: 3 # number of image channels [1/3] 48 | input_dim_b: 3 # number of image channels [1/3] 49 | num_workers: 8 # number of data loading threads 50 | new_size_a: 512 # first resize the shortest image side to this size 51 | new_size_b: 512 # first resize the shortest image side to this size 52 | crop_image_height: 512 # random crop image of this height 53 | crop_image_width: 512 # random crop image of this width 54 | 55 | data_folder_train_a: /mnt/scratch.nvdrivenet/adundar/GTA_Intel_Labs 56 | data_list_train_a: /mnt/scratch.nvdrivenet/adundar/GTA_Intel_Labs/train_images.txt 57 | data_folder_test_a: /mnt/scratch.nvdrivenet/adundar/GTA_Intel_Labs 58 | data_list_test_a: /mnt/scratch.nvdrivenet/adundar/GTA_Intel_Labs/train_images.txt 59 | data_folder_train_b: /mnt/scratch.nvdrivenet/adundar/cityscapes_raw 60 | data_list_train_b: /mnt/scratch.nvdrivenet/adundar/cityscapes_raw/train_images.txt 61 | data_folder_test_b: /mnt/scratch.nvdrivenet/adundar/cityscapes_raw 62 | data_list_test_b: /mnt/scratch.nvdrivenet/adundar/cityscapes_raw/train_images.txt -------------------------------------------------------------------------------- /configs/unit_summer2winter_yosemite256_folder.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2018 NVIDIA Corporation. All rights reserved. 2 | # Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
3 | 4 | # logger options 5 | image_save_iter: 10000 # How often do you want to save output images during training 6 | image_display_iter: 100 # How often do you want to display output images during training 7 | display_size: 16 # How many images do you want to display each time 8 | snapshot_save_iter: 10000 # How often do you want to save trained models 9 | log_iter: 10 # How often do you want to log the training stats 10 | 11 | # optimization options 12 | max_iter: 1000000 # maximum number of training iterations 13 | batch_size: 1 # batch size 14 | weight_decay: 0.0001 # weight decay 15 | beta1: 0.5 # Adam parameter 16 | beta2: 0.999 # Adam parameter 17 | init: kaiming # initialization [gaussian/kaiming/xavier/orthogonal] 18 | lr: 0.0001 # initial learning rate 19 | lr_policy: step # learning rate scheduler 20 | step_size: 100000 # how often to decay learning rate 21 | gamma: 0.5 # how much to decay learning rate 22 | gan_w: 1 # weight of adversarial loss 23 | recon_x_w: 10 # weight of image reconstruction loss 24 | recon_h_w: 0 # weight of hidden reconstruction loss 25 | recon_kl_w: 0.01 # weight of KL loss for reconstruction 26 | recon_x_cyc_w: 10 # weight of cycle consistency loss 27 | recon_kl_cyc_w: 0.01 # weight of KL loss for cycle consistency 28 | vgg_w: 0 # weight of domain-invariant perceptual loss 29 | 30 | # model options 31 | gen: 32 | dim: 64 # number of filters in the bottommost layer 33 | mlp_dim: 256 # number of filters in MLP 34 | style_dim: 8 # length of style code 35 | activ: relu # activation function [relu/lrelu/prelu/selu/tanh] 36 | n_downsample: 2 # number of downsampling layers in content encoder 37 | n_res: 4 # number of residual blocks in content encoder/decoder 38 | pad_type: reflect # padding type [zero/reflect] 39 | dis: 40 | dim: 64 # number of filters in the bottommost layer 41 | norm: none # normalization layer [none/bn/in/ln] 42 | activ: lrelu # activation function [relu/lrelu/prelu/selu/tanh] 43 | n_layer: 4 # number of layers in D 44 | gan_type: lsgan # GAN loss [lsgan/nsgan] 45 | num_scales: 3 # number of scales 46 | pad_type: reflect # padding type [zero/reflect] 47 | 48 | # data options 49 | input_dim_a: 3 # number of image channels [1/3] 50 | input_dim_b: 3 # number of image channels [1/3] 51 | num_workers: 8 # number of data loading threads 52 | new_size: 256 # first resize the shortest image side to this size 53 | crop_image_height: 256 # random crop image of this height 54 | crop_image_width: 256 # random crop image of this width 55 | data_root: ./datasets/summer2winter_yosemite256/summer2winter_yosemite # dataset folder location -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2018 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
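Dataset helpers for UNIT: ImageFilelist and ImageLabelFilelist load images named
in a list file of relative paths, and ImageFolder recursively loads every image
found under a directory.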
4 | """ 5 | import torch.utils.data as data 6 | import os.path 7 | 8 | def default_loader(path): 9 | return Image.open(path).convert('RGB') 10 | 11 | 12 | def default_flist_reader(flist): 13 | """ 14 | flist format: impath label\nimpath label\n ...(same to caffe's filelist) 15 | """ 16 | imlist = [] 17 | with open(flist, 'r') as rf: 18 | for line in rf.readlines(): 19 | impath = line.strip() 20 | imlist.append(impath) 21 | 22 | return imlist 23 | 24 | 25 | class ImageFilelist(data.Dataset): 26 | def __init__(self, root, flist, transform=None, 27 | flist_reader=default_flist_reader, loader=default_loader): 28 | self.root = root 29 | self.imlist = flist_reader(flist) 30 | self.transform = transform 31 | self.loader = loader 32 | 33 | def __getitem__(self, index): 34 | impath = self.imlist[index] 35 | img = self.loader(os.path.join(self.root, impath)) 36 | if self.transform is not None: 37 | img = self.transform(img) 38 | 39 | return img 40 | 41 | def __len__(self): 42 | return len(self.imlist) 43 | 44 | 45 | class ImageLabelFilelist(data.Dataset): 46 | def __init__(self, root, flist, transform=None, 47 | flist_reader=default_flist_reader, loader=default_loader): 48 | self.root = root 49 | self.imlist = flist_reader(os.path.join(self.root, flist)) 50 | self.transform = transform 51 | self.loader = loader 52 | self.classes = sorted(list(set([path.split('/')[0] for path in self.imlist]))) 53 | self.class_to_idx = {self.classes[i]: i for i in range(len(self.classes))} 54 | self.imgs = [(impath, self.class_to_idx[impath.split('/')[0]]) for impath in self.imlist] 55 | 56 | def __getitem__(self, index): 57 | impath, label = self.imgs[index] 58 | img = self.loader(os.path.join(self.root, impath)) 59 | if self.transform is not None: 60 | img = self.transform(img) 61 | return img, label 62 | 63 | def __len__(self): 64 | return len(self.imgs) 65 | 66 | ############################################################################### 67 | # Code from 68 | # https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py 69 | # Modified the original code so that it also loads images from the current 70 | # directory as well as the subdirectories 71 | ############################################################################### 72 | 73 | import torch.utils.data as data 74 | 75 | from PIL import Image 76 | import os 77 | import os.path 78 | 79 | IMG_EXTENSIONS = [ 80 | '.jpg', '.JPG', '.jpeg', '.JPEG', 81 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 82 | ] 83 | 84 | 85 | def is_image_file(filename): 86 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 87 | 88 | 89 | def make_dataset(dir): 90 | images = [] 91 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 92 | 93 | for root, _, fnames in sorted(os.walk(dir)): 94 | for fname in fnames: 95 | if is_image_file(fname): 96 | path = os.path.join(root, fname) 97 | images.append(path) 98 | 99 | return images 100 | 101 | 102 | class ImageFolder(data.Dataset): 103 | 104 | def __init__(self, root, transform=None, return_paths=False, 105 | loader=default_loader): 106 | imgs = sorted(make_dataset(root)) 107 | if len(imgs) == 0: 108 | raise(RuntimeError("Found 0 images in: " + root + "\n" 109 | "Supported image extensions are: " + 110 | ",".join(IMG_EXTENSIONS))) 111 | 112 | self.root = root 113 | self.imgs = imgs 114 | self.transform = transform 115 | self.return_paths = return_paths 116 | self.loader = loader 117 | 118 | def __getitem__(self, index): 119 | path = self.imgs[index] 120 | img = self.loader(path) 
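        # Apply the optional transform (e.g. resize/crop/normalize) to the loaded
        # PIL image; when return_paths=True, the image path is returned as well.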
121 | if self.transform is not None: 122 | img = self.transform(img) 123 | if self.return_paths: 124 | return img, path 125 | else: 126 | return img 127 | 128 | def __len__(self): 129 | return len(self.imgs) 130 | -------------------------------------------------------------------------------- /datasets/gta2city/list_testA.txt: -------------------------------------------------------------------------------- 1 | ./00002.jpg 2 | ./00004.jpg 3 | ./00001.jpg 4 | ./00003.jpg 5 | ./00005.jpg 6 | -------------------------------------------------------------------------------- /datasets/gta2city/list_testB.txt: -------------------------------------------------------------------------------- 1 | ./aachen_000000_000019_leftImg8bit.jpg 2 | ./aachen_000002_000019_leftImg8bit.jpg 3 | ./aachen_000001_000019_leftImg8bit.jpg 4 | ./aachen_000003_000019_leftImg8bit.jpg 5 | ./aachen_000004_000019_leftImg8bit.jpg 6 | -------------------------------------------------------------------------------- /datasets/gta2city/list_trainA.txt: -------------------------------------------------------------------------------- 1 | ./00002.jpg 2 | ./00004.jpg 3 | ./00001.jpg 4 | ./00003.jpg 5 | ./00005.jpg 6 | -------------------------------------------------------------------------------- /datasets/gta2city/list_trainB.txt: -------------------------------------------------------------------------------- 1 | ./aachen_000000_000019_leftImg8bit.jpg 2 | ./aachen_000002_000019_leftImg8bit.jpg 3 | ./aachen_000001_000019_leftImg8bit.jpg 4 | ./aachen_000003_000019_leftImg8bit.jpg 5 | ./aachen_000004_000019_leftImg8bit.jpg 6 | -------------------------------------------------------------------------------- /datasets/gta2city/testA/00001.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingyuliutw/UNIT/28d5982934e9abf4dbdaaa72650bdd3a128b5519/datasets/gta2city/testA/00001.jpg -------------------------------------------------------------------------------- /datasets/gta2city/testA/00002.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingyuliutw/UNIT/28d5982934e9abf4dbdaaa72650bdd3a128b5519/datasets/gta2city/testA/00002.jpg -------------------------------------------------------------------------------- /datasets/gta2city/testA/00003.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingyuliutw/UNIT/28d5982934e9abf4dbdaaa72650bdd3a128b5519/datasets/gta2city/testA/00003.jpg -------------------------------------------------------------------------------- /datasets/gta2city/testA/00004.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingyuliutw/UNIT/28d5982934e9abf4dbdaaa72650bdd3a128b5519/datasets/gta2city/testA/00004.jpg -------------------------------------------------------------------------------- /datasets/gta2city/testA/00005.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingyuliutw/UNIT/28d5982934e9abf4dbdaaa72650bdd3a128b5519/datasets/gta2city/testA/00005.jpg -------------------------------------------------------------------------------- /datasets/gta2city/testB/aachen_000000_000019_leftImg8bit.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/mingyuliutw/UNIT/28d5982934e9abf4dbdaaa72650bdd3a128b5519/datasets/gta2city/testB/aachen_000000_000019_leftImg8bit.jpg -------------------------------------------------------------------------------- /datasets/gta2city/testB/aachen_000001_000019_leftImg8bit.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingyuliutw/UNIT/28d5982934e9abf4dbdaaa72650bdd3a128b5519/datasets/gta2city/testB/aachen_000001_000019_leftImg8bit.jpg -------------------------------------------------------------------------------- /datasets/gta2city/testB/aachen_000002_000019_leftImg8bit.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingyuliutw/UNIT/28d5982934e9abf4dbdaaa72650bdd3a128b5519/datasets/gta2city/testB/aachen_000002_000019_leftImg8bit.jpg -------------------------------------------------------------------------------- /datasets/gta2city/testB/aachen_000003_000019_leftImg8bit.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingyuliutw/UNIT/28d5982934e9abf4dbdaaa72650bdd3a128b5519/datasets/gta2city/testB/aachen_000003_000019_leftImg8bit.jpg -------------------------------------------------------------------------------- /datasets/gta2city/testB/aachen_000004_000019_leftImg8bit.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingyuliutw/UNIT/28d5982934e9abf4dbdaaa72650bdd3a128b5519/datasets/gta2city/testB/aachen_000004_000019_leftImg8bit.jpg -------------------------------------------------------------------------------- /datasets/gta2city/trainA/00001.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingyuliutw/UNIT/28d5982934e9abf4dbdaaa72650bdd3a128b5519/datasets/gta2city/trainA/00001.jpg -------------------------------------------------------------------------------- /datasets/gta2city/trainA/00002.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingyuliutw/UNIT/28d5982934e9abf4dbdaaa72650bdd3a128b5519/datasets/gta2city/trainA/00002.jpg -------------------------------------------------------------------------------- /datasets/gta2city/trainA/00003.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingyuliutw/UNIT/28d5982934e9abf4dbdaaa72650bdd3a128b5519/datasets/gta2city/trainA/00003.jpg -------------------------------------------------------------------------------- /datasets/gta2city/trainA/00004.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingyuliutw/UNIT/28d5982934e9abf4dbdaaa72650bdd3a128b5519/datasets/gta2city/trainA/00004.jpg -------------------------------------------------------------------------------- /datasets/gta2city/trainA/00005.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingyuliutw/UNIT/28d5982934e9abf4dbdaaa72650bdd3a128b5519/datasets/gta2city/trainA/00005.jpg -------------------------------------------------------------------------------- /datasets/gta2city/trainB/aachen_000000_000019_leftImg8bit.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/mingyuliutw/UNIT/28d5982934e9abf4dbdaaa72650bdd3a128b5519/datasets/gta2city/trainB/aachen_000000_000019_leftImg8bit.jpg -------------------------------------------------------------------------------- /datasets/gta2city/trainB/aachen_000001_000019_leftImg8bit.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingyuliutw/UNIT/28d5982934e9abf4dbdaaa72650bdd3a128b5519/datasets/gta2city/trainB/aachen_000001_000019_leftImg8bit.jpg -------------------------------------------------------------------------------- /datasets/gta2city/trainB/aachen_000002_000019_leftImg8bit.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingyuliutw/UNIT/28d5982934e9abf4dbdaaa72650bdd3a128b5519/datasets/gta2city/trainB/aachen_000002_000019_leftImg8bit.jpg -------------------------------------------------------------------------------- /datasets/gta2city/trainB/aachen_000003_000019_leftImg8bit.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingyuliutw/UNIT/28d5982934e9abf4dbdaaa72650bdd3a128b5519/datasets/gta2city/trainB/aachen_000003_000019_leftImg8bit.jpg -------------------------------------------------------------------------------- /datasets/gta2city/trainB/aachen_000004_000019_leftImg8bit.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingyuliutw/UNIT/28d5982934e9abf4dbdaaa72650bdd3a128b5519/datasets/gta2city/trainB/aachen_000004_000019_leftImg8bit.jpg -------------------------------------------------------------------------------- /docs/cat_species.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingyuliutw/UNIT/28d5982934e9abf4dbdaaa72650bdd3a128b5519/docs/cat_species.gif -------------------------------------------------------------------------------- /docs/cat_trans.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingyuliutw/UNIT/28d5982934e9abf4dbdaaa72650bdd3a128b5519/docs/cat_trans.png -------------------------------------------------------------------------------- /docs/day2night.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingyuliutw/UNIT/28d5982934e9abf4dbdaaa72650bdd3a128b5519/docs/day2night.gif -------------------------------------------------------------------------------- /docs/dog_breed.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingyuliutw/UNIT/28d5982934e9abf4dbdaaa72650bdd3a128b5519/docs/dog_breed.gif -------------------------------------------------------------------------------- /docs/dog_trans.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingyuliutw/UNIT/28d5982934e9abf4dbdaaa72650bdd3a128b5519/docs/dog_trans.png -------------------------------------------------------------------------------- /docs/faces.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingyuliutw/UNIT/28d5982934e9abf4dbdaaa72650bdd3a128b5519/docs/faces.png -------------------------------------------------------------------------------- /docs/shared-latent-space.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingyuliutw/UNIT/28d5982934e9abf4dbdaaa72650bdd3a128b5519/docs/shared-latent-space.png -------------------------------------------------------------------------------- /docs/snowy2summery.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingyuliutw/UNIT/28d5982934e9abf4dbdaaa72650bdd3a128b5519/docs/snowy2summery.gif -------------------------------------------------------------------------------- /docs/street_scene.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingyuliutw/UNIT/28d5982934e9abf4dbdaaa72650bdd3a128b5519/docs/street_scene.png -------------------------------------------------------------------------------- /docs/two-minute-paper.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingyuliutw/UNIT/28d5982934e9abf4dbdaaa72650bdd3a128b5519/docs/two-minute-paper.png -------------------------------------------------------------------------------- /docs/unit_nips_2017.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingyuliutw/UNIT/28d5982934e9abf4dbdaaa72650bdd3a128b5519/docs/unit_nips_2017.pdf -------------------------------------------------------------------------------- /inputs/city_example.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingyuliutw/UNIT/28d5982934e9abf4dbdaaa72650bdd3a128b5519/inputs/city_example.jpg -------------------------------------------------------------------------------- /inputs/gta_example.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingyuliutw/UNIT/28d5982934e9abf4dbdaaa72650bdd3a128b5519/inputs/gta_example.jpg -------------------------------------------------------------------------------- /networks.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2018 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
4 | """ 5 | from torch import nn 6 | from torch.autograd import Variable 7 | import torch 8 | import torch.nn.functional as F 9 | try: 10 | from itertools import izip as zip 11 | except ImportError: # will be 3.x series 12 | pass 13 | 14 | ################################################################################## 15 | # Discriminator 16 | ################################################################################## 17 | 18 | class MsImageDis(nn.Module): 19 | # Multi-scale discriminator architecture 20 | def __init__(self, input_dim, params): 21 | super(MsImageDis, self).__init__() 22 | self.n_layer = params['n_layer'] 23 | self.gan_type = params['gan_type'] 24 | self.dim = params['dim'] 25 | self.norm = params['norm'] 26 | self.activ = params['activ'] 27 | self.num_scales = params['num_scales'] 28 | self.pad_type = params['pad_type'] 29 | self.input_dim = input_dim 30 | self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False) 31 | self.cnns = nn.ModuleList() 32 | for _ in range(self.num_scales): 33 | self.cnns.append(self._make_net()) 34 | 35 | def _make_net(self): 36 | dim = self.dim 37 | cnn_x = [] 38 | cnn_x += [Conv2dBlock(self.input_dim, dim, 4, 2, 1, norm='none', activation=self.activ, pad_type=self.pad_type)] 39 | for i in range(self.n_layer - 1): 40 | cnn_x += [Conv2dBlock(dim, dim * 2, 4, 2, 1, norm=self.norm, activation=self.activ, pad_type=self.pad_type)] 41 | dim *= 2 42 | cnn_x += [nn.Conv2d(dim, 1, 1, 1, 0)] 43 | cnn_x = nn.Sequential(*cnn_x) 44 | return cnn_x 45 | 46 | def forward(self, x): 47 | outputs = [] 48 | for model in self.cnns: 49 | outputs.append(model(x)) 50 | x = self.downsample(x) 51 | return outputs 52 | 53 | def calc_dis_loss(self, input_fake, input_real): 54 | # calculate the loss to train D 55 | outs0 = self.forward(input_fake) 56 | outs1 = self.forward(input_real) 57 | loss = 0 58 | 59 | for it, (out0, out1) in enumerate(zip(outs0, outs1)): 60 | if self.gan_type == 'lsgan': 61 | loss += torch.mean((out0 - 0)**2) + torch.mean((out1 - 1)**2) 62 | elif self.gan_type == 'nsgan': 63 | all0 = Variable(torch.zeros_like(out0.data).cuda(), requires_grad=False) 64 | all1 = Variable(torch.ones_like(out1.data).cuda(), requires_grad=False) 65 | loss += torch.mean(F.binary_cross_entropy(F.sigmoid(out0), all0) + 66 | F.binary_cross_entropy(F.sigmoid(out1), all1)) 67 | else: 68 | assert 0, "Unsupported GAN type: {}".format(self.gan_type) 69 | return loss 70 | 71 | def calc_gen_loss(self, input_fake): 72 | # calculate the loss to train G 73 | outs0 = self.forward(input_fake) 74 | loss = 0 75 | for it, (out0) in enumerate(outs0): 76 | if self.gan_type == 'lsgan': 77 | loss += torch.mean((out0 - 1)**2) # LSGAN 78 | elif self.gan_type == 'nsgan': 79 | all1 = Variable(torch.ones_like(out0.data).cuda(), requires_grad=False) 80 | loss += torch.mean(F.binary_cross_entropy(F.sigmoid(out0), all1)) 81 | else: 82 | assert 0, "Unsupported GAN type: {}".format(self.gan_type) 83 | return loss 84 | 85 | ################################################################################## 86 | # Generator 87 | ################################################################################## 88 | 89 | class AdaINGen(nn.Module): 90 | # AdaIN auto-encoder architecture 91 | def __init__(self, input_dim, params): 92 | super(AdaINGen, self).__init__() 93 | dim = params['dim'] 94 | style_dim = params['style_dim'] 95 | n_downsample = params['n_downsample'] 96 | n_res = params['n_res'] 97 | activ = params['activ'] 98 | pad_type = params['pad_type'] 99 | 
mlp_dim = params['mlp_dim'] 100 | 101 | # style encoder 102 | self.enc_style = StyleEncoder(4, input_dim, dim, style_dim, norm='none', activ=activ, pad_type=pad_type) 103 | 104 | # content encoder 105 | self.enc_content = ContentEncoder(n_downsample, n_res, input_dim, dim, 'in', activ, pad_type=pad_type) 106 | self.dec = Decoder(n_downsample, n_res, self.enc_content.output_dim, input_dim, res_norm='adain', activ=activ, pad_type=pad_type) 107 | 108 | # MLP to generate AdaIN parameters 109 | self.mlp = MLP(style_dim, self.get_num_adain_params(self.dec), mlp_dim, 3, norm='none', activ=activ) 110 | 111 | def forward(self, images): 112 | # reconstruct an image 113 | content, style_fake = self.encode(images) 114 | images_recon = self.decode(content, style_fake) 115 | return images_recon 116 | 117 | def encode(self, images): 118 | # encode an image to its content and style codes 119 | style_fake = self.enc_style(images) 120 | content = self.enc_content(images) 121 | return content, style_fake 122 | 123 | def decode(self, content, style): 124 | # decode content and style codes to an image 125 | adain_params = self.mlp(style) 126 | self.assign_adain_params(adain_params, self.dec) 127 | images = self.dec(content) 128 | return images 129 | 130 | def assign_adain_params(self, adain_params, model): 131 | # assign the adain_params to the AdaIN layers in model 132 | for m in model.modules(): 133 | if m.__class__.__name__ == "AdaptiveInstanceNorm2d": 134 | mean = adain_params[:, :m.num_features] 135 | std = adain_params[:, m.num_features:2*m.num_features] 136 | m.bias = mean.contiguous().view(-1) 137 | m.weight = std.contiguous().view(-1) 138 | if adain_params.size(1) > 2*m.num_features: 139 | adain_params = adain_params[:, 2*m.num_features:] 140 | 141 | def get_num_adain_params(self, model): 142 | # return the number of AdaIN parameters needed by the model 143 | num_adain_params = 0 144 | for m in model.modules(): 145 | if m.__class__.__name__ == "AdaptiveInstanceNorm2d": 146 | num_adain_params += 2*m.num_features 147 | return num_adain_params 148 | 149 | 150 | class VAEGen(nn.Module): 151 | # VAE architecture 152 | def __init__(self, input_dim, params): 153 | super(VAEGen, self).__init__() 154 | dim = params['dim'] 155 | n_downsample = params['n_downsample'] 156 | n_res = params['n_res'] 157 | activ = params['activ'] 158 | pad_type = params['pad_type'] 159 | 160 | # content encoder 161 | self.enc = ContentEncoder(n_downsample, n_res, input_dim, dim, 'in', activ, pad_type=pad_type) 162 | self.dec = Decoder(n_downsample, n_res, self.enc.output_dim, input_dim, res_norm='in', activ=activ, pad_type=pad_type) 163 | 164 | def forward(self, images): 165 | # This is a reduced VAE implementation where we assume the outputs are multivariate Gaussian distribution with mean = hiddens and std_dev = all ones. 
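# Under this assumption the KL term against a unit-Gaussian prior reduces to a squared penalty on
# the code itself: KL(N(mu, I) || N(0, I)) = 0.5 * ||mu||^2. This is why UNIT_Trainer.__compute_kl
# in trainer.py only needs torch.mean(mu ** 2) (proportional to the exact term).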
166 | hiddens = self.encode(images) 167 | if self.training == True: 168 | noise = Variable(torch.randn(hiddens.size()).cuda(hiddens.data.get_device())) 169 | images_recon = self.decode(hiddens + noise) 170 | else: 171 | images_recon = self.decode(hiddens) 172 | return images_recon, hiddens 173 | 174 | def encode(self, images): 175 | hiddens = self.enc(images) 176 | noise = Variable(torch.randn(hiddens.size()).cuda(hiddens.data.get_device())) 177 | return hiddens, noise 178 | 179 | def decode(self, hiddens): 180 | images = self.dec(hiddens) 181 | return images 182 | 183 | 184 | ################################################################################## 185 | # Encoder and Decoders 186 | ################################################################################## 187 | 188 | class StyleEncoder(nn.Module): 189 | def __init__(self, n_downsample, input_dim, dim, style_dim, norm, activ, pad_type): 190 | super(StyleEncoder, self).__init__() 191 | self.model = [] 192 | self.model += [Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)] 193 | for i in range(2): 194 | self.model += [Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)] 195 | dim *= 2 196 | for i in range(n_downsample - 2): 197 | self.model += [Conv2dBlock(dim, dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)] 198 | self.model += [nn.AdaptiveAvgPool2d(1)] # global average pooling 199 | self.model += [nn.Conv2d(dim, style_dim, 1, 1, 0)] 200 | self.model = nn.Sequential(*self.model) 201 | self.output_dim = dim 202 | 203 | def forward(self, x): 204 | return self.model(x) 205 | 206 | class ContentEncoder(nn.Module): 207 | def __init__(self, n_downsample, n_res, input_dim, dim, norm, activ, pad_type): 208 | super(ContentEncoder, self).__init__() 209 | self.model = [] 210 | self.model += [Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)] 211 | # downsampling blocks 212 | for i in range(n_downsample): 213 | self.model += [Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)] 214 | dim *= 2 215 | # residual blocks 216 | self.model += [ResBlocks(n_res, dim, norm=norm, activation=activ, pad_type=pad_type)] 217 | self.model = nn.Sequential(*self.model) 218 | self.output_dim = dim 219 | 220 | def forward(self, x): 221 | return self.model(x) 222 | 223 | class Decoder(nn.Module): 224 | def __init__(self, n_upsample, n_res, dim, output_dim, res_norm='adain', activ='relu', pad_type='zero'): 225 | super(Decoder, self).__init__() 226 | 227 | self.model = [] 228 | # AdaIN residual blocks 229 | self.model += [ResBlocks(n_res, dim, res_norm, activ, pad_type=pad_type)] 230 | # upsampling blocks 231 | for i in range(n_upsample): 232 | self.model += [nn.Upsample(scale_factor=2), 233 | Conv2dBlock(dim, dim // 2, 5, 1, 2, norm='ln', activation=activ, pad_type=pad_type)] 234 | dim //= 2 235 | # use reflection padding in the last conv layer 236 | self.model += [Conv2dBlock(dim, output_dim, 7, 1, 3, norm='none', activation='tanh', pad_type=pad_type)] 237 | self.model = nn.Sequential(*self.model) 238 | 239 | def forward(self, x): 240 | return self.model(x) 241 | 242 | ################################################################################## 243 | # Sequential Models 244 | ################################################################################## 245 | class ResBlocks(nn.Module): 246 | def __init__(self, num_blocks, dim, norm='in', activation='relu', pad_type='zero'): 247 | super(ResBlocks, 
self).__init__() 248 | self.model = [] 249 | for i in range(num_blocks): 250 | self.model += [ResBlock(dim, norm=norm, activation=activation, pad_type=pad_type)] 251 | self.model = nn.Sequential(*self.model) 252 | 253 | def forward(self, x): 254 | return self.model(x) 255 | 256 | class MLP(nn.Module): 257 | def __init__(self, input_dim, output_dim, dim, n_blk, norm='none', activ='relu'): 258 | 259 | super(MLP, self).__init__() 260 | self.model = [] 261 | self.model += [LinearBlock(input_dim, dim, norm=norm, activation=activ)] 262 | for i in range(n_blk - 2): 263 | self.model += [LinearBlock(dim, dim, norm=norm, activation=activ)] 264 | self.model += [LinearBlock(dim, output_dim, norm='none', activation='none')] # no output activations 265 | self.model = nn.Sequential(*self.model) 266 | 267 | def forward(self, x): 268 | return self.model(x.view(x.size(0), -1)) 269 | 270 | ################################################################################## 271 | # Basic Blocks 272 | ################################################################################## 273 | class ResBlock(nn.Module): 274 | def __init__(self, dim, norm='in', activation='relu', pad_type='zero'): 275 | super(ResBlock, self).__init__() 276 | 277 | model = [] 278 | model += [Conv2dBlock(dim ,dim, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type)] 279 | model += [Conv2dBlock(dim ,dim, 3, 1, 1, norm=norm, activation='none', pad_type=pad_type)] 280 | self.model = nn.Sequential(*model) 281 | 282 | def forward(self, x): 283 | residual = x 284 | out = self.model(x) 285 | out += residual 286 | return out 287 | 288 | class Conv2dBlock(nn.Module): 289 | def __init__(self, input_dim ,output_dim, kernel_size, stride, 290 | padding=0, norm='none', activation='relu', pad_type='zero'): 291 | super(Conv2dBlock, self).__init__() 292 | self.use_bias = True 293 | # initialize padding 294 | if pad_type == 'reflect': 295 | self.pad = nn.ReflectionPad2d(padding) 296 | elif pad_type == 'replicate': 297 | self.pad = nn.ReplicationPad2d(padding) 298 | elif pad_type == 'zero': 299 | self.pad = nn.ZeroPad2d(padding) 300 | else: 301 | assert 0, "Unsupported padding type: {}".format(pad_type) 302 | 303 | # initialize normalization 304 | norm_dim = output_dim 305 | if norm == 'bn': 306 | self.norm = nn.BatchNorm2d(norm_dim) 307 | elif norm == 'in': 308 | #self.norm = nn.InstanceNorm2d(norm_dim, track_running_stats=True) 309 | self.norm = nn.InstanceNorm2d(norm_dim) 310 | elif norm == 'ln': 311 | self.norm = LayerNorm(norm_dim) 312 | elif norm == 'adain': 313 | self.norm = AdaptiveInstanceNorm2d(norm_dim) 314 | elif norm == 'none': 315 | self.norm = None 316 | else: 317 | assert 0, "Unsupported normalization: {}".format(norm) 318 | 319 | # initialize activation 320 | if activation == 'relu': 321 | self.activation = nn.ReLU(inplace=True) 322 | elif activation == 'lrelu': 323 | self.activation = nn.LeakyReLU(0.2, inplace=True) 324 | elif activation == 'prelu': 325 | self.activation = nn.PReLU() 326 | elif activation == 'selu': 327 | self.activation = nn.SELU(inplace=True) 328 | elif activation == 'tanh': 329 | self.activation = nn.Tanh() 330 | elif activation == 'none': 331 | self.activation = None 332 | else: 333 | assert 0, "Unsupported activation: {}".format(activation) 334 | 335 | # initialize convolution 336 | self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias) 337 | 338 | def forward(self, x): 339 | x = self.conv(self.pad(x)) 340 | if self.norm: 341 | x = self.norm(x) 342 | if self.activation: 343 | 
x = self.activation(x) 344 | return x 345 | 346 | class LinearBlock(nn.Module): 347 | def __init__(self, input_dim, output_dim, norm='none', activation='relu'): 348 | super(LinearBlock, self).__init__() 349 | use_bias = True 350 | # initialize fully connected layer 351 | self.fc = nn.Linear(input_dim, output_dim, bias=use_bias) 352 | 353 | # initialize normalization 354 | norm_dim = output_dim 355 | if norm == 'bn': 356 | self.norm = nn.BatchNorm1d(norm_dim) 357 | elif norm == 'in': 358 | self.norm = nn.InstanceNorm1d(norm_dim) 359 | elif norm == 'ln': 360 | self.norm = LayerNorm(norm_dim) 361 | elif norm == 'none': 362 | self.norm = None 363 | else: 364 | assert 0, "Unsupported normalization: {}".format(norm) 365 | 366 | # initialize activation 367 | if activation == 'relu': 368 | self.activation = nn.ReLU(inplace=True) 369 | elif activation == 'lrelu': 370 | self.activation = nn.LeakyReLU(0.2, inplace=True) 371 | elif activation == 'prelu': 372 | self.activation = nn.PReLU() 373 | elif activation == 'selu': 374 | self.activation = nn.SELU(inplace=True) 375 | elif activation == 'tanh': 376 | self.activation = nn.Tanh() 377 | elif activation == 'none': 378 | self.activation = None 379 | else: 380 | assert 0, "Unsupported activation: {}".format(activation) 381 | 382 | def forward(self, x): 383 | out = self.fc(x) 384 | if self.norm: 385 | out = self.norm(out) 386 | if self.activation: 387 | out = self.activation(out) 388 | return out 389 | 390 | ################################################################################## 391 | # VGG network definition 392 | ################################################################################## 393 | class Vgg16(nn.Module): 394 | def __init__(self): 395 | super(Vgg16, self).__init__() 396 | self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1) 397 | self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) 398 | 399 | self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1) 400 | self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1) 401 | 402 | self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1) 403 | self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) 404 | self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) 405 | 406 | self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1) 407 | self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) 408 | self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) 409 | 410 | self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) 411 | self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) 412 | self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) 413 | 414 | def forward(self, X): 415 | h = F.relu(self.conv1_1(X), inplace=True) 416 | h = F.relu(self.conv1_2(h), inplace=True) 417 | # relu1_2 = h 418 | h = F.max_pool2d(h, kernel_size=2, stride=2) 419 | 420 | h = F.relu(self.conv2_1(h), inplace=True) 421 | h = F.relu(self.conv2_2(h), inplace=True) 422 | # relu2_2 = h 423 | h = F.max_pool2d(h, kernel_size=2, stride=2) 424 | 425 | h = F.relu(self.conv3_1(h), inplace=True) 426 | h = F.relu(self.conv3_2(h), inplace=True) 427 | h = F.relu(self.conv3_3(h), inplace=True) 428 | # relu3_3 = h 429 | h = F.max_pool2d(h, kernel_size=2, stride=2) 430 | 431 | h = F.relu(self.conv4_1(h), inplace=True) 432 | h = F.relu(self.conv4_2(h), inplace=True) 433 | h = F.relu(self.conv4_3(h), inplace=True) 434 | # relu4_3 = h 435 | 
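# The remaining conv5_x block below produces relu5_3, the only feature map returned. The trainers
# pass these features through an affine-free InstanceNorm and take a mean-squared difference,
# which is the domain-invariant perceptual loss applied when vgg_w > 0.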
436 | h = F.relu(self.conv5_1(h), inplace=True) 437 | h = F.relu(self.conv5_2(h), inplace=True) 438 | h = F.relu(self.conv5_3(h), inplace=True) 439 | relu5_3 = h 440 | 441 | return relu5_3 442 | # return [relu1_2, relu2_2, relu3_3, relu4_3] 443 | 444 | ################################################################################## 445 | # Normalization layers 446 | ################################################################################## 447 | class AdaptiveInstanceNorm2d(nn.Module): 448 | def __init__(self, num_features, eps=1e-5, momentum=0.1): 449 | super(AdaptiveInstanceNorm2d, self).__init__() 450 | self.num_features = num_features 451 | self.eps = eps 452 | self.momentum = momentum 453 | # weight and bias are dynamically assigned 454 | self.weight = None 455 | self.bias = None 456 | # just dummy buffers, not used 457 | self.register_buffer('running_mean', torch.zeros(num_features)) 458 | self.register_buffer('running_var', torch.ones(num_features)) 459 | 460 | def forward(self, x): 461 | assert self.weight is not None and self.bias is not None, "Please assign weight and bias before calling AdaIN!" 462 | b, c = x.size(0), x.size(1) 463 | running_mean = self.running_mean.repeat(b) 464 | running_var = self.running_var.repeat(b) 465 | 466 | # Apply instance norm 467 | x_reshaped = x.contiguous().view(1, b * c, *x.size()[2:]) 468 | 469 | out = F.batch_norm( 470 | x_reshaped, running_mean, running_var, self.weight, self.bias, 471 | True, self.momentum, self.eps) 472 | 473 | return out.view(b, c, *x.size()[2:]) 474 | 475 | def __repr__(self): 476 | return self.__class__.__name__ + '(' + str(self.num_features) + ')' 477 | 478 | 479 | class LayerNorm(nn.Module): 480 | def __init__(self, num_features, eps=1e-5, affine=True): 481 | super(LayerNorm, self).__init__() 482 | self.num_features = num_features 483 | self.affine = affine 484 | self.eps = eps 485 | 486 | if self.affine: 487 | self.gamma = nn.Parameter(torch.Tensor(num_features).uniform_()) 488 | self.beta = nn.Parameter(torch.zeros(num_features)) 489 | 490 | def forward(self, x): 491 | shape = [-1] + [1] * (x.dim() - 1) 492 | # print(x.size()) 493 | if x.size(0) == 1: 494 | # These two lines run much faster in pytorch 0.4 than the two lines listed below. 
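# (With a batch of one, flattening the whole tensor covers exactly the same elements as the
# per-sample reduction in the else-branch, so this shortcut is exact rather than an approximation.)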
495 | mean = x.view(-1).mean().view(*shape) 496 | std = x.view(-1).std().view(*shape) 497 | else: 498 | mean = x.view(x.size(0), -1).mean(1).view(*shape) 499 | std = x.view(x.size(0), -1).std(1).view(*shape) 500 | 501 | x = (x - mean) / (std + self.eps) 502 | 503 | if self.affine: 504 | shape = [1, -1] + [1] * (x.dim() - 2) 505 | x = x * self.gamma.view(*shape) + self.beta.view(*shape) 506 | return x 507 | 508 | -------------------------------------------------------------------------------- /results/city2gta/input.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingyuliutw/UNIT/28d5982934e9abf4dbdaaa72650bdd3a128b5519/results/city2gta/input.jpg -------------------------------------------------------------------------------- /results/city2gta/output.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingyuliutw/UNIT/28d5982934e9abf4dbdaaa72650bdd3a128b5519/results/city2gta/output.jpg -------------------------------------------------------------------------------- /results/gta2city/input.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingyuliutw/UNIT/28d5982934e9abf4dbdaaa72650bdd3a128b5519/results/gta2city/input.jpg -------------------------------------------------------------------------------- /results/gta2city/output.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingyuliutw/UNIT/28d5982934e9abf4dbdaaa72650bdd3a128b5519/results/gta2city/output.jpg -------------------------------------------------------------------------------- /scripts/unit_demo_train_edges2handbags.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | rm datasets/edges2handbags -rf 3 | mkdir datasets/edges2handbags -p 4 | axel -n 1 https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/edges2handbags.tar.gz --output=datasets/edges2handbags/edges2handbags.tar.gz 5 | tar -zxvf datasets/edges2handbags/edges2handbags.tar.gz -C datasets/ 6 | mkdir datasets/edges2handbags/train1 -p 7 | mkdir datasets/edges2handbags/train0 -p 8 | mkdir datasets/edges2handbags/test1 -p 9 | mkdir datasets/edges2handbags/test0 -p 10 | for f in datasets/edges2handbags/train/*; do convert -quality 100 -crop 50%x100% +repage $f datasets/edges2handbags/train%d/${f##*/}; done; 11 | for f in datasets/edges2handbags/val/*; do convert -quality 100 -crop 50%x100% +repage $f datasets/edges2handbags/test%d/${f##*/}; done; 12 | mv datasets/edges2handbags/train0 datasets/edges2handbags/trainA 13 | mv datasets/edges2handbags/train1 datasets/edges2handbags/trainB 14 | mv datasets/edges2handbags/test0 datasets/edges2handbags/testA 15 | mv datasets/edges2handbags/test1 datasets/edges2handbags/testB 16 | python train.py --config configs/unit_edges2handbags_folder.yaml --trainer UNIT 17 | 18 | -------------------------------------------------------------------------------- /scripts/unit_demo_train_edges2shoes.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | rm datasets/edges2shoes -rf 3 | mkdir datasets/edges2shoes -p 4 | axel -n 1 https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/edges2shoes.tar.gz --output=datasets/edges2shoes/edges2shoes.tar.gz 5 | tar -zxvf datasets/edges2shoes/edges2shoes.tar.gz -C datasets 6 | mkdir 
datasets/edges2shoes/train1 -p 7 | mkdir datasets/edges2shoes/train0 -p 8 | mkdir datasets/edges2shoes/test1 -p 9 | mkdir datasets/edges2shoes/test0 -p 10 | for f in datasets/edges2shoes/train/*; do convert -quality 100 -crop 50%x100% +repage $f datasets/edges2shoes/train%d/${f##*/}; done; 11 | for f in datasets/edges2shoes/val/*; do convert -quality 100 -crop 50%x100% +repage $f datasets/edges2shoes/test%d/${f##*/}; done; 12 | mv datasets/edges2shoes/train0 datasets/edges2shoes/trainA 13 | mv datasets/edges2shoes/train1 datasets/edges2shoes/trainB 14 | mv datasets/edges2shoes/test0 datasets/edges2shoes/testA 15 | mv datasets/edges2shoes/test1 datasets/edges2shoes/testB 16 | python train.py --config configs/unit_edges2shoes_folder.yaml --trainer UNIT 17 | -------------------------------------------------------------------------------- /scripts/unit_demo_train_summer2winter_yosemite256.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | rm datasets/summer2winter_yosemite256 -rf 3 | mkdir datasets/summer2winter_yosemite256 -p 4 | axel -n 1 https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/summer2winter_yosemite.zip --output=datasets/summer2winter_yosemite256/summer2winter_yosemite.zip 5 | unzip datasets/summer2winter_yosemite256/summer2winter_yosemite.zip -d datasets/summer2winter_yosemite256 6 | python train.py --config configs/unit_summer2winter_yosemite256_folder.yaml --trainer UNIT 7 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2018 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """ 5 | from __future__ import print_function 6 | from utils import get_config, pytorch03_to_pytorch04 7 | from trainer import MUNIT_Trainer, UNIT_Trainer 8 | import argparse 9 | from torch.autograd import Variable 10 | import torchvision.utils as vutils 11 | import sys 12 | import torch 13 | import os 14 | from torchvision import transforms 15 | from PIL import Image 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--config', type=str, help="net configuration") 19 | parser.add_argument('--input', type=str, help="input image path") 20 | parser.add_argument('--output_folder', type=str, help="output image path") 21 | parser.add_argument('--checkpoint', type=str, help="checkpoint of autoencoders") 22 | parser.add_argument('--style', type=str, default='', help="style image path") 23 | parser.add_argument('--a2b', type=int, default=1, help="1 for a2b and others for b2a") 24 | parser.add_argument('--seed', type=int, default=10, help="random seed") 25 | parser.add_argument('--num_style',type=int, default=10, help="number of styles to sample") 26 | parser.add_argument('--synchronized', action='store_true', help="whether use synchronized style code or not") 27 | parser.add_argument('--output_only', action='store_true', help="whether use synchronized style code or not") 28 | parser.add_argument('--output_path', type=str, default='.', help="path for logs, checkpoints, and VGG model weight") 29 | parser.add_argument('--trainer', type=str, default='MUNIT', help="MUNIT|UNIT") 30 | opts = parser.parse_args() 31 | 32 | 33 | 34 | torch.manual_seed(opts.seed) 35 | torch.cuda.manual_seed(opts.seed) 36 | if not os.path.exists(opts.output_folder): 37 | os.makedirs(opts.output_folder) 38 | 39 | # Load experiment setting 40 | config = get_config(opts.config) 41 | opts.num_style = 1 if opts.style != '' else opts.num_style 42 | 43 | # Setup model and data loader 44 | config['vgg_model_path'] = opts.output_path 45 | if opts.trainer == 'MUNIT': 46 | style_dim = config['gen']['style_dim'] 47 | trainer = MUNIT_Trainer(config) 48 | elif opts.trainer == 'UNIT': 49 | trainer = UNIT_Trainer(config) 50 | else: 51 | sys.exit("Only support MUNIT|UNIT") 52 | 53 | try: 54 | state_dict = torch.load(opts.checkpoint) 55 | trainer.gen_a.load_state_dict(state_dict['a']) 56 | trainer.gen_b.load_state_dict(state_dict['b']) 57 | except: 58 | state_dict = pytorch03_to_pytorch04(torch.load(opts.checkpoint)) 59 | trainer.gen_a.load_state_dict(state_dict['a']) 60 | trainer.gen_b.load_state_dict(state_dict['b']) 61 | 62 | trainer.cuda() 63 | trainer.eval() 64 | encode = trainer.gen_a.encode if opts.a2b else trainer.gen_b.encode # encode function 65 | style_encode = trainer.gen_b.encode if opts.a2b else trainer.gen_a.encode # encode function 66 | decode = trainer.gen_b.decode if opts.a2b else trainer.gen_a.decode # decode function 67 | 68 | if 'new_size' in config: 69 | new_size = config['new_size'] 70 | else: 71 | if opts.a2b==1: 72 | new_size = config['new_size_a'] 73 | else: 74 | new_size = config['new_size_b'] 75 | 76 | with torch.no_grad(): 77 | transform = transforms.Compose([transforms.Resize(new_size), 78 | transforms.ToTensor(), 79 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 80 | image = Variable(transform(Image.open(opts.input).convert('RGB')).unsqueeze(0).cuda()) 81 | style_image = Variable(transform(Image.open(opts.style).convert('RGB')).unsqueeze(0).cuda()) if opts.style != '' else None 82 | 83 | # Start testing 84 | content, _ = encode(image) 85 | 86 | if opts.trainer == 'MUNIT': 87 | 
style_rand = Variable(torch.randn(opts.num_style, style_dim, 1, 1).cuda()) 88 | if opts.style != '': 89 | _, style = style_encode(style_image) 90 | else: 91 | style = style_rand 92 | for j in range(opts.num_style): 93 | s = style[j].unsqueeze(0) 94 | outputs = decode(content, s) 95 | outputs = (outputs + 1) / 2. 96 | path = os.path.join(opts.output_folder, 'output{:03d}.jpg'.format(j)) 97 | vutils.save_image(outputs.data, path, padding=0, normalize=True) 98 | elif opts.trainer == 'UNIT': 99 | outputs = decode(content) 100 | outputs = (outputs + 1) / 2. 101 | path = os.path.join(opts.output_folder, 'output.jpg') 102 | vutils.save_image(outputs.data, path, padding=0, normalize=True) 103 | else: 104 | pass 105 | 106 | if not opts.output_only: 107 | # also save input images 108 | vutils.save_image(image.data, os.path.join(opts.output_folder, 'input.jpg'), padding=0, normalize=True) 109 | 110 | -------------------------------------------------------------------------------- /test_batch.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2018 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | from __future__ import print_function 6 | from utils import get_config, get_data_loader_folder, pytorch03_to_pytorch04 7 | from trainer import MUNIT_Trainer, UNIT_Trainer 8 | import argparse 9 | from torch.autograd import Variable 10 | from data import ImageFolder 11 | import torchvision.utils as vutils 12 | try: 13 | from itertools import izip as zip 14 | except ImportError: # will be 3.x series 15 | pass 16 | import sys 17 | import torch 18 | import os 19 | 20 | 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument('--config', type=str, default='configs/edges2handbags_folder', help='Path to the config file.') 23 | parser.add_argument('--input_folder', type=str, help="input image folder") 24 | parser.add_argument('--output_folder', type=str, help="output image folder") 25 | parser.add_argument('--checkpoint', type=str, help="checkpoint of autoencoders") 26 | parser.add_argument('--a2b', type=int, help="1 for a2b and others for b2a", default=1) 27 | parser.add_argument('--seed', type=int, default=1, help="random seed") 28 | parser.add_argument('--num_style',type=int, default=10, help="number of styles to sample") 29 | parser.add_argument('--synchronized', action='store_true', help="whether use synchronized style code or not") 30 | parser.add_argument('--output_only', action='store_true', help="whether use synchronized style code or not") 31 | parser.add_argument('--output_path', type=str, default='.', help="path for logs, checkpoints, and VGG model weight") 32 | parser.add_argument('--trainer', type=str, default='MUNIT', help="MUNIT|UNIT") 33 | 34 | opts = parser.parse_args() 35 | 36 | 37 | torch.manual_seed(opts.seed) 38 | torch.cuda.manual_seed(opts.seed) 39 | 40 | # Load experiment setting 41 | config = get_config(opts.config) 42 | input_dim = config['input_dim_a'] if opts.a2b else config['input_dim_b'] 43 | 44 | # Setup model and data loader 45 | image_names = ImageFolder(opts.input_folder, transform=None, return_paths=True) 46 | data_loader = get_data_loader_folder(opts.input_folder, 1, False, new_size=config['new_size_a'], crop=False) 47 | 48 | config['vgg_model_path'] = opts.output_path 49 | if opts.trainer == 'MUNIT': 50 | style_dim = config['gen']['style_dim'] 51 | trainer = MUNIT_Trainer(config) 52 | elif opts.trainer 
== 'UNIT': 53 | trainer = UNIT_Trainer(config) 54 | else: 55 | sys.exit("Only support MUNIT|UNIT") 56 | 57 | try: 58 | state_dict = torch.load(opts.checkpoint) 59 | trainer.gen_a.load_state_dict(state_dict['a']) 60 | trainer.gen_b.load_state_dict(state_dict['b']) 61 | except: 62 | state_dict = pytorch03_to_pytorch04(torch.load(opts.checkpoint)) 63 | trainer.gen_a.load_state_dict(state_dict['a']) 64 | trainer.gen_b.load_state_dict(state_dict['b']) 65 | 66 | trainer.cuda() 67 | trainer.eval() 68 | encode = trainer.gen_a.encode if opts.a2b else trainer.gen_b.encode # encode function 69 | decode = trainer.gen_b.decode if opts.a2b else trainer.gen_a.decode # decode function 70 | 71 | if opts.trainer == 'MUNIT': 72 | # Start testing 73 | style_fixed = Variable(torch.randn(opts.num_style, style_dim, 1, 1).cuda(), volatile=True) 74 | for i, (images, names) in enumerate(zip(data_loader, image_names)): 75 | print(names[1]) 76 | images = Variable(images.cuda(), volatile=True) 77 | content, _ = encode(images) 78 | style = style_fixed if opts.synchronized else Variable(torch.randn(opts.num_style, style_dim, 1, 1).cuda(), volatile=True) 79 | for j in range(opts.num_style): 80 | s = style[j].unsqueeze(0) 81 | outputs = decode(content, s) 82 | outputs = (outputs + 1) / 2. 83 | # path = os.path.join(opts.output_folder, 'input{:03d}_output{:03d}.jpg'.format(i, j)) 84 | basename = os.path.basename(names[1]) 85 | path = os.path.join(opts.output_folder+"_%02d"%j,basename) 86 | if not os.path.exists(os.path.dirname(path)): 87 | os.makedirs(os.path.dirname(path)) 88 | vutils.save_image(outputs.data, path, padding=0, normalize=True) 89 | if not opts.output_only: 90 | # also save input images 91 | vutils.save_image(images.data, os.path.join(opts.output_folder, 'input{:03d}.jpg'.format(i)), padding=0, normalize=True) 92 | elif opts.trainer == 'UNIT': 93 | # Start testing 94 | for i, (images, names) in enumerate(zip(data_loader, image_names)): 95 | print(names[1]) 96 | images = Variable(images.cuda(), volatile=True) 97 | content, _ = encode(images) 98 | 99 | outputs = decode(content) 100 | outputs = (outputs + 1) / 2. 101 | # path = os.path.join(opts.output_folder, 'input{:03d}_output{:03d}.jpg'.format(i, j)) 102 | basename = os.path.basename(names[1]) 103 | path = os.path.join(opts.output_folder,basename) 104 | if not os.path.exists(os.path.dirname(path)): 105 | os.makedirs(os.path.dirname(path)) 106 | vutils.save_image(outputs.data, path, padding=0, normalize=True) 107 | if not opts.output_only: 108 | # also save input images 109 | vutils.save_image(images.data, os.path.join(opts.output_folder, 'input{:03d}.jpg'.format(i)), padding=0, normalize=True) 110 | else: 111 | pass 112 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2018 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
4 | """ 5 | from utils import get_all_data_loaders, prepare_sub_folder, write_html, write_loss, get_config, write_2images, Timer 6 | import argparse 7 | from torch.autograd import Variable 8 | from trainer import MUNIT_Trainer, UNIT_Trainer 9 | import torch.backends.cudnn as cudnn 10 | import torch 11 | try: 12 | from itertools import izip as zip 13 | except ImportError: # will be 3.x series 14 | pass 15 | import os 16 | import sys 17 | import tensorboardX 18 | import shutil 19 | 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument('--config', type=str, default='configs/edges2handbags_folder.yaml', help='Path to the config file.') 22 | parser.add_argument('--output_path', type=str, default='.', help="outputs path") 23 | parser.add_argument("--resume", action="store_true") 24 | parser.add_argument('--trainer', type=str, default='MUNIT', help="MUNIT|UNIT") 25 | opts = parser.parse_args() 26 | 27 | cudnn.benchmark = True 28 | 29 | # Load experiment setting 30 | config = get_config(opts.config) 31 | max_iter = config['max_iter'] 32 | display_size = config['display_size'] 33 | config['vgg_model_path'] = opts.output_path 34 | 35 | # Setup model and data loader 36 | if opts.trainer == 'MUNIT': 37 | trainer = MUNIT_Trainer(config) 38 | elif opts.trainer == 'UNIT': 39 | trainer = UNIT_Trainer(config) 40 | else: 41 | sys.exit("Only support MUNIT|UNIT") 42 | trainer.cuda() 43 | train_loader_a, train_loader_b, test_loader_a, test_loader_b = get_all_data_loaders(config) 44 | train_display_images_a = torch.stack([train_loader_a.dataset[i] for i in range(display_size)]).cuda() 45 | train_display_images_b = torch.stack([train_loader_b.dataset[i] for i in range(display_size)]).cuda() 46 | test_display_images_a = torch.stack([test_loader_a.dataset[i] for i in range(display_size)]).cuda() 47 | test_display_images_b = torch.stack([test_loader_b.dataset[i] for i in range(display_size)]).cuda() 48 | 49 | # Setup logger and output folders 50 | model_name = os.path.splitext(os.path.basename(opts.config))[0] 51 | train_writer = tensorboardX.SummaryWriter(os.path.join(opts.output_path + "/logs", model_name)) 52 | output_directory = os.path.join(opts.output_path + "/outputs", model_name) 53 | checkpoint_directory, image_directory = prepare_sub_folder(output_directory) 54 | shutil.copy(opts.config, os.path.join(output_directory, 'config.yaml')) # copy config file to output folder 55 | 56 | # Start training 57 | iterations = trainer.resume(checkpoint_directory, hyperparameters=config) if opts.resume else 0 58 | while True: 59 | for it, (images_a, images_b) in enumerate(zip(train_loader_a, train_loader_b)): 60 | trainer.update_learning_rate() 61 | images_a, images_b = images_a.cuda().detach(), images_b.cuda().detach() 62 | 63 | with Timer("Elapsed time in update: %f"): 64 | # Main training code 65 | trainer.dis_update(images_a, images_b, config) 66 | trainer.gen_update(images_a, images_b, config) 67 | torch.cuda.synchronize() 68 | 69 | # Dump training stats in log file 70 | if (iterations + 1) % config['log_iter'] == 0: 71 | print("Iteration: %08d/%08d" % (iterations + 1, max_iter)) 72 | write_loss(iterations, trainer, train_writer) 73 | 74 | # Write images 75 | if (iterations + 1) % config['image_save_iter'] == 0: 76 | with torch.no_grad(): 77 | test_image_outputs = trainer.sample(test_display_images_a, test_display_images_b) 78 | train_image_outputs = trainer.sample(train_display_images_a, train_display_images_b) 79 | write_2images(test_image_outputs, display_size, image_directory, 'test_%08d' % 
(iterations + 1)) 80 | write_2images(train_image_outputs, display_size, image_directory, 'train_%08d' % (iterations + 1)) 81 | # HTML 82 | write_html(output_directory + "/index.html", iterations + 1, config['image_save_iter'], 'images') 83 | 84 | if (iterations + 1) % config['image_display_iter'] == 0: 85 | with torch.no_grad(): 86 | image_outputs = trainer.sample(train_display_images_a, train_display_images_b) 87 | write_2images(image_outputs, display_size, image_directory, 'train_current') 88 | 89 | # Save network weights 90 | if (iterations + 1) % config['snapshot_save_iter'] == 0: 91 | trainer.save(checkpoint_directory, iterations) 92 | 93 | iterations += 1 94 | if iterations >= max_iter: 95 | sys.exit('Finish training') 96 | 97 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2017 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | from networks import AdaINGen, MsImageDis, VAEGen 6 | from utils import weights_init, get_model_list, vgg_preprocess, load_vgg16, get_scheduler 7 | from torch.autograd import Variable 8 | import torch 9 | import torch.nn as nn 10 | import os 11 | 12 | class MUNIT_Trainer(nn.Module): 13 | def __init__(self, hyperparameters): 14 | super(MUNIT_Trainer, self).__init__() 15 | lr = hyperparameters['lr'] 16 | # Initiate the networks 17 | self.gen_a = AdaINGen(hyperparameters['input_dim_a'], hyperparameters['gen']) # auto-encoder for domain a 18 | self.gen_b = AdaINGen(hyperparameters['input_dim_b'], hyperparameters['gen']) # auto-encoder for domain b 19 | self.dis_a = MsImageDis(hyperparameters['input_dim_a'], hyperparameters['dis']) # discriminator for domain a 20 | self.dis_b = MsImageDis(hyperparameters['input_dim_b'], hyperparameters['dis']) # discriminator for domain b 21 | self.instancenorm = nn.InstanceNorm2d(512, affine=False) 22 | self.style_dim = hyperparameters['gen']['style_dim'] 23 | 24 | # fix the noise used in sampling 25 | display_size = int(hyperparameters['display_size']) 26 | self.s_a = torch.randn(display_size, self.style_dim, 1, 1).cuda() 27 | self.s_b = torch.randn(display_size, self.style_dim, 1, 1).cuda() 28 | 29 | # Setup the optimizers 30 | beta1 = hyperparameters['beta1'] 31 | beta2 = hyperparameters['beta2'] 32 | dis_params = list(self.dis_a.parameters()) + list(self.dis_b.parameters()) 33 | gen_params = list(self.gen_a.parameters()) + list(self.gen_b.parameters()) 34 | self.dis_opt = torch.optim.Adam([p for p in dis_params if p.requires_grad], 35 | lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay']) 36 | self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad], 37 | lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay']) 38 | self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters) 39 | self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters) 40 | 41 | # Network weight initialization 42 | self.apply(weights_init(hyperparameters['init'])) 43 | self.dis_a.apply(weights_init('gaussian')) 44 | self.dis_b.apply(weights_init('gaussian')) 45 | 46 | # Load VGG model if needed 47 | if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0: 48 | self.vgg = load_vgg16(hyperparameters['vgg_model_path'] + '/models') 49 | self.vgg.eval() 50 | for param in self.vgg.parameters(): 51 | 
param.requires_grad = False 52 | 53 | def recon_criterion(self, input, target): 54 | return torch.mean(torch.abs(input - target)) 55 | 56 | def forward(self, x_a, x_b): 57 | self.eval() 58 | s_a = Variable(self.s_a) 59 | s_b = Variable(self.s_b) 60 | c_a, s_a_fake = self.gen_a.encode(x_a) 61 | c_b, s_b_fake = self.gen_b.encode(x_b) 62 | x_ba = self.gen_a.decode(c_b, s_a) 63 | x_ab = self.gen_b.decode(c_a, s_b) 64 | self.train() 65 | return x_ab, x_ba 66 | 67 | def gen_update(self, x_a, x_b, hyperparameters): 68 | self.gen_opt.zero_grad() 69 | s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda()) 70 | s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda()) 71 | # encode 72 | c_a, s_a_prime = self.gen_a.encode(x_a) 73 | c_b, s_b_prime = self.gen_b.encode(x_b) 74 | # decode (within domain) 75 | x_a_recon = self.gen_a.decode(c_a, s_a_prime) 76 | x_b_recon = self.gen_b.decode(c_b, s_b_prime) 77 | # decode (cross domain) 78 | x_ba = self.gen_a.decode(c_b, s_a) 79 | x_ab = self.gen_b.decode(c_a, s_b) 80 | # encode again 81 | c_b_recon, s_a_recon = self.gen_a.encode(x_ba) 82 | c_a_recon, s_b_recon = self.gen_b.encode(x_ab) 83 | # decode again (if needed) 84 | x_aba = self.gen_a.decode(c_a_recon, s_a_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None 85 | x_bab = self.gen_b.decode(c_b_recon, s_b_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None 86 | 87 | # reconstruction loss 88 | self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a) 89 | self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b) 90 | self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a) 91 | self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b) 92 | self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c_a) 93 | self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, c_b) 94 | self.loss_gen_cycrecon_x_a = self.recon_criterion(x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0 95 | self.loss_gen_cycrecon_x_b = self.recon_criterion(x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0 96 | # GAN loss 97 | self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba) 98 | self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab) 99 | # domain-invariant perceptual loss 100 | self.loss_gen_vgg_a = self.compute_vgg_loss(self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0 101 | self.loss_gen_vgg_b = self.compute_vgg_loss(self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0 102 | # total loss 103 | self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \ 104 | hyperparameters['gan_w'] * self.loss_gen_adv_b + \ 105 | hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \ 106 | hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \ 107 | hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \ 108 | hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \ 109 | hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \ 110 | hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b + \ 111 | hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \ 112 | hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \ 113 | hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \ 114 | hyperparameters['vgg_w'] * self.loss_gen_vgg_b 115 | self.loss_gen_total.backward() 116 | self.gen_opt.step() 117 | 118 | def compute_vgg_loss(self, vgg, img, target): 119 | img_vgg = vgg_preprocess(img) 120 | target_vgg = vgg_preprocess(target) 121 | img_fea = vgg(img_vgg) 122 | target_fea = vgg(target_vgg) 123 | return 
torch.mean((self.instancenorm(img_fea) - self.instancenorm(target_fea)) ** 2) 124 | 125 | def sample(self, x_a, x_b): 126 | self.eval() 127 | s_a1 = Variable(self.s_a) 128 | s_b1 = Variable(self.s_b) 129 | s_a2 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda()) 130 | s_b2 = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda()) 131 | x_a_recon, x_b_recon, x_ba1, x_ba2, x_ab1, x_ab2 = [], [], [], [], [], [] 132 | for i in range(x_a.size(0)): 133 | c_a, s_a_fake = self.gen_a.encode(x_a[i].unsqueeze(0)) 134 | c_b, s_b_fake = self.gen_b.encode(x_b[i].unsqueeze(0)) 135 | x_a_recon.append(self.gen_a.decode(c_a, s_a_fake)) 136 | x_b_recon.append(self.gen_b.decode(c_b, s_b_fake)) 137 | x_ba1.append(self.gen_a.decode(c_b, s_a1[i].unsqueeze(0))) 138 | x_ba2.append(self.gen_a.decode(c_b, s_a2[i].unsqueeze(0))) 139 | x_ab1.append(self.gen_b.decode(c_a, s_b1[i].unsqueeze(0))) 140 | x_ab2.append(self.gen_b.decode(c_a, s_b2[i].unsqueeze(0))) 141 | x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon) 142 | x_ba1, x_ba2 = torch.cat(x_ba1), torch.cat(x_ba2) 143 | x_ab1, x_ab2 = torch.cat(x_ab1), torch.cat(x_ab2) 144 | self.train() 145 | return x_a, x_a_recon, x_ab1, x_ab2, x_b, x_b_recon, x_ba1, x_ba2 146 | 147 | def dis_update(self, x_a, x_b, hyperparameters): 148 | self.dis_opt.zero_grad() 149 | s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda()) 150 | s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda()) 151 | # encode 152 | c_a, _ = self.gen_a.encode(x_a) 153 | c_b, _ = self.gen_b.encode(x_b) 154 | # decode (cross domain) 155 | x_ba = self.gen_a.decode(c_b, s_a) 156 | x_ab = self.gen_b.decode(c_a, s_b) 157 | # D loss 158 | self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a) 159 | self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b) 160 | self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + hyperparameters['gan_w'] * self.loss_dis_b 161 | self.loss_dis_total.backward() 162 | self.dis_opt.step() 163 | 164 | def update_learning_rate(self): 165 | if self.dis_scheduler is not None: 166 | self.dis_scheduler.step() 167 | if self.gen_scheduler is not None: 168 | self.gen_scheduler.step() 169 | 170 | def resume(self, checkpoint_dir, hyperparameters): 171 | # Load generators 172 | last_model_name = get_model_list(checkpoint_dir, "gen") 173 | state_dict = torch.load(last_model_name) 174 | self.gen_a.load_state_dict(state_dict['a']) 175 | self.gen_b.load_state_dict(state_dict['b']) 176 | iterations = int(last_model_name[-11:-3]) 177 | # Load discriminators 178 | last_model_name = get_model_list(checkpoint_dir, "dis") 179 | state_dict = torch.load(last_model_name) 180 | self.dis_a.load_state_dict(state_dict['a']) 181 | self.dis_b.load_state_dict(state_dict['b']) 182 | # Load optimizers 183 | state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt')) 184 | self.dis_opt.load_state_dict(state_dict['dis']) 185 | self.gen_opt.load_state_dict(state_dict['gen']) 186 | # Reinitilize schedulers 187 | self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations) 188 | self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations) 189 | print('Resume from iteration %d' % iterations) 190 | return iterations 191 | 192 | def save(self, snapshot_dir, iterations): 193 | # Save generators, discriminators, and optimizers 194 | gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1)) 195 | dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1)) 196 | opt_name 
= os.path.join(snapshot_dir, 'optimizer.pt') 197 | torch.save({'a': self.gen_a.state_dict(), 'b': self.gen_b.state_dict()}, gen_name) 198 | torch.save({'a': self.dis_a.state_dict(), 'b': self.dis_b.state_dict()}, dis_name) 199 | torch.save({'gen': self.gen_opt.state_dict(), 'dis': self.dis_opt.state_dict()}, opt_name) 200 | 201 | 202 | class UNIT_Trainer(nn.Module): 203 | def __init__(self, hyperparameters): 204 | super(UNIT_Trainer, self).__init__() 205 | lr = hyperparameters['lr'] 206 | # Initiate the networks 207 | self.gen_a = VAEGen(hyperparameters['input_dim_a'], hyperparameters['gen']) # auto-encoder for domain a 208 | self.gen_b = VAEGen(hyperparameters['input_dim_b'], hyperparameters['gen']) # auto-encoder for domain b 209 | self.dis_a = MsImageDis(hyperparameters['input_dim_a'], hyperparameters['dis']) # discriminator for domain a 210 | self.dis_b = MsImageDis(hyperparameters['input_dim_b'], hyperparameters['dis']) # discriminator for domain b 211 | self.instancenorm = nn.InstanceNorm2d(512, affine=False) 212 | 213 | # Setup the optimizers 214 | beta1 = hyperparameters['beta1'] 215 | beta2 = hyperparameters['beta2'] 216 | dis_params = list(self.dis_a.parameters()) + list(self.dis_b.parameters()) 217 | gen_params = list(self.gen_a.parameters()) + list(self.gen_b.parameters()) 218 | self.dis_opt = torch.optim.Adam([p for p in dis_params if p.requires_grad], 219 | lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay']) 220 | self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad], 221 | lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay']) 222 | self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters) 223 | self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters) 224 | 225 | # Network weight initialization 226 | self.apply(weights_init(hyperparameters['init'])) 227 | self.dis_a.apply(weights_init('gaussian')) 228 | self.dis_b.apply(weights_init('gaussian')) 229 | 230 | # Load VGG model if needed 231 | if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0: 232 | self.vgg = load_vgg16(hyperparameters['vgg_model_path'] + '/models') 233 | self.vgg.eval() 234 | for param in self.vgg.parameters(): 235 | param.requires_grad = False 236 | 237 | def recon_criterion(self, input, target): 238 | return torch.mean(torch.abs(input - target)) 239 | 240 | def forward(self, x_a, x_b): 241 | self.eval() 242 | h_a, _ = self.gen_a.encode(x_a) 243 | h_b, _ = self.gen_b.encode(x_b) 244 | x_ba = self.gen_a.decode(h_b) 245 | x_ab = self.gen_b.decode(h_a) 246 | self.train() 247 | return x_ab, x_ba 248 | 249 | def __compute_kl(self, mu): 250 | # def _compute_kl(self, mu, sd): 251 | # mu_2 = torch.pow(mu, 2) 252 | # sd_2 = torch.pow(sd, 2) 253 | # encoding_loss = (mu_2 + sd_2 - torch.log(sd_2)).sum() / mu_2.size(0) 254 | # return encoding_loss 255 | mu_2 = torch.pow(mu, 2) 256 | encoding_loss = torch.mean(mu_2) 257 | return encoding_loss 258 | 259 | def gen_update(self, x_a, x_b, hyperparameters): 260 | self.gen_opt.zero_grad() 261 | # encode 262 | h_a, n_a = self.gen_a.encode(x_a) 263 | h_b, n_b = self.gen_b.encode(x_b) 264 | # decode (within domain) 265 | x_a_recon = self.gen_a.decode(h_a + n_a) 266 | x_b_recon = self.gen_b.decode(h_b + n_b) 267 | # decode (cross domain) 268 | x_ba = self.gen_a.decode(h_b + n_b) 269 | x_ab = self.gen_b.decode(h_a + n_a) 270 | # encode again 271 | h_b_recon, n_b_recon = self.gen_a.encode(x_ba) 272 | h_a_recon, n_a_recon = self.gen_b.encode(x_ab) 273 | # decode again (if needed) 274 | 
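# x_aba / x_bab are only decoded when recon_x_cyc_w > 0, but the cycle losses further down use
# them unconditionally, so UNIT configs are expected to keep recon_x_cyc_w strictly positive.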
x_aba = self.gen_a.decode(h_a_recon + n_a_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None 275 | x_bab = self.gen_b.decode(h_b_recon + n_b_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None 276 | 277 | # reconstruction loss 278 | self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a) 279 | self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b) 280 | self.loss_gen_recon_kl_a = self.__compute_kl(h_a) 281 | self.loss_gen_recon_kl_b = self.__compute_kl(h_b) 282 | self.loss_gen_cyc_x_a = self.recon_criterion(x_aba, x_a) 283 | self.loss_gen_cyc_x_b = self.recon_criterion(x_bab, x_b) 284 | self.loss_gen_recon_kl_cyc_aba = self.__compute_kl(h_a_recon) 285 | self.loss_gen_recon_kl_cyc_bab = self.__compute_kl(h_b_recon) 286 | # GAN loss 287 | self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba) 288 | self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab) 289 | # domain-invariant perceptual loss 290 | self.loss_gen_vgg_a = self.compute_vgg_loss(self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0 291 | self.loss_gen_vgg_b = self.compute_vgg_loss(self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0 292 | # total loss 293 | self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \ 294 | hyperparameters['gan_w'] * self.loss_gen_adv_b + \ 295 | hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \ 296 | hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_a + \ 297 | hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \ 298 | hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_b + \ 299 | hyperparameters['recon_x_cyc_w'] * self.loss_gen_cyc_x_a + \ 300 | hyperparameters['recon_kl_cyc_w'] * self.loss_gen_recon_kl_cyc_aba + \ 301 | hyperparameters['recon_x_cyc_w'] * self.loss_gen_cyc_x_b + \ 302 | hyperparameters['recon_kl_cyc_w'] * self.loss_gen_recon_kl_cyc_bab + \ 303 | hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \ 304 | hyperparameters['vgg_w'] * self.loss_gen_vgg_b 305 | self.loss_gen_total.backward() 306 | self.gen_opt.step() 307 | 308 | def compute_vgg_loss(self, vgg, img, target): 309 | img_vgg = vgg_preprocess(img) 310 | target_vgg = vgg_preprocess(target) 311 | img_fea = vgg(img_vgg) 312 | target_fea = vgg(target_vgg) 313 | return torch.mean((self.instancenorm(img_fea) - self.instancenorm(target_fea)) ** 2) 314 | 315 | def sample(self, x_a, x_b): 316 | self.eval() 317 | x_a_recon, x_b_recon, x_ba, x_ab = [], [], [], [] 318 | for i in range(x_a.size(0)): 319 | h_a, _ = self.gen_a.encode(x_a[i].unsqueeze(0)) 320 | h_b, _ = self.gen_b.encode(x_b[i].unsqueeze(0)) 321 | x_a_recon.append(self.gen_a.decode(h_a)) 322 | x_b_recon.append(self.gen_b.decode(h_b)) 323 | x_ba.append(self.gen_a.decode(h_b)) 324 | x_ab.append(self.gen_b.decode(h_a)) 325 | x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon) 326 | x_ba = torch.cat(x_ba) 327 | x_ab = torch.cat(x_ab) 328 | self.train() 329 | return x_a, x_a_recon, x_ab, x_b, x_b_recon, x_ba 330 | 331 | def dis_update(self, x_a, x_b, hyperparameters): 332 | self.dis_opt.zero_grad() 333 | # encode 334 | h_a, n_a = self.gen_a.encode(x_a) 335 | h_b, n_b = self.gen_b.encode(x_b) 336 | # decode (cross domain) 337 | x_ba = self.gen_a.decode(h_b + n_b) 338 | x_ab = self.gen_b.decode(h_a + n_a) 339 | # D loss 340 | self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a) 341 | self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b) 342 | self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + hyperparameters['gan_w'] * self.loss_dis_b 343 | 
self.loss_dis_total.backward() 344 | self.dis_opt.step() 345 | 346 | def update_learning_rate(self): 347 | if self.dis_scheduler is not None: 348 | self.dis_scheduler.step() 349 | if self.gen_scheduler is not None: 350 | self.gen_scheduler.step() 351 | 352 | def resume(self, checkpoint_dir, hyperparameters): 353 | # Load generators 354 | last_model_name = get_model_list(checkpoint_dir, "gen") 355 | state_dict = torch.load(last_model_name) 356 | self.gen_a.load_state_dict(state_dict['a']) 357 | self.gen_b.load_state_dict(state_dict['b']) 358 | iterations = int(last_model_name[-11:-3]) 359 | # Load discriminators 360 | last_model_name = get_model_list(checkpoint_dir, "dis") 361 | state_dict = torch.load(last_model_name) 362 | self.dis_a.load_state_dict(state_dict['a']) 363 | self.dis_b.load_state_dict(state_dict['b']) 364 | # Load optimizers 365 | state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt')) 366 | self.dis_opt.load_state_dict(state_dict['dis']) 367 | self.gen_opt.load_state_dict(state_dict['gen']) 368 | # Reinitilize schedulers 369 | self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations) 370 | self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations) 371 | print('Resume from iteration %d' % iterations) 372 | return iterations 373 | 374 | def save(self, snapshot_dir, iterations): 375 | # Save generators, discriminators, and optimizers 376 | gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1)) 377 | dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1)) 378 | opt_name = os.path.join(snapshot_dir, 'optimizer.pt') 379 | torch.save({'a': self.gen_a.state_dict(), 'b': self.gen_b.state_dict()}, gen_name) 380 | torch.save({'a': self.dis_a.state_dict(), 'b': self.dis_b.state_dict()}, dis_name) 381 | torch.save({'gen': self.gen_opt.state_dict(), 'dis': self.dis_opt.state_dict()}, opt_name) 382 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2018 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | from torch.utils.serialization import load_lua 6 | from torch.utils.data import DataLoader 7 | from networks import Vgg16 8 | from torch.autograd import Variable 9 | from torch.optim import lr_scheduler 10 | from torchvision import transforms 11 | from data import ImageFilelist, ImageFolder 12 | import torch 13 | import os 14 | import math 15 | import torchvision.utils as vutils 16 | import yaml 17 | import numpy as np 18 | import torch.nn.init as init 19 | import time 20 | # Methods 21 | # get_all_data_loaders : primary data loader interface (load trainA, testA, trainB, testB) 22 | # get_data_loader_list : list-based data loader 23 | # get_data_loader_folder : folder-based data loader 24 | # get_config : load yaml file 25 | # eformat : 26 | # write_2images : save output image 27 | # prepare_sub_folder : create checkpoints and images folders for saving outputs 28 | # write_one_row_html : write one row of the html file for output images 29 | # write_html : create the html file. 
30 | # write_loss 31 | # slerp 32 | # get_slerp_interp 33 | # get_model_list 34 | # load_vgg16 35 | # vgg_preprocess 36 | # get_scheduler 37 | # weights_init 38 | 39 | def get_all_data_loaders(conf): 40 | batch_size = conf['batch_size'] 41 | num_workers = conf['num_workers'] 42 | if 'new_size' in conf: 43 | new_size_a = new_size_b = conf['new_size'] 44 | else: 45 | new_size_a = conf['new_size_a'] 46 | new_size_b = conf['new_size_b'] 47 | height = conf['crop_image_height'] 48 | width = conf['crop_image_width'] 49 | 50 | if 'data_root' in conf: 51 | train_loader_a = get_data_loader_folder(os.path.join(conf['data_root'], 'trainA'), batch_size, True, 52 | new_size_a, height, width, num_workers, True) 53 | test_loader_a = get_data_loader_folder(os.path.join(conf['data_root'], 'testA'), batch_size, False, 54 | new_size_a, new_size_a, new_size_a, num_workers, True) 55 | train_loader_b = get_data_loader_folder(os.path.join(conf['data_root'], 'trainB'), batch_size, True, 56 | new_size_b, height, width, num_workers, True) 57 | test_loader_b = get_data_loader_folder(os.path.join(conf['data_root'], 'testB'), batch_size, False, 58 | new_size_b, new_size_b, new_size_b, num_workers, True) 59 | else: 60 | train_loader_a = get_data_loader_list(conf['data_folder_train_a'], conf['data_list_train_a'], batch_size, True, 61 | new_size_a, height, width, num_workers, True) 62 | test_loader_a = get_data_loader_list(conf['data_folder_test_a'], conf['data_list_test_a'], batch_size, False, 63 | new_size_a, new_size_a, new_size_a, num_workers, True) 64 | train_loader_b = get_data_loader_list(conf['data_folder_train_b'], conf['data_list_train_b'], batch_size, True, 65 | new_size_b, height, width, num_workers, True) 66 | test_loader_b = get_data_loader_list(conf['data_folder_test_b'], conf['data_list_test_b'], batch_size, False, 67 | new_size_b, new_size_b, new_size_b, num_workers, True) 68 | return train_loader_a, train_loader_b, test_loader_a, test_loader_b 69 | 70 | 71 | def get_data_loader_list(root, file_list, batch_size, train, new_size=None, 72 | height=256, width=256, num_workers=4, crop=True): 73 | transform_list = [transforms.ToTensor(), 74 | transforms.Normalize((0.5, 0.5, 0.5), 75 | (0.5, 0.5, 0.5))] 76 | transform_list = [transforms.RandomCrop((height, width))] + transform_list if crop else transform_list 77 | transform_list = [transforms.Resize(new_size)] + transform_list if new_size is not None else transform_list 78 | transform_list = [transforms.RandomHorizontalFlip()] + transform_list if train else transform_list 79 | transform = transforms.Compose(transform_list) 80 | dataset = ImageFilelist(root, file_list, transform=transform) 81 | loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=train, drop_last=True, num_workers=num_workers) 82 | return loader 83 | 84 | def get_data_loader_folder(input_folder, batch_size, train, new_size=None, 85 | height=256, width=256, num_workers=4, crop=True): 86 | transform_list = [transforms.ToTensor(), 87 | transforms.Normalize((0.5, 0.5, 0.5), 88 | (0.5, 0.5, 0.5))] 89 | transform_list = [transforms.RandomCrop((height, width))] + transform_list if crop else transform_list 90 | transform_list = [transforms.Resize(new_size)] + transform_list if new_size is not None else transform_list 91 | transform_list = [transforms.RandomHorizontalFlip()] + transform_list if train else transform_list 92 | transform = transforms.Compose(transform_list) 93 | dataset = ImageFolder(input_folder, transform=transform) 94 | loader = DataLoader(dataset=dataset, 
batch_size=batch_size, shuffle=train, drop_last=True, num_workers=num_workers) 95 | return loader 96 | 97 | 98 | def get_config(config): 99 | with open(config, 'r') as stream: 100 | return yaml.load(stream) 101 | 102 | 103 | def eformat(f, prec): 104 | s = "%.*e"%(prec, f) 105 | mantissa, exp = s.split('e') 106 | # add 1 to digits as 1 is taken by sign +/- 107 | return "%se%d"%(mantissa, int(exp)) 108 | 109 | 110 | def __write_images(image_outputs, display_image_num, file_name): 111 | image_outputs = [images.expand(-1, 3, -1, -1) for images in image_outputs] # expand gray-scale images to 3 channels 112 | image_tensor = torch.cat([images[:display_image_num] for images in image_outputs], 0) 113 | image_grid = vutils.make_grid(image_tensor.data, nrow=display_image_num, padding=0, normalize=True) 114 | vutils.save_image(image_grid, file_name, nrow=1) 115 | 116 | 117 | def write_2images(image_outputs, display_image_num, image_directory, postfix): 118 | n = len(image_outputs) 119 | __write_images(image_outputs[0:n//2], display_image_num, '%s/gen_a2b_%s.jpg' % (image_directory, postfix)) 120 | __write_images(image_outputs[n//2:n], display_image_num, '%s/gen_b2a_%s.jpg' % (image_directory, postfix)) 121 | 122 | 123 | def prepare_sub_folder(output_directory): 124 | image_directory = os.path.join(output_directory, 'images') 125 | if not os.path.exists(image_directory): 126 | print("Creating directory: {}".format(image_directory)) 127 | os.makedirs(image_directory) 128 | checkpoint_directory = os.path.join(output_directory, 'checkpoints') 129 | if not os.path.exists(checkpoint_directory): 130 | print("Creating directory: {}".format(checkpoint_directory)) 131 | os.makedirs(checkpoint_directory) 132 | return checkpoint_directory, image_directory 133 | 134 | 135 | def write_one_row_html(html_file, iterations, img_filename, all_size): 136 | html_file.write("
<h3>iteration [%d] (%s)</h3>" % (iterations,img_filename.split('/')[-1])) 137 | html_file.write(""" 138 | <p><a href="%s"> 139 | <img src="%s" style="width:%dpx"> 140 | </a><br> 141 | <p> 142 | """ % (img_filename, img_filename, all_size)) 143 | return 144 | 145 | 146 | def write_html(filename, iterations, image_save_iterations, image_directory, all_size=1536): 147 | html_file = open(filename, "w") 148 | html_file.write(''' 149 | <!DOCTYPE html> 150 | <html> 151 | <head> 152 | <title>Experiment name = %s</title> 153 | <meta http-equiv="refresh" content="30"> 154 | </head> 155 | <body> 156 | ''' % os.path.basename(filename)) 157 | html_file.write("<h3>current</h3>
") 158 | write_one_row_html(html_file, iterations, '%s/gen_a2b_train_current.jpg' % (image_directory), all_size) 159 | write_one_row_html(html_file, iterations, '%s/gen_b2a_train_current.jpg' % (image_directory), all_size) 160 | for j in range(iterations, image_save_iterations-1, -1): 161 | if j % image_save_iterations == 0: 162 | write_one_row_html(html_file, j, '%s/gen_a2b_test_%08d.jpg' % (image_directory, j), all_size) 163 | write_one_row_html(html_file, j, '%s/gen_b2a_test_%08d.jpg' % (image_directory, j), all_size) 164 | write_one_row_html(html_file, j, '%s/gen_a2b_train_%08d.jpg' % (image_directory, j), all_size) 165 | write_one_row_html(html_file, j, '%s/gen_b2a_train_%08d.jpg' % (image_directory, j), all_size) 166 | html_file.write("") 167 | html_file.close() 168 | 169 | 170 | def write_loss(iterations, trainer, train_writer): 171 | members = [attr for attr in dir(trainer) \ 172 | if not callable(getattr(trainer, attr)) and not attr.startswith("__") and ('loss' in attr or 'grad' in attr or 'nwd' in attr)] 173 | for m in members: 174 | train_writer.add_scalar(m, getattr(trainer, m), iterations + 1) 175 | 176 | 177 | def slerp(val, low, high): 178 | """ 179 | original: Animating Rotation with Quaternion Curves, Ken Shoemake 180 | https://arxiv.org/abs/1609.04468 181 | Code: https://github.com/soumith/dcgan.torch/issues/14, Tom White 182 | """ 183 | omega = np.arccos(np.dot(low / np.linalg.norm(low), high / np.linalg.norm(high))) 184 | so = np.sin(omega) 185 | return np.sin((1.0 - val) * omega) / so * low + np.sin(val * omega) / so * high 186 | 187 | 188 | def get_slerp_interp(nb_latents, nb_interp, z_dim): 189 | """ 190 | modified from: PyTorch inference for "Progressive Growing of GANs" with CelebA snapshot 191 | https://github.com/ptrblck/prog_gans_pytorch_inference 192 | """ 193 | 194 | latent_interps = np.empty(shape=(0, z_dim), dtype=np.float32) 195 | for _ in range(nb_latents): 196 | low = np.random.randn(z_dim) 197 | high = np.random.randn(z_dim) # low + np.random.randn(512) * 0.7 198 | interp_vals = np.linspace(0, 1, num=nb_interp) 199 | latent_interp = np.array([slerp(v, low, high) for v in interp_vals], 200 | dtype=np.float32) 201 | latent_interps = np.vstack((latent_interps, latent_interp)) 202 | 203 | return latent_interps[:, :, np.newaxis, np.newaxis] 204 | 205 | 206 | # Get model list for resume 207 | def get_model_list(dirname, key): 208 | if os.path.exists(dirname) is False: 209 | return None 210 | gen_models = [os.path.join(dirname, f) for f in os.listdir(dirname) if 211 | os.path.isfile(os.path.join(dirname, f)) and key in f and ".pt" in f] 212 | if gen_models is None: 213 | return None 214 | gen_models.sort() 215 | last_model_name = gen_models[-1] 216 | return last_model_name 217 | 218 | 219 | def load_vgg16(model_dir): 220 | """ Use the model from https://github.com/abhiskk/fast-neural-style/blob/master/neural_style/utils.py """ 221 | if not os.path.exists(model_dir): 222 | os.mkdir(model_dir) 223 | if not os.path.exists(os.path.join(model_dir, 'vgg16.weight')): 224 | if not os.path.exists(os.path.join(model_dir, 'vgg16.t7')): 225 | os.system('wget https://www.dropbox.com/s/76l3rt4kyi3s8x7/vgg16.t7?dl=1 -O ' + os.path.join(model_dir, 'vgg16.t7')) 226 | vgglua = load_lua(os.path.join(model_dir, 'vgg16.t7')) 227 | vgg = Vgg16() 228 | for (src, dst) in zip(vgglua.parameters()[0], vgg.parameters()): 229 | dst.data[:] = src 230 | torch.save(vgg.state_dict(), os.path.join(model_dir, 'vgg16.weight')) 231 | vgg = Vgg16() 232 | 
vgg.load_state_dict(torch.load(os.path.join(model_dir, 'vgg16.weight'))) 233 | return vgg 234 | 235 | 236 | def vgg_preprocess(batch): 237 | tensortype = type(batch.data) 238 | (r, g, b) = torch.chunk(batch, 3, dim = 1) 239 | batch = torch.cat((b, g, r), dim = 1) # convert RGB to BGR 240 | batch = (batch + 1) * 255 * 0.5 # [-1, 1] -> [0, 255] 241 | mean = tensortype(batch.data.size()).cuda() 242 | mean[:, 0, :, :] = 103.939 243 | mean[:, 1, :, :] = 116.779 244 | mean[:, 2, :, :] = 123.680 245 | batch = batch.sub(Variable(mean)) # subtract mean 246 | return batch 247 | 248 | 249 | def get_scheduler(optimizer, hyperparameters, iterations=-1): 250 | if 'lr_policy' not in hyperparameters or hyperparameters['lr_policy'] == 'constant': 251 | scheduler = None # constant scheduler 252 | elif hyperparameters['lr_policy'] == 'step': 253 | scheduler = lr_scheduler.StepLR(optimizer, step_size=hyperparameters['step_size'], 254 | gamma=hyperparameters['gamma'], last_epoch=iterations) 255 | else: 256 | return NotImplementedError('learning rate policy [%s] is not implemented', hyperparameters['lr_policy']) 257 | return scheduler 258 | 259 | 260 | def weights_init(init_type='gaussian'): 261 | def init_fun(m): 262 | classname = m.__class__.__name__ 263 | if (classname.find('Conv') == 0 or classname.find('Linear') == 0) and hasattr(m, 'weight'): 264 | # print m.__class__.__name__ 265 | if init_type == 'gaussian': 266 | init.normal_(m.weight.data, 0.0, 0.02) 267 | elif init_type == 'xavier': 268 | init.xavier_normal_(m.weight.data, gain=math.sqrt(2)) 269 | elif init_type == 'kaiming': 270 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 271 | elif init_type == 'orthogonal': 272 | init.orthogonal_(m.weight.data, gain=math.sqrt(2)) 273 | elif init_type == 'default': 274 | pass 275 | else: 276 | assert 0, "Unsupported initialization: {}".format(init_type) 277 | if hasattr(m, 'bias') and m.bias is not None: 278 | init.constant_(m.bias.data, 0.0) 279 | 280 | return init_fun 281 | 282 | 283 | class Timer: 284 | def __init__(self, msg): 285 | self.msg = msg 286 | self.start_time = None 287 | 288 | def __enter__(self): 289 | self.start_time = time.time() 290 | 291 | def __exit__(self, exc_type, exc_value, exc_tb): 292 | print(self.msg % (time.time() - self.start_time)) 293 | 294 | 295 | def pytorch03_to_pytorch04(state_dict_base): 296 | def __conversion_core(state_dict_base): 297 | state_dict = state_dict_base.copy() 298 | for key, value in state_dict_base.items(): 299 | if key.endswith(('enc.model.0.norm.running_mean', 300 | 'enc.model.0.norm.running_var', 301 | 'enc.model.1.norm.running_mean', 302 | 'enc.model.1.norm.running_var', 303 | 'enc.model.2.norm.running_mean', 304 | 'enc.model.2.norm.running_var', 305 | 'enc.model.3.model.0.model.1.norm.running_mean', 306 | 'enc.model.3.model.0.model.1.norm.running_var', 307 | 'enc.model.3.model.0.model.0.norm.running_mean', 308 | 'enc.model.3.model.0.model.0.norm.running_var', 309 | 'enc.model.3.model.1.model.1.norm.running_mean', 310 | 'enc.model.3.model.1.model.1.norm.running_var', 311 | 'enc.model.3.model.1.model.0.norm.running_mean', 312 | 'enc.model.3.model.1.model.0.norm.running_var', 313 | 'enc.model.3.model.2.model.1.norm.running_mean', 314 | 'enc.model.3.model.2.model.1.norm.running_var', 315 | 'enc.model.3.model.2.model.0.norm.running_mean', 316 | 'enc.model.3.model.2.model.0.norm.running_var', 317 | 'enc.model.3.model.3.model.1.norm.running_mean', 318 | 'enc.model.3.model.3.model.1.norm.running_var', 319 | 
'enc.model.3.model.3.model.0.norm.running_mean', 320 | 'enc.model.3.model.3.model.0.norm.running_var', 321 | 'dec.model.0.model.0.model.1.norm.running_mean', 322 | 'dec.model.0.model.0.model.1.norm.running_var', 323 | 'dec.model.0.model.0.model.0.norm.running_mean', 324 | 'dec.model.0.model.0.model.0.norm.running_var', 325 | 'dec.model.0.model.1.model.1.norm.running_mean', 326 | 'dec.model.0.model.1.model.1.norm.running_var', 327 | 'dec.model.0.model.1.model.0.norm.running_mean', 328 | 'dec.model.0.model.1.model.0.norm.running_var', 329 | 'dec.model.0.model.2.model.1.norm.running_mean', 330 | 'dec.model.0.model.2.model.1.norm.running_var', 331 | 'dec.model.0.model.2.model.0.norm.running_mean', 332 | 'dec.model.0.model.2.model.0.norm.running_var', 333 | 'dec.model.0.model.3.model.1.norm.running_mean', 334 | 'dec.model.0.model.3.model.1.norm.running_var', 335 | 'dec.model.0.model.3.model.0.norm.running_mean', 336 | 'dec.model.0.model.3.model.0.norm.running_var', 337 | )): 338 | del state_dict[key] 339 | return state_dict 340 | state_dict = dict() 341 | state_dict['a'] = __conversion_core(state_dict_base['a']) 342 | state_dict['b'] = __conversion_core(state_dict_base['b']) 343 | return state_dict --------------------------------------------------------------------------------
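The listing above ends with utils.py; its helpers are consumed by the repository's train.py, which is not reproduced in this dump. Purely as orientation, here is a minimal sketch of how the UNIT_Trainer from trainer.py and the utilities above could be wired together. It is not the actual train.py: the config path, output directory, iteration budget, logging and snapshot intervals, and the assumption that the data loaders yield bare image tensors are all illustrative assumptions.

```python
# Illustrative driver only -- a sketch, not the repository's train.py.
import torch
from tensorboardX import SummaryWriter

from trainer import UNIT_Trainer
from utils import (get_config, get_all_data_loaders, prepare_sub_folder,
                   write_2images, write_loss, Timer)

config = get_config('configs/unit_gta2city_folder.yaml')        # example config from this repo
trainer = UNIT_Trainer(config).cuda()
train_a, train_b, test_a, test_b = get_all_data_loaders(config)
checkpoint_dir, image_dir = prepare_sub_folder('outputs/unit_gta2city')  # hypothetical output folder
train_writer = SummaryWriter('outputs/unit_gta2city/logs')               # write_loss only needs add_scalar

iterations = 0
max_iter = 10000                        # placeholder budget; the shipped YAML configs define their own
while iterations < max_iter:
    for x_a, x_b in zip(train_a, train_b):      # assumes the loaders yield image tensors only
        x_a, x_b = x_a.cuda(), x_b.cuda()
        with Timer("Elapsed time in one update: %f"):
            trainer.dis_update(x_a, x_b, config)   # discriminator step
            trainer.gen_update(x_a, x_b, config)   # VAE / generator step
            trainer.update_learning_rate()
        if (iterations + 1) % 100 == 0:            # logging interval is arbitrary here
            write_loss(iterations, trainer, train_writer)
        if (iterations + 1) % 1000 == 0:           # snapshot interval is arbitrary here
            with torch.no_grad():
                image_outputs = trainer.sample(x_a, x_b)
            write_2images(image_outputs, config['batch_size'], image_dir, 'train_%08d' % (iterations + 1))
            trainer.save(checkpoint_dir, iterations)
        iterations += 1
        if iterations >= max_iter:
            break
```

The authoritative entry point is train.py together with the YAML files under configs/; consult those for the real hyperparameters and intervals.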