├── .gitignore
├── LICENSE
├── README.md
├── data
│   ├── aligned_data_loader.lua
│   ├── base_data_loader.lua
│   ├── data.lua
│   ├── data_util.lua
│   ├── dataset.lua
│   ├── donkey_folder.lua
│   └── unaligned_data_loader.lua
├── datasets
│   ├── bibtex
│   │   ├── cityscapes.tex
│   │   └── facades.tex
│   ├── download_dataset.sh
│   └── prepare_cityscapes_dataset.py
├── examples
│   ├── test_vangogh_style_on_ae_photos.sh
│   └── train_maps.sh
├── imgs
│   ├── failure_putin.jpg
│   ├── horse2zebra.gif
│   ├── objects.jpg
│   ├── painting2photo.jpg
│   ├── paper_thumbnail.jpg
│   ├── photo2painting.jpg
│   ├── photo_enhancement.jpg
│   ├── season.jpg
│   └── teaser.jpg
├── models
│   ├── architectures.lua
│   ├── base_model.lua
│   ├── bigan_model.lua
│   ├── content_gan_model.lua
│   ├── cycle_gan_model.lua
│   ├── one_direction_test_model.lua
│   └── pix2pix_model.lua
├── options.lua
├── pretrained_models
│   ├── download_model.sh
│   ├── download_vgg.sh
│   └── places_vgg.prototxt
├── test.lua
├── train.lua
└── util
    ├── InstanceNormalization.lua
    ├── VGG_preprocess.lua
    ├── content_loss.lua
    ├── cudnn_convert_custom.lua
    ├── image_pool.lua
    ├── plot_util.lua
    ├── util.lua
    └── visualizer.lua
/.gitignore: -------------------------------------------------------------------------------- 1 | datasets/ 2 | checkpoints/ 3 | results/ 4 | build/ 5 | dist/ 6 | *.png 7 | torch.egg-info/ 8 | */**/__pycache__ 9 | torch/version.py 10 | torch/csrc/generic/TensorMethods.cpp 11 | torch/lib/*.so* 12 | torch/lib/*.dylib* 13 | torch/lib/*.h 14 | torch/lib/build 15 | torch/lib/tmp_install 16 | torch/lib/include 17 | torch/lib/torch_shm_manager 18 | torch/csrc/cudnn/cuDNN.cpp 19 | torch/csrc/nn/THNN.cwrap 20 | torch/csrc/nn/THNN.cpp 21 | torch/csrc/nn/THCUNN.cwrap 22 | torch/csrc/nn/THCUNN.cpp 23 | torch/csrc/nn/THNN_generic.cwrap 24 | torch/csrc/nn/THNN_generic.cpp 25 | torch/csrc/nn/THNN_generic.h 26 | docs/src/**/* 27 | test/data/legacy_modules.t7 28 | test/data/gpu_tensors.pt 29 | test/htmlcov 30 | test/.coverage 31 | */*.pyc 32 | */**/*.pyc 33 | */**/**/*.pyc 34 | */**/**/**/*.pyc 35 | */**/**/**/**/*.pyc 36 | */*.so* 37 | */**/*.so* 38 | */**/*.dylib* 39 | test/data/legacy_serialized.pt 40 | *~ 41 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2017, Jun-Yan Zhu and Taesung Park 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 15 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 16 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 17 | DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 18 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 19 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 20 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 21 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 22 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 23 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 24 | 25 | 26 | --------------------------- LICENSE FOR pix2pix -------------------------------- 27 | BSD License 28 | 29 | For pix2pix software 30 | Copyright (c) 2016, Phillip Isola and Jun-Yan Zhu 31 | All rights reserved. 32 | 33 | Redistribution and use in source and binary forms, with or without 34 | modification, are permitted provided that the following conditions are met: 35 | 36 | * Redistributions of source code must retain the above copyright notice, this 37 | list of conditions and the following disclaimer. 38 | 39 | * Redistributions in binary form must reproduce the above copyright notice, 40 | this list of conditions and the following disclaimer in the documentation 41 | and/or other materials provided with the distribution. 42 | 43 | ----------------------------- LICENSE FOR DCGAN -------------------------------- 44 | BSD License 45 | 46 | For dcgan.torch software 47 | 48 | Copyright (c) 2015, Facebook, Inc. All rights reserved. 49 | 50 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 51 | 52 | Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 53 | 54 | Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 55 | 56 | Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 57 | 58 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 59 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 |


4 | 5 | # CycleGAN 6 | ### [PyTorch](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix) | [project page](https://junyanz.github.io/CycleGAN/) | [paper](https://arxiv.org/pdf/1703.10593.pdf) 7 | 8 | Torch implementation for learning image-to-image translation (i.e. [pix2pix](https://github.com/phillipi/pix2pix)) **without** input-output pairs, for example: 9 | 10 | **New**: Please check out [contrastive-unpaired-translation](https://github.com/taesungp/contrastive-unpaired-translation) (CUT), our new unpaired image-to-image translation model that enables fast and memory-efficient training. 11 | 12 | 13 | 14 | [Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks](https://junyanz.github.io/CycleGAN/) 15 | [Jun-Yan Zhu](https://people.eecs.berkeley.edu/~junyanz/)\*, [Taesung Park](https://taesung.me/)\*, [Phillip Isola](http://web.mit.edu/phillipi/), [Alexei A. Efros](https://people.eecs.berkeley.edu/~efros/) 16 | Berkeley AI Research Lab, UC Berkeley 17 | In ICCV 2017. (* equal contributions) 18 | 19 | This package includes CycleGAN, [pix2pix](https://github.com/phillipi/pix2pix), as well as other methods like [BiGAN](https://arxiv.org/abs/1605.09782)/[ALI](https://ishmaelbelghazi.github.io/ALI/) and Apple's [S+U learning](https://arxiv.org/pdf/1612.07828.pdf) approach. 20 | The code was written by [Jun-Yan Zhu](https://github.com/junyanz) and [Taesung Park](https://github.com/taesung). 21 | **Update**: Please check out the [PyTorch](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix) implementation of CycleGAN and pix2pix. 22 | The PyTorch version is under active development and can produce results comparable to or better than this Torch version. 23 | 24 | ## Other implementations: 25 |

[Tensorflow] (by Harry Yang), 26 | [Tensorflow] (by Archit Rathore), 27 | [Tensorflow] (by Van Huy), 28 | [Tensorflow] (by Xiaowei Hu), 29 | [Tensorflow-simple] (by Zhenliang He), 30 | [TensorLayer] (by luoxier), 31 | [Chainer] (by Yanghua Jin), 32 | [Minimal PyTorch] (by yunjey), 33 | [Mxnet] (by Ldpe2G), 34 | [lasagne/Keras] (by tjwei), 35 | [Keras] (by Simon Karlsson)

36 | 37 | 38 | ## Applications 39 | ### Monet Paintings to Photos 40 | 41 | 42 | ### Collection Style Transfer 43 | 44 | 45 | ### Object Transfiguration 46 | 47 | 48 | ### Season Transfer 49 | 50 | 51 | ### Photo Enhancement: Narrow depth of field 52 | 53 | 54 | 55 | 56 | ## Prerequisites 57 | - Linux or OSX 58 | - NVIDIA GPU + CUDA CuDNN (CPU mode and CUDA without CuDNN may work with minimal modification, but untested) 59 | - For MAC users, you need the Linux/GNU commands `gfind` and `gwc`, which can be installed with `brew install findutils coreutils`. 60 | 61 | ## Getting Started 62 | ### Installation 63 | - Install torch and dependencies from https://github.com/torch/distro 64 | - Install torch packages `nngraph`, `class`, `display` 65 | ```bash 66 | luarocks install nngraph 67 | luarocks install class 68 | luarocks install https://raw.githubusercontent.com/szym/display/master/display-scm-0.rockspec 69 | ``` 70 | - Clone this repo: 71 | ```bash 72 | git clone https://github.com/junyanz/CycleGAN 73 | cd CycleGAN 74 | ``` 75 | 76 | ### Apply a Pre-trained Model 77 | - Download the test photos (taken by [Alexei Efros](https://www.flickr.com/photos/aaefros)): 78 | ``` 79 | bash ./datasets/download_dataset.sh ae_photos 80 | ``` 81 | - Download the pre-trained model `style_cezanne` (For CPU model, use `style_cezanne_cpu`): 82 | ``` 83 | bash ./pretrained_models/download_model.sh style_cezanne 84 | ``` 85 | - Now, let's generate Paul Cézanne style images: 86 | ``` 87 | DATA_ROOT=./datasets/ae_photos name=style_cezanne_pretrained model=one_direction_test phase=test loadSize=256 fineSize=256 resize_or_crop="scale_width" th test.lua 88 | ``` 89 | The test results will be saved to `./results/style_cezanne_pretrained/latest_test/index.html`. 90 | Please refer to [Model Zoo](#model-zoo) for more pre-trained models. 91 | `./examples/test_vangogh_style_on_ae_photos.sh` is an example script that downloads the pretrained Van Gogh style network and runs it on Efros's photos. 92 | 93 | ### Train 94 | - Download a dataset (e.g. zebra and horse images from ImageNet): 95 | ```bash 96 | bash ./datasets/download_dataset.sh horse2zebra 97 | ``` 98 | - Train a model: 99 | ```bash 100 | DATA_ROOT=./datasets/horse2zebra name=horse2zebra_model th train.lua 101 | ``` 102 | - (CPU only) The same training command without using a GPU or CUDNN. Setting the environment variables ```gpu=0 cudnn=0``` forces CPU only 103 | ```bash 104 | DATA_ROOT=./datasets/horse2zebra name=horse2zebra_model gpu=0 cudnn=0 th train.lua 105 | ``` 106 | - (Optionally) start the display server to view results as the model trains. (See [Display UI](#display-ui) for more details): 107 | ```bash 108 | th -ldisplay.start 8000 0.0.0.0 109 | ``` 110 | 111 | ### Test 112 | - Finally, test the model: 113 | ```bash 114 | DATA_ROOT=./datasets/horse2zebra name=horse2zebra_model phase=test th test.lua 115 | ``` 116 | The test results will be saved to an HTML file here: `./results/horse2zebra_model/latest_test/index.html`. 117 | 118 | 119 | ## Model Zoo 120 | Download the pre-trained models with the following script. The model will be saved to `./checkpoints/model_name/latest_net_G.t7`. 121 | ```bash 122 | bash ./pretrained_models/download_model.sh model_name 123 | ``` 124 | - `orange2apple` (orange -> apple) and `apple2orange`: trained on ImageNet categories `apple` and `orange`. 125 | - `horse2zebra` (horse -> zebra) and `zebra2horse` (zebra -> horse): trained on ImageNet categories `horse` and `zebra`. 
126 | - `style_monet` (landscape photo -> Monet painting style), `style_vangogh` (landscape photo -> Van Gogh painting style), `style_ukiyoe` (landscape photo -> Ukiyo-e painting style), `style_cezanne` (landscape photo -> Cezanne painting style): trained on paintings and Flickr landscape photos. 127 | - `monet2photo` (Monet paintings -> real landscapes): trained on paintings and Flickr landscape photographs. 128 | - `cityscapes_photo2label` (street scene -> label) and `cityscapes_label2photo` (label -> street scene): trained on the Cityscapes dataset. 129 | - `map2sat` (map -> aerial photo) and `sat2map` (aerial photo -> map): trained on Google Maps. 130 | - `iphone2dslr_flower` (iPhone photos of flowers -> DSLR photos of flowers): trained on Flickr photos. 131 | 132 | CPU models can be downloaded using: 133 | ```bash 134 | bash pretrained_models/download_model.sh <name>_cpu 135 | ``` 136 | where `<name>` can be `horse2zebra`, `style_monet`, etc. You just need to append `_cpu` to the target model name. 137 | 138 | ## Training and Test Details 139 | To train a model: 140 | ```bash 141 | DATA_ROOT=/path/to/data/ name=expt_name th train.lua 142 | ``` 143 | Models are saved to `./checkpoints/expt_name` (this can be changed by passing `checkpoint_dir=your_dir` to train.lua). 144 | See `opt_train` in `options.lua` for additional training options. 145 | 146 | To test the model: 147 | ```bash 148 | DATA_ROOT=/path/to/data/ name=expt_name phase=test th test.lua 149 | ``` 150 | This will run the model named `expt_name` in both directions on all images in `/path/to/data/testA` and `/path/to/data/testB`. 151 | A webpage with the result images will be saved to `./results/expt_name` (this can be changed by passing `results_dir=your_dir` to test.lua). 152 | See `opt_test` in `options.lua` for additional test options. Please use `model=one_direction_test` if you would like to generate outputs of the trained network in only one direction, and specify `which_direction=AtoB` or `which_direction=BtoA` to set the direction. 153 | 154 | There are other options that can be used. For example, you can specify the `resize_or_crop=crop` option to avoid resizing the images to squares. This is how we trained the GTA2Cityscapes model for the project [webpage](https://junyanz.github.io/CycleGAN/) and for [Cycada](https://arxiv.org/pdf/1711.03213.pdf): we prepared the images at 1024px resolution and used `resize_or_crop=crop fineSize=360` to train on 360x360 crops. We also used `lambda_identity=1.0`. 155 |
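Putting those options together, here is a sketch of what such a high-resolution training run could look like; the dataset path and experiment name below are placeholders for your own data, not files shipped with this repository:
```bash
DATA_ROOT=/path/to/your_1024px_data name=highres_crop_model \
resize_or_crop=crop fineSize=360 lambda_identity=1.0 th train.lua
```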
156 | ## Datasets 157 | Download the datasets using the following script. Many of the datasets were collected by other researchers. Please cite their papers if you use the data. 158 | ```bash 159 | bash ./datasets/download_dataset.sh dataset_name 160 | ``` 161 | - `facades`: 400 images from the [CMP Facades dataset](http://cmp.felk.cvut.cz/~tylecr1/facade/). [[Citation](datasets/bibtex/facades.tex)] 162 | - `cityscapes`: 2975 images from the [Cityscapes training set](https://www.cityscapes-dataset.com/). [[Citation](datasets/bibtex/cityscapes.tex)] Note: Due to licensing issues, we do not host the dataset in our repo. Please download the dataset directly from the Cityscapes webpage, and refer to `./datasets/prepare_cityscapes_dataset.py` for more details. 163 | - `maps`: 1096 training images scraped from Google Maps. 164 | - `horse2zebra`: 939 horse images and 1177 zebra images downloaded from [ImageNet](http://www.image-net.org/) using the keywords `wild horse` and `zebra`. 165 | - `apple2orange`: 996 apple images and 1020 orange images downloaded from [ImageNet](http://www.image-net.org/) using the keywords `apple` and `navel orange`. 166 | - `summer2winter_yosemite`: 1273 summer Yosemite images and 854 winter Yosemite images downloaded using the Flickr API. See more details in our paper. 167 | - `monet2photo`, `vangogh2photo`, `ukiyoe2photo`, `cezanne2photo`: The art images were downloaded from [Wikiart](https://www.wikiart.org/). The real photos were downloaded from Flickr using a combination of the tags *landscape* and *landscapephotography*. The training set sizes are Monet: 1074, Cezanne: 584, Van Gogh: 401, Ukiyo-e: 1433, Photographs: 6853. 168 | - `iphone2dslr_flower`: both classes of images were downloaded from Flickr. The training set sizes are iPhone: 1813, DSLR: 3316. See more details in our paper. 169 | 170 | 171 | ## Display UI 172 | Optionally, for displaying images during training and test, use the [display package](https://github.com/szym/display). 173 | 174 | - Install it with: `luarocks install https://raw.githubusercontent.com/szym/display/master/display-scm-0.rockspec` 175 | - Then start the server with: `th -ldisplay.start` 176 | - Open this URL in your browser: [http://localhost:8000](http://localhost:8000) 177 | 178 | By default, the server listens on localhost. Pass `0.0.0.0` to allow external connections on any interface: 179 | ```bash 180 | th -ldisplay.start 8000 0.0.0.0 181 | ``` 182 | Then open `http://(hostname):(port)/` in your browser to load the remote desktop. 183 | 184 | ## Setup Training and Test data 185 | To train a CycleGAN model on your own datasets, you need to create a data folder with two subdirectories `trainA` and `trainB` that contain images from domains A and B, as sketched below. You can test your model on your training set by setting ``phase='train'`` in `test.lua`. You can also create subdirectories `testA` and `testB` if you have test data. 186 |
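For concreteness, a minimal sketch of that layout, using a hypothetical dataset name (`bears2pandas` is only an example):
```bash
mkdir -p ./datasets/bears2pandas/{trainA,trainB,testA,testB}
# place domain-A images in trainA/ and domain-B images in trainB/
# (testA/ and testB/ are optional), then train with:
DATA_ROOT=./datasets/bears2pandas name=bears2pandas_model th train.lua
```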
187 | You should **not** expect our method to work on just any random combination of input and output datasets (e.g. `cats<->keyboards`). From our experiments, we find that it works better when the two datasets share similar visual content. For example, `landscape painting<->landscape photographs` works much better than `portrait painting <-> landscape photographs`, and `zebras<->horses` achieves compelling results while `cats<->dogs` completely fails. See the following section for more discussion. 188 | 189 | 190 | ## Failure cases 191 | 192 | 193 | Our model does not work well when the test image is rather different from the images on which the model was trained, as in the figure to the left (we trained on horses and zebras without riders, but test here on a horse with a rider). See additional typical failure cases [here](https://junyanz.github.io/CycleGAN/images/failures.jpg). On translation tasks that involve color and texture changes, like many of those reported above, the method often succeeds. We have also explored tasks that require geometric changes, with little success. For example, on the task of `dog<->cat` transfiguration, the learned translation degenerates into making minimal changes to the input. We also observe a lingering gap between the results achievable with paired training data and those achieved by our unpaired method. In some cases, this gap may be very hard -- or even impossible -- to close: for example, our method sometimes permutes the labels for tree and building in the output of the cityscapes photos->labels task. 194 | 195 | 196 | 197 | ## Citation 198 | If you use this code for your research, please cite our [paper](https://junyanz.github.io/CycleGAN/): 199 | 200 | ``` 201 | @inproceedings{CycleGAN2017, 202 | title={Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks}, 203 | author={Zhu, Jun-Yan and Park, Taesung and Isola, Phillip and Efros, Alexei A}, 204 | booktitle={Computer Vision (ICCV), 2017 IEEE International Conference on}, 205 | year={2017} 206 | } 207 | 208 | ``` 209 | 210 | 211 | ## Related Projects: 212 | **[contrastive-unpaired-translation](https://github.com/taesungp/contrastive-unpaired-translation) (CUT)**
213 | **[pix2pix-Torch](https://github.com/phillipi/pix2pix) | [pix2pixHD](https://github.com/NVIDIA/pix2pixHD) | 214 | [BicycleGAN](https://github.com/junyanz/BicycleGAN) | [vid2vid](https://tcwang0509.github.io/vid2vid/) | [SPADE/GauGAN](https://github.com/NVlabs/SPADE)**
215 | **[iGAN](https://github.com/junyanz/iGAN) | [GAN Dissection](https://github.com/CSAILVision/GANDissect) | [GAN Paint](http://ganpaint.io/)** 216 | 217 | ## Cat Paper Collection 218 | If you love cats, and love reading cool graphics, vision, and ML papers, please check out the Cat Paper [Collection](https://github.com/junyanz/CatPapers). 219 | 220 | 221 | ## Acknowledgments 222 | Code borrows from [pix2pix](https://github.com/phillipi/pix2pix) and [DCGAN](https://github.com/soumith/dcgan.torch). The data loader is modified from [DCGAN](https://github.com/soumith/dcgan.torch) and [Context-Encoder](https://github.com/pathak22/context-encoder). The generative network is adopted from [neural-style](https://github.com/jcjohnson/neural-style) with [Instance Normalization](https://github.com/DmitryUlyanov/texture_nets/blob/master/InstanceNormalization.lua). 223 | -------------------------------------------------------------------------------- /data/aligned_data_loader.lua: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- 2 | -- Subclass of BaseDataLoader that provides data from two datasets. 3 | -- The samples from the datasets are aligned 4 | -- The datasets are of the same size 5 | -------------------------------------------------------------------------------- 6 | require 'data.base_data_loader' 7 | 8 | local class = require 'class' 9 | data_util = paths.dofile('data_util.lua') 10 | 11 | AlignedDataLoader = class('AlignedDataLoader', 'BaseDataLoader') 12 | 13 | function AlignedDataLoader:__init(conf) 14 | BaseDataLoader.__init(self, conf) 15 | conf = conf or {} 16 | end 17 | 18 | function AlignedDataLoader:name() 19 | return 'AlignedDataLoader' 20 | end 21 | 22 | function AlignedDataLoader:Initialize(opt) 23 | opt.align_data = 1 24 | self.idx_A = {1, opt.input_nc} 25 | self.idx_B = {opt.input_nc+1, opt.input_nc+opt.output_nc} 26 | local nc = 3--opt.input_nc + opt.output_nc 27 | self.data = data_util.load_dataset('', opt, nc) 28 | end 29 | 30 | -- actually fetches the data 31 | -- |return|: a table of two tables, each corresponding to 32 | -- the batch for dataset A and dataset B 33 | function AlignedDataLoader:LoadBatchForAllDatasets() 34 | local batch_data, path = self.data:getBatch() 35 | local batchA = batch_data[{ {}, self.idx_A, {}, {} }] 36 | local batchB = batch_data[{ {}, self.idx_B, {}, {} }] 37 | 38 | return batchA, batchB, path, path 39 | end 40 | 41 | -- returns the size of each dataset 42 | function AlignedDataLoader:size(dataset) 43 | return self.data:size() 44 | end 45 | -------------------------------------------------------------------------------- /data/base_data_loader.lua: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- 2 | -- Base Class for Providing Data 3 | -------------------------------------------------------------------------------- 4 | 5 | local class = require 'class' 6 | require 'torch' 7 | 8 | BaseDataLoader = class('BaseDataLoader') 9 | 10 | function BaseDataLoader:__init(conf) 11 | conf = conf or {} 12 | self.data_tm = torch.Timer() 13 | end 14 | 15 | function BaseDataLoader:name() 16 | return 'BaseDataLoader' 17 | end 18 | 19 | function BaseDataLoader:Initialize(opt) 20 | end 21 | 22 | -- actually fetches the data 23 | -- |return|: a table of two tables, each corresponding to 24 | -- the batch for dataset A and dataset B 25 
| function BaseDataLoader:LoadBatchForAllDatasets() 26 | return {},{},{},{} 27 | end 28 | 29 | -- returns the next batch 30 | -- a wrapper of getBatch(), which is meant to be overriden by subclasses 31 | -- |return|: a table of two tables, each corresponding to 32 | -- the batch for dataset A and dataset B 33 | function BaseDataLoader:GetNextBatch() 34 | self.data_tm:reset() 35 | self.data_tm:resume() 36 | local dataA, dataB, pathA, pathB = self:LoadBatchForAllDatasets() 37 | self.data_tm:stop() 38 | return dataA, dataB, pathA, pathB 39 | end 40 | 41 | function BaseDataLoader:time_elapsed_to_fetch_data() 42 | return self.data_tm:time().real 43 | end 44 | 45 | -- returns the size of each dataset 46 | function BaseDataLoader:size(dataset) 47 | return 0 48 | end 49 | 50 | 51 | 52 | 53 | 54 | -------------------------------------------------------------------------------- /data/data.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | This data loader is a modified version of the one from dcgan.torch 3 | (see https://github.com/soumith/dcgan.torch/blob/master/data/data.lua). 4 | 5 | Copyright (c) 2016, Deepak Pathak [See LICENSE file for details] 6 | ]]-- 7 | 8 | local Threads = require 'threads' 9 | Threads.serialization('threads.sharedserialize') 10 | 11 | local data = {} 12 | 13 | local result = {} 14 | local unpack = unpack and unpack or table.unpack 15 | 16 | function data.new(n, opt_) 17 | opt_ = opt_ or {} 18 | local self = {} 19 | for k,v in pairs(data) do 20 | self[k] = v 21 | end 22 | 23 | local donkey_file = 'donkey_folder.lua' 24 | -- print('n..' .. n) 25 | if n > 0 then 26 | local options = opt_ 27 | self.threads = Threads(n, 28 | function() require 'torch' end, 29 | function(idx) 30 | opt = options 31 | tid = idx 32 | local seed = (opt.manualSeed and opt.manualSeed or 0) + idx 33 | torch.manualSeed(seed) 34 | torch.setnumthreads(1) 35 | print(string.format('Starting donkey with id: %d seed: %d', tid, seed)) 36 | assert(options, 'options not found') 37 | assert(opt, 'opt not given') 38 | print(opt) 39 | paths.dofile(donkey_file) 40 | end 41 | 42 | ) 43 | else 44 | if donkey_file then paths.dofile(donkey_file) end 45 | -- print('empty threads') 46 | self.threads = {} 47 | function self.threads:addjob(f1, f2) f2(f1()) end 48 | function self.threads:dojob() end 49 | function self.threads:synchronize() end 50 | end 51 | 52 | local nSamples = 0 53 | self.threads:addjob(function() return trainLoader:size() end, 54 | function(c) nSamples = c end) 55 | self.threads:synchronize() 56 | self._size = nSamples 57 | 58 | for i = 1, n do 59 | self.threads:addjob(self._getFromThreads, 60 | self._pushResult) 61 | end 62 | -- print(self.threads) 63 | return self 64 | end 65 | 66 | function data._getFromThreads() 67 | assert(opt.batchSize, 'opt.batchSize not found') 68 | return trainLoader:sample(opt.batchSize) 69 | end 70 | 71 | function data._pushResult(...) 
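-- end-callback for the donkey jobs: it runs on the main thread with whatever
-- _getFromThreads() returned, and parks that batch in the shared |result|
-- table so that getBatch() below can pick it up after dojob().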
72 | local res = {...} 73 | if res == nil then 74 | self.threads:synchronize() 75 | end 76 | result[1] = res 77 | end 78 | 79 | 80 | 81 | function data:getBatch() 82 | -- queue another job 83 | self.threads:addjob(self._getFromThreads, self._pushResult) 84 | self.threads:dojob() 85 | local res = result[1] 86 | 87 | img_data = res[1] 88 | img_paths = res[3] 89 | 90 | result[1] = nil 91 | if torch.type(img_data) == 'table' then 92 | img_data = unpack(img_data) 93 | end 94 | 95 | 96 | return img_data, img_paths 97 | end 98 | 99 | function data:size() 100 | return self._size 101 | end 102 | 103 | return data 104 | -------------------------------------------------------------------------------- /data/data_util.lua: -------------------------------------------------------------------------------- 1 | local data_util = {} 2 | 3 | require 'torch' 4 | -- options = require '../options.lua' 5 | -- load dataset from the file system 6 | -- |name|: name of the dataset. It's currently either 'A' or 'B' 7 | function data_util.load_dataset(name, opt, nc) 8 | local tensortype = torch.getdefaulttensortype() 9 | torch.setdefaulttensortype('torch.FloatTensor') 10 | 11 | local new_opt = options.clone(opt) 12 | new_opt.manualSeed = torch.random(1, 10000) -- fix seed 13 | new_opt.nc = nc 14 | torch.manualSeed(new_opt.manualSeed) 15 | local data_loader = paths.dofile('../data/data.lua') 16 | new_opt.phase = new_opt.phase .. name 17 | local data = data_loader.new(new_opt.nThreads, new_opt) 18 | print("Dataset Size " .. name .. ": ", data:size()) 19 | 20 | torch.setdefaulttensortype(tensortype) 21 | return data 22 | end 23 | 24 | return data_util 25 | -------------------------------------------------------------------------------- /data/dataset.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2015-present, Facebook, Inc. 3 | All rights reserved. 4 | 5 | This source code is licensed under the BSD-style license found in the 6 | LICENSE file in the root directory of this source tree. An additional grant 7 | of patent rights can be found in the PATENTS file in the same directory. 8 | ]]-- 9 | 10 | require 'torch' 11 | torch.setdefaulttensortype('torch.FloatTensor') 12 | local ffi = require 'ffi' 13 | local class = require('pl.class') 14 | local dir = require 'pl.dir' 15 | local tablex = require 'pl.tablex' 16 | local argcheck = require 'argcheck' 17 | require 'sys' 18 | require 'xlua' 19 | require 'image' 20 | 21 | local dataset = torch.class('dataLoader') 22 | 23 | local initcheck = argcheck{ 24 | pack=true, 25 | help=[[ 26 | A dataset class for images in a flat folder structure (folder-name is class-name). 27 | Optimized for extremely large datasets (upwards of 14 million images). 
28 | Tested only on Linux (as it uses command-line linux utilities to scale up) 29 | ]], 30 | {check=function(paths) 31 | local out = true; 32 | for k,v in ipairs(paths) do 33 | if type(v) ~= 'string' then 34 | print('paths can only be of string input'); 35 | out = false 36 | end 37 | end 38 | return out 39 | end, 40 | name="paths", 41 | type="table", 42 | help="Multiple paths of directories with images"}, 43 | 44 | {name="sampleSize", 45 | type="table", 46 | help="a consistent sample size to resize the images"}, 47 | 48 | {name="split", 49 | type="number", 50 | help="Percentage of split to go to Training" 51 | }, 52 | {name="serial_batches", 53 | type="number", 54 | help="if randomly sample training images"}, 55 | 56 | {name="samplingMode", 57 | type="string", 58 | help="Sampling mode: random | balanced ", 59 | default = "balanced"}, 60 | 61 | {name="verbose", 62 | type="boolean", 63 | help="Verbose mode during initialization", 64 | default = false}, 65 | 66 | {name="loadSize", 67 | type="table", 68 | help="a size to load the images to, initially", 69 | opt = true}, 70 | 71 | {name="forceClasses", 72 | type="table", 73 | help="If you want this loader to map certain classes to certain indices, " 74 | .. "pass a classes table that has {classname : classindex} pairs." 75 | .. " For example: {3 : 'dog', 5 : 'cat'}" 76 | .. "This function is very useful when you want two loaders to have the same " 77 | .. "class indices (trainLoader/testLoader for example)", 78 | opt = true}, 79 | 80 | {name="sampleHookTrain", 81 | type="function", 82 | help="applied to sample during training(ex: for lighting jitter). " 83 | .. "It takes the image path as input", 84 | opt = true}, 85 | 86 | {name="sampleHookTest", 87 | type="function", 88 | help="applied to sample during testing", 89 | opt = true}, 90 | } 91 | 92 | function dataset:__init(...) 93 | 94 | -- argcheck 95 | local args = initcheck(...) 
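-- argcheck has validated the constructor options against the spec above;
-- each option is then copied onto self (paths, sampleSize, split, ...) below.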
96 | print(args) 97 | for k,v in pairs(args) do self[k] = v end 98 | 99 | if not self.loadSize then self.loadSize = self.sampleSize; end 100 | 101 | if not self.sampleHookTrain then self.sampleHookTrain = self.defaultSampleHook end 102 | if not self.sampleHookTest then self.sampleHookTest = self.defaultSampleHook end 103 | self.image_count = 1 104 | -- print('image_count_init', self.image_count) 105 | -- find class names 106 | self.classes = {} 107 | local classPaths = {} 108 | if self.forceClasses then 109 | for k,v in pairs(self.forceClasses) do 110 | self.classes[k] = v 111 | classPaths[k] = {} 112 | end 113 | end 114 | local function tableFind(t, o) for k,v in pairs(t) do if v == o then return k end end end 115 | -- loop over each paths folder, get list of unique class names, 116 | -- also store the directory paths per class 117 | -- for each class, 118 | for k,path in ipairs(self.paths) do 119 | -- print('path', path) 120 | local dirs = {} -- hack 121 | dirs[1] = path 122 | -- local dirs = dir.getdirectories(path); 123 | for k,dirpath in ipairs(dirs) do 124 | local class = paths.basename(dirpath) 125 | local idx = tableFind(self.classes, class) 126 | -- print(class) 127 | -- print(idx) 128 | if not idx then 129 | table.insert(self.classes, class) 130 | idx = #self.classes 131 | classPaths[idx] = {} 132 | end 133 | if not tableFind(classPaths[idx], dirpath) then 134 | table.insert(classPaths[idx], dirpath); 135 | end 136 | end 137 | end 138 | 139 | self.classIndices = {} 140 | for k,v in ipairs(self.classes) do 141 | self.classIndices[v] = k 142 | end 143 | 144 | -- define command-line tools, try your best to maintain OSX compatibility 145 | local wc = 'wc' 146 | local cut = 'cut' 147 | local find = 'find -H' -- if folder name is symlink, do find inside it after dereferencing 148 | 149 | if ffi.os == 'OSX' then 150 | wc = 'gwc' 151 | cut = 'gcut' 152 | find = 'gfind' 153 | end 154 | ---------------------------------------------------------------------- 155 | -- Options for the GNU find command 156 | local extensionList = {'jpg', 'png','JPG','PNG','JPEG', 'ppm', 'PPM', 'bmp', 'BMP'} 157 | local findOptions = ' -iname "*.' .. extensionList[1] .. '"' 158 | for i=2,#extensionList do 159 | findOptions = findOptions .. ' -o -iname "*.' .. extensionList[i] .. '"' 160 | end 161 | 162 | -- find the image path names 163 | self.imagePath = torch.CharTensor() -- path to each image in dataset 164 | self.imageClass = torch.LongTensor() -- class index of each image (class index in self.classes) 165 | self.classList = {} -- index of imageList to each image of a particular class 166 | self.classListSample = self.classList -- the main list used when sampling data 167 | 168 | print('running "find" on each class directory, and concatenate all' 169 | .. ' those filenames into a single file containing all image paths for a given class') 170 | -- so, generates one file per class 171 | local classFindFiles = {} 172 | for i=1,#self.classes do 173 | classFindFiles[i] = os.tmpname() 174 | end 175 | local combinedFindList = os.tmpname(); 176 | 177 | local tmpfile = os.tmpname() 178 | local tmphandle = assert(io.open(tmpfile, 'w')) 179 | -- iterate over classes 180 | for i, class in ipairs(self.classes) do 181 | -- iterate over classPaths 182 | for j,path in ipairs(classPaths[i]) do 183 | local command = find .. ' "' .. path .. '" ' .. findOptions 184 | .. ' >>"' .. classFindFiles[i] .. '" \n' 185 | tmphandle:write(command) 186 | end 187 | end 188 | io.close(tmphandle) 189 | os.execute('bash ' .. 
tmpfile) 190 | os.execute('rm -f ' .. tmpfile) 191 | 192 | print('now combine all the files to a single large file') 193 | local tmpfile = os.tmpname() 194 | local tmphandle = assert(io.open(tmpfile, 'w')) 195 | -- concat all finds to a single large file in the order of self.classes 196 | for i=1,#self.classes do 197 | local command = 'cat "' .. classFindFiles[i] .. '" >>' .. combinedFindList .. ' \n' 198 | tmphandle:write(command) 199 | end 200 | io.close(tmphandle) 201 | os.execute('bash ' .. tmpfile) 202 | os.execute('rm -f ' .. tmpfile) 203 | 204 | --========================================================================== 205 | print('load the large concatenated list of sample paths to self.imagePath') 206 | local cmd = wc .. " -L '" 207 | .. combinedFindList .. "' |" 208 | .. cut .. " -f1 -d' '" 209 | print('cmd..' .. cmd) 210 | local maxPathLength = tonumber(sys.fexecute(wc .. " -L '" 211 | .. combinedFindList .. "' |" 212 | .. cut .. " -f1 -d' '")) + 1 213 | local length = tonumber(sys.fexecute(wc .. " -l '" 214 | .. combinedFindList .. "' |" 215 | .. cut .. " -f1 -d' '")) 216 | assert(length > 0, "Could not find any image file in the given input paths") 217 | assert(maxPathLength > 0, "paths of files are length 0?") 218 | self.imagePath:resize(length, maxPathLength):fill(0) 219 | local s_data = self.imagePath:data() 220 | local count = 0 221 | for line in io.lines(combinedFindList) do 222 | ffi.copy(s_data, line) 223 | s_data = s_data + maxPathLength 224 | if self.verbose and count % 10000 == 0 then 225 | xlua.progress(count, length) 226 | end; 227 | count = count + 1 228 | end 229 | 230 | self.numSamples = self.imagePath:size(1) 231 | if self.verbose then print(self.numSamples .. ' samples found.') end 232 | --========================================================================== 233 | print('Updating classList and imageClass appropriately') 234 | self.imageClass:resize(self.numSamples) 235 | local runningIndex = 0 236 | for i=1,#self.classes do 237 | if self.verbose then xlua.progress(i, #(self.classes)) end 238 | local length = tonumber(sys.fexecute(wc .. " -l '" 239 | .. classFindFiles[i] .. "' |" 240 | .. cut .. " -f1 -d' '")) 241 | if length == 0 then 242 | error('Class has zero samples') 243 | else 244 | self.classList[i] = torch.linspace(runningIndex + 1, runningIndex + length, length):long() 245 | self.imageClass[{{runningIndex + 1, runningIndex + length}}]:fill(i) 246 | end 247 | runningIndex = runningIndex + length 248 | end 249 | 250 | --========================================================================== 251 | -- clean up temporary files 252 | print('Cleaning up temporary files') 253 | local tmpfilelistall = '' 254 | for i=1,#(classFindFiles) do 255 | tmpfilelistall = tmpfilelistall .. ' "' .. classFindFiles[i] .. '"' 256 | if i % 1000 == 0 then 257 | os.execute('rm -f ' .. tmpfilelistall) 258 | tmpfilelistall = '' 259 | end 260 | end 261 | os.execute('rm -f ' .. tmpfilelistall) 262 | os.execute('rm -f "' .. combinedFindList .. '"') 263 | --========================================================================== 264 | 265 | if self.split == 100 then 266 | self.testIndicesSize = 0 267 | else 268 | print('Splitting training and test sets to a ratio of ' 269 | .. self.split .. '/' .. 
(100-self.split)) 270 | self.classListTrain = {} 271 | self.classListTest = {} 272 | self.classListSample = self.classListTrain 273 | local totalTestSamples = 0 274 | -- split the classList into classListTrain and classListTest 275 | for i=1,#self.classes do 276 | local list = self.classList[i] 277 | local count = self.classList[i]:size(1) 278 | local splitidx = math.floor((count * self.split / 100) + 0.5) -- +round 279 | local perm = torch.randperm(count) 280 | self.classListTrain[i] = torch.LongTensor(splitidx) 281 | for j=1,splitidx do 282 | self.classListTrain[i][j] = list[perm[j]] 283 | end 284 | if splitidx == count then -- all samples were allocated to train set 285 | self.classListTest[i] = torch.LongTensor() 286 | else 287 | self.classListTest[i] = torch.LongTensor(count-splitidx) 288 | totalTestSamples = totalTestSamples + self.classListTest[i]:size(1) 289 | local idx = 1 290 | for j=splitidx+1,count do 291 | self.classListTest[i][idx] = list[perm[j]] 292 | idx = idx + 1 293 | end 294 | end 295 | end 296 | -- Now combine classListTest into a single tensor 297 | self.testIndices = torch.LongTensor(totalTestSamples) 298 | self.testIndicesSize = totalTestSamples 299 | local tdata = self.testIndices:data() 300 | local tidx = 0 301 | for i=1,#self.classes do 302 | local list = self.classListTest[i] 303 | if list:dim() ~= 0 then 304 | local ldata = list:data() 305 | for j=0,list:size(1)-1 do 306 | tdata[tidx] = ldata[j] 307 | tidx = tidx + 1 308 | end 309 | end 310 | end 311 | end 312 | end 313 | 314 | -- size(), size(class) 315 | function dataset:size(class, list) 316 | list = list or self.classList 317 | if not class then 318 | return self.numSamples 319 | elseif type(class) == 'string' then 320 | return list[self.classIndices[class]]:size(1) 321 | elseif type(class) == 'number' then 322 | return list[class]:size(1) 323 | end 324 | end 325 | 326 | -- getByClass 327 | function dataset:getByClass(class) 328 | local index = 0 329 | if self.serial_batches == 1 then 330 | index = math.fmod(self.image_count-1, self.classListSample[class]:nElement())+1 331 | self.image_count = self.image_count +1 332 | else 333 | index = math.ceil(torch.uniform() * self.classListSample[class]:nElement()) 334 | end 335 | 336 | local imgpath = ffi.string(torch.data(self.imagePath[self.classListSample[class][index]])) 337 | return self:sampleHookTrain(imgpath), imgpath 338 | end 339 | 340 | -- converts a table of samples (and corresponding labels) to a clean tensor 341 | local function tableToOutput(self, dataTable, scalarTable) 342 | local data, scalarLabels, labels 343 | if opt.resize_or_crop == 'crop' or opt.resize_or_crop == 'scale_width' or opt.resize_or_crop == 'scale_height' then 344 | assert(#scalarTable == 1) 345 | data = torch.Tensor(1, 346 | dataTable[1]:size(1), dataTable[1]:size(2), dataTable[1]:size(3)) 347 | data[1]:copy(dataTable[1]) 348 | scalarLabels = torch.LongTensor(#scalarTable):fill(-1111) 349 | else 350 | local quantity = #scalarTable 351 | data = torch.Tensor(quantity, 352 | self.sampleSize[1], self.sampleSize[2], self.sampleSize[3]) 353 | scalarLabels = torch.LongTensor(quantity):fill(-1111) 354 | for i=1,#dataTable do 355 | data[i]:copy(dataTable[i]) 356 | scalarLabels[i] = scalarTable[i] 357 | end 358 | end 359 | return data, scalarLabels 360 | end 361 | 362 | -- sampler, samples from the training set. 
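-- |quantity| is the batch size; each image is drawn from a uniformly random
-- class (or sequentially when serial_batches == 1) and returned together with
-- dummy labels and the source file paths.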
363 | function dataset:sample(quantity) 364 | assert(quantity) 365 | local dataTable = {} 366 | local scalarTable = {} 367 | local samplePaths = {} 368 | for i=1,quantity do 369 | local class = torch.random(1, #self.classes) 370 | local out, imgpath = self:getByClass(class) 371 | table.insert(dataTable, out) 372 | table.insert(scalarTable, class) 373 | samplePaths[i] = imgpath 374 | end 375 | 376 | local data, scalarLabels = tableToOutput(self, dataTable, scalarTable) 377 | return data, scalarLabels, samplePaths-- filePaths 378 | end 379 | 380 | function dataset:get(i1, i2) 381 | local indices = torch.range(i1, i2); 382 | local quantity = i2 - i1 + 1; 383 | assert(quantity > 0) 384 | -- now that indices has been initialized, get the samples 385 | local dataTable = {} 386 | local scalarTable = {} 387 | for i=1,quantity do 388 | -- load the sample 389 | local imgpath = ffi.string(torch.data(self.imagePath[indices[i]])) 390 | local out = self:sampleHookTest(imgpath) 391 | table.insert(dataTable, out) 392 | table.insert(scalarTable, self.imageClass[indices[i]]) 393 | end 394 | local data, scalarLabels = tableToOutput(self, dataTable, scalarTable) 395 | return data, scalarLabels 396 | end 397 | 398 | return dataset 399 | -------------------------------------------------------------------------------- /data/donkey_folder.lua: -------------------------------------------------------------------------------- 1 | 2 | --[[ 3 | This data loader is a modified version of the one from dcgan.torch 4 | (see https://github.com/soumith/dcgan.torch/blob/master/data/donkey_folder.lua). 5 | Copyright (c) 2016, Deepak Pathak [See LICENSE file for details] 6 | Copyright (c) 2015-present, Facebook, Inc. 7 | All rights reserved. 8 | This source code is licensed under the BSD-style license found in the 9 | LICENSE file in the root directory of this source tree. An additional grant 10 | of patent rights can be found in the PATENTS file in the same directory. 11 | ]]-- 12 | 13 | require 'image' 14 | paths.dofile('dataset.lua') 15 | -- This file contains the data-loading logic and details. 16 | -- It is run by each data-loader thread. 17 | ------------------------------------------ 18 | -------- COMMON CACHES and PATHS 19 | -- Check for existence of opt.data 20 | if opt.DATA_ROOT then 21 | opt.data = paths.concat(opt.DATA_ROOT, opt.phase) 22 | else 23 | print(os.getenv('DATA_ROOT')) 24 | opt.data = paths.concat(os.getenv('DATA_ROOT'), opt.phase) 25 | end 26 | 27 | if not paths.dirp(opt.data) then 28 | error('Did not find directory: ' .. opt.data) 29 | end 30 | 31 | -- a cache file of the training metadata (if doesnt exist, will be created) 32 | local cache_prefix = opt.data:gsub('/', '_') 33 | os.execute(('mkdir -p %s'):format(opt.cache_dir)) 34 | local trainCache = paths.concat(opt.cache_dir, cache_prefix .. 
'_trainCache.t7') 35 | 36 | -------------------------------------------------------------------------------------------- 37 | local input_nc = opt.nc -- input channels 38 | local loadSize = {input_nc, opt.loadSize} 39 | local sampleSize = {input_nc, opt.fineSize} 40 | 41 | local function loadImage(path) 42 | local input = image.load(path, 3, 'float') 43 | local h = input:size(2) 44 | local w = input:size(3) 45 | 46 | local imA = image.crop(input, 0, 0, w/2, h) 47 | imA = image.scale(imA, loadSize[2], loadSize[2]) 48 | local imB = image.crop(input, w/2, 0, w, h) 49 | imB = image.scale(imB, loadSize[2], loadSize[2]) 50 | 51 | local perm = torch.LongTensor{3, 2, 1} 52 | imA = imA:index(1, perm) 53 | imA = imA:mul(2):add(-1) 54 | imB = imB:index(1, perm) 55 | imB = imB:mul(2):add(-1) 56 | 57 | assert(imA:max()<=1,"A: badly scaled inputs") 58 | assert(imA:min()>=-1,"A: badly scaled inputs") 59 | assert(imB:max()<=1,"B: badly scaled inputs") 60 | assert(imB:min()>=-1,"B: badly scaled inputs") 61 | 62 | 63 | local oW = sampleSize[2] 64 | local oH = sampleSize[2] 65 | local iH = imA:size(2) 66 | local iW = imA:size(3) 67 | 68 | if iH~=oH then 69 | h1 = math.ceil(torch.uniform(1e-2, iH-oH)) 70 | end 71 | 72 | if iW~=oW then 73 | w1 = math.ceil(torch.uniform(1e-2, iW-oW)) 74 | end 75 | if iH ~= oH or iW ~= oW then 76 | imA = image.crop(imA, w1, h1, w1 + oW, h1 + oH) 77 | imB = image.crop(imB, w1, h1, w1 + oW, h1 + oH) 78 | end 79 | 80 | if opt.flip == 1 and torch.uniform() > 0.5 then 81 | imA = image.hflip(imA) 82 | imB = image.hflip(imB) 83 | end 84 | 85 | local concatenated = torch.cat(imA,imB,1) 86 | 87 | return concatenated 88 | end 89 | 90 | 91 | local function loadSingleImage(path) 92 | local im = image.load(path, input_nc, 'float') 93 | if opt.resize_or_crop == 'resize_and_crop' then 94 | im = image.scale(im, loadSize[2], loadSize[2]) 95 | end 96 | if input_nc == 3 then 97 | local perm = torch.LongTensor{3, 2, 1} 98 | im = im:index(1, perm)--:mul(256.0): brg, rgb 99 | im = im:mul(2):add(-1) 100 | end 101 | assert(im:max()<=1,"A: badly scaled inputs") 102 | assert(im:min()>=-1,"A: badly scaled inputs") 103 | 104 | local oW = sampleSize[2] 105 | local oH = sampleSize[2] 106 | local iH = im:size(2) 107 | local iW = im:size(3) 108 | if (opt.resize_or_crop == 'resize_and_crop' ) then 109 | local h1, w1 = 0, 0 110 | if iH~=oH then 111 | h1 = math.ceil(torch.uniform(1e-2, iH-oH)) 112 | end 113 | if iW~=oW then 114 | w1 = math.ceil(torch.uniform(1e-2, iW-oW)) 115 | end 116 | if iH ~= oH or iW ~= oW then 117 | im = image.crop(im, w1, h1, w1 + oW, h1 + oH) 118 | end 119 | elseif (opt.resize_or_crop == 'combined') then 120 | local sH = math.min(math.ceil(oH * torch.uniform(1+1e-2, 2.0-1e-2)), iH-1e-2) 121 | local sW = math.min(math.ceil(oW * torch.uniform(1+1e-2, 2.0-1e-2)), iW-1e-2) 122 | local h1 = math.ceil(torch.uniform(1e-2, iH-sH)) 123 | local w1 = math.ceil(torch.uniform(1e-2, iW-sW)) 124 | im = image.crop(im, w1, h1, w1 + sW, h1 + sH) 125 | im = image.scale(im, oW, oH) 126 | elseif (opt.resize_or_crop == 'crop') then 127 | local w = math.min(math.min(oH, iH),iW) 128 | w = math.floor(w/4)*4 129 | local x = math.floor(torch.uniform(0, iW - w)) 130 | local y = math.floor(torch.uniform(0, iH - w)) 131 | im = image.crop(im, x, y, x+w, y+w) 132 | elseif (opt.resize_or_crop == 'scale_width') then 133 | w = oW 134 | h = torch.floor(iH * oW/iW) 135 | im = image.scale(im, w, h) 136 | elseif (opt.resize_or_crop == 'scale_height') then 137 | h = oH 138 | w = torch.floor(iW * oH / iH) 139 | im = 
image.scale(im, w, h) 140 | end 141 | 142 | if opt.flip == 1 and torch.uniform() > 0.5 then 143 | im = image.hflip(im) 144 | end 145 | 146 | return im 147 | 148 | end 149 | 150 | -- channel-wise mean and std. Calculate or load them from disk later in the script. 151 | local mean,std 152 | -------------------------------------------------------------------------------- 153 | -- Hooks that are used for each image that is loaded 154 | 155 | -- function to load the image, jitter it appropriately (random crops etc.) 156 | local trainHook_singleimage = function(self, path) 157 | collectgarbage() 158 | -- print('load single image') 159 | local im = loadSingleImage(path) 160 | return im 161 | end 162 | 163 | -- function that loads images that have juxtaposition 164 | -- of two images from two domains 165 | local trainHook_doubleimage = function(self, path) 166 | -- print('load double image') 167 | collectgarbage() 168 | 169 | local im = loadImage(path) 170 | return im 171 | end 172 | 173 | 174 | if opt.align_data > 0 then 175 | sample_nc = input_nc*2 176 | trainHook = trainHook_doubleimage 177 | else 178 | sample_nc = input_nc 179 | trainHook = trainHook_singleimage 180 | end 181 | 182 | trainLoader = dataLoader{ 183 | paths = {opt.data}, 184 | loadSize = {input_nc, loadSize[2], loadSize[2]}, 185 | sampleSize = {sample_nc, sampleSize[2], sampleSize[2]}, 186 | split = 100, 187 | serial_batches = opt.serial_batches, 188 | verbose = true 189 | } 190 | 191 | trainLoader.sampleHookTrain = trainHook 192 | collectgarbage() 193 | 194 | -- do some sanity checks on trainLoader 195 | do 196 | local class = trainLoader.imageClass 197 | local nClasses = #trainLoader.classes 198 | assert(class:max() <= nClasses, "class logic has error") 199 | assert(class:min() >= 1, "class logic has error") 200 | end 201 | -------------------------------------------------------------------------------- /data/unaligned_data_loader.lua: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- 2 | -- Subclass of BaseDataLoader that provides data from two datasets. 3 | -- The samples from the datasets are not aligned. 
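-- (each call to getBatch() draws from the two domains independently)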
4 | -- The datasets can have different sizes 5 | -------------------------------------------------------------------------------- 6 | require 'data.base_data_loader' 7 | 8 | local class = require 'class' 9 | data_util = paths.dofile('data_util.lua') 10 | 11 | UnalignedDataLoader = class('UnalignedDataLoader', 'BaseDataLoader') 12 | 13 | function UnalignedDataLoader:__init(conf) 14 | BaseDataLoader.__init(self, conf) 15 | conf = conf or {} 16 | end 17 | 18 | function UnalignedDataLoader:name() 19 | return 'UnalignedDataLoader' 20 | end 21 | 22 | function UnalignedDataLoader:Initialize(opt) 23 | opt.align_data = 0 24 | self.dataA = data_util.load_dataset('A', opt, opt.input_nc) 25 | self.dataB = data_util.load_dataset('B', opt, opt.output_nc) 26 | end 27 | 28 | -- actually fetches the data 29 | -- |return|: a table of two tables, each corresponding to 30 | -- the batch for dataset A and dataset B 31 | function UnalignedDataLoader:LoadBatchForAllDatasets() 32 | local batchA, pathA = self.dataA:getBatch() 33 | local batchB, pathB = self.dataB:getBatch() 34 | return batchA, batchB, pathA, pathB 35 | end 36 | 37 | -- returns the size of each dataset 38 | function UnalignedDataLoader:size(dataset) 39 | if dataset == 'A' then 40 | return self.dataA:size() 41 | end 42 | 43 | if dataset == 'B' then 44 | return self.dataB:size() 45 | end 46 | 47 | return math.max(self.dataA:size(), self.dataB:size()) 48 | -- return the size of the largest dataset by default 49 | end 50 | -------------------------------------------------------------------------------- /datasets/bibtex/cityscapes.tex: -------------------------------------------------------------------------------- 1 | @inproceedings{Cordts2016Cityscapes, 2 | title={The Cityscapes Dataset for Semantic Urban Scene Understanding}, 3 | author={Cordts, Marius and Omran, Mohamed and Ramos, Sebastian and Rehfeld, Timo and Enzweiler, Markus and Benenson, Rodrigo and Franke, Uwe and Roth, Stefan and Schiele, Bernt}, 4 | booktitle={Proc. of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, 5 | year={2016} 6 | } 7 | -------------------------------------------------------------------------------- /datasets/bibtex/facades.tex: -------------------------------------------------------------------------------- 1 | @INPROCEEDINGS{Tylecek13, 2 | author = {Radim Tyle{\v c}ek, Radim {\v S}{\' a}ra}, 3 | title = {Spatial Pattern Templates for Recognition of Objects with Regular Structure}, 4 | booktitle = {Proc. GCPR}, 5 | year = {2013}, 6 | address = {Saarbrucken, Germany}, 7 | } 8 | -------------------------------------------------------------------------------- /datasets/download_dataset.sh: -------------------------------------------------------------------------------- 1 | FILE=$1 2 | 3 | if [[ $FILE != "ae_photos" && $FILE != "apple2orange" && $FILE != "summer2winter_yosemite" && $FILE != "horse2zebra" && $FILE != "monet2photo" && $FILE != "cezanne2photo" && $FILE != "ukiyoe2photo" && $FILE != "vangogh2photo" && $FILE != "maps" && $FILE != "cityscapes" && $FILE != "facades" && $FILE != "iphone2dslr_flower" && $FILE != "ae_photos" ]]; then 4 | echo "Available datasets are: apple2orange, summer2winter_yosemite, horse2zebra, monet2photo, cezanne2photo, ukiyoe2photo, vangogh2photo, maps, cityscapes, facades, iphone2dslr_flower, ae_photos" 5 | exit 1 6 | fi 7 | 8 | if [[ $FILE == "cityscapes" ]]; then 9 | echo "Due to license issue, we cannot provide the Cityscapes dataset from our repository. 
Please download the Cityscapes dataset from https://cityscapes-dataset.com, and use the script ./datasets/prepare_cityscapes_dataset.py." 10 | echo "You need to download gtFine_trainvaltest.zip and leftImg8bit_trainvaltest.zip. For further instructions, please read ./datasets/prepare_cityscapes_dataset.py" 11 | exit 1 12 | fi 13 | 14 | URL=https://efrosgans.eecs.berkeley.edu/cyclegan/datasets/$FILE.zip 15 | ZIP_FILE=./datasets/$FILE.zip 16 | TARGET_DIR=./datasets/$FILE/ 17 | wget -N $URL -O $ZIP_FILE 18 | mkdir $TARGET_DIR 19 | unzip $ZIP_FILE -d ./datasets/ 20 | rm $ZIP_FILE 21 | -------------------------------------------------------------------------------- /datasets/prepare_cityscapes_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | from PIL import Image 4 | 5 | help_msg = """ 6 | The dataset can be downloaded from https://cityscapes-dataset.com. 7 | Please download the datasets [gtFine_trainvaltest.zip] and [leftImg8bit_trainvaltest.zip] and unzip them. 8 | gtFine contains the semantic segmentations. Use --gtFine_dir to specify the path to the unzipped gtFine_trainvaltest directory. 9 | leftImg8bit contains the dashcam photographs. Use --leftImg8bit_dir to specify the path to the unzipped leftImg8bit_trainvaltest directory. 10 | The processed images will be placed at --output_dir. 11 | 12 | Example usage: 13 | 14 | python prepare_cityscapes_dataset.py --gtFine_dir ./gtFine/ --leftImg8bit_dir ./leftImg8bit --output_dir ./datasets/cityscapes/ 15 | """ 16 | 17 | def load_resized_img(path): 18 | return Image.open(path).convert('RGB').resize((256, 256)) 19 | 20 | def check_matching_pair(segmap_path, photo_path): 21 | segmap_identifier = os.path.basename(segmap_path).replace('_gtFine_color', '') 22 | photo_identifier = os.path.basename(photo_path).replace('_leftImg8bit', '') 23 | 24 | assert segmap_identifier == photo_identifier, \ 25 | "[%s] and [%s] don't seem to be matching. Aborting." % (segmap_path, photo_path) 26 | 27 | 28 | def process_cityscapes(gtFine_dir, leftImg8bit_dir, output_dir, phase): 29 | save_phase = 'test' if phase == 'val' else 'train' 30 | savedir = os.path.join(output_dir, save_phase) 31 | os.makedirs(savedir, exist_ok=True) 32 | os.makedirs(savedir + 'A', exist_ok=True) 33 | os.makedirs(savedir + 'B', exist_ok=True) 34 | print("Directory structure prepared at %s" % output_dir) 35 | 36 | segmap_expr = os.path.join(gtFine_dir, phase) + "/*/*_color.png" 37 | segmap_paths = glob.glob(segmap_expr) 38 | segmap_paths = sorted(segmap_paths) 39 | 40 | photo_expr = os.path.join(leftImg8bit_dir, phase) + "/*/*_leftImg8bit.png" 41 | photo_paths = glob.glob(photo_expr) 42 | photo_paths = sorted(photo_paths) 43 | 44 | assert len(segmap_paths) == len(photo_paths), \ 45 | "%d images that match [%s], and %d images that match [%s]. Aborting."
% (len(segmap_paths), segmap_expr, len(photo_paths), photo_expr) 46 | 47 | for i, (segmap_path, photo_path) in enumerate(zip(segmap_paths, photo_paths)): 48 | check_matching_pair(segmap_path, photo_path) 49 | segmap = load_resized_img(segmap_path) 50 | photo = load_resized_img(photo_path) 51 | 52 | # data for pix2pix where the two images are placed side-by-side 53 | sidebyside = Image.new('RGB', (512, 256)) 54 | sidebyside.paste(segmap, (256, 0)) 55 | sidebyside.paste(photo, (0, 0)) 56 | savepath = os.path.join(savedir, "%d.jpg" % i) 57 | sidebyside.save(savepath, format='JPEG', subsampling=0, quality=100) 58 | 59 | # data for cyclegan where the two images are stored at two distinct directories 60 | savepath = os.path.join(savedir + 'A', "%d_A.jpg" % i) 61 | photo.save(savepath, format='JPEG', subsampling=0, quality=100) 62 | savepath = os.path.join(savedir + 'B', "%d_B.jpg" % i) 63 | segmap.save(savepath, format='JPEG', subsampling=0, quality=100) 64 | 65 | if i % (len(segmap_paths) // 10) == 0: 66 | print("%d / %d: last image saved at %s, " % (i, len(segmap_paths), savepath)) 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | if __name__ == '__main__': 78 | import argparse 79 | parser = argparse.ArgumentParser() 80 | parser.add_argument('--gtFine_dir', type=str, required=True, 81 | help='Path to the Cityscapes gtFine directory.') 82 | parser.add_argument('--leftImg8bit_dir', type=str, required=True, 83 | help='Path to the Cityscapes leftImg8bit_trainvaltest directory.') 84 | parser.add_argument('--output_dir', type=str, required=True, 85 | default='./datasets/cityscapes', 86 | help='Directory the output images will be written to.') 87 | opt = parser.parse_args() 88 | 89 | print(help_msg) 90 | 91 | print('Preparing Cityscapes Dataset for val phase') 92 | process_cityscapes(opt.gtFine_dir, opt.leftImg8bit_dir, opt.output_dir, "val") 93 | print('Preparing Cityscapes Dataset for train phase') 94 | process_cityscapes(opt.gtFine_dir, opt.leftImg8bit_dir, opt.output_dir, "train") 95 | 96 | print('Done') 97 | 98 | 99 | 100 | -------------------------------------------------------------------------------- /examples/test_vangogh_style_on_ae_photos.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | ## This script download the dataset and pre-trained network, 4 | ## and generates style transferred images. 5 | 6 | # Download the dataset. The downloaded dataset is stored in ./datasets/${DATASET_NAME} 7 | DATASET_NAME='ae_photos' 8 | bash ./datasets/download_dataset.sh $DATASET_NAME 9 | 10 | # Download the pre-trained model. The downloaded model is stored in ./models/${MODEL_NAME}_pretrained/latest_net_G.t7 11 | MODEL_NAME='style_vangogh' 12 | bash ./pretrained_models/download_model.sh $MODEL_NAME 13 | 14 | # Run style transfer using the downloaded dataset and model 15 | DATA_ROOT=./datasets/$DATASET_NAME name=${MODEL_NAME}_pretrained model=one_direction_test phase=test how_many='all' loadSize=256 fineSize=256 resize_or_crop='scale_width' th test.lua 16 | 17 | if [ $? 
17 | if [ $? -eq 0 ]; then 18 | echo "The result can be viewed at ./results/${MODEL_NAME}_pretrained/latest_test/index.html" 19 | fi 20 | -------------------------------------------------------------------------------- /examples/train_maps.sh: -------------------------------------------------------------------------------- 1 | DB_NAME='maps' 2 | GPU_ID=1 3 | DISPLAY_ID=1 4 | NET_G=resnet_6blocks 5 | NET_D=basic 6 | MODEL=cycle_gan 7 | SAVE_EPOCH=5 8 | ALIGN_DATA=0 9 | LAMBDA=10 10 | NF=64 11 | PLOT='errG_A,errD_A,errRec_A,errG_B,errD_B,errRec_B' # error curves to plot (see CycleGANModel:DisplayPlot) 12 | 13 | EXPR_NAME=${DB_NAME}_${MODEL}_${LAMBDA} 14 | 15 | CHECKPOINT_DIR=./checkpoints/ 16 | LOG_FILE=${CHECKPOINT_DIR}${EXPR_NAME}/log.txt 17 | mkdir -p ${CHECKPOINT_DIR}${EXPR_NAME} 18 | 19 | DATA_ROOT=./datasets/$DB_NAME align_data=$ALIGN_DATA use_lsgan=1 \ 20 | which_direction='AtoB' display_plot=$PLOT pool_size=50 niter=100 niter_decay=100 \ 21 | which_model_netG=$NET_G which_model_netD=$NET_D model=$MODEL lr=0.0002 print_freq=200 lambda_A=$LAMBDA lambda_B=$LAMBDA \ 22 | loadSize=143 fineSize=128 gpu=$GPU_ID display_winsize=128 \ 23 | name=$EXPR_NAME flip=1 save_epoch_freq=$SAVE_EPOCH \ 24 | continue_train=0 display_id=$DISPLAY_ID \ 25 | checkpoints_dir=$CHECKPOINT_DIR \ 26 | th train.lua | tee -a $LOG_FILE 27 | -------------------------------------------------------------------------------- /imgs/failure_putin.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junyanz/CycleGAN/40b4498526de6b566f94a000d98fe18b9b8921db/imgs/failure_putin.jpg -------------------------------------------------------------------------------- /imgs/horse2zebra.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junyanz/CycleGAN/40b4498526de6b566f94a000d98fe18b9b8921db/imgs/horse2zebra.gif -------------------------------------------------------------------------------- /imgs/objects.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junyanz/CycleGAN/40b4498526de6b566f94a000d98fe18b9b8921db/imgs/objects.jpg -------------------------------------------------------------------------------- /imgs/painting2photo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junyanz/CycleGAN/40b4498526de6b566f94a000d98fe18b9b8921db/imgs/painting2photo.jpg -------------------------------------------------------------------------------- /imgs/paper_thumbnail.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junyanz/CycleGAN/40b4498526de6b566f94a000d98fe18b9b8921db/imgs/paper_thumbnail.jpg -------------------------------------------------------------------------------- /imgs/photo2painting.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junyanz/CycleGAN/40b4498526de6b566f94a000d98fe18b9b8921db/imgs/photo2painting.jpg -------------------------------------------------------------------------------- /imgs/photo_enhancement.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junyanz/CycleGAN/40b4498526de6b566f94a000d98fe18b9b8921db/imgs/photo_enhancement.jpg -------------------------------------------------------------------------------- /imgs/season.jpg: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/junyanz/CycleGAN/40b4498526de6b566f94a000d98fe18b9b8921db/imgs/season.jpg -------------------------------------------------------------------------------- /imgs/teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junyanz/CycleGAN/40b4498526de6b566f94a000d98fe18b9b8921db/imgs/teaser.jpg -------------------------------------------------------------------------------- /models/architectures.lua: -------------------------------------------------------------------------------- 1 | require 'nngraph' 2 | 3 | 4 | ---------------------------------------------------------------------------- 5 | local function weights_init(m) 6 | local name = torch.type(m) 7 | if name:find('Convolution') then 8 | m.weight:normal(0.0, 0.02) 9 | m.bias:fill(0) 10 | elseif name:find('Normalization') then 11 | if m.weight then m.weight:normal(1.0, 0.02) end 12 | if m.bias then m.bias:fill(0) end 13 | end 14 | end 15 | 16 | 17 | normalization = nil 18 | 19 | function set_normalization(norm) 20 | if norm == 'instance' then 21 | require 'util.InstanceNormalization' 22 | print('use InstanceNormalization') 23 | normalization = nn.InstanceNormalization 24 | elseif norm == 'batch' then 25 | print('use SpatialBatchNormalization') 26 | normalization = nn.SpatialBatchNormalization 27 | end 28 | end 29 | 30 | function defineG(input_nc, output_nc, ngf, which_model_netG, nz, arch) 31 | local netG = nil 32 | if which_model_netG == "encoder_decoder" then netG = defineG_encoder_decoder(input_nc, output_nc, ngf) 33 | elseif which_model_netG == "unet128" then netG = defineG_unet128(input_nc, output_nc, ngf) 34 | elseif which_model_netG == "unet256" then netG = defineG_unet256(input_nc, output_nc, ngf) 35 | elseif which_model_netG == "resnet_6blocks" then netG = defineG_resnet_6blocks(input_nc, output_nc, ngf) 36 | elseif which_model_netG == "resnet_9blocks" then netG = defineG_resnet_9blocks(input_nc, output_nc, ngf) 37 | else error("unsupported netG model") 38 | end 39 | netG:apply(weights_init) 40 | 41 | return netG 42 | end 43 | 44 | function defineD(input_nc, ndf, which_model_netD, n_layers_D, use_sigmoid) 45 | local netD = nil 46 | if which_model_netD == "basic" then netD = defineD_basic(input_nc, ndf, use_sigmoid) 47 | elseif which_model_netD == "imageGAN" then netD = defineD_imageGAN(input_nc, ndf, use_sigmoid) 48 | elseif which_model_netD == "n_layers" then netD = defineD_n_layers(input_nc, ndf, n_layers_D, use_sigmoid) 49 | else error("unsupported netD model") 50 | end 51 | netD:apply(weights_init) 52 | 53 | return netD 54 | end 55 | 56 | function defineG_encoder_decoder(input_nc, output_nc, ngf) 57 | -- input is (nc) x 256 x 256 58 | local e1 = - nn.SpatialConvolution(input_nc, ngf, 4, 4, 2, 2, 1, 1) 59 | -- input is (ngf) x 128 x 128 60 | local e2 = e1 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf, ngf * 2, 4, 4, 2, 2, 1, 1) - normalization(ngf * 2) 61 | -- input is (ngf * 2) x 64 x 64 62 | local e3 = e2 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 2, ngf * 4, 4, 4, 2, 2, 1, 1) - normalization(ngf * 4) 63 | -- input is (ngf * 4) x 32 x 32 64 | local e4 = e3 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 4, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8) 65 | -- input is (ngf * 8) x 16 x 16 66 | local e5 = e4 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8) 67 | -- input is (ngf * 8) x 8 x 8 68 | local e6 = e5 - nn.LeakyReLU(0.2, 
true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8) 69 | -- input is (ngf * 8) x 4 x 4 70 | local e7 = e6 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8) 71 | -- input is (ngf * 8) x 2 x 2 72 | local e8 = e7 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) -- normalization(ngf * 8) 73 | -- input is (ngf * 8) x 1 x 1 74 | 75 | local d1 = e8 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8) - nn.Dropout(0.5) 76 | -- input is (ngf * 8) x 2 x 2 77 | local d2 = d1 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8) - nn.Dropout(0.5) 78 | -- input is (ngf * 8) x 4 x 4 79 | local d3 = d2 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8) - nn.Dropout(0.5) 80 | -- input is (ngf * 8) x 8 x 8 81 | local d4 = d3 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8) 82 | -- input is (ngf * 8) x 16 x 16 83 | local d5 = d4 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8, ngf * 4, 4, 4, 2, 2, 1, 1) - normalization(ngf * 4) 84 | -- input is (ngf * 4) x 32 x 32 85 | local d6 = d5 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 4, ngf * 2, 4, 4, 2, 2, 1, 1) - normalization(ngf * 2) 86 | -- input is (ngf * 2) x 64 x 64 87 | local d7 = d6 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 2, ngf, 4, 4, 2, 2, 1, 1) - normalization(ngf) 88 | -- input is (ngf) x128 x 128 89 | local d8 = d7 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf, output_nc, 4, 4, 2, 2, 1, 1) 90 | -- input is (nc) x 256 x 256 91 | local o1 = d8 - nn.Tanh() 92 | 93 | local netG = nn.gModule({e1},{o1}) 94 | return netG 95 | end 96 | 97 | 98 | function defineG_unet128(input_nc, output_nc, ngf) 99 | local netG = nil 100 | -- input is (nc) x 128 x 128 101 | local e1 = - nn.SpatialConvolution(input_nc, ngf, 4, 4, 2, 2, 1, 1) 102 | -- input is (ngf) x 64 x 64 103 | local e2 = e1 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf, ngf * 2, 4, 4, 2, 2, 1, 1) - normalization(ngf * 2) 104 | -- input is (ngf * 2) x 32 x 32 105 | local e3 = e2 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 2, ngf * 4, 4, 4, 2, 2, 1, 1) - normalization(ngf * 4) 106 | -- input is (ngf * 4) x 16 x 16 107 | local e4 = e3 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 4, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8) 108 | -- input is (ngf * 8) x 8 x 8 109 | local e5 = e4 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8) 110 | -- input is (ngf * 8) x 4 x 4 111 | local e6 = e5 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8) 112 | -- input is (ngf * 8) x 2 x 2 113 | local e7 = e6 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) -- normalization(ngf * 8) 114 | -- input is (ngf * 8) x 1 x 1 115 | 116 | local d1_ = e7 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8) - nn.Dropout(0.5) 117 | -- input is (ngf * 8) x 2 x 2 118 | local d1 = {d1_,e6} - nn.JoinTable(2) 119 | local d2_ = d1 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8 * 2, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8) - nn.Dropout(0.5) 120 | -- input is (ngf * 8) x 4 x 4 121 | local d2 = {d2_,e5} - nn.JoinTable(2) 122 
| local d3_ = d2 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8 * 2, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8) - nn.Dropout(0.5) 123 | -- input is (ngf * 8) x 8 x 8 124 | local d3 = {d3_,e4} - nn.JoinTable(2) 125 | local d4_ = d3 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8 * 2, ngf * 4, 4, 4, 2, 2, 1, 1) - normalization(ngf * 4) 126 | -- input is (ngf * 8) x 16 x 16 127 | local d4 = {d4_,e3} - nn.JoinTable(2) 128 | local d5_ = d4 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 4 * 2, ngf * 2, 4, 4, 2, 2, 1, 1) - normalization(ngf * 2) 129 | -- input is (ngf * 4) x 32 x 32 130 | local d5 = {d5_,e2} - nn.JoinTable(2) 131 | local d6_ = d5 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 2 * 2, ngf, 4, 4, 2, 2, 1, 1) - normalization(ngf) 132 | -- input is (ngf * 2) x 64 x 64 133 | local d6 = {d6_,e1} - nn.JoinTable(2) 134 | 135 | local d7 = d6 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 2, output_nc, 4, 4, 2, 2, 1, 1) 136 | -- input is (nc) x 128 x 128 137 | 138 | local o1 = d7 - nn.Tanh() 139 | local netG = nn.gModule({e1},{o1}) 140 | return netG 141 | end 142 | 143 | 144 | function defineG_unet256(input_nc, output_nc, ngf) 145 | local netG = nil 146 | -- input is (nc) x 256 x 256 147 | local e1 = - nn.SpatialConvolution(input_nc, ngf, 4, 4, 2, 2, 1, 1) 148 | -- input is (ngf) x 128 x 128 149 | local e2 = e1 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf, ngf * 2, 4, 4, 2, 2, 1, 1) - normalization(ngf * 2) 150 | -- input is (ngf * 2) x 64 x 64 151 | local e3 = e2 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 2, ngf * 4, 4, 4, 2, 2, 1, 1) - normalization(ngf * 4) 152 | -- input is (ngf * 4) x 32 x 32 153 | local e4 = e3 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 4, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8) 154 | -- input is (ngf * 8) x 16 x 16 155 | local e5 = e4 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8) 156 | -- input is (ngf * 8) x 8 x 8 157 | local e6 = e5 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8) 158 | -- input is (ngf * 8) x 4 x 4 159 | local e7 = e6 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8) 160 | -- input is (ngf * 8) x 2 x 2 161 | local e8 = e7 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) -- - normalization(ngf * 8) 162 | -- input is (ngf * 8) x 1 x 1 163 | 164 | local d1_ = e8 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8) - nn.Dropout(0.5) 165 | -- input is (ngf * 8) x 2 x 2 166 | local d1 = {d1_,e7} - nn.JoinTable(2) 167 | local d2_ = d1 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8 * 2, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8) - nn.Dropout(0.5) 168 | -- input is (ngf * 8) x 4 x 4 169 | local d2 = {d2_,e6} - nn.JoinTable(2) 170 | local d3_ = d2 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8 * 2, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8) - nn.Dropout(0.5) 171 | -- input is (ngf * 8) x 8 x 8 172 | local d3 = {d3_,e5} - nn.JoinTable(2) 173 | local d4_ = d3 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8 * 2, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8) 174 | -- input is (ngf * 8) x 16 x 16 175 | local d4 = {d4_,e4} - nn.JoinTable(2) 176 | local d5_ = d4 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8 * 2, ngf * 4, 4, 4, 2, 2, 1, 1) - normalization(ngf * 4) 177 | -- 
input is (ngf * 4) x 32 x 32 178 | local d5 = {d5_,e3} - nn.JoinTable(2) 179 | local d6_ = d5 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 4 * 2, ngf * 2, 4, 4, 2, 2, 1, 1) - normalization(ngf * 2) 180 | -- input is (ngf * 2) x 64 x 64 181 | local d6 = {d6_,e2} - nn.JoinTable(2) 182 | local d7_ = d6 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 2 * 2, ngf, 4, 4, 2, 2, 1, 1) - normalization(ngf) 183 | -- input is (ngf) x128 x 128 184 | local d7 = {d7_,e1} - nn.JoinTable(2) 185 | local d8 = d7 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 2, output_nc, 4, 4, 2, 2, 1, 1) 186 | -- input is (nc) x 256 x 256 187 | 188 | local o1 = d8 - nn.Tanh() 189 | local netG = nn.gModule({e1},{o1}) 190 | return netG 191 | end 192 | 193 | -------------------------------------------------------------------------------- 194 | -- Justin Johnson's model from https://github.com/jcjohnson/fast-neural-style/ 195 | -------------------------------------------------------------------------------- 196 | 197 | local function build_conv_block(dim, padding_type) 198 | local conv_block = nn.Sequential() 199 | local p = 0 200 | if padding_type == 'reflect' then 201 | conv_block:add(nn.SpatialReflectionPadding(1, 1, 1, 1)) 202 | elseif padding_type == 'replicate' then 203 | conv_block:add(nn.SpatialReplicationPadding(1, 1, 1, 1)) 204 | elseif padding_type == 'zero' then 205 | p = 1 206 | end 207 | conv_block:add(nn.SpatialConvolution(dim, dim, 3, 3, 1, 1, p, p)) 208 | conv_block:add(normalization(dim)) 209 | conv_block:add(nn.ReLU(true)) 210 | if padding_type == 'reflect' then 211 | conv_block:add(nn.SpatialReflectionPadding(1, 1, 1, 1)) 212 | elseif padding_type == 'replicate' then 213 | conv_block:add(nn.SpatialReplicationPadding(1, 1, 1, 1)) 214 | end 215 | conv_block:add(nn.SpatialConvolution(dim, dim, 3, 3, 1, 1, p, p)) 216 | conv_block:add(normalization(dim)) 217 | return conv_block 218 | end 219 | 220 | 221 | local function build_res_block(dim, padding_type) 222 | local conv_block = build_conv_block(dim, padding_type) 223 | local res_block = nn.Sequential() 224 | local concat = nn.ConcatTable() 225 | concat:add(conv_block) 226 | concat:add(nn.Identity()) 227 | 228 | res_block:add(concat):add(nn.CAddTable()) 229 | return res_block 230 | end 231 | 232 | function defineG_resnet_6blocks(input_nc, output_nc, ngf) 233 | padding_type = 'reflect' 234 | local ks = 3 235 | local netG = nil 236 | local f = 7 237 | local p = (f - 1) / 2 238 | local data = -nn.Identity() 239 | local e1 = data - nn.SpatialReflectionPadding(p, p, p, p) - nn.SpatialConvolution(input_nc, ngf, f, f, 1, 1) - normalization(ngf) - nn.ReLU(true) 240 | local e2 = e1 - nn.SpatialConvolution(ngf, ngf*2, ks, ks, 2, 2, 1, 1) - normalization(ngf*2) - nn.ReLU(true) 241 | local e3 = e2 - nn.SpatialConvolution(ngf*2, ngf*4, ks, ks, 2, 2, 1, 1) - normalization(ngf*4) - nn.ReLU(true) 242 | local d1 = e3 - build_res_block(ngf*4, padding_type) - build_res_block(ngf*4, padding_type) - build_res_block(ngf*4, padding_type) 243 | - build_res_block(ngf*4, padding_type) - build_res_block(ngf*4, padding_type) - build_res_block(ngf*4, padding_type) 244 | local d2 = d1 - nn.SpatialFullConvolution(ngf*4, ngf*2, ks, ks, 2, 2, 1, 1,1,1) - normalization(ngf*2) - nn.ReLU(true) 245 | local d3 = d2 - nn.SpatialFullConvolution(ngf*2, ngf, ks, ks, 2, 2, 1, 1,1,1) - normalization(ngf) - nn.ReLU(true) 246 | local d4 = d3 - nn.SpatialReflectionPadding(p, p, p, p) - nn.SpatialConvolution(ngf, output_nc, f, f, 1, 1) - nn.Tanh() 247 | netG = nn.gModule({data},{d4}) 248 | return 
netG 249 | end 250 | 251 | function defineG_resnet_9blocks(input_nc, output_nc, ngf) 252 | padding_type = 'reflect' 253 | local ks = 3 254 | local netG = nil 255 | local f = 7 256 | local p = (f - 1) / 2 257 | local data = -nn.Identity() 258 | local e1 = data - nn.SpatialReflectionPadding(p, p, p, p) - nn.SpatialConvolution(input_nc, ngf, f, f, 1, 1) - normalization(ngf) - nn.ReLU(true) 259 | local e2 = e1 - nn.SpatialConvolution(ngf, ngf*2, ks, ks, 2, 2, 1, 1) - normalization(ngf*2) - nn.ReLU(true) 260 | local e3 = e2 - nn.SpatialConvolution(ngf*2, ngf*4, ks, ks, 2, 2, 1, 1) - normalization(ngf*4) - nn.ReLU(true) 261 | local d1 = e3 - build_res_block(ngf*4, padding_type) - build_res_block(ngf*4, padding_type) - build_res_block(ngf*4, padding_type) 262 | - build_res_block(ngf*4, padding_type) - build_res_block(ngf*4, padding_type) - build_res_block(ngf*4, padding_type) 263 | - build_res_block(ngf*4, padding_type) - build_res_block(ngf*4, padding_type) - build_res_block(ngf*4, padding_type) 264 | local d2 = d1 - nn.SpatialFullConvolution(ngf*4, ngf*2, ks, ks, 2, 2, 1, 1,1,1) - normalization(ngf*2) - nn.ReLU(true) 265 | local d3 = d2 - nn.SpatialFullConvolution(ngf*2, ngf, ks, ks, 2, 2, 1, 1,1,1) - normalization(ngf) - nn.ReLU(true) 266 | local d4 = d3 - nn.SpatialReflectionPadding(p, p, p, p) - nn.SpatialConvolution(ngf, output_nc, f, f, 1, 1) - nn.Tanh() 267 | netG = nn.gModule({data},{d4}) 268 | return netG 269 | end 270 | 271 | function defineD_imageGAN(input_nc, ndf, use_sigmoid) 272 | local netD = nn.Sequential() 273 | 274 | -- input is (nc) x 256 x 256 275 | netD:add(nn.SpatialConvolution(input_nc, ndf, 4, 4, 2, 2, 1, 1)) 276 | netD:add(nn.LeakyReLU(0.2, true)) 277 | -- state size: (ndf) x 128 x 128 278 | netD:add(nn.SpatialConvolution(ndf, ndf * 2, 4, 4, 2, 2, 1, 1)) 279 | netD:add(nn.SpatialBatchNormalization(ndf * 2)):add(nn.LeakyReLU(0.2, true)) 280 | -- state size: (ndf*2) x 64 x 64 281 | netD:add(nn.SpatialConvolution(ndf * 2, ndf*4, 4, 4, 2, 2, 1, 1)) 282 | netD:add(nn.SpatialBatchNormalization(ndf * 4)):add(nn.LeakyReLU(0.2, true)) 283 | -- state size: (ndf*4) x 32 x 32 284 | netD:add(nn.SpatialConvolution(ndf * 4, ndf * 8, 4, 4, 2, 2, 1, 1)) 285 | netD:add(nn.SpatialBatchNormalization(ndf * 8)):add(nn.LeakyReLU(0.2, true)) 286 | -- state size: (ndf*8) x 16 x 16 287 | netD:add(nn.SpatialConvolution(ndf * 8, ndf * 8, 4, 4, 2, 2, 1, 1)) 288 | netD:add(nn.SpatialBatchNormalization(ndf * 8)):add(nn.LeakyReLU(0.2, true)) 289 | -- state size: (ndf*8) x 8 x 8 290 | netD:add(nn.SpatialConvolution(ndf * 8, ndf * 8, 4, 4, 2, 2, 1, 1)) 291 | netD:add(nn.SpatialBatchNormalization(ndf * 8)):add(nn.LeakyReLU(0.2, true)) 292 | -- state size: (ndf*8) x 4 x 4 293 | netD:add(nn.SpatialConvolution(ndf * 8, 1, 4, 4, 2, 2, 1, 1)) 294 | -- state size: 1 x 1 x 1 295 | if use_sigmoid then 296 | netD:add(nn.Sigmoid()) 297 | end 298 | 299 | return netD 300 | end 301 | 302 | 303 | 304 | function defineD_basic(input_nc, ndf, use_sigmoid) 305 | n_layers = 3 306 | return defineD_n_layers(input_nc, ndf, n_layers, use_sigmoid) 307 | end 308 | 309 | -- rf=1 310 | function defineD_pixelGAN(input_nc, ndf, use_sigmoid) 311 | 312 | local netD = nn.Sequential() 313 | 314 | -- input is (nc) x 256 x 256 315 | netD:add(nn.SpatialConvolution(input_nc, ndf, 1, 1, 1, 1, 0, 0)) 316 | netD:add(nn.LeakyReLU(0.2, true)) 317 | -- state size: (ndf) x 256 x 256 318 | netD:add(nn.SpatialConvolution(ndf, ndf * 2, 1, 1, 1, 1, 0, 0)) 319 | netD:add(normalization(ndf * 2)):add(nn.LeakyReLU(0.2, true)) 320 | -- state size: (ndf*2) x 
256 x 256 321 | netD:add(nn.SpatialConvolution(ndf * 2, 1, 1, 1, 1, 1, 0, 0)) 322 | -- state size: 1 x 256 x 256 323 | if use_sigmoid then 324 | netD:add(nn.Sigmoid()) 325 | -- state size: 1 x 256 x 256 326 | end 327 | 328 | return netD 329 | end 330 | 331 | -- if n=0, then use pixelGAN (rf=1) 332 | -- else rf is 16 if n=1 333 | -- 34 if n=2 334 | -- 70 if n=3 335 | -- 142 if n=4 336 | -- 286 if n=5 337 | -- 574 if n=6 338 | function defineD_n_layers(input_nc, ndf, n_layers, use_sigmoid, kw, dropout_ratio) 339 | 340 | if dropout_ratio == nil then 341 | dropout_ratio = 0.0 342 | end 343 | 344 | if kw == nil then 345 | kw = 4 346 | end 347 | padw = math.ceil((kw-1)/2) 348 | 349 | if n_layers==0 then 350 | return defineD_pixelGAN(input_nc, ndf, use_sigmoid) 351 | else 352 | 353 | local netD = nn.Sequential() 354 | 355 | -- input is (nc) x 256 x 256 356 | -- print('input_nc', input_nc) 357 | netD:add(nn.SpatialConvolution(input_nc, ndf, kw, kw, 2, 2, padw, padw)) 358 | netD:add(nn.LeakyReLU(0.2, true)) 359 | 360 | local nf_mult = 1 361 | local nf_mult_prev = 1 362 | for n = 1, n_layers-1 do 363 | nf_mult_prev = nf_mult 364 | nf_mult = math.min(2^n,8) 365 | netD:add(nn.SpatialConvolution(ndf * nf_mult_prev, ndf * nf_mult, kw, kw, 2, 2, padw,padw)) 366 | netD:add(normalization(ndf * nf_mult)):add(nn.Dropout(dropout_ratio)) 367 | netD:add(nn.LeakyReLU(0.2, true)) 368 | end 369 | 370 | -- state size: (ndf*M) x N x N 371 | nf_mult_prev = nf_mult 372 | nf_mult = math.min(2^n_layers,8) 373 | netD:add(nn.SpatialConvolution(ndf * nf_mult_prev, ndf * nf_mult, kw, kw, 1, 1, padw, padw)) 374 | netD:add(normalization(ndf * nf_mult)):add(nn.LeakyReLU(0.2, true)) 375 | -- state size: (ndf*M*2) x (N-1) x (N-1) 376 | netD:add(nn.SpatialConvolution(ndf * nf_mult, 1, kw, kw, 1, 1, padw,padw)) 377 | -- state size: 1 x (N-2) x (N-2) 378 | if use_sigmoid then 379 | netD:add(nn.Sigmoid()) 380 | end 381 | -- state size: 1 x (N-2) x (N-2) 382 | return netD 383 | end 384 | end 385 | -------------------------------------------------------------------------------- /models/base_model.lua: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- 2 | -- Base Class for Providing Models 3 | -------------------------------------------------------------------------------- 4 | 5 | local class = require 'class' 6 | 7 | BaseModel = class('BaseModel') 8 | 9 | function BaseModel:__init(conf) 10 | conf = conf or {} 11 | end 12 | 13 | -- Returns the name of the model 14 | function BaseModel:model_name() 15 | return 'DoesNothingModel' 16 | end 17 | 18 | -- Defines models and networks 19 | function BaseModel:Initialize(opt) 20 | models = {} 21 | return models 22 | end 23 | 24 | -- Runs the forward pass of the network 25 | function BaseModel:Forward(input, opt) 26 | output = {} 27 | return output 28 | end 29 | 30 | -- Runs the backprop gradient descent 31 | -- Corresponds to a single batch of data 32 | function BaseModel:OptimizeParameters(opt) 33 | end 34 | 35 | -- This function can be used to reset momentum after each epoch 36 | function BaseModel:RefreshParameters(opt) 37 | end 38 | 39 | -- This function can be used to update the learning rate after each epoch 40 | function BaseModel:UpdateLearningRate(opt) 41 | end 42 | -- Save the current model to the file system 43 | function BaseModel:Save(prefix, opt) 44 | end 45 | 46 | -- returns a string that describes the current errors 47 | function BaseModel:GetCurrentErrorDescription() 48 | 
return "No Error exists in BaseModel" 49 | end 50 | 51 | -- returns current errors 52 | function BaseModel:GetCurrentErrors(opt) 53 | return {} 54 | end 55 | 56 | -- returns a table of image/label pairs that describe 57 | -- the current results. 58 | -- |return|: a table of table. List of image/label pairs 59 | function BaseModel:GetCurrentVisuals(opt, size) 60 | return {} 61 | end 62 | 63 | -- returns a string that describes the display plot configuration 64 | function BaseModel:DisplayPlot(opt) 65 | return {} 66 | end 67 | -------------------------------------------------------------------------------- /models/bigan_model.lua: -------------------------------------------------------------------------------- 1 | local class = require 'class' 2 | require 'models.base_model' 3 | require 'models.architectures' 4 | require 'util.image_pool' 5 | util = paths.dofile('../util/util.lua') 6 | content = paths.dofile('../util/content_loss.lua') 7 | 8 | BiGANModel = class('BiGANModel', 'BaseModel') 9 | 10 | function BiGANModel:__init(conf) 11 | BaseModel.__init(self, conf) 12 | conf = conf or {} 13 | end 14 | 15 | function BiGANModel:model_name() 16 | return 'BiGANModel' 17 | end 18 | 19 | function BiGANModel:InitializeStates(use_wgan) 20 | optimState = {learningRate=opt.lr, beta1=opt.beta1,} 21 | return optimState 22 | end 23 | -- Defines models and networks 24 | function BiGANModel:Initialize(opt) 25 | if opt.test == 0 then 26 | self.realABPool = ImagePool(opt.pool_size) 27 | self.fakeABPool = ImagePool(opt.pool_size) 28 | end 29 | -- define tensors 30 | local d_input_nc = opt.input_nc + opt.output_nc 31 | self.real_AB = torch.Tensor(opt.batchSize, d_input_nc, opt.fineSize, opt.fineSize) 32 | self.fake_AB = torch.Tensor(opt.batchSize, d_input_nc, opt.fineSize, opt.fineSize) 33 | -- load/define models 34 | self.criterionGAN = nn.MSECriterion() 35 | 36 | local netG, netE, netD = nil, nil, nil 37 | if opt.continue_train == 1 then 38 | if opt.test == 1 then -- which_epoch option exists in test mode 39 | netG = util.load_test_model('G', opt) 40 | netE = util.load_test_model('E', opt) 41 | netD = util.load_test_model('D', opt) 42 | else 43 | netG = util.load_model('G', opt) 44 | netE = util.load_model('E', opt) 45 | netD = util.load_model('D', opt) 46 | end 47 | else 48 | -- netG_test = defineG(opt.input_nc, opt.output_nc, opt.ngf, "resnet_unet", opt.arch) 49 | -- os.exit() 50 | netD = defineD(d_input_nc, opt.ndf, opt.which_model_netD, opt.n_layers_D, false) -- no sigmoid layer 51 | print('netD...', netD) 52 | netG = defineG(opt.input_nc, opt.output_nc, opt.ngf, opt.which_model_netG, opt.arch) 53 | print('netG...', netG) 54 | netE = defineG(opt.output_nc, opt.input_nc, opt.ngf, opt.which_model_netG, opt.arch) 55 | print('netE...', netE) 56 | 57 | end 58 | 59 | self.netD = netD 60 | self.netG = netG 61 | self.netE = netE 62 | 63 | -- define real/fake labels 64 | netD_output_size = self.netD:forward(self.real_AB):size() 65 | self.fake_label = torch.Tensor(netD_output_size):fill(0.0) 66 | self.real_label = torch.Tensor(netD_output_size):fill(1.0) -- no soft smoothing 67 | 68 | self.optimStateD = self:InitializeStates() 69 | self.optimStateG = self:InitializeStates() 70 | self.optimStateE = self:InitializeStates() 71 | self.A_idx = {{}, {1, opt.input_nc}, {}, {}} 72 | self.B_idx = {{}, {opt.input_nc+1, opt.input_nc+opt.output_nc}, {}, {}} 73 | self:RefreshParameters() 74 | 75 | print('---------- # Learnable Parameters --------------') 76 | print(('G = %d'):format(self.parametersG:size(1))) 77 | print(('E = 
%d'):format(self.parametersE:size(1))) 78 | print(('D = %d'):format(self.parametersD:size(1))) 79 | print('------------------------------------------------') 80 | -- os.exit() 81 | end 82 | 83 | -- Runs the forward pass of the network and 84 | -- saves the result to member variables of the class 85 | function BiGANModel:Forward(input, opt) 86 | if opt.which_direction == 'BtoA' then 87 | local temp = input.real_A 88 | input.real_A = input.real_B 89 | input.real_B = temp 90 | end 91 | self.real_AB[self.A_idx]:copy(input.real_A) 92 | self.fake_AB[self.B_idx]:copy(input.real_B) 93 | self.real_A = self.real_AB[self.A_idx] 94 | self.real_B = self.fake_AB[self.B_idx] 95 | self.fake_B = self.netG:forward(self.real_A):clone() 96 | self.fake_A = self.netE:forward(self.real_B):clone() 97 | self.real_AB[self.B_idx]:copy(self.fake_B) -- real_AB: real_A, fake_B -> real_label 98 | self.fake_AB[self.A_idx]:copy(self.fake_A) -- fake_AB: fake_A, real_B -> fake_label 99 | -- if opt.test == 0 then 100 | -- self.real_AB = self.realABPool:Query(self.real_AB) -- batch history 101 | -- self.fake_AB = self.fakeABPool:Query(self.fake_AB) -- batch history 102 | -- end 103 | end 104 | 105 | -- create closure to evaluate f(X) and df/dX of discriminator 106 | function BiGANModel:fDx_basic(x, gradParams, netD, real_AB, fake_AB, opt) 107 | util.BiasZero(netD) 108 | gradParams:zero() 109 | -- Real log(D_A(B)) 110 | local output = netD:forward(real_AB):clone() 111 | local errD_real = self.criterionGAN:forward(output, self.real_label) 112 | local df_do = self.criterionGAN:backward(output, self.real_label) 113 | netD:backward(real_AB, df_do) 114 | -- Fake + log(1 - D_A(G(A))) 115 | output = netD:forward(fake_AB):clone() 116 | local errD_fake = self.criterionGAN:forward(output, self.fake_label) 117 | local df_do2 = self.criterionGAN:backward(output, self.fake_label) 118 | netD:backward(fake_AB, df_do2) 119 | -- Compute loss 120 | local errD = (errD_real + errD_fake) / 2.0 121 | return errD, gradParams 122 | end 123 | 124 | 125 | function BiGANModel:fDx(x, opt) 126 | -- use image pool that stores the old fake images 127 | real_AB = self.realABPool:Query(self.real_AB) 128 | fake_AB = self.fakeABPool:Query(self.fake_AB) 129 | self.errD, gradParams = self:fDx_basic(x, self.gradParametersD, self.netD, real_AB, fake_AB, opt) 130 | return self.errD, gradParams 131 | end 132 | 133 | 134 | 135 | function BiGANModel:fGx_basic(x, netG, netD, gradParametersG, opt) 136 | util.BiasZero(netG) 137 | util.BiasZero(netD) 138 | gradParametersG:zero() 139 | 140 | -- First. G(A) should fake the discriminator 141 | local output = netD:forward(self.real_AB):clone() 142 | local errG = self.criterionGAN:forward(output, self.fake_label) 143 | local dgan_loss_dd = self.criterionGAN:backward(output, self.fake_label) 144 | local dgan_loss_do = netD:updateGradInput(self.real_AB, dgan_loss_dd) 145 | netG:backward(self.real_A, dgan_loss_do[self.B_idx]) -- real_AB: real_A, fake_B -> real_label 146 | return gradParametersG, errG 147 | end 148 | 149 | 150 | function BiGANModel:fGx(x, opt) 151 | self.gradParametersG, self.errG = self:fGx_basic(x, self.netG, self.netD, 152 | self.gradParametersG, opt) 153 | return self.errG, self.gradParametersG 154 | end 155 | 156 | 157 | function BiGANModel:fEx_basic(x, netE, netD, gradParametersE, opt) 158 | util.BiasZero(netE) 159 | util.BiasZero(netD) 160 | gradParametersE:zero() 161 | 162 | -- First. 
G(A) should fake the discriminator 163 | local output = netD:forward(self.fake_AB):clone() 164 | local errE= self.criterionGAN:forward(output, self.real_label) 165 | local dgan_loss_dd = self.criterionGAN:backward(output, self.real_label) 166 | local dgan_loss_do = netD:updateGradInput(self.fake_AB, dgan_loss_dd) 167 | netE:backward(self.real_B, dgan_loss_do[self.A_idx])-- fake_AB: fake_A, real_B -> fake_label 168 | return gradParametersE, errE 169 | end 170 | 171 | 172 | function BiGANModel:fEx(x, opt) 173 | self.gradParametersE, self.errE = self:fEx_basic(x, self.netE, self.netD, 174 | self.gradParametersE, opt) 175 | return self.errE, self.gradParametersE 176 | end 177 | 178 | 179 | function BiGANModel:OptimizeParameters(opt) 180 | local fG = function(x) return self:fGx(x, opt) end 181 | local fE = function(x) return self:fEx(x, opt) end 182 | local fD = function(x) return self:fDx(x, opt) end 183 | optim.adam(fD, self.parametersD, self.optimStateD) 184 | optim.adam(fG, self.parametersG, self.optimStateG) 185 | optim.adam(fE, self.parametersE, self.optimStateE) 186 | end 187 | 188 | function BiGANModel:RefreshParameters() 189 | self.parametersD, self.gradParametersD = nil, nil -- nil them to avoid spiking memory 190 | self.parametersG, self.gradParametersG = nil, nil 191 | self.parametersE, self.gradParametersE = nil, nil 192 | -- define parameters of optimization 193 | self.parametersD, self.gradParametersD = self.netD:getParameters() 194 | self.parametersG, self.gradParametersG = self.netG:getParameters() 195 | self.parametersE, self.gradParametersE = self.netE:getParameters() 196 | end 197 | 198 | function BiGANModel:Save(prefix, opt) 199 | util.save_model(self.netG, prefix .. '_net_G.t7', 1) 200 | util.save_model(self.netE, prefix .. '_net_E.t7', 1) 201 | util.save_model(self.netD, prefix .. 
'_net_D.t7', 1) 202 | end 203 | 204 | function BiGANModel:GetCurrentErrorDescription() 205 | description = ('D: %.4f G: %.4f E: %.4f'):format( 206 | self.errD and self.errD or -1, 207 | self.errG and self.errG or -1, 208 | self.errE and self.errE or -1) 209 | return description 210 | end 211 | 212 | function BiGANModel:GetCurrentErrors() 213 | local errors = {errD=self.errD, errG=self.errG, errE=self.errE} 214 | return errors 215 | end 216 | 217 | -- returns a string that describes the display plot configuration 218 | function BiGANModel:DisplayPlot(opt) 219 | return 'errD,errG,errE' 220 | end 221 | function BiGANModel:UpdateLearningRate(opt) 222 | local lrd = opt.lr / opt.niter_decay 223 | local old_lr = self.optimStateD['learningRate'] 224 | local lr = old_lr - lrd 225 | self.optimStateD['learningRate'] = lr 226 | self.optimStateG['learningRate'] = lr 227 | self.optimStateE['learningRate'] = lr 228 | print(('update learning rate: %f -> %f'):format(old_lr, lr)) 229 | end 230 | 231 | local function MakeIm3(im) 232 | -- print('before im_size', im:size()) 233 | local im3 = nil 234 | if im:size(2) == 1 then 235 | im3 = torch.repeatTensor(im, 1,3,1,1) 236 | else 237 | im3 = im 238 | end 239 | -- print('after im_size', im:size()) 240 | -- print('after im3_size', im3:size()) 241 | return im3 242 | end 243 | function BiGANModel:GetCurrentVisuals(opt, size) 244 | if not size then 245 | size = opt.display_winsize 246 | end 247 | 248 | local visuals = {} 249 | table.insert(visuals, {img=MakeIm3(self.real_A), label='real_A'}) 250 | table.insert(visuals, {img=MakeIm3(self.fake_B), label='fake_B'}) 251 | table.insert(visuals, {img=MakeIm3(self.real_B), label='real_B'}) 252 | table.insert(visuals, {img=MakeIm3(self.fake_A), label='fake_A'}) 253 | return visuals 254 | end 255 | -------------------------------------------------------------------------------- /models/content_gan_model.lua: -------------------------------------------------------------------------------- 1 | local class = require 'class' 2 | require 'models.base_model' 3 | require 'models.architectures' 4 | require 'util.image_pool' 5 | util = paths.dofile('../util/util.lua') 6 | content = paths.dofile('../util/content_loss.lua') 7 | 8 | ContentGANModel = class('ContentGANModel', 'BaseModel') 9 | 10 | function ContentGANModel:__init(conf) 11 | BaseModel.__init(self, conf) 12 | conf = conf or {} 13 | end 14 | 15 | function ContentGANModel:model_name() 16 | return 'ContentGANModel' 17 | end 18 | 19 | function ContentGANModel:InitializeStates() 20 | local optimState = {learningRate=opt.lr, beta1=opt.beta1,} 21 | return optimState 22 | end 23 | -- Defines models and networks 24 | function ContentGANModel:Initialize(opt) 25 | if opt.test == 0 then 26 | self.fakePool = ImagePool(opt.pool_size) 27 | end 28 | -- define tensors 29 | self.real_A = torch.Tensor(opt.batchSize, opt.input_nc, opt.fineSize, opt.fineSize) 30 | self.fake_B = torch.Tensor(opt.batchSize, opt.output_nc, opt.fineSize, opt.fineSize) 31 | self.real_B = self.fake_B:clone() --torch.Tensor(opt.batchSize, opt.output_nc, opt.fineSize, opt.fineSize) 32 | 33 | -- load/define models 34 | self.criterionGAN = nn.MSECriterion() 35 | self.criterionContent = nn.AbsCriterion() 36 | self.contentFunc = content.defineContent(opt.content_loss, opt.layer_name) 37 | self.netG, self.netD = nil, nil 38 | if opt.continue_train == 1 then 39 | if opt.which_epoch then -- which_epoch option exists in test mode 40 | self.netG = util.load_test_model('G_A', opt) 41 | self.netD = util.load_test_model('D_A', 
opt) 42 | else 43 | self.netG = util.load_model('G_A', opt) 44 | self.netD = util.load_model('D_A', opt) 45 | end 46 | else 47 | self.netG = defineG(opt.input_nc, opt.output_nc, opt.ngf, opt.which_model_netG) 48 | print('netG...', self.netG) 49 | self.netD = defineD(opt.output_nc, opt.ndf, opt.which_model_netD, opt.n_layers_D, false) 50 | print('netD...', self.netD) 51 | end 52 | -- define real/fake labels 53 | netD_output_size = self.netD:forward(self.real_A):size() 54 | self.fake_label = torch.Tensor(netD_output_size):fill(0.0) 55 | self.real_label = torch.Tensor(netD_output_size):fill(1.0) -- no soft smoothing 56 | self.optimStateD = self:InitializeStates() 57 | self.optimStateG = self:InitializeStates() 58 | self:RefreshParameters() 59 | print('---------- # Learnable Parameters --------------') 60 | print(('G = %d'):format(self.parametersG:size(1))) 61 | print(('D = %d'):format(self.parametersD:size(1))) 62 | print('------------------------------------------------') 63 | -- os.exit() 64 | end 65 | 66 | -- Runs the forward pass of the network and 67 | -- saves the result to member variables of the class 68 | function ContentGANModel:Forward(input, opt) 69 | if opt.which_direction == 'BtoA' then 70 | local temp = input.real_A 71 | input.real_A = input.real_B 72 | input.real_B = temp 73 | end 74 | 75 | self.real_A:copy(input.real_A) 76 | self.real_B:copy(input.real_B) 77 | self.fake_B = self.netG:forward(self.real_A):clone() 78 | -- output = {self.fake_B} 79 | output = {} 80 | -- if opt.test == 1 then 81 | 82 | -- end 83 | return output 84 | end 85 | 86 | -- create closure to evaluate f(X) and df/dX of discriminator 87 | function ContentGANModel:fDx_basic(x, gradParams, netD, netG, 88 | real_target, fake_target, opt) 89 | util.BiasZero(netD) 90 | util.BiasZero(netG) 91 | gradParams:zero() 92 | 93 | local errD_real, errD_rec, errD_fake, errD = 0, 0, 0, 0 94 | -- Real log(D_A(B)) 95 | local output = netD:forward(real_target) 96 | errD_real = self.criterionGAN:forward(output, self.real_label) 97 | df_do = self.criterionGAN:backward(output, self.real_label) 98 | netD:backward(real_target, df_do) 99 | 100 | -- Fake + log(1 - D_A(G_A(A))) 101 | output = netD:forward(fake_target) 102 | errD_fake = self.criterionGAN:forward(output, self.fake_label) 103 | df_do = self.criterionGAN:backward(output, self.fake_label) 104 | netD:backward(fake_target, df_do) 105 | errD = (errD_real + errD_fake) / 2.0 106 | -- print('errD', errD 107 | return errD, gradParams 108 | end 109 | 110 | 111 | function ContentGANModel:fDx(x, opt) 112 | fake_B = self.fakePool:Query(self.fake_B) 113 | self.errD, gradParams = self:fDx_basic(x, self.gradparametersD, self.netD, self.netG, 114 | self.real_B, fake_B, opt) 115 | return self.errD, gradParams 116 | end 117 | 118 | function ContentGANModel:fGx_basic(x, netG_source, netD_source, real_source, real_target, fake_target, 119 | gradParametersG_source, opt) 120 | util.BiasZero(netD_source) 121 | util.BiasZero(netG_source) 122 | gradParametersG_source:zero() 123 | -- GAN loss 124 | -- local df_d_GAN = torch.zeros(fake_target:size()) 125 | -- local errGAN = 0 126 | -- local errRec = 0 127 | --- Domain GAN loss: D_A(G_A(A)) 128 | local output = netD_source.output -- [hack] forward was already executed in fDx, so save computation netD_source:forward(fake_B) --- 129 | local errGAN = self.criterionGAN:forward(output, self.real_label) 130 | local df_do = self.criterionGAN:backward(output, self.real_label) 131 | local df_d_GAN = netD_source:updateGradInput(fake_target, df_do) 
---:narrow(2,fake_AB:size(2)-output_nc+1, output_nc) 132 | 133 | -- content loss 134 | -- print('content_loss', opt.content_loss) 135 | -- function content.lossUpdate(criterionContent, real_source, fake_target, contentFunc, loss_type, weight) 136 | local errContent, df_d_content = content.lossUpdate(self.criterionContent, real_source, fake_target, self.contentFunc, opt.content_loss, opt.lambda_A) 137 | netG_source:forward(real_source) 138 | netG_source:backward(real_source, df_d_GAN + df_d_content) 139 | -- print('errD', errGAN) 140 | return gradParametersG_source, errGAN, errContent 141 | end 142 | 143 | function ContentGANModel:fGx(x, opt) 144 | self.gradparametersG, self.errG, self.errCont = 145 | self:fGx_basic(x, self.netG, self.netD, 146 | self.real_A, self.real_B, self.fake_B, 147 | self.gradparametersG, opt) 148 | return self.errG, self.gradparametersG 149 | end 150 | 151 | function ContentGANModel:OptimizeParameters(opt) 152 | local fDx = function(x) return self:fDx(x, opt) end 153 | local fGx = function(x) return self:fGx(x, opt) end 154 | optim.adam(fDx, self.parametersD, self.optimStateD) 155 | optim.adam(fGx, self.parametersG, self.optimStateG) 156 | end 157 | 158 | function ContentGANModel:RefreshParameters() 159 | self.parametersD, self.gradparametersD = nil, nil -- nil them to avoid spiking memory 160 | self.parametersG, self.gradparametersG = nil, nil 161 | -- define parameters of optimization 162 | self.parametersG, self.gradparametersG = self.netG:getParameters() 163 | self.parametersD, self.gradparametersD = self.netD:getParameters() 164 | end 165 | 166 | function ContentGANModel:Save(prefix, opt) 167 | util.save_model(self.netG, prefix .. '_net_G_A.t7', 1.0) 168 | util.save_model(self.netD, prefix .. '_net_D_A.t7', 1.0) 169 | end 170 | 171 | function ContentGANModel:GetCurrentErrorDescription() 172 | description = ('G: %.4f D: %.4f Content: %.4f'):format(self.errG and self.errG or -1, 173 | self.errD and self.errD or -1, 174 | self.errCont and self.errCont or -1) 175 | return description 176 | end 177 | 178 | 179 | function ContentGANModel:GetCurrentErrors() 180 | local errors = {errG=self.errG and self.errG or -1, errD=self.errD and self.errD or -1, 181 | errCont=self.errCont and self.errCont or -1} 182 | return errors 183 | end 184 | 185 | -- returns a string that describes the display plot configuration 186 | function ContentGANModel:DisplayPlot(opt) 187 | return 'errG,errD,errCont' 188 | end 189 | 190 | 191 | function ContentGANModel:GetCurrentVisuals(opt, size) 192 | if not size then 193 | size = opt.display_winsize 194 | end 195 | 196 | local visuals = {} 197 | table.insert(visuals, {img=self.real_A, label='real_A'}) 198 | table.insert(visuals, {img=self.fake_B, label='fake_B'}) 199 | table.insert(visuals, {img=self.real_B, label='real_B'}) 200 | return visuals 201 | end 202 | -------------------------------------------------------------------------------- /models/cycle_gan_model.lua: -------------------------------------------------------------------------------- 1 | local class = require 'class' 2 | require 'models.base_model' 3 | require 'models.architectures' 4 | require 'util.image_pool' 5 | 6 | util = paths.dofile('../util/util.lua') 7 | CycleGANModel = class('CycleGANModel', 'BaseModel') 8 | 9 | function CycleGANModel:__init(conf) 10 | BaseModel.__init(self, conf) 11 | conf = conf or {} 12 | end 13 | 14 | function CycleGANModel:model_name() 15 | return 'CycleGANModel' 16 | end 17 | 18 | function CycleGANModel:InitializeStates(use_wgan) 19 | optimState = 
{learningRate=opt.lr, beta1=opt.beta1,} 20 | return optimState 21 | end 22 | -- Defines models and networks 23 | function CycleGANModel:Initialize(opt) 24 | if opt.test == 0 then 25 | self.fakeAPool = ImagePool(opt.pool_size) 26 | self.fakeBPool = ImagePool(opt.pool_size) 27 | end 28 | -- define tensors 29 | if opt.test == 0 then -- allocate tensors for training 30 | self.real_A = torch.Tensor(opt.batchSize, opt.input_nc, opt.fineSize, opt.fineSize) 31 | self.real_B = torch.Tensor(opt.batchSize, opt.output_nc, opt.fineSize, opt.fineSize) 32 | self.fake_A = torch.Tensor(opt.batchSize, opt.input_nc, opt.fineSize, opt.fineSize) 33 | self.fake_B = torch.Tensor(opt.batchSize, opt.output_nc, opt.fineSize, opt.fineSize) 34 | self.rec_A = torch.Tensor(opt.batchSize, opt.input_nc, opt.fineSize, opt.fineSize) 35 | self.rec_B = torch.Tensor(opt.batchSize, opt.output_nc, opt.fineSize, opt.fineSize) 36 | end 37 | -- load/define models 38 | local use_lsgan = ((opt.use_lsgan ~= nil) and (opt.use_lsgan == 1)) 39 | if not use_lsgan then 40 | self.criterionGAN = nn.BCECriterion() 41 | else 42 | self.criterionGAN = nn.MSECriterion() 43 | end 44 | self.criterionRec = nn.AbsCriterion() 45 | 46 | local netG_A, netD_A, netG_B, netD_B = nil, nil, nil, nil 47 | if opt.continue_train == 1 then 48 | if opt.test == 1 then -- test mode 49 | netG_A = util.load_test_model('G_A', opt) 50 | netG_B = util.load_test_model('G_B', opt) 51 | 52 | --setup optnet to save a little bit of memory 53 | if opt.use_optnet == 1 then 54 | local sample_input = torch.randn(1, opt.input_nc, 2, 2) 55 | local optnet = require 'optnet' 56 | optnet.optimizeMemory(netG_A, sample_input, {inplace=true, reuseBuffers=true}) 57 | optnet.optimizeMemory(netG_B, sample_input, {inplace=true, reuseBuffers=true}) 58 | end 59 | else 60 | netG_A = util.load_model('G_A', opt) 61 | netG_B = util.load_model('G_B', opt) 62 | netD_A = util.load_model('D_A', opt) 63 | netD_B = util.load_model('D_B', opt) 64 | end 65 | else 66 | local use_sigmoid = (not use_lsgan) 67 | -- netG_test = defineG(opt.input_nc, opt.output_nc, opt.ngf, "resnet_unet", opt.arch) 68 | -- os.exit() 69 | netG_A = defineG(opt.input_nc, opt.output_nc, opt.ngf, opt.which_model_netG, opt.arch) 70 | print('netG_A...', netG_A) 71 | netD_A = defineD(opt.output_nc, opt.ndf, opt.which_model_netD, opt.n_layers_D, use_sigmoid) -- no sigmoid layer 72 | print('netD_A...', netD_A) 73 | netG_B = defineG(opt.output_nc, opt.input_nc, opt.ngf, opt.which_model_netG, opt.arch) 74 | print('netG_B...', netG_B) 75 | netD_B = defineD(opt.input_nc, opt.ndf, opt.which_model_netD, opt.n_layers_D, use_sigmoid) -- no sigmoid layer 76 | print('netD_B', netD_B) 77 | end 78 | 79 | self.netD_A = netD_A 80 | self.netG_A = netG_A 81 | self.netG_B = netG_B 82 | self.netD_B = netD_B 83 | 84 | -- define real/fake labels 85 | if opt.test == 0 then 86 | local D_A_size = self.netD_A:forward(self.real_B):size() -- hack: assume D_size_A = D_size_B 87 | self.fake_label_A = torch.Tensor(D_A_size):fill(0.0) 88 | self.real_label_A = torch.Tensor(D_A_size):fill(1.0) -- no soft smoothing 89 | local D_B_size = self.netD_B:forward(self.real_A):size() -- hack: assume D_size_A = D_size_B 90 | self.fake_label_B = torch.Tensor(D_B_size):fill(0.0) 91 | self.real_label_B = torch.Tensor(D_B_size):fill(1.0) -- no soft smoothing 92 | self.optimStateD_A = self:InitializeStates() 93 | self.optimStateG_A = self:InitializeStates() 94 | self.optimStateD_B = self:InitializeStates() 95 | self.optimStateG_B = self:InitializeStates() 96 | 
self:RefreshParameters() 97 | print('---------- # Learnable Parameters --------------') 98 | print(('G_A = %d'):format(self.parametersG_A:size(1))) 99 | print(('D_A = %d'):format(self.parametersD_A:size(1))) 100 | print(('G_B = %d'):format(self.parametersG_B:size(1))) 101 | print(('D_B = %d'):format(self.parametersD_B:size(1))) 102 | print('------------------------------------------------') 103 | end 104 | end 105 | 106 | -- Runs the forward pass of the network and 107 | -- saves the result to member variables of the class 108 | function CycleGANModel:Forward(input, opt) 109 | if opt.which_direction == 'BtoA' then 110 | local temp = input.real_A:clone() 111 | input.real_A = input.real_B:clone() 112 | input.real_B = temp 113 | end 114 | 115 | if opt.test == 0 then 116 | self.real_A:copy(input.real_A) 117 | self.real_B:copy(input.real_B) 118 | end 119 | 120 | if opt.test == 1 then -- forward for test 121 | if opt.gpu > 0 then 122 | self.real_A = input.real_A:cuda() 123 | self.real_B = input.real_B:cuda() 124 | else 125 | self.real_A = input.real_A:clone() 126 | self.real_B = input.real_B:clone() 127 | end 128 | self.fake_B = self.netG_A:forward(self.real_A):clone() 129 | self.fake_A = self.netG_B:forward(self.real_B):clone() 130 | self.rec_A = self.netG_B:forward(self.fake_B):clone() 131 | self.rec_B = self.netG_A:forward(self.fake_A):clone() 132 | end 133 | end 134 | 135 | -- create closure to evaluate f(X) and df/dX of discriminator 136 | function CycleGANModel:fDx_basic(x, gradParams, netD, netG, real, fake, real_label, fake_label, opt) 137 | util.BiasZero(netD) 138 | util.BiasZero(netG) 139 | gradParams:zero() 140 | -- Real log(D_A(B)) 141 | local output = netD:forward(real) 142 | local errD_real = self.criterionGAN:forward(output, real_label) 143 | local df_do = self.criterionGAN:backward(output, real_label) 144 | netD:backward(real, df_do) 145 | -- Fake + log(1 - D_A(G_A(A))) 146 | output = netD:forward(fake) 147 | local errD_fake = self.criterionGAN:forward(output, fake_label) 148 | local df_do2 = self.criterionGAN:backward(output, fake_label) 149 | netD:backward(fake, df_do2) 150 | -- Compute loss 151 | local errD = (errD_real + errD_fake) / 2.0 152 | return errD, gradParams 153 | end 154 | 155 | 156 | function CycleGANModel:fDAx(x, opt) 157 | -- use image pool that stores the old fake images 158 | fake_B = self.fakeBPool:Query(self.fake_B) 159 | self.errD_A, gradParams = self:fDx_basic(x, self.gradparametersD_A, self.netD_A, self.netG_A, 160 | self.real_B, fake_B, self.real_label_A, self.fake_label_A, opt) 161 | return self.errD_A, gradParams 162 | end 163 | 164 | 165 | function CycleGANModel:fDBx(x, opt) 166 | -- use image pool that stores the old fake images 167 | fake_A = self.fakeAPool:Query(self.fake_A) 168 | self.errD_B, gradParams = self:fDx_basic(x, self.gradparametersD_B, self.netD_B, self.netG_B, 169 | self.real_A, fake_A, self.real_label_B, self.fake_label_B, opt) 170 | return self.errD_B, gradParams 171 | end 172 | 173 | 174 | function CycleGANModel:fGx_basic(x, gradParams, netG, netD, netE, real, real2, real_label, lambda1, lambda2, opt) 175 | util.BiasZero(netD) 176 | util.BiasZero(netG) 177 | util.BiasZero(netE) -- inverse mapping 178 | gradParams:zero() 179 | 180 | -- G should be identity if real2 is fed. 
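-- How the terms below combine (reading of this closure for fGAx: netG = G: A->B,
-- netE = F: B->A, real = A, real2 = B; fGBx swaps the roles). With criterionRec = L1
-- and, under LSGAN (use_lsgan=1), criterionGAN = MSE:
--   identity:       errI    = lambda_identity * lambda2 * ||G(B) - B||_1  (skipped when lambda_identity == 0)
--   GAN:            errG    = MSE(D(G(A)), real_label)
--   forward cycle:  errRec  = lambda1 * ||F(G(A)) - A||_1
--   backward cycle: errAdapt = lambda2 * ||G(F(B)) - B||_1
-- All four gradient contributions are accumulated into netG's parameters here.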
181 | local errI = nil 182 | local identity = nil 183 | if opt.lambda_identity > 0 then 184 | identity = netG:forward(real2):clone() 185 | errI = self.criterionRec:forward(identity, real2) * lambda2 * opt.lambda_identity 186 | local didentity_loss_do = self.criterionRec:backward(identity, real2):mul(lambda2):mul(opt.lambda_identity) 187 | netG:backward(real2, didentity_loss_do) 188 | end 189 | 190 | --- GAN loss: D_A(G_A(A)) 191 | local fake = netG:forward(real):clone() 192 | local output = netD:forward(fake) 193 | local errG = self.criterionGAN:forward(output, real_label) 194 | local df_do1 = self.criterionGAN:backward(output, real_label) 195 | local df_d_GAN = netD:updateGradInput(fake, df_do1) -- 196 | 197 | -- forward cycle loss 198 | local rec = netE:forward(fake):clone() 199 | local errRec = self.criterionRec:forward(rec, real) * lambda1 200 | local df_do2 = self.criterionRec:backward(rec, real):mul(lambda1) 201 | local df_do_rec = netE:updateGradInput(fake, df_do2) 202 | 203 | netG:backward(real, df_d_GAN + df_do_rec) 204 | 205 | -- backward cycle loss 206 | local fake2 = netE:forward(real2)--:clone() 207 | local rec2 = netG:forward(fake2)--:clone() 208 | local errAdapt = self.criterionRec:forward(rec2, real2) * lambda2 209 | local df_do_coadapt = self.criterionRec:backward(rec2, real2):mul(lambda2) 210 | netG:backward(fake2, df_do_coadapt) 211 | 212 | return gradParams, errG, errRec, errI, fake, rec, identity 213 | end 214 | 215 | function CycleGANModel:fGAx(x, opt) 216 | self.gradparametersG_A, self.errG_A, self.errRec_A, self.errI_A, self.fake_B, self.rec_A, self.identity_B = 217 | self:fGx_basic(x, self.gradparametersG_A, self.netG_A, self.netD_A, self.netG_B, self.real_A, self.real_B, 218 | self.real_label_A, opt.lambda_A, opt.lambda_B, opt) 219 | return self.errG_A, self.gradparametersG_A 220 | end 221 | 222 | function CycleGANModel:fGBx(x, opt) 223 | self.gradparametersG_B, self.errG_B, self.errRec_B, self.errI_B, self.fake_A, self.rec_B, self.identity_A = 224 | self:fGx_basic(x, self.gradparametersG_B, self.netG_B, self.netD_B, self.netG_A, self.real_B, self.real_A, 225 | self.real_label_B, opt.lambda_B, opt.lambda_A, opt) 226 | return self.errG_B, self.gradparametersG_B 227 | end 228 | 229 | 230 | function CycleGANModel:OptimizeParameters(opt) 231 | local fDA = function(x) return self:fDAx(x, opt) end 232 | local fGA = function(x) return self:fGAx(x, opt) end 233 | local fDB = function(x) return self:fDBx(x, opt) end 234 | local fGB = function(x) return self:fGBx(x, opt) end 235 | 236 | optim.adam(fGA, self.parametersG_A, self.optimStateG_A) 237 | optim.adam(fDA, self.parametersD_A, self.optimStateD_A) 238 | optim.adam(fGB, self.parametersG_B, self.optimStateG_B) 239 | optim.adam(fDB, self.parametersD_B, self.optimStateD_B) 240 | end 241 | 242 | function CycleGANModel:RefreshParameters() 243 | self.parametersD_A, self.gradparametersD_A = nil, nil -- nil them to avoid spiking memory 244 | self.parametersG_A, self.gradparametersG_A = nil, nil 245 | self.parametersG_B, self.gradparametersG_B = nil, nil 246 | self.parametersD_B, self.gradparametersD_B = nil, nil 247 | -- define parameters of optimization 248 | self.parametersG_A, self.gradparametersG_A = self.netG_A:getParameters() 249 | self.parametersD_A, self.gradparametersD_A = self.netD_A:getParameters() 250 | self.parametersG_B, self.gradparametersG_B = self.netG_B:getParameters() 251 | self.parametersD_B, self.gradparametersD_B = self.netD_B:getParameters() 252 | end 253 | 254 | function CycleGANModel:Save(prefix, opt) 
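-- Persists all four networks as Torch .t7 checkpoints named
-- <prefix>_net_G_A.t7, <prefix>_net_D_A.t7, <prefix>_net_G_B.t7 and
-- <prefix>_net_D_B.t7; the prefix is chosen by the caller (e.g. 'latest',
-- matching the latest_net_G.t7 files the pretrained/test scripts refer to).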
255 | util.save_model(self.netG_A, prefix .. '_net_G_A.t7', 1) 256 | util.save_model(self.netD_A, prefix .. '_net_D_A.t7', 1) 257 | util.save_model(self.netG_B, prefix .. '_net_G_B.t7', 1) 258 | util.save_model(self.netD_B, prefix .. '_net_D_B.t7', 1) 259 | end 260 | 261 | function CycleGANModel:GetCurrentErrorDescription() 262 | description = ('[A] G: %.4f D: %.4f Rec: %.4f I: %.4f || [B] G: %.4f D: %.4f Rec: %.4f I:%.4f'):format( 263 | self.errG_A and self.errG_A or -1, 264 | self.errD_A and self.errD_A or -1, 265 | self.errRec_A and self.errRec_A or -1, 266 | self.errI_A and self.errI_A or -1, 267 | self.errG_B and self.errG_B or -1, 268 | self.errD_B and self.errD_B or -1, 269 | self.errRec_B and self.errRec_B or -1, 270 | self.errI_B and self.errI_B or -1) 271 | return description 272 | end 273 | 274 | function CycleGANModel:GetCurrentErrors() 275 | local errors = {errG_A=self.errG_A, errD_A=self.errD_A, errRec_A=self.errRec_A, errI_A=self.errI_A, 276 | errG_B=self.errG_B, errD_B=self.errD_B, errRec_B=self.errRec_B, errI_B=self.errI_B} 277 | return errors 278 | end 279 | 280 | -- returns a string that describes the display plot configuration 281 | function CycleGANModel:DisplayPlot(opt) 282 | if opt.lambda_identity > 0 then 283 | return 'errG_A,errD_A,errRec_A,errI_A,errG_B,errD_B,errRec_B,errI_B' 284 | else 285 | return 'errG_A,errD_A,errRec_A,errG_B,errD_B,errRec_B' 286 | end 287 | end 288 | 289 | function CycleGANModel:UpdateLearningRate(opt) 290 | local lrd = opt.lr / opt.niter_decay 291 | local old_lr = self.optimStateD_A['learningRate'] 292 | local lr = old_lr - lrd 293 | self.optimStateD_A['learningRate'] = lr 294 | self.optimStateD_B['learningRate'] = lr 295 | self.optimStateG_A['learningRate'] = lr 296 | self.optimStateG_B['learningRate'] = lr 297 | print(('update learning rate: %f -> %f'):format(old_lr, lr)) 298 | end 299 | 300 | local function MakeIm3(im) 301 | if im:size(2) == 1 then 302 | local im3 = torch.repeatTensor(im, 1,3,1,1) 303 | return im3 304 | else 305 | return im 306 | end 307 | end 308 | 309 | function CycleGANModel:GetCurrentVisuals(opt, size) 310 | local visuals = {} 311 | table.insert(visuals, {img=MakeIm3(self.real_A), label='real_A'}) 312 | table.insert(visuals, {img=MakeIm3(self.fake_B), label='fake_B'}) 313 | table.insert(visuals, {img=MakeIm3(self.rec_A), label='rec_A'}) 314 | if opt.test == 0 and opt.lambda_identity > 0 then 315 | table.insert(visuals, {img=MakeIm3(self.identity_A), label='identity_A'}) 316 | end 317 | table.insert(visuals, {img=MakeIm3(self.real_B), label='real_B'}) 318 | table.insert(visuals, {img=MakeIm3(self.fake_A), label='fake_A'}) 319 | table.insert(visuals, {img=MakeIm3(self.rec_B), label='rec_B'}) 320 | if opt.test == 0 and opt.lambda_identity > 0 then 321 | table.insert(visuals, {img=MakeIm3(self.identity_B), label='identity_B'}) 322 | end 323 | return visuals 324 | end 325 | -------------------------------------------------------------------------------- /models/one_direction_test_model.lua: -------------------------------------------------------------------------------- 1 | local class = require 'class' 2 | require 'models.base_model' 3 | require 'models.architectures' 4 | require 'util.image_pool' 5 | 6 | util = paths.dofile('../util/util.lua') 7 | OneDirectionTestModel = class('OneDirectionTestModel', 'BaseModel') 8 | 9 | function OneDirectionTestModel:__init(conf) 10 | BaseModel.__init(self, conf) 11 | conf = conf or {} 12 | end 13 | 14 | function OneDirectionTestModel:model_name() 15 | return 'OneDirectionTestModel' 
16 | end 17 | 18 | -- Defines models and networks 19 | function OneDirectionTestModel:Initialize(opt) 20 | -- define tensors 21 | self.real_A = torch.Tensor(opt.batchSize, opt.input_nc, opt.fineSize, opt.fineSize) 22 | 23 | -- load/define models 24 | self.netG_A = util.load_test_model('G', opt) 25 | 26 | -- setup optnet to save a bit of memory 27 | if opt.use_optnet == 1 then 28 | local optnet = require 'optnet' 29 | local sample_input = torch.randn(1, opt.input_nc, 2, 2) 30 | optnet.optimizeMemory(self.netG_A, sample_input, {inplace=true, reuseBuffers=true}) 31 | end 32 | 33 | self:RefreshParameters() 34 | 35 | print('---------- # Learnable Parameters --------------') 36 | print(('G_A = %d'):format(self.parametersG_A:size(1))) 37 | print('------------------------------------------------') 38 | end 39 | 40 | -- Runs the forward pass of the network and 41 | -- saves the result to member variables of the class 42 | function OneDirectionTestModel:Forward(input, opt) 43 | if opt.which_direction == 'BtoA' then 44 | input.real_A = input.real_B:clone() 45 | end 46 | 47 | self.real_A = input.real_A:clone() 48 | if opt.gpu > 0 then 49 | self.real_A = self.real_A:cuda() 50 | end 51 | 52 | self.fake_B = self.netG_A:forward(self.real_A):clone() 53 | end 54 | 55 | function OneDirectionTestModel:RefreshParameters() 56 | self.parametersG_A, self.gradparametersG_A = nil, nil 57 | self.parametersG_A, self.gradparametersG_A = self.netG_A:getParameters() 58 | end 59 | 60 | 61 | local function MakeIm3(im) 62 | if im:size(2) == 1 then 63 | local im3 = torch.repeatTensor(im, 1,3,1,1) 64 | return im3 65 | else 66 | return im 67 | end 68 | end 69 | 70 | function OneDirectionTestModel:GetCurrentVisuals(opt, size) 71 | if not size then 72 | size = opt.display_winsize 73 | end 74 | 75 | local visuals = {} 76 | table.insert(visuals, {img=MakeIm3(self.real_A), label='real_A'}) 77 | table.insert(visuals, {img=MakeIm3(self.fake_B), label='fake_B'}) 78 | return visuals 79 | end 80 | -------------------------------------------------------------------------------- /models/pix2pix_model.lua: -------------------------------------------------------------------------------- 1 | local class = require 'class' 2 | require 'models.base_model' 3 | require 'models.architectures' 4 | require 'util.image_pool' 5 | util = paths.dofile('../util/util.lua') 6 | Pix2PixModel = class('Pix2PixModel', 'BaseModel') 7 | 8 | function Pix2PixModel:__init(conf) 9 | conf = conf or {} 10 | end 11 | 12 | -- Returns the name of the model 13 | function Pix2PixModel:model_name() 14 | return 'Pix2PixModel' 15 | end 16 | 17 | function Pix2PixModel:InitializeStates() 18 | return {learningRate=opt.lr, beta1=opt.beta1,} 19 | end 20 | 21 | -- Defines models and networks 22 | function Pix2PixModel:Initialize(opt) -- use lsgan 23 | -- define tensors 24 | local d_input_nc = opt.input_nc + opt.output_nc 25 | self.real_AB = torch.Tensor(opt.batchSize, d_input_nc, opt.fineSize, opt.fineSize) 26 | self.fake_AB = torch.Tensor(opt.batchSize, d_input_nc, opt.fineSize, opt.fineSize) 27 | if opt.test == 0 then 28 | self.fakeABPool = ImagePool(opt.pool_size) 29 | end 30 | -- load/define models 31 | self.criterionGAN = nn.MSECriterion() 32 | self.criterionL1 = nn.AbsCriterion() 33 | 34 | local netG, netD = nil, nil 35 | if opt.continue_train == 1 then 36 | if opt.test == 1 then -- only load model G for test 37 | netG = util.load_test_model('G', opt) 38 | else 39 | netG = util.load_model('G', opt) 40 | netD = util.load_model('D', opt) 41 | end 42 | else 43 | netG = 
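-- (judging from the call sites here and in models/architectures.lua, the
-- constructor signatures are defineG(input_nc, output_nc, ngf, which_model_netG)
-- and defineD(input_nc, ndf, which_model_netD, n_layers_D, use_sigmoid);
-- with the defaults from options.lua the two calls below amount to:
--   defineG(3, 3, 64, 'resnet_6blocks')  -- 3-channel in/out, ngf = 64
--   defineD(6, 64, 'basic', 3, false)    -- 6 = input_nc + output_nc, no sigmoid
-- )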
defineG(opt.input_nc, opt.output_nc, opt.ngf, opt.which_model_netG)
44 |     netD = defineD(d_input_nc, opt.ndf, opt.which_model_netD, opt.n_layers_D, false)  -- use_sigmoid = false: LSGAN trains against the raw discriminator output
45 |   end
46 | 
47 |   self.netD = netD
48 |   self.netG = netG
49 | 
50 |   -- define real/fake labels
51 |   if opt.test == 0 then
52 |     local netD_output_size = self.netD:forward(self.real_AB):size()
53 |     self.fake_label = torch.Tensor(netD_output_size):fill(0.0)
54 |     self.real_label = torch.Tensor(netD_output_size):fill(1.0)  -- no soft smoothing
55 | 
56 |     self.optimStateD = self:InitializeStates()
57 |     self.optimStateG = self:InitializeStates()
58 | 
59 |     self:RefreshParameters()
60 | 
61 |     print('---------- # Learnable Parameters --------------')
62 |     print(('G = %d'):format(self.parametersG:size(1)))
63 |     print(('D = %d'):format(self.parametersD:size(1)))
64 |     print('------------------------------------------------')
65 |   end
66 | 
67 |   self.A_idx = {{}, {1, opt.input_nc}, {}, {}}
68 |   self.B_idx = {{}, {opt.input_nc+1, opt.input_nc+opt.output_nc}, {}, {}}
69 | end
70 | 
71 | -- Runs the forward pass of the network
72 | function Pix2PixModel:Forward(input, opt)
73 |   if opt.which_direction == 'BtoA' then
74 |     local temp = input.real_A
75 |     input.real_A = input.real_B
76 |     input.real_B = temp
77 |   end
78 | 
79 |   if opt.test == 0 then
80 |     self.real_AB[self.A_idx]:copy(input.real_A)
81 |     self.real_AB[self.B_idx]:copy(input.real_B)
82 |     self.real_A = self.real_AB[self.A_idx]
83 |     self.real_B = self.real_AB[self.B_idx]
84 | 
85 |     self.fake_AB[self.A_idx]:copy(self.real_A)
86 |     self.fake_B = self.netG:forward(self.real_A):clone()
87 |     self.fake_AB[self.B_idx]:copy(self.fake_B)
88 |   else
89 |     if opt.gpu > 0 then
90 |       self.real_A = input.real_A:cuda()
91 |       self.real_B = input.real_B:cuda()
92 |     else
93 |       self.real_A = input.real_A:clone()
94 |       self.real_B = input.real_B:clone()
95 |     end
96 |     self.fake_B = self.netG:forward(self.real_A):clone()
97 |   end
98 | end
99 | 
100 | -- create closure to evaluate f(X) and df/dX of the discriminator
101 | function Pix2PixModel:fDx_basic(x, gradParams, netD, netG, real, fake, opt)
102 |   util.BiasZero(netD)
103 |   util.BiasZero(netG)
104 |   gradParams:zero()
105 | 
106 |   -- Real pass: push D(A, B) toward the real label
107 |   local output = netD:forward(real)
108 |   local errD_real = self.criterionGAN:forward(output, self.real_label)
109 |   local df_do = self.criterionGAN:backward(output, self.real_label)
110 |   netD:backward(real, df_do)
111 |   -- Fake pass: push D(A, G(A)) toward the fake label
112 |   output = netD:forward(fake)
113 |   local errD_fake = self.criterionGAN:forward(output, self.fake_label)
114 |   local df_do2 = self.criterionGAN:backward(output, self.fake_label)
115 |   netD:backward(fake, df_do2)
116 |   -- calculate loss
117 |   local errD = (errD_real + errD_fake) / 2.0
118 |   return errD, gradParams
119 | end
120 | 
121 | 
122 | function Pix2PixModel:fDx(x, opt)
123 |   local fake_AB = self.fakeABPool:Query(self.fake_AB)
124 |   self.errD, gradParams = self:fDx_basic(x, self.gradParametersD, self.netD, self.netG,
125 |     self.real_AB, fake_AB, opt)
126 |   return self.errD, gradParams
127 | end
128 | 
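-- A sketch of what the LSGAN criterion in the closures above computes; the
-- 30x30 patch size is a hypothetical PatchGAN output shape, not taken from a
-- run of this code:
--[[
local crit = nn.MSECriterion()
local output = torch.Tensor(1, 1, 30, 30):fill(0.8)      -- D's per-patch scores
local real_label = torch.Tensor(1, 1, 30, 30):fill(1.0)  -- same shape as D's output
print(crit:forward(output, real_label))                  -- mean squared gap: (0.8 - 1)^2 = 0.04
]]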
129 | function Pix2PixModel:fGx_basic(x, netG, netD, real, fake, gradParametersG, opt)
130 |   util.BiasZero(netG)
131 |   util.BiasZero(netD)
132 |   gradParametersG:zero()
133 | 
134 |   -- First, G(A) should fool the discriminator
135 |   local output = netD:forward(fake)
136 |   local errG = self.criterionGAN:forward(output, self.real_label)
137 |   local dgan_loss_dd = self.criterionGAN:backward(output, self.real_label)
138 |   local dgan_loss_do = netD:updateGradInput(fake, dgan_loss_dd)
139 | 
140 |   -- Second, G(A) should be close to the real B
141 |   local real_B = real[self.B_idx]
142 |   local real_A = real[self.A_idx]
143 |   local fake_B = fake[self.B_idx]
144 |   local errL1 = self.criterionL1:forward(fake_B, real_B) * opt.lambda_A
145 |   local dl1_loss_do = self.criterionL1:backward(fake_B, real_B) * opt.lambda_A
146 |   netG:backward(real_A, dgan_loss_do[self.B_idx] + dl1_loss_do)
147 | 
148 |   return gradParametersG, errG, errL1
149 | end
150 | 
151 | function Pix2PixModel:fGx(x, opt)
152 |   self.gradParametersG, self.errG, self.errL1 = self:fGx_basic(x, self.netG, self.netD,
153 |     self.real_AB, self.fake_AB, self.gradParametersG, opt)
154 |   return self.errG, self.gradParametersG
155 | end
156 | 
157 | -- Runs the backprop gradient descent
158 | -- Corresponds to a single batch of data
159 | function Pix2PixModel:OptimizeParameters(opt)
160 |   local fD = function(x) return self:fDx(x, opt) end
161 |   local fG = function(x) return self:fGx(x, opt) end
162 |   optim.adam(fD, self.parametersD, self.optimStateD)
163 |   optim.adam(fG, self.parametersG, self.optimStateG)
164 | end
165 | 
166 | -- This function can be used to reset momentum after each epoch
167 | function Pix2PixModel:RefreshParameters()
168 |   self.parametersD, self.gradParametersD = nil, nil  -- nil them to avoid spiking memory
169 |   self.parametersG, self.gradParametersG = nil, nil
170 | 
171 |   -- define parameters of optimization
172 |   self.parametersG, self.gradParametersG = self.netG:getParameters()
173 |   self.parametersD, self.gradParametersD = self.netD:getParameters()
174 | end
175 | 
176 | -- Updates the learning rate: keep lr for the first opt.niter epochs, then gradually decay it to 0 over the next opt.niter_decay epochs
177 | function Pix2PixModel:UpdateLearningRate(opt)
178 |   local lrd = opt.lr / opt.niter_decay
179 |   local old_lr = self.optimStateD['learningRate']
180 |   local lr = old_lr - lrd
181 |   self.optimStateD['learningRate'] = lr
182 |   self.optimStateG['learningRate'] = lr
183 |   print(('update learning rate: %f -> %f'):format(old_lr, lr))
184 | end
185 | 
186 | 
187 | -- Save the current model to the file system
188 | function Pix2PixModel:Save(prefix, opt)
189 |   util.save_model(self.netG, prefix .. '_net_G.t7', 1.0)
190 |   util.save_model(self.netD, prefix .. '_net_D.t7', 1.0)
191 | end
192 | 
193 | -- returns a string that describes the current errors
194 | function Pix2PixModel:GetCurrentErrorDescription()
195 |   local description = ('G: %.4f D: %.4f L1: %.4f'):format(
196 |     self.errG and self.errG or -1, self.errD and self.errD or -1, self.errL1 and self.errL1 or -1)
197 |   return description
198 | 
199 | end
200 | 
201 | 
202 | -- returns a string that describes the display plot configuration
203 | function Pix2PixModel:DisplayPlot(opt)
204 |   return 'errG,errD,errL1'
205 | end
206 | 
207 | 
208 | -- returns current errors
209 | function Pix2PixModel:GetCurrentErrors()
210 |   local errors = {errG=self.errG, errD=self.errD, errL1=self.errL1}
211 |   return errors
212 | end
213 | 
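-- fGx_basic above assembles the pix2pix generator objective
--   L_G = MSE(D(A, G(A)), 1) + lambda_A * L1(G(A), B)
-- and sums the two gradients before the single netG:backward call. A numeric
-- sketch with the default lambda_A = 10.0 (the loss values are made up):
--[[
local errG  = 0.25          -- GAN term from criterionGAN:forward
local errL1 = 0.08 * 10.0   -- AbsCriterion output scaled by lambda_A
-- total generator loss would be errG + errL1 = 1.05; lambda_A sets the
-- trade-off between L1 fidelity to the target and fooling the discriminator
]]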
214 | -- returns a table of image/label pairs that describe
215 | -- the current results.
216 | -- |return|: a table of tables. List of image/label pairs
217 | function Pix2PixModel:GetCurrentVisuals(opt, size)
218 |   if not size then
219 |     size = opt.display_winsize
220 |   end
221 | 
222 |   local visuals = {}
223 |   table.insert(visuals, {img=self.real_A, label='real_A'})
224 |   table.insert(visuals, {img=self.fake_B, label='fake_B'})
225 |   table.insert(visuals, {img=self.real_B, label='real_B'})
226 | 
227 |   return visuals
228 | end
229 | 
-------------------------------------------------------------------------------- /options.lua: --------------------------------------------------------------------------------
1 | --------------------------------------------------------------------------------
2 | -- Configure options
3 | --------------------------------------------------------------------------------
4 | 
5 | local options = {}
6 | -- options for train
7 | local opt_train = {
8 |   DATA_ROOT = '',           -- path to images (should have subfolders 'train', 'val', etc)
9 |   batchSize = 1,            -- # images in batch
10 |   loadSize = 143,           -- scale images to this size
11 |   fineSize = 128,           -- then crop to this size
12 |   ngf = 64,                 -- # of gen filters in first conv layer
13 |   ndf = 64,                 -- # of discrim filters in first conv layer
14 |   input_nc = 3,             -- # of input image channels
15 |   output_nc = 3,            -- # of output image channels
16 |   niter = 100,              -- # of epochs at the starting learning rate
17 |   niter_decay = 100,        -- # of epochs to linearly decay the learning rate to zero
18 |   lr = 0.0002,              -- initial learning rate for adam
19 |   beta1 = 0.5,              -- momentum term of adam
20 |   ntrain = math.huge,       -- # of examples per epoch. math.huge for full dataset
21 |   flip = 1,                 -- whether to flip the images for data augmentation
22 |   display_id = 10,          -- display window id.
23 |   display_winsize = 128,    -- display window size
24 |   display_freq = 25,        -- display the current results every display_freq iterations
25 |   gpu = 1,                  -- gpu = 0 is CPU mode. gpu=X is GPU mode on GPU X
26 |   name = '',                -- name of the experiment, should generally be passed on the command line
27 |   which_direction = 'AtoB', -- AtoB or BtoA
28 |   phase = 'train',          -- train, val, test, etc
29 |   nThreads = 2,             -- # threads for loading data
30 |   save_epoch_freq = 1,      -- save a model every save_epoch_freq epochs (does not overwrite previously saved models)
31 |   save_latest_freq = 5000,  -- save the latest model every save_latest_freq sgd iterations (overwrites the previous latest model)
32 |   print_freq = 50,          -- print the debug information every print_freq iterations
33 |   save_display_freq = 2500, -- save the current display of results every save_display_freq iterations
34 |   continue_train = 0,       -- if continue training, load the latest model: 1: true, 0: false
35 |   serial_batches = 0,       -- if 1, takes images in order to make batches, otherwise takes them randomly
36 |   checkpoints_dir = './checkpoints', -- models are saved here
37 |   cache_dir = './cache',    -- cache files are saved here
38 |   cudnn = 1,                -- set to 0 to not use cudnn
39 |   which_model_netD = 'basic', -- selects model to use for netD
40 |   which_model_netG = 'resnet_6blocks', -- selects model to use for netG
41 |   norm = 'instance',        -- batch or instance normalization
42 |   n_layers_D = 3,           -- only used if which_model_netD=='n_layers'
43 |   content_loss = 'pixel',   -- content loss type: pixel, vgg
44 |   layer_name = 'pixel',     -- layer used in content loss (e.g. relu4_2)
45 |   lambda_A = 10.0,          -- weight for cycle loss (A -> B -> A)
46 |   lambda_B = 10.0,          -- weight for cycle loss (B -> A -> B)
47 |   model = 'cycle_gan',      -- which model to run: 'cycle_gan', 'pix2pix', 'bigan', 'content_gan'
48 |   use_lsgan = 1,            -- if 1, use least square GAN; if 0, use vanilla GAN
49 |   align_data = 0,           -- if > 0, use the data loader for aligned image pairs
50 |   pool_size = 50,           -- the size of the image buffer that stores previously generated images
51 |   resize_or_crop = 'resize_and_crop', -- resizing/cropping strategy: resize_and_crop | crop | scale_width | scale_height
52 |   lambda_identity = 0.5,    -- use identity mapping. Setting opt.lambda_identity to a value other than 0 scales the weight of the identity mapping loss. For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, set opt.lambda_identity = 0.1
53 |   use_optnet = 0,           -- use optnet to save GPU memory during test
54 | }
55 | 
56 | -- options for test
57 | local opt_test = {
58 |   DATA_ROOT = '',           -- path to images (should have subfolders 'train', 'val', etc)
59 |   loadSize = 128,           -- scale images to this size
60 |   fineSize = 128,           -- then crop to this size
61 |   flip = 0,                 -- horizontal mirroring data augmentation
62 |   display = 1,              -- display samples while testing. 0 = false
63 |   display_id = 200,         -- display window id.
64 |   gpu = 1,                  -- gpu = 0 is CPU mode. gpu=X is GPU mode on GPU X
65 |   how_many = 'all',         -- how many test images to run (set to 'all' to run on every image found in the data/phase folder)
66 |   phase = 'test',           -- train, val, test, etc
67 |   aspect_ratio = 1.0,       -- aspect ratio of result images
68 |   norm = 'instance',        -- batch or instance normalization
69 |   name = '',                -- name of the experiment; selects which model to load; should generally be passed on the command line
70 |   input_nc = 3,             -- # of input image channels
71 |   output_nc = 3,            -- # of output image channels
72 |   serial_batches = 1,       -- if 1, takes images in order to make batches, otherwise takes them randomly
73 |   cudnn = 1,                -- set to 0 to not use cudnn (untested)
74 |   checkpoints_dir = './checkpoints', -- loads models from here
75 |   cache_dir = './cache',    -- cache files are saved here
76 |   results_dir = './results/', -- saves results here
77 |   which_epoch = 'latest',   -- which epoch to test? set to 'latest' to use the latest cached model
78 |   model = 'cycle_gan',      -- which model to run: 'cycle_gan', 'pix2pix', 'bigan', 'content_gan'; to use a pretrained model, select `one_direction_test`
79 |   align_data = 0,           -- if > 0, use the dataloader for pix2pix
80 |   which_direction = 'AtoB', -- AtoB or BtoA
81 |   resize_or_crop = 'resize_and_crop', -- resizing/cropping strategy: resize_and_crop | crop | scale_width | scale_height
82 | }
83 | 
84 | --------------------------------------------------------------------------------
85 | -- util functions
86 | --------------------------------------------------------------------------------
87 | function options.clone(opt)
88 |   local copy = {}
89 |   for orig_key, orig_value in pairs(opt) do
90 |     copy[orig_key] = orig_value
91 |   end
92 |   return copy
93 | end
94 | 
95 | function options.parse_options(mode)
96 |   if mode == 'train' then
97 |     opt = opt_train
98 |     opt.test = 0
99 |   elseif mode == 'test' then
100 |     opt = opt_test
101 |     opt.test = 1
102 |   else
103 |     print("Invalid option [" .. mode .. "]")
104 |     return nil
105 |   end
106 | 
107 |   -- one-line argument parser: parses environment variables to override the defaults
108 |   for k,v in pairs(opt) do opt[k] = tonumber(os.getenv(k)) or os.getenv(k) or opt[k] end
109 |   if mode == 'test' then
110 |     opt.nThreads = 1
111 |     opt.continue_train = 1
112 |     opt.batchSize = 1  -- test code only supports batchSize=1
113 |   end
114 | 
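-- Every default above can be overridden from the environment through the
-- one-line parser; a typical invocation mirrors the usage line in train.lua:
--   DATA_ROOT=/path/to/data/ name=expt1 niter=200 th train.lua
-- tonumber() is tried first, so niter=200 arrives as the number 200 while
-- DATA_ROOT stays a string; unset variables keep the defaults.
--[[
-- with `export niter=200` in the shell:
-- os.getenv('niter')            --> '200'  (a string)
-- tonumber(os.getenv('niter'))  --> 200    (what lands in opt.niter)
]]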
115 |   -- print options sorted by key
116 |   local keyset = {}
117 |   for k,v in pairs(opt) do
118 |     table.insert(keyset, k)
119 |   end
120 |   table.sort(keyset)
121 |   print("------------------- Options -------------------")
122 |   for i,k in ipairs(keyset) do
123 |     print(('%+25s: %s'):format(k, opt[k]))
124 |   end
125 |   print("-----------------------------------------------")
126 | 
127 |   -- create the checkpoint directories
128 |   paths.mkdir(opt.checkpoints_dir)
129 |   paths.mkdir(paths.concat(opt.checkpoints_dir, opt.name))
130 |   opt.visual_dir = paths.concat(opt.checkpoints_dir, opt.name, 'visuals')
131 |   paths.mkdir(opt.visual_dir)
132 |   -- save opt to the disk
133 |   local fd = io.open(paths.concat(opt.checkpoints_dir, opt.name, 'opt_' .. mode .. '.txt'), 'w')
134 |   for i,k in ipairs(keyset) do
135 |     fd:write(("%+25s: %s\n"):format(k, opt[k]))
136 |   end
137 |   fd:close()
138 | 
139 |   return opt
140 | end
141 | 
142 | 
143 | return options
144 | 
-------------------------------------------------------------------------------- /pretrained_models/download_model.sh: --------------------------------------------------------------------------------
1 | FILE=$1
2 | 
3 | echo "Note: available models are apple2orange, facades_photo2label, map2sat, orange2apple, style_cezanne, style_ukiyoe, summer2winter_yosemite, zebra2horse, facades_label2photo, horse2zebra, monet2photo, sat2map, style_monet, style_vangogh, winter2summer_yosemite, iphone2dslr_flower"
4 | 
5 | echo "Specified [$FILE]"
6 | 
7 | mkdir -p ./checkpoints/${FILE}_pretrained
8 | URL=https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/models/$FILE.t7
9 | MODEL_FILE=./checkpoints/${FILE}_pretrained/latest_net_G.t7
10 | wget -N $URL -O $MODEL_FILE
11 | 
-------------------------------------------------------------------------------- /pretrained_models/download_vgg.sh: --------------------------------------------------------------------------------
1 | URL1=https://people.eecs.berkeley.edu/~taesung_park/projects/CycleGAN/models/places_vgg.caffemodel
2 | MODEL_FILE1=./models/places_vgg.caffemodel
3 | URL2=https://people.eecs.berkeley.edu/~taesung_park/projects/CycleGAN/models/places_vgg.prototxt
4 | MODEL_FILE2=./models/places_vgg.prototxt
5 | wget -N $URL1 -O $MODEL_FILE1
6 | wget -N $URL2 -O $MODEL_FILE2
7 | 
-------------------------------------------------------------------------------- /pretrained_models/places_vgg.prototxt: --------------------------------------------------------------------------------
1 | name: "VGG-Places365"
2 | input: "data"
3 | input_dim: 1
4 | input_dim: 3
5 | input_dim: 224
6 | input_dim: 224
7 | layer {
8 |   name: "conv1_1"
9 |   type: "Convolution"
10 |   bottom: "data"
11 |   top: "conv1_1"
12 |   param {
13 |     lr_mult: 1.0
14 |     decay_mult: 1.0
15 |   }
16 |   param {
17 |     lr_mult: 2.0
18 |     decay_mult: 0.0
19 |   }
20 |   convolution_param {
21 |     num_output: 64
22 |     pad: 1
23 |     kernel_size: 3
24 |     weight_filler {
25 |       type: "gaussian"
26 |       std: 0.01
27 |     }
28 |     bias_filler {
29 |       type: "constant"
30 |       value: 0.0
31 |     }
32 |   }
33 | }
34 | layer {
35 |   name: "relu1_1"
36 |   type: "ReLU"
37 |   bottom: "conv1_1"
38 |   top: "conv1_1"
39 | }
40 | layer {
41 |   name: "conv1_2"
42 |   type: "Convolution"
43 |   bottom: "conv1_1"
44 |   top: "conv1_2"
45 |   param {
46 |     lr_mult: 1.0
47 | 
decay_mult: 1.0 48 | } 49 | param { 50 | lr_mult: 2.0 51 | decay_mult: 0.0 52 | } 53 | convolution_param { 54 | num_output: 64 55 | pad: 1 56 | kernel_size: 3 57 | weight_filler { 58 | type: "gaussian" 59 | std: 0.01 60 | } 61 | bias_filler { 62 | type: "constant" 63 | value: 0.0 64 | } 65 | } 66 | } 67 | layer { 68 | name: "relu1_2" 69 | type: "ReLU" 70 | bottom: "conv1_2" 71 | top: "conv1_2" 72 | } 73 | layer { 74 | name: "pool1" 75 | type: "Pooling" 76 | bottom: "conv1_2" 77 | top: "pool1" 78 | pooling_param { 79 | pool: MAX 80 | kernel_size: 2 81 | stride: 2 82 | } 83 | } 84 | layer { 85 | name: "conv2_1" 86 | type: "Convolution" 87 | bottom: "pool1" 88 | top: "conv2_1" 89 | param { 90 | lr_mult: 1.0 91 | decay_mult: 1.0 92 | } 93 | param { 94 | lr_mult: 2.0 95 | decay_mult: 0.0 96 | } 97 | convolution_param { 98 | num_output: 128 99 | pad: 1 100 | kernel_size: 3 101 | weight_filler { 102 | type: "gaussian" 103 | std: 0.01 104 | } 105 | bias_filler { 106 | type: "constant" 107 | value: 0.0 108 | } 109 | } 110 | } 111 | layer { 112 | name: "relu2_1" 113 | type: "ReLU" 114 | bottom: "conv2_1" 115 | top: "conv2_1" 116 | } 117 | layer { 118 | name: "conv2_2" 119 | type: "Convolution" 120 | bottom: "conv2_1" 121 | top: "conv2_2" 122 | param { 123 | lr_mult: 1.0 124 | decay_mult: 1.0 125 | } 126 | param { 127 | lr_mult: 2.0 128 | decay_mult: 0.0 129 | } 130 | convolution_param { 131 | num_output: 128 132 | pad: 1 133 | kernel_size: 3 134 | weight_filler { 135 | type: "gaussian" 136 | std: 0.01 137 | } 138 | bias_filler { 139 | type: "constant" 140 | value: 0.0 141 | } 142 | } 143 | } 144 | layer { 145 | name: "relu2_2" 146 | type: "ReLU" 147 | bottom: "conv2_2" 148 | top: "conv2_2" 149 | } 150 | layer { 151 | name: "pool2" 152 | type: "Pooling" 153 | bottom: "conv2_2" 154 | top: "pool2" 155 | pooling_param { 156 | pool: MAX 157 | kernel_size: 2 158 | stride: 2 159 | } 160 | } 161 | layer { 162 | name: "conv3_1" 163 | type: "Convolution" 164 | bottom: "pool2" 165 | top: "conv3_1" 166 | param { 167 | lr_mult: 1.0 168 | decay_mult: 1.0 169 | } 170 | param { 171 | lr_mult: 2.0 172 | decay_mult: 0.0 173 | } 174 | convolution_param { 175 | num_output: 256 176 | pad: 1 177 | kernel_size: 3 178 | weight_filler { 179 | type: "gaussian" 180 | std: 0.01 181 | } 182 | bias_filler { 183 | type: "constant" 184 | value: 0.0 185 | } 186 | } 187 | } 188 | layer { 189 | name: "relu3_1" 190 | type: "ReLU" 191 | bottom: "conv3_1" 192 | top: "conv3_1" 193 | } 194 | layer { 195 | name: "conv3_2" 196 | type: "Convolution" 197 | bottom: "conv3_1" 198 | top: "conv3_2" 199 | param { 200 | lr_mult: 1.0 201 | decay_mult: 1.0 202 | } 203 | param { 204 | lr_mult: 2.0 205 | decay_mult: 0.0 206 | } 207 | convolution_param { 208 | num_output: 256 209 | pad: 1 210 | kernel_size: 3 211 | weight_filler { 212 | type: "gaussian" 213 | std: 0.01 214 | } 215 | bias_filler { 216 | type: "constant" 217 | value: 0.0 218 | } 219 | } 220 | } 221 | layer { 222 | name: "relu3_2" 223 | type: "ReLU" 224 | bottom: "conv3_2" 225 | top: "conv3_2" 226 | } 227 | layer { 228 | name: "conv3_3" 229 | type: "Convolution" 230 | bottom: "conv3_2" 231 | top: "conv3_3" 232 | param { 233 | lr_mult: 1.0 234 | decay_mult: 1.0 235 | } 236 | param { 237 | lr_mult: 2.0 238 | decay_mult: 0.0 239 | } 240 | convolution_param { 241 | num_output: 256 242 | pad: 1 243 | kernel_size: 3 244 | weight_filler { 245 | type: "gaussian" 246 | std: 0.01 247 | } 248 | bias_filler { 249 | type: "constant" 250 | value: 0.0 251 | } 252 | } 253 | } 254 | layer { 255 | name: 
"relu3_3" 256 | type: "ReLU" 257 | bottom: "conv3_3" 258 | top: "conv3_3" 259 | } 260 | layer { 261 | name: "pool3" 262 | type: "Pooling" 263 | bottom: "conv3_3" 264 | top: "pool3" 265 | pooling_param { 266 | pool: MAX 267 | kernel_size: 2 268 | stride: 2 269 | } 270 | } 271 | layer { 272 | name: "conv4_1" 273 | type: "Convolution" 274 | bottom: "pool3" 275 | top: "conv4_1" 276 | param { 277 | lr_mult: 1.0 278 | decay_mult: 1.0 279 | } 280 | param { 281 | lr_mult: 2.0 282 | decay_mult: 0.0 283 | } 284 | convolution_param { 285 | num_output: 512 286 | pad: 1 287 | kernel_size: 3 288 | weight_filler { 289 | type: "gaussian" 290 | std: 0.01 291 | } 292 | bias_filler { 293 | type: "constant" 294 | value: 0.0 295 | } 296 | } 297 | } 298 | layer { 299 | name: "relu4_1" 300 | type: "ReLU" 301 | bottom: "conv4_1" 302 | top: "conv4_1" 303 | } 304 | layer { 305 | name: "conv4_2" 306 | type: "Convolution" 307 | bottom: "conv4_1" 308 | top: "conv4_2" 309 | param { 310 | lr_mult: 1.0 311 | decay_mult: 1.0 312 | } 313 | param { 314 | lr_mult: 2.0 315 | decay_mult: 0.0 316 | } 317 | convolution_param { 318 | num_output: 512 319 | pad: 1 320 | kernel_size: 3 321 | weight_filler { 322 | type: "gaussian" 323 | std: 0.01 324 | } 325 | bias_filler { 326 | type: "constant" 327 | value: 0.0 328 | } 329 | } 330 | } 331 | layer { 332 | name: "relu4_2" 333 | type: "ReLU" 334 | bottom: "conv4_2" 335 | top: "conv4_2" 336 | } 337 | layer { 338 | name: "conv4_3" 339 | type: "Convolution" 340 | bottom: "conv4_2" 341 | top: "conv4_3" 342 | param { 343 | lr_mult: 1.0 344 | decay_mult: 1.0 345 | } 346 | param { 347 | lr_mult: 2.0 348 | decay_mult: 0.0 349 | } 350 | convolution_param { 351 | num_output: 512 352 | pad: 1 353 | kernel_size: 3 354 | weight_filler { 355 | type: "gaussian" 356 | std: 0.01 357 | } 358 | bias_filler { 359 | type: "constant" 360 | value: 0.0 361 | } 362 | } 363 | } 364 | layer { 365 | name: "relu4_3" 366 | type: "ReLU" 367 | bottom: "conv4_3" 368 | top: "conv4_3" 369 | } 370 | layer { 371 | name: "pool4" 372 | type: "Pooling" 373 | bottom: "conv4_3" 374 | top: "pool4" 375 | pooling_param { 376 | pool: MAX 377 | kernel_size: 2 378 | stride: 2 379 | } 380 | } 381 | layer { 382 | name: "conv5_1" 383 | type: "Convolution" 384 | bottom: "pool4" 385 | top: "conv5_1" 386 | param { 387 | lr_mult: 1.0 388 | decay_mult: 1.0 389 | } 390 | param { 391 | lr_mult: 2.0 392 | decay_mult: 0.0 393 | } 394 | convolution_param { 395 | num_output: 512 396 | pad: 1 397 | kernel_size: 3 398 | weight_filler { 399 | type: "gaussian" 400 | std: 0.01 401 | } 402 | bias_filler { 403 | type: "constant" 404 | value: 0.0 405 | } 406 | } 407 | } 408 | layer { 409 | name: "relu5_1" 410 | type: "ReLU" 411 | bottom: "conv5_1" 412 | top: "conv5_1" 413 | } 414 | layer { 415 | name: "conv5_2" 416 | type: "Convolution" 417 | bottom: "conv5_1" 418 | top: "conv5_2" 419 | param { 420 | lr_mult: 1.0 421 | decay_mult: 1.0 422 | } 423 | param { 424 | lr_mult: 2.0 425 | decay_mult: 0.0 426 | } 427 | convolution_param { 428 | num_output: 512 429 | pad: 1 430 | kernel_size: 3 431 | weight_filler { 432 | type: "gaussian" 433 | std: 0.01 434 | } 435 | bias_filler { 436 | type: "constant" 437 | value: 0.0 438 | } 439 | } 440 | } 441 | layer { 442 | name: "relu5_2" 443 | type: "ReLU" 444 | bottom: "conv5_2" 445 | top: "conv5_2" 446 | } 447 | layer { 448 | name: "conv5_3" 449 | type: "Convolution" 450 | bottom: "conv5_2" 451 | top: "conv5_3" 452 | param { 453 | lr_mult: 1.0 454 | decay_mult: 1.0 455 | } 456 | param { 457 | lr_mult: 2.0 458 | 
decay_mult: 0.0 459 | } 460 | convolution_param { 461 | num_output: 512 462 | pad: 1 463 | kernel_size: 3 464 | weight_filler { 465 | type: "gaussian" 466 | std: 0.01 467 | } 468 | bias_filler { 469 | type: "constant" 470 | value: 0.0 471 | } 472 | } 473 | } 474 | layer { 475 | name: "relu5_3" 476 | type: "ReLU" 477 | bottom: "conv5_3" 478 | top: "conv5_3" 479 | } 480 | layer { 481 | name: "pool5" 482 | type: "Pooling" 483 | bottom: "conv5_3" 484 | top: "pool5" 485 | pooling_param { 486 | pool: MAX 487 | kernel_size: 2 488 | stride: 2 489 | } 490 | } 491 | layer { 492 | name: "fc6" 493 | type: "InnerProduct" 494 | bottom: "pool5" 495 | top: "fc6" 496 | param { 497 | lr_mult: 1.0 498 | decay_mult: 1.0 499 | } 500 | param { 501 | lr_mult: 2.0 502 | decay_mult: 0.0 503 | } 504 | inner_product_param { 505 | num_output: 4096 506 | weight_filler { 507 | type: "gaussian" 508 | std: 0.01 509 | } 510 | bias_filler { 511 | type: "constant" 512 | value: 0.0 513 | } 514 | } 515 | } 516 | layer { 517 | name: "relu6" 518 | type: "ReLU" 519 | bottom: "fc6" 520 | top: "fc6" 521 | } 522 | layer { 523 | name: "drop6" 524 | type: "Dropout" 525 | bottom: "fc6" 526 | top: "fc6" 527 | dropout_param { 528 | dropout_ratio: 0.5 529 | } 530 | } 531 | layer { 532 | name: "fc7" 533 | type: "InnerProduct" 534 | bottom: "fc6" 535 | top: "fc7" 536 | param { 537 | lr_mult: 1.0 538 | decay_mult: 1.0 539 | } 540 | param { 541 | lr_mult: 2.0 542 | decay_mult: 0.0 543 | } 544 | inner_product_param { 545 | num_output: 4096 546 | weight_filler { 547 | type: "gaussian" 548 | std: 0.01 549 | } 550 | bias_filler { 551 | type: "constant" 552 | value: 0.0 553 | } 554 | } 555 | } 556 | layer { 557 | name: "relu7" 558 | type: "ReLU" 559 | bottom: "fc7" 560 | top: "fc7" 561 | } 562 | layer { 563 | name: "drop7" 564 | type: "Dropout" 565 | bottom: "fc7" 566 | top: "fc7" 567 | dropout_param { 568 | dropout_ratio: 0.5 569 | } 570 | } 571 | layer { 572 | name: "fc8a" 573 | type: "InnerProduct" 574 | bottom: "fc7" 575 | top: "fc8a" 576 | param { 577 | lr_mult: 1.0 578 | decay_mult: 1.0 579 | } 580 | param { 581 | lr_mult: 2.0 582 | decay_mult: 0.0 583 | } 584 | inner_product_param { 585 | num_output: 365 586 | } 587 | } 588 | layer { 589 | name: "prob" 590 | type: "Softmax" 591 | bottom: "fc8a" 592 | top: "prob" 593 | } 594 | -------------------------------------------------------------------------------- /test.lua: -------------------------------------------------------------------------------- 1 | -- usage: DATA_ROOT=/path/to/data/ name=expt1 which_direction=BtoA th test.lua 2 | -- 3 | -- code derived from https://github.com/soumith/dcgan.torch and https://github.com/phillipi/pix2pix 4 | require 'image' 5 | require 'nn' 6 | require 'nngraph' 7 | require 'models.architectures' 8 | 9 | 10 | util = paths.dofile('util/util.lua') 11 | options = require 'options' 12 | opt = options.parse_options('test') 13 | 14 | -- initialize torch GPU/CPU mode 15 | if opt.gpu > 0 then 16 | require 'cutorch' 17 | require 'cunn' 18 | cutorch.setDevice(opt.gpu) 19 | print ("GPU Mode") 20 | torch.setdefaulttensortype('torch.CudaTensor') 21 | else 22 | torch.setdefaulttensortype('torch.FloatTensor') 23 | print ("CPU Mode") 24 | end 25 | 26 | -- setup visualization 27 | visualizer = require 'util/visualizer' 28 | 29 | function TableConcat(t1,t2) 30 | for i=1,#t2 do 31 | t1[#t1+1] = t2[i] 32 | end 33 | return t1 34 | end 35 | 36 | 37 | -- load data 38 | local data_loader = nil 39 | if opt.align_data > 0 then 40 | require 'data.aligned_data_loader' 41 | data_loader 
= AlignedDataLoader()
42 | else
43 |   require 'data.unaligned_data_loader'
44 |   data_loader = UnalignedDataLoader()
45 | end
46 | print("DataLoader " .. data_loader:name() .. " was created.")
47 | data_loader:Initialize(opt)
48 | 
49 | if opt.how_many == 'all' then
50 |   opt.how_many = data_loader:size()
51 | end
52 | 
53 | opt.how_many = math.min(opt.how_many, data_loader:size())
54 | 
55 | -- set batch/instance normalization
56 | set_normalization(opt.norm)
57 | 
58 | -- load model
59 | opt.continue_train = 1
60 | -- define model
61 | if opt.model == 'cycle_gan' then
62 |   require 'models.cycle_gan_model'
63 |   model = CycleGANModel()
64 | elseif opt.model == 'one_direction_test' then
65 |   require 'models.one_direction_test_model'
66 |   model = OneDirectionTestModel()
67 | elseif opt.model == 'pix2pix' then
68 |   require 'models.pix2pix_model'
69 |   model = Pix2PixModel()
70 | elseif opt.model == 'bigan' then
71 |   require 'models.bigan_model'
72 |   model = BiGANModel()
73 | elseif opt.model == 'content_gan' then
74 |   require 'models.content_gan_model'
75 |   model = ContentGANModel()
76 | else
77 |   error('Please specify a correct model')
78 | end
79 | model:Initialize(opt)
80 | 
81 | local pathsA = {}  -- paths to images A tested on
82 | local pathsB = {}  -- paths to images B tested on
83 | local web_dir = paths.concat(opt.results_dir, opt.name .. '/' .. opt.which_epoch .. '_' .. opt.phase)
84 | paths.mkdir(web_dir)
85 | local image_dir = paths.concat(web_dir, 'images')
86 | paths.mkdir(image_dir)
87 | s1 = opt.fineSize
88 | s2 = opt.fineSize / opt.aspect_ratio
89 | 
90 | visuals = {}
91 | 
92 | for n = 1, math.floor(opt.how_many) do
93 |   print('processing batch ' .. n)
94 |   local cur_dataA, cur_dataB, cur_pathsA, cur_pathsB = data_loader:GetNextBatch()
95 | 
96 |   cur_pathsA = util.basename_batch(cur_pathsA)
97 |   cur_pathsB = util.basename_batch(cur_pathsB)
98 |   print('pathsA', cur_pathsA)
99 |   print('pathsB', cur_pathsB)
100 |   model:Forward({real_A=cur_dataA, real_B=cur_dataB}, opt)
101 | 
102 |   visuals = model:GetCurrentVisuals(opt, opt.fineSize)
103 | 
104 |   for i,visual in ipairs(visuals) do
105 |     if opt.resize_or_crop == 'scale_width' or opt.resize_or_crop == 'scale_height' then
106 |       s1 = nil
107 |       s2 = nil
108 |     end
109 |     visualizer.save_images(visual.img, paths.concat(image_dir, visual.label), {string.gsub(cur_pathsA[1],'.jpg','.png')}, s1, s2)
110 |   end
111 | 
112 | 
113 |   print('Saved images to: ', image_dir)
114 |   pathsA = TableConcat(pathsA, cur_pathsA)
115 |   pathsB = TableConcat(pathsB, cur_pathsB)
116 | end
117 | 
118 | labels = {}
119 | for i,visual in ipairs(visuals) do
120 |   table.insert(labels, visual.label)
121 | end
122 | 
123 | -- make webpage
124 | io.output(paths.concat(web_dir, 'index.html'))
125 | io.write('<table>')
126 | io.write('<tr>')
127 | for i = 1, #labels do
128 |   io.write('<td>Image ' .. labels[i] .. '</td>')
129 | end
130 | io.write('</tr>')
131 | 
132 | for n = 1,math.floor(opt.how_many) do
133 |   io.write('<tr>')
134 |   io.write('<td>' .. tostring(n) .. '</td>')
135 |   for j = 1, #labels do
136 |     label = labels[j]
137 |     io.write('<td><img src="./images/' .. label .. '/' .. string.gsub(pathsA[n],'.jpg','.png') .. '"/></td>')
138 |   end
139 |   io.write('</tr>')
140 | end
141 | 
142 | io.write('</table>')
143 | 
-------------------------------------------------------------------------------- /train.lua: --------------------------------------------------------------------------------
1 | -- usage example: DATA_ROOT=/path/to/data/ which_direction=BtoA name=expt1 th train.lua
2 | -- code derived from https://github.com/soumith/dcgan.torch and https://github.com/phillipi/pix2pix
3 | 
4 | require 'torch'
5 | require 'nn'
6 | require 'optim'
7 | util = paths.dofile('util/util.lua')
8 | content = paths.dofile('util/content_loss.lua')
9 | require 'image'
10 | require 'models.architectures'
11 | 
12 | -- load configuration file
13 | options = require 'options'
14 | opt = options.parse_options('train')
15 | 
16 | -- setup visualization
17 | visualizer = require 'util/visualizer'
18 | 
19 | -- initialize torch GPU/CPU mode
20 | if opt.gpu > 0 then
21 |   require 'cutorch'
22 |   require 'cunn'
23 |   cutorch.setDevice(opt.gpu)
24 |   print("GPU Mode")
25 |   torch.setdefaulttensortype('torch.CudaTensor')
26 | else
27 |   torch.setdefaulttensortype('torch.FloatTensor')
28 |   print("CPU Mode")
29 | end
30 | 
31 | -- load data
32 | local data_loader = nil
33 | if opt.align_data > 0 then
34 |   require 'data.aligned_data_loader'
35 |   data_loader = AlignedDataLoader()
36 | else
37 |   require 'data.unaligned_data_loader'
38 |   data_loader = UnalignedDataLoader()
39 | end
40 | print("DataLoader " .. data_loader:name() .. " was created.")
41 | data_loader:Initialize(opt)
42 | 
43 | -- set batch/instance normalization
44 | set_normalization(opt.norm)
45 | 
46 | --- timer
47 | local epoch_tm = torch.Timer()
48 | local tm = torch.Timer()
49 | 
50 | -- define model
51 | local model = nil
52 | local display_plot = nil
53 | if opt.model == 'cycle_gan' then
54 |   assert(data_loader:name() == 'UnalignedDataLoader')
55 |   require 'models.cycle_gan_model'
56 |   model = CycleGANModel()
57 | elseif opt.model == 'pix2pix' then
58 |   require 'models.pix2pix_model'
59 |   assert(data_loader:name() == 'AlignedDataLoader')
60 |   model = Pix2PixModel()
61 | elseif opt.model == 'bigan' then
62 |   assert(data_loader:name() == 'UnalignedDataLoader')
63 |   require 'models.bigan_model'
64 |   model = BiGANModel()
65 | elseif opt.model == 'content_gan' then
66 |   require 'models.content_gan_model'
67 |   assert(data_loader:name() == 'UnalignedDataLoader')
68 |   model = ContentGANModel()
69 | else
70 |   error('Please specify a correct model')
71 | end
72 | 
73 | -- print the model name
74 | print('Model ' .. model:model_name() .. ' was specified.')
75 | model:Initialize(opt)
76 | 
77 | -- set up the loss plot
78 | require 'util/plot_util'
79 | plotUtil = PlotUtil()
80 | display_plot = model:DisplayPlot(opt)
81 | plotUtil:Initialize(display_plot, opt.display_id, opt.name)
82 | 
83 | --------------------------------------------------------------------------------
84 | -- Helper Functions
85 | --------------------------------------------------------------------------------
86 | function visualize_current_results()
87 |   local visuals = model:GetCurrentVisuals(opt)
88 |   for i,visual in ipairs(visuals) do
89 |     visualizer.disp_image(visual.img, opt.display_winsize,
90 |       opt.display_id+i, opt.name .. ' ' .. visual.label)
91 |   end
92 | end
93 | 
94 | function save_current_results(epoch, counter)
95 |   local visuals = model:GetCurrentVisuals(opt)
96 |   for i,visual in ipairs(visuals) do
97 |     output_path = paths.concat(opt.visual_dir, 'train_epoch' .. epoch .. '_iter' .. counter .. '_' .. visual.label ..
'.jpg') 98 | visualizer.save_results(visual.img, output_path) 99 | end 100 | end 101 | 102 | function print_current_errors(epoch, counter_in_epoch) 103 | print(('Epoch: [%d][%8d / %8d]\t Time: %.3f DataTime: %.3f ' 104 | .. '%s'): 105 | format(epoch, ((counter_in_epoch-1) / opt.batchSize), 106 | math.floor(math.min(data_loader:size(), opt.ntrain) / opt.batchSize), 107 | tm:time().real / opt.batchSize, 108 | data_loader:time_elapsed_to_fetch_data() / opt.batchSize, 109 | model:GetCurrentErrorDescription() 110 | )) 111 | end 112 | 113 | function plot_current_errors(epoch, counter_ratio, opt) 114 | local errs = model:GetCurrentErrors(opt) 115 | local plot_vals = { epoch + counter_ratio} 116 | plotUtil:Display(plot_vals, errs) 117 | end 118 | 119 | -------------------------------------------------------------------------------- 120 | -- Main Training Loop 121 | -------------------------------------------------------------------------------- 122 | local counter = 0 123 | local num_batches = math.floor(math.min(data_loader:size(), opt.ntrain) / opt.batchSize) 124 | print('#training iterations: ' .. opt.niter+opt.niter_decay ) 125 | 126 | for epoch = 1, opt.niter+opt.niter_decay do 127 | epoch_tm:reset() 128 | for counter_in_epoch = 1, math.min(data_loader:size(), opt.ntrain), opt.batchSize do 129 | tm:reset() 130 | -- load a batch and run G on that batch 131 | local real_dataA, real_dataB, _, _ = data_loader:GetNextBatch() 132 | 133 | model:Forward({real_A=real_dataA, real_B=real_dataB}, opt) 134 | -- run forward pass 135 | opt.counter = counter 136 | -- run backward pass 137 | model:OptimizeParameters(opt) 138 | -- display on the web server 139 | if counter % opt.display_freq == 0 and opt.display_id > 0 then 140 | visualize_current_results() 141 | end 142 | 143 | -- logging 144 | if counter % opt.print_freq == 0 then 145 | print_current_errors(epoch, counter_in_epoch) 146 | plot_current_errors(epoch, counter_in_epoch/num_batches, opt) 147 | end 148 | 149 | -- save latest model 150 | if counter % opt.save_latest_freq == 0 and counter > 0 then 151 | print(('saving the latest model (epoch %d, iters %d)'):format(epoch, counter)) 152 | model:Save('latest', opt) 153 | end 154 | 155 | -- save latest results 156 | if counter % opt.save_display_freq == 0 then 157 | save_current_results(epoch, counter) 158 | end 159 | counter = counter + 1 160 | end 161 | 162 | -- save model at the end of epoch 163 | if epoch % opt.save_epoch_freq == 0 then 164 | print(('saving the model (epoch %d, iters %d)'):format(epoch, counter)) 165 | model:Save('latest', opt) 166 | model:Save(epoch, opt) 167 | end 168 | -- print the timing information after each epoch 169 | print(('End of epoch %d / %d \t Time Taken: %.3f'): 170 | format(epoch, opt.niter+opt.niter_decay, epoch_tm:time().real)) 171 | 172 | -- update learning rate 173 | if epoch > opt.niter then 174 | model:UpdateLearningRate(opt) 175 | end 176 | -- refresh parameters 177 | model:RefreshParameters(opt) 178 | end 179 | -------------------------------------------------------------------------------- /util/InstanceNormalization.lua: -------------------------------------------------------------------------------- 1 | require 'nn' 2 | 3 | --[[ 4 | Implements instance normalization as described in the paper 5 | 6 | Instance Normalization: The Missing Ingredient for Fast Stylization 7 | Dmitry Ulyanov, Andrea Vedaldi, Victor Lempitsky 8 | https://arxiv.org/abs/1607.08022 9 | This implementation is based on 10 | https://github.com/DmitryUlyanov/texture_nets 11 | ]] 12 | 13 | 
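-- Instance normalization normalizes each (sample, channel) plane on its own:
--   y[n][c] = (x[n][c] - mean(x[n][c])) / sqrt(var(x[n][c]) + eps) * weight[c] + bias[c]
-- The module below reuses nn.SpatialBatchNormalization by folding the batch
-- dimension into the channel dimension. Shape bookkeeping for a hypothetical
-- input (not traced from a run):
--[[
-- input: 4 x 64 x 32 x 32  -->  viewed as 1 x 256 x 32 x 32
-- nn.SpatialBatchNormalization(256) then computes one mean/var per (n, c)
-- pair, i.e. exactly one statistic per instance and channel.
]]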
local InstanceNormalization, parent = torch.class('nn.InstanceNormalization', 'nn.Module') 14 | 15 | function InstanceNormalization:__init(nOutput, eps, momentum, affine) 16 | parent.__init(self) 17 | self.running_mean = torch.zeros(nOutput) 18 | self.running_var = torch.ones(nOutput) 19 | 20 | self.eps = eps or 1e-5 21 | self.momentum = momentum or 0.0 22 | if affine ~= nil then 23 | assert(type(affine) == 'boolean', 'affine has to be true/false') 24 | self.affine = affine 25 | else 26 | self.affine = true 27 | end 28 | 29 | self.nOutput = nOutput 30 | self.prev_batch_size = -1 31 | 32 | if self.affine then 33 | self.weight = torch.Tensor(nOutput):uniform() 34 | self.bias = torch.Tensor(nOutput):zero() 35 | self.gradWeight = torch.Tensor(nOutput) 36 | self.gradBias = torch.Tensor(nOutput) 37 | end 38 | end 39 | 40 | function InstanceNormalization:updateOutput(input) 41 | self.output = self.output or input.new() 42 | assert(input:size(2) == self.nOutput) 43 | 44 | local batch_size = input:size(1) 45 | 46 | if batch_size ~= self.prev_batch_size or (self.bn and self:type() ~= self.bn:type()) then 47 | self.bn = nn.SpatialBatchNormalization(input:size(1)*input:size(2), self.eps, self.momentum, self.affine) 48 | self.bn:type(self:type()) 49 | self.bn.running_mean:copy(self.running_mean:repeatTensor(batch_size)) 50 | self.bn.running_var:copy(self.running_var:repeatTensor(batch_size)) 51 | 52 | self.prev_batch_size = input:size(1) 53 | end 54 | 55 | -- Get statistics 56 | self.running_mean:copy(self.bn.running_mean:view(input:size(1),self.nOutput):mean(1)) 57 | self.running_var:copy(self.bn.running_var:view(input:size(1),self.nOutput):mean(1)) 58 | 59 | -- Set params for BN 60 | if self.affine then 61 | self.bn.weight:copy(self.weight:repeatTensor(batch_size)) 62 | self.bn.bias:copy(self.bias:repeatTensor(batch_size)) 63 | end 64 | 65 | local input_1obj = input:contiguous():view(1,input:size(1)*input:size(2),input:size(3),input:size(4)) 66 | self.output = self.bn:forward(input_1obj):viewAs(input) 67 | 68 | return self.output 69 | end 70 | 71 | function InstanceNormalization:updateGradInput(input, gradOutput) 72 | self.gradInput = self.gradInput or gradOutput.new() 73 | 74 | assert(self.bn) 75 | 76 | local input_1obj = input:contiguous():view(1,input:size(1)*input:size(2),input:size(3),input:size(4)) 77 | local gradOutput_1obj = gradOutput:contiguous():view(1,input:size(1)*input:size(2),input:size(3),input:size(4)) 78 | 79 | if self.affine then 80 | self.bn.gradWeight:zero() 81 | self.bn.gradBias:zero() 82 | end 83 | 84 | self.gradInput = self.bn:backward(input_1obj, gradOutput_1obj):viewAs(input) 85 | 86 | if self.affine then 87 | self.gradWeight:add(self.bn.gradWeight:view(input:size(1),self.nOutput):sum(1)) 88 | self.gradBias:add(self.bn.gradBias:view(input:size(1),self.nOutput):sum(1)) 89 | end 90 | return self.gradInput 91 | end 92 | 93 | function InstanceNormalization:clearState() 94 | self.output = self.output.new() 95 | self.gradInput = self.gradInput.new() 96 | 97 | if self.bn then 98 | self.bn:clearState() 99 | end 100 | end 101 | 102 | function InstanceNormalization:evaluate() 103 | end 104 | 105 | function InstanceNormalization:training() 106 | end 107 | -------------------------------------------------------------------------------- /util/VGG_preprocess.lua: -------------------------------------------------------------------------------- 1 | -- define nn module for VGG postprocessing 2 | local VGG_postprocess, parent = torch.class('nn.VGG_postprocess', 'nn.Module') 3 | 4 | function 
VGG_postprocess:__init() 5 | parent.__init(self) 6 | end 7 | 8 | function VGG_postprocess:updateOutput(input) 9 | self.output = input:add(1):mul(127.5) 10 | -- print(self.output:max(), self.output:min()) 11 | if self.output:max() > 255 or self.output:min() < 0 then 12 | print(self.output:min(), self.output:max()) 13 | end 14 | -- assert(self.output:min()>=0,"badly scaled inputs") 15 | -- assert(self.output:max()<=255,"badly scaled inputs") 16 | 17 | local mean_pixel = torch.FloatTensor({103.939, 116.779, 123.68}) 18 | mean_pixel = mean_pixel:reshape(1,3,1,1) 19 | mean_pixel = mean_pixel:repeatTensor(input:size(1), 1, input:size(3), input:size(4)):cuda() 20 | self.output:add(-1, mean_pixel) 21 | return self.output 22 | end 23 | 24 | function VGG_postprocess:updateGradInput(input, gradOutput) 25 | self.gradInput = gradOutput:div(127.5) 26 | return self.gradInput 27 | end 28 | -------------------------------------------------------------------------------- /util/content_loss.lua: -------------------------------------------------------------------------------- 1 | require 'torch' 2 | require 'nn' 3 | local content = {} 4 | 5 | function content.defineVGG(content_layer) 6 | local contentFunc = nn.Sequential() 7 | require 'loadcaffe' 8 | require 'util/VGG_preprocess' 9 | cnn = loadcaffe.load('../models/vgg.prototxt', '../models/vgg.caffemodel', 'cudnn') 10 | contentFunc:add(nn.SpatialUpSamplingBilinear({oheight=224, owidth=224})) 11 | contentFunc:add(nn.VGG_postprocess()) 12 | for i = 1, #cnn do 13 | local layer = cnn:get(i):clone() 14 | local name = layer.name 15 | local layer_type = torch.type(layer) 16 | contentFunc:add(layer) 17 | if name == content_layer then 18 | print("Setting up content layer: ", layer.name) 19 | break 20 | end 21 | end 22 | cnn = nil 23 | collectgarbage() 24 | print(contentFunc) 25 | return contentFunc 26 | end 27 | 28 | function content.defineAlexNet(content_layer) 29 | local contentFunc = nn.Sequential() 30 | require 'loadcaffe' 31 | require 'util/VGG_preprocess' 32 | cnn = loadcaffe.load('../models/alexnet.prototxt', '../models/alexnet.caffemodel', 'cudnn') 33 | contentFunc:add(nn.SpatialUpSamplingBilinear({oheight=224, owidth=224})) 34 | contentFunc:add(nn.VGG_postprocess()) 35 | for i = 1, #cnn do 36 | local layer = cnn:get(i):clone() 37 | local name = layer.name 38 | local layer_type = torch.type(layer) 39 | contentFunc:add(layer) 40 | if name == content_layer then 41 | print("Setting up content layer: ", layer.name) 42 | break 43 | end 44 | end 45 | cnn = nil 46 | collectgarbage() 47 | print(contentFunc) 48 | return contentFunc 49 | end 50 | 51 | 52 | 53 | function content.defineContent(content_loss, layer_name) 54 | -- print('content_loss_define', content_loss) 55 | if content_loss == 'pixel' or content_loss == 'none' then 56 | return nil 57 | elseif content_loss == 'vgg' then 58 | return content.defineVGG(layer_name) 59 | else 60 | print("unsupported content loss") 61 | return nil 62 | end 63 | end 64 | 65 | 66 | function content.lossUpdate(criterionContent, real_source, fake_target, contentFunc, loss_type, weight) 67 | if loss_type == 'none' then 68 | local errCont = 0.0 69 | local df_d_content = torch.zeros(fake_target:size()) 70 | return errCont, df_d_content 71 | elseif loss_type == 'pixel' then 72 | local errCont = criterionContent:forward(fake_target, real_source) * weight 73 | local df_do_content = criterionContent:backward(fake_target, real_source)*weight 74 | return errCont, df_do_content 75 | elseif loss_type == 'vgg' then 76 | local f_fake = 
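-- (the 'vgg' branch of lossUpdate below compares truncated-VGG features of the
-- two images instead of raw pixels; a hedged usage sketch, where the choice of
-- criterion is an assumption rather than something this repo pins down:
--   local contentFunc = content.defineContent('vgg', 'relu4_2')
--   local err, grad = content.lossUpdate(nn.MSECriterion(), real, fake,
--                                        contentFunc, 'vgg', 1.0) )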
contentFunc:forward(fake_target):clone() 77 | local f_real = contentFunc:forward(real_source):clone() 78 | local errCont = criterionContent:forward(f_fake, f_real) * weight 79 | local df_do_tmp = criterionContent:backward(f_fake, f_real) * weight 80 | local df_do_content = contentFunc:updateGradInput(fake_target, df_do_tmp)--:mul(weight) 81 | return errCont, df_do_content 82 | else error("unsupported content loss") 83 | end 84 | end 85 | 86 | 87 | return content 88 | -------------------------------------------------------------------------------- /util/cudnn_convert_custom.lua: -------------------------------------------------------------------------------- 1 | -- modified from https://github.com/NVIDIA/torch-cudnn/blob/master/convert.lua 2 | -- removed error on nngraph 3 | 4 | -- modules that can be converted to nn seamlessly 5 | local layer_list = { 6 | 'BatchNormalization', 7 | 'SpatialBatchNormalization', 8 | 'SpatialConvolution', 9 | 'SpatialCrossMapLRN', 10 | 'SpatialFullConvolution', 11 | 'SpatialMaxPooling', 12 | 'SpatialAveragePooling', 13 | 'ReLU', 14 | 'Tanh', 15 | 'Sigmoid', 16 | 'SoftMax', 17 | 'LogSoftMax', 18 | 'VolumetricBatchNormalization', 19 | 'VolumetricConvolution', 20 | 'VolumetricFullConvolution', 21 | 'VolumetricMaxPooling', 22 | 'VolumetricAveragePooling', 23 | } 24 | 25 | -- goes over a given net and converts all layers to dst backend 26 | -- for example: net = cudnn_convert_custom(net, cudnn) 27 | -- same as cudnn.convert with gModule check commented out 28 | function cudnn_convert_custom(net, dst, exclusion_fn) 29 | return net:replace(function(x) 30 | --if torch.type(x) == 'nn.gModule' then 31 | -- io.stderr:write('Warning: cudnn.convert does not work with nngraph yet. Ignoring nn.gModule') 32 | -- return x 33 | --end 34 | local y = 0 35 | local src = dst == nn and cudnn or nn 36 | local src_prefix = src == nn and 'nn.' or 'cudnn.' 37 | local dst_prefix = dst == nn and 'nn.' or 'cudnn.' 
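-- A usage sketch matching the header comment; the exclusion function is a
-- hypothetical example, not something this repo ships:
--[[
-- convert to the cudnn backend, but keep InstanceNormalization on plain nn:
net = cudnn_convert_custom(net, cudnn, function(m)
  return torch.type(m):find('InstanceNormalization') ~= nil
end)
]]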
38 | 39 | local function convert(v) 40 | local y = {} 41 | torch.setmetatable(y, dst_prefix..v) 42 | if v == 'ReLU' then y = dst.ReLU() end -- because parameters 43 | for k,u in pairs(x) do y[k] = u end 44 | if src == cudnn and x.clearDesc then x.clearDesc(y) end 45 | if src == cudnn and v == 'SpatialAveragePooling' then 46 | y.divide = true 47 | y.count_include_pad = v.mode == 'CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING' 48 | end 49 | if src == nn and string.find(v, 'Convolution') then 50 | y.groups = 1 51 | end 52 | return y 53 | end 54 | 55 | if exclusion_fn and exclusion_fn(x) then 56 | return x 57 | end 58 | local t = torch.typename(x) 59 | if t == 'nn.SpatialConvolutionMM' then 60 | y = convert('SpatialConvolution') 61 | elseif t == 'inn.SpatialCrossResponseNormalization' then 62 | y = convert('SpatialCrossMapLRN') 63 | else 64 | for i,v in ipairs(layer_list) do 65 | if torch.typename(x) == src_prefix..v then 66 | y = convert(v) 67 | end 68 | end 69 | end 70 | return y == 0 and x or y 71 | end) 72 | end 73 | -------------------------------------------------------------------------------- /util/image_pool.lua: -------------------------------------------------------------------------------- 1 | local class = require 'class' 2 | ImagePool= class('ImagePool') 3 | 4 | require 'torch' 5 | require 'image' 6 | 7 | function ImagePool:__init(pool_size) 8 | self.pool_size = pool_size 9 | if pool_size > 0 then 10 | self.num_imgs = 0 11 | self.images = {} 12 | end 13 | end 14 | 15 | function ImagePool:model_name() 16 | return 'ImagePool' 17 | end 18 | -- 19 | -- function ImagePool:Initialize(pool_size) 20 | -- -- torch.manualSeed(0) 21 | -- -- assert(pool_size > 0) 22 | -- self.pool_size = pool_size 23 | -- if pool_size > 0 then 24 | -- self.num_imgs = 0 25 | -- self.images = {} 26 | -- end 27 | -- end 28 | 29 | function ImagePool:Query(image) 30 | -- print('query image') 31 | if self.pool_size == 0 then 32 | -- print('get identical image') 33 | return image 34 | end 35 | if self.num_imgs < self.pool_size then 36 | -- self.images.insert(image:clone()) 37 | self.num_imgs = self.num_imgs + 1 38 | self.images[self.num_imgs] = image 39 | return image 40 | else 41 | local p = math.random() 42 | -- print('p' ,p) 43 | -- os.exit() 44 | if p > 0.5 then 45 | -- print('use old image') 46 | -- random_id = torch.Tensor(1) 47 | -- random_id:random(1, self.pool_size) 48 | local random_id = math.random(self.pool_size) 49 | -- print('random_id', random_id) 50 | local tmp = self.images[random_id]:clone() 51 | self.images[random_id] = image:clone() 52 | return tmp 53 | else 54 | return image 55 | end 56 | 57 | end 58 | 59 | end 60 | -------------------------------------------------------------------------------- /util/plot_util.lua: -------------------------------------------------------------------------------- 1 | local class = require 'class' 2 | PlotUtil = class('PlotUtil') 3 | 4 | 5 | require 'torch' 6 | disp = require 'display' 7 | util = require 'util/util' 8 | require 'image' 9 | 10 | local unpack = unpack or table.unpack 11 | 12 | function PlotUtil:__init(conf) 13 | conf = conf or {} 14 | end 15 | 16 | function PlotUtil:model_name() 17 | return 'PlotUtil' 18 | end 19 | 20 | function PlotUtil:Initialize(display_plot, display_id, name) 21 | self.display_plot = string.split(string.gsub(display_plot, "%s+", ""), ",") 22 | 23 | self.plot_config = { 24 | title = name .. 
' loss over time', 25 | labels = {'epoch', unpack(self.display_plot)}, 26 | ylabel = 'loss', 27 | win = display_id, 28 | } 29 | 30 | self.plot_data = {} 31 | print('display_opt', self.display_plot) 32 | end 33 | 34 | 35 | function PlotUtil:Display(plot_vals, loss) 36 | for k, v in ipairs(self.display_plot) do 37 | if loss[v] ~= nil then 38 | plot_vals[#plot_vals + 1] = loss[v] 39 | end 40 | end 41 | 42 | table.insert(self.plot_data, plot_vals) 43 | disp.plot(self.plot_data, self.plot_config) 44 | end 45 | -------------------------------------------------------------------------------- /util/util.lua: -------------------------------------------------------------------------------- 1 | -- 2 | -- code derived from https://github.com/soumith/dcgan.torch 3 | -- 4 | 5 | local util = {} 6 | 7 | require 'torch' 8 | 9 | 10 | function util.BiasZero(net) 11 | net:apply(function(m) if torch.type(m):find('Convolution') then m.bias:zero() end end) 12 | end 13 | 14 | 15 | function util.checkEqual(A, B, name) 16 | local dif = (A:float()-B:float()):abs():mean() 17 | print(name, dif) 18 | end 19 | 20 | function util.containsValue(table, value) 21 | for k, v in pairs(table) do 22 | if v == value then return true end 23 | end 24 | return false 25 | end 26 | 27 | 28 | function util.CheckTensor(A, name) 29 | print(name, A:min(), A:max(), A:mean()) 30 | end 31 | 32 | 33 | function util.normalize(img) 34 | -- rescale image to 0 .. 1 35 | local min = img:min() 36 | local max = img:max() 37 | 38 | img = torch.FloatTensor(img:size()):copy(img) 39 | img:add(-min):mul(1/(max-min)) 40 | return img 41 | end 42 | 43 | function util.normalizeBatch(batch) 44 | for i = 1, batch:size(1) do 45 | batch[i] = util.normalize(batch[i]:squeeze()) 46 | end 47 | return batch 48 | end 49 | 50 | function util.basename_batch(batch) 51 | for i = 1, #batch do 52 | batch[i] = paths.basename(batch[i]) 53 | end 54 | return batch 55 | end 56 | 57 | 58 | 59 | -- default preprocessing 60 | -- 61 | -- Preprocesses an image before passing it to a net 62 | -- Converts from RGB to BGR and rescales from [0,1] to [-1,1] 63 | function util.preprocess(img) 64 | -- RGB to BGR 65 | if img:size(1) == 3 then 66 | local perm = torch.LongTensor{3, 2, 1} 67 | img = img:index(1, perm) 68 | end 69 | -- [0,1] to [-1,1] 70 | img = img:mul(2):add(-1) 71 | 72 | -- check that input is in expected range 73 | assert(img:max()<=1,"badly scaled inputs") 74 | assert(img:min()>=-1,"badly scaled inputs") 75 | 76 | return img 77 | end 78 | 79 | -- Undo the above preprocessing. 
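-- A round-trip sketch (the shape is arbitrary): preprocess maps RGB in [0,1]
-- to BGR in [-1,1], and deprocess inverts both steps.
--[[
local img  = torch.rand(3, 8, 8)                           -- RGB in [0,1]
local back = util.deprocess(util.preprocess(img:clone()))  -- equals img up to float precision
]]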
80 | function util.deprocess(img) 81 | -- BGR to RGB 82 | if img:size(1) == 3 then 83 | local perm = torch.LongTensor{3, 2, 1} 84 | img = img:index(1, perm) 85 | end 86 | 87 | -- [-1,1] to [0,1] 88 | img = img:add(1):div(2) 89 | 90 | return img 91 | end 92 | 93 | function util.preprocess_batch(batch) 94 | for i = 1, batch:size(1) do 95 | batch[i] = util.preprocess(batch[i]:squeeze()) 96 | end 97 | return batch 98 | end 99 | 100 | function util.print_tensor(name, x) 101 | print(name, x:size(), x:min(), x:mean(), x:max()) 102 | end 103 | 104 | function util.deprocess_batch(batch) 105 | for i = 1, batch:size(1) do 106 | batch[i] = util.deprocess(batch[i]:squeeze()) 107 | end 108 | return batch 109 | end 110 | 111 | 112 | function util.scaleBatch(batch,s1,s2) 113 | -- print('s1', s1) 114 | -- print('s2', s2) 115 | local scaled_batch = torch.Tensor(batch:size(1),batch:size(2),s1,s2) 116 | for i = 1, batch:size(1) do 117 | scaled_batch[i] = image.scale(batch[i],s1,s2):squeeze() 118 | end 119 | return scaled_batch 120 | end 121 | 122 | 123 | 124 | function util.toTrivialBatch(input) 125 | return input:reshape(1,input:size(1),input:size(2),input:size(3)) 126 | end 127 | function util.fromTrivialBatch(input) 128 | return input[1] 129 | end 130 | 131 | -- input is between -1 and 1 132 | function util.jitter(input) 133 | local noise = torch.rand(input:size())/256.0 134 | input:add(1.0):mul(0.5*255.0/256.0):add(noise):add(-0.5):mul(2.0) 135 | --local scaled = (input+1.0)*0.5 136 | --local jittered = scaled*255.0/256.0 + torch.rand(input:size())/256.0 137 | --local scaled_back = (jittered-0.5)*2.0 138 | --return scaled_back 139 | end 140 | 141 | function util.scaleImage(input, loadSize) 142 | 143 | -- replicate bw images to 3 channels 144 | if input:size(1)==1 then 145 | input = torch.repeatTensor(input,3,1,1) 146 | end 147 | 148 | input = image.scale(input, loadSize, loadSize) 149 | 150 | return input 151 | end 152 | 153 | function util.getAspectRatio(path) 154 | local input = image.load(path, 3, 'float') 155 | local ar = input:size(3)/input:size(2) 156 | return ar 157 | end 158 | 159 | function util.loadImage(path, loadSize, nc) 160 | local input = image.load(path, 3, 'float') 161 | input= util.preprocess(util.scaleImage(input, loadSize)) 162 | 163 | if nc == 1 then 164 | input = input[{{1}, {}, {}}] 165 | end 166 | 167 | return input 168 | end 169 | 170 | function file_exists(filename) 171 | local f = io.open(filename,"r") 172 | if f ~= nil then io.close(f) return true else return false end 173 | end 174 | 175 | -- TO DO: loading code is rather hacky; clean it up and make sure it works on all types of nets / cpu/gpu configurations 176 | function load_helper(filename, opt) 177 | fileExists = file_exists(filename) 178 | if not fileExists then 179 | print('model not found! ' .. 
function file_exists(filename)
  local f = io.open(filename, "r")
  if f ~= nil then io.close(f) return true else return false end
end

-- TODO: the loading code is rather hacky; clean it up and make sure it works
-- on all types of nets and CPU/GPU configurations.
function load_helper(filename, opt)
  local fileExists = file_exists(filename)
  if not fileExists then
    print('model not found! ' .. filename)
    return nil
  end
  print(('loading previously trained model (%s)'):format(filename))
  if opt.norm == 'instance' then
    print('using InstanceNormalization')
    require 'util.InstanceNormalization'
  end

  if opt.cudnn > 0 then
    require 'cudnn'
  end

  local net = torch.load(filename)
  if opt.gpu > 0 then
    require 'cunn'
    net:cuda()

    -- calling cuda() on cudnn-saved nngraphs doesn't change all variables to
    -- cuda, so convert the modules one by one below
    if net.forwardnodes then
      for i = 1, #net.forwardnodes do
        if net.forwardnodes[i].data.module then
          net.forwardnodes[i].data.module:cuda()
        end
      end
    end
  else
    net:float()
  end
  -- recreate zeroed gradient buffers for the loaded parameters
  net:apply(function(m)
    if m.weight then m.gradWeight = m.weight:clone():zero() end
    if m.bias then m.gradBias = m.bias:clone():zero() end
  end)
  return net
end

function util.load_model(name, opt)
  return load_helper(paths.concat(opt.checkpoints_dir, opt.name,
                                  'latest_net_' .. name .. '.t7'), opt)
end

function util.load_test_model(name, opt)
  return load_helper(paths.concat(opt.checkpoints_dir, opt.name,
                                  opt.which_epoch .. '_net_' .. name .. '.t7'), opt)
end


-- load dataset from the file system
-- |name|: name of the dataset; currently either 'A' or 'B'
-- function util.load_dataset(name, nc, opt)
--   local tensortype = torch.getdefaulttensortype()
--   torch.setdefaulttensortype('torch.FloatTensor')
--
--   local new_opt = options.clone(opt)
--   new_opt.manualSeed = torch.random(1, 10000) -- fix seed
--   new_opt.nc = nc
--   torch.manualSeed(new_opt.manualSeed)
--   local data_loader = paths.dofile('../data/data.lua')
--   new_opt.phase = new_opt.phase .. name
--   local data = data_loader.new(new_opt.nThreads, new_opt)
--   print("Dataset Size " .. name .. ": ", data:size())
--
--   torch.setdefaulttensortype(tensortype)
--   return data
-- end
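-- Illustrative usage (a sketch; the option values and the net name 'G_A' are
-- hypothetical, but the fields shown are the ones load_helper actually reads):
--   local opt = {checkpoints_dir = './checkpoints', name = 'my_experiment',
--                which_epoch = 'latest', norm = 'instance', cudnn = 1, gpu = 1}
--   local netG = util.load_test_model('G_A', opt)  -- loads latest_net_G_A.t7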
": ", data:size()) 244 | -- 245 | -- torch.setdefaulttensortype(tensortype) 246 | -- return data 247 | -- end 248 | 249 | 250 | 251 | function util.cudnn(net) 252 | require 'cudnn' 253 | require 'util/cudnn_convert_custom' 254 | return cudnn_convert_custom(net, cudnn) 255 | end 256 | 257 | function util.save_model(net, net_name, weight) 258 | if weight > 0.0 then 259 | torch.save(paths.concat(opt.checkpoints_dir, opt.name, net_name), net:clearState()) 260 | end 261 | end 262 | 263 | 264 | 265 | 266 | return util 267 | -------------------------------------------------------------------------------- /util/visualizer.lua: -------------------------------------------------------------------------------- 1 | ------------------------------------------------------------- 2 | -- Various utilities for visualization through the web server 3 | ------------------------------------------------------------- 4 | 5 | local visualizer = {} 6 | 7 | require 'torch' 8 | disp = nil 9 | print(opt) 10 | if opt.display_id > 0 then -- [hack]: assume that opt already existed 11 | disp = require 'display' 12 | end 13 | util = require 'util/util' 14 | require 'image' 15 | 16 | -- function visualizer 17 | function visualizer.disp_image(img_data, win_size, display_id, title) 18 | images = util.deprocess_batch(util.scaleBatch(img_data:float(),win_size,win_size)) 19 | disp.image(images, {win=display_id, title=title}) 20 | end 21 | 22 | function visualizer.save_results(img_data, output_path) 23 | local tensortype = torch.getdefaulttensortype() 24 | torch.setdefaulttensortype('torch.FloatTensor') 25 | local image_out = nil 26 | local win_size = opt.display_winsize 27 | images = torch.squeeze(util.deprocess_batch(util.scaleBatch(img_data:float(), win_size, win_size))) 28 | 29 | if images:dim() == 3 then 30 | image_out = images 31 | else 32 | for i = 1,images:size(1) do 33 | img = images[i] 34 | if image_out == nil then 35 | image_out = img 36 | else 37 | image_out = torch.cat(image_out, img) 38 | end 39 | end 40 | end 41 | image.save(output_path, image_out) 42 | torch.setdefaulttensortype(tensortype) 43 | end 44 | 45 | function visualizer.save_images(imgs, save_dir, impaths, s1, s2) 46 | local tensortype = torch.getdefaulttensortype() 47 | torch.setdefaulttensortype('torch.FloatTensor') 48 | batchSize = imgs:size(1) 49 | imgs_f = util.deprocess_batch(imgs):float() 50 | paths.mkdir(save_dir) 51 | for i = 1, batchSize do -- imgs_f[i]:size(2), imgs_f[i]:size(3)/opt.aspect_ratio 52 | if s1 ~= nil and s2 ~= nil then 53 | im_s = image.scale(imgs_f[i], s1, s2):float() 54 | else 55 | im_s = imgs_f[i]:float() 56 | end 57 | img_to_save = torch.FloatTensor(im_s:size()):copy(im_s) 58 | image.save(paths.concat(save_dir, impaths[i]), img_to_save) 59 | end 60 | torch.setdefaulttensortype(tensortype) 61 | end 62 | 63 | return visualizer 64 | --------------------------------------------------------------------------------