├── .gitignore
├── LICENSE
├── README.md
├── data
│   ├── aligned_data_loader.lua
│   ├── base_data_loader.lua
│   ├── data.lua
│   ├── data_util.lua
│   ├── dataset.lua
│   ├── donkey_folder.lua
│   └── unaligned_data_loader.lua
├── datasets
│   ├── bibtex
│   │   ├── cityscapes.tex
│   │   └── facades.tex
│   ├── download_dataset.sh
│   └── prepare_cityscapes_dataset.py
├── examples
│   ├── test_vangogh_style_on_ae_photos.sh
│   └── train_maps.sh
├── imgs
│   ├── failure_putin.jpg
│   ├── horse2zebra.gif
│   ├── objects.jpg
│   ├── painting2photo.jpg
│   ├── paper_thumbnail.jpg
│   ├── photo2painting.jpg
│   ├── photo_enhancement.jpg
│   ├── season.jpg
│   └── teaser.jpg
├── models
│   ├── architectures.lua
│   ├── base_model.lua
│   ├── bigan_model.lua
│   ├── content_gan_model.lua
│   ├── cycle_gan_model.lua
│   ├── one_direction_test_model.lua
│   └── pix2pix_model.lua
├── options.lua
├── pretrained_models
│   ├── download_model.sh
│   ├── download_vgg.sh
│   └── places_vgg.prototxt
├── test.lua
├── train.lua
└── util
    ├── InstanceNormalization.lua
    ├── VGG_preprocess.lua
    ├── content_loss.lua
    ├── cudnn_convert_custom.lua
    ├── image_pool.lua
    ├── plot_util.lua
    ├── util.lua
    └── visualizer.lua
/.gitignore:
--------------------------------------------------------------------------------
1 | datasets/
2 | checkpoints/
3 | results/
4 | build/
5 | dist/
6 | *.png
7 | torch.egg-info/
8 | */**/__pycache__
9 | torch/version.py
10 | torch/csrc/generic/TensorMethods.cpp
11 | torch/lib/*.so*
12 | torch/lib/*.dylib*
13 | torch/lib/*.h
14 | torch/lib/build
15 | torch/lib/tmp_install
16 | torch/lib/include
17 | torch/lib/torch_shm_manager
18 | torch/csrc/cudnn/cuDNN.cpp
19 | torch/csrc/nn/THNN.cwrap
20 | torch/csrc/nn/THNN.cpp
21 | torch/csrc/nn/THCUNN.cwrap
22 | torch/csrc/nn/THCUNN.cpp
23 | torch/csrc/nn/THNN_generic.cwrap
24 | torch/csrc/nn/THNN_generic.cpp
25 | torch/csrc/nn/THNN_generic.h
26 | docs/src/**/*
27 | test/data/legacy_modules.t7
28 | test/data/gpu_tensors.pt
29 | test/htmlcov
30 | test/.coverage
31 | */*.pyc
32 | */**/*.pyc
33 | */**/**/*.pyc
34 | */**/**/**/*.pyc
35 | */**/**/**/**/*.pyc
36 | */*.so*
37 | */**/*.so*
38 | */**/*.dylib*
39 | test/data/legacy_serialized.pt
40 | *~
41 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2017, Jun-Yan Zhu and Taesung Park
2 | All rights reserved.
3 |
4 | Redistribution and use in source and binary forms, with or without
5 | modification, are permitted provided that the following conditions are met:
6 |
7 | * Redistributions of source code must retain the above copyright notice, this
8 | list of conditions and the following disclaimer.
9 |
10 | * Redistributions in binary form must reproduce the above copyright notice,
11 | this list of conditions and the following disclaimer in the documentation
12 | and/or other materials provided with the distribution.
13 |
14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
15 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
16 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
17 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
18 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
19 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
20 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
21 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
22 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
23 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24 |
25 |
26 | --------------------------- LICENSE FOR pix2pix --------------------------------
27 | BSD License
28 |
29 | For pix2pix software
30 | Copyright (c) 2016, Phillip Isola and Jun-Yan Zhu
31 | All rights reserved.
32 |
33 | Redistribution and use in source and binary forms, with or without
34 | modification, are permitted provided that the following conditions are met:
35 |
36 | * Redistributions of source code must retain the above copyright notice, this
37 | list of conditions and the following disclaimer.
38 |
39 | * Redistributions in binary form must reproduce the above copyright notice,
40 | this list of conditions and the following disclaimer in the documentation
41 | and/or other materials provided with the distribution.
42 |
43 | ----------------------------- LICENSE FOR DCGAN --------------------------------
44 | BSD License
45 |
46 | For dcgan.torch software
47 |
48 | Copyright (c) 2015, Facebook, Inc. All rights reserved.
49 |
50 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
51 |
52 | Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
53 |
54 | Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
55 |
56 | Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
57 |
58 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
59 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | # CycleGAN
6 | ### [PyTorch](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix) | [project page](https://junyanz.github.io/CycleGAN/) | [paper](https://arxiv.org/pdf/1703.10593.pdf)
7 |
8 | Torch implementation for learning an image-to-image translation (i.e. [pix2pix](https://github.com/phillipi/pix2pix)) **without** input-output pairs, for example:
9 |
10 | **New**: Please check out [contrastive-unpaired-translation](https://github.com/taesungp/contrastive-unpaired-translation) (CUT), our new unpaired image-to-image translation model that enables fast and memory-efficient training.
11 |
12 |
13 |
14 | [Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks](https://junyanz.github.io/CycleGAN/)
15 | [Jun-Yan Zhu](https://people.eecs.berkeley.edu/~junyanz/)\*, [Taesung Park](https://taesung.me/)\*, [Phillip Isola](http://web.mit.edu/phillipi/), [Alexei A. Efros](https://people.eecs.berkeley.edu/~efros/)
16 | Berkeley AI Research Lab, UC Berkeley
17 | In ICCV 2017. (* equal contributions)
18 |
19 | This package includes CycleGAN, [pix2pix](https://github.com/phillipi/pix2pix), as well as other methods like [BiGAN](https://arxiv.org/abs/1605.09782)/[ALI](https://ishmaelbelghazi.github.io/ALI/) and Apple's paper [S+U learning](https://arxiv.org/pdf/1612.07828.pdf).
20 | The code was written by [Jun-Yan Zhu](https://github.com/junyanz) and [Taesung Park](https://github.com/taesung).
21 | **Update**: Please check out [PyTorch](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix) implementation for CycleGAN and pix2pix.
22 | The PyTorch version is under active development and can produce results comparable to or better than this Torch version.
23 |
24 | ## Other implementations:
25 | [Tensorflow] (by Harry Yang),
26 | [Tensorflow] (by Archit Rathore),
27 | [Tensorflow] (by Van Huy),
28 | [Tensorflow] (by Xiaowei Hu),
29 | [Tensorflow-simple] (by Zhenliang He),
30 | [TensorLayer] (by luoxier),
31 | [Chainer] (by Yanghua Jin),
32 | [Minimal PyTorch] (by yunjey),
33 | [Mxnet] (by Ldpe2G),
34 | [lasagne/Keras] (by tjwei),
35 | [Keras] (by Simon Karlsson)
36 |
37 |
38 | ## Applications
39 | ### Monet Paintings to Photos
40 |
41 |
42 | ### Collection Style Transfer
43 |
44 |
45 | ### Object Transfiguration
46 |
47 |
48 | ### Season Transfer
49 |
50 |
51 | ### Photo Enhancement: Narrow depth of field
52 |
53 |
54 |
55 |
56 | ## Prerequisites
57 | - Linux or OSX
58 | - NVIDIA GPU + CUDA CuDNN (CPU mode and CUDA without CuDNN may work with minimal modification, but untested)
59 | - For Mac users, you need the GNU commands `gfind` and `gwc`, which can be installed with `brew install findutils coreutils`.
60 |
61 | ## Getting Started
62 | ### Installation
63 | - Install torch and dependencies from https://github.com/torch/distro
64 | - Install torch packages `nngraph`, `class`, `display`
65 | ```bash
66 | luarocks install nngraph
67 | luarocks install class
68 | luarocks install https://raw.githubusercontent.com/szym/display/master/display-scm-0.rockspec
69 | ```
70 | - Clone this repo:
71 | ```bash
72 | git clone https://github.com/junyanz/CycleGAN
73 | cd CycleGAN
74 | ```
75 |
76 | ### Apply a Pre-trained Model
77 | - Download the test photos (taken by [Alexei Efros](https://www.flickr.com/photos/aaefros)):
78 | ```
79 | bash ./datasets/download_dataset.sh ae_photos
80 | ```
81 | - Download the pre-trained model `style_cezanne` (For CPU model, use `style_cezanne_cpu`):
82 | ```
83 | bash ./pretrained_models/download_model.sh style_cezanne
84 | ```
85 | - Now, let's generate Paul Cézanne style images:
86 | ```
87 | DATA_ROOT=./datasets/ae_photos name=style_cezanne_pretrained model=one_direction_test phase=test loadSize=256 fineSize=256 resize_or_crop="scale_width" th test.lua
88 | ```
89 | The test results will be saved to `./results/style_cezanne_pretrained/latest_test/index.html`.
90 | Please refer to [Model Zoo](#model-zoo) for more pre-trained models.
91 | `./examples/test_vangogh_style_on_ae_photos.sh` is an example script that downloads the pretrained Van Gogh style network and runs it on Efros's photos.
92 |
93 | ### Train
94 | - Download a dataset (e.g. zebra and horse images from ImageNet):
95 | ```bash
96 | bash ./datasets/download_dataset.sh horse2zebra
97 | ```
98 | - Train a model:
99 | ```bash
100 | DATA_ROOT=./datasets/horse2zebra name=horse2zebra_model th train.lua
101 | ```
102 | - (CPU only) The same training command without a GPU or CuDNN. Setting the environment variables `gpu=0 cudnn=0` forces CPU-only mode:
103 | ```bash
104 | DATA_ROOT=./datasets/horse2zebra name=horse2zebra_model gpu=0 cudnn=0 th train.lua
105 | ```
106 | - (Optionally) start the display server to view results as the model trains. (See [Display UI](#display-ui) for more details):
107 | ```bash
108 | th -ldisplay.start 8000 0.0.0.0
109 | ```
110 |
111 | ### Test
112 | - Finally, test the model:
113 | ```bash
114 | DATA_ROOT=./datasets/horse2zebra name=horse2zebra_model phase=test th test.lua
115 | ```
116 | The test results will be saved to an HTML file here: `./results/horse2zebra_model/latest_test/index.html`.
117 |
118 |
119 | ## Model Zoo
120 | Download the pre-trained models with the following script. The model will be saved to `./checkpoints/model_name/latest_net_G.t7`.
121 | ```bash
122 | bash ./pretrained_models/download_model.sh model_name
123 | ```
124 | - `orange2apple` (orange -> apple) and `apple2orange`: trained on ImageNet categories `apple` and `orange`.
125 | - `horse2zebra` (horse -> zebra) and `zebra2horse` (zebra -> horse): trained on ImageNet categories `horse` and `zebra`.
126 | - `style_monet` (landscape photo -> Monet painting style), `style_vangogh` (landscape photo -> Van Gogh painting style), `style_ukiyoe` (landscape photo -> Ukiyo-e painting style), `style_cezanne` (landscape photo -> Cezanne painting style): trained on paintings and Flickr landscape photos.
127 | - `monet2photo` (Monet paintings -> real landscape): trained on paintings and Flickr landscape photographs.
128 | - `cityscapes_photo2label` (street scene -> label) and `cityscapes_label2photo` (label -> street scene): trained on the Cityscapes dataset.
129 | - `map2sat` (map -> aerial photo) and `sat2map` (aerial photo -> map): trained on Google maps.
130 | - `iphone2dslr_flower` (iPhone photos of flowers -> DSLR photos of flowers): trained on Flickr photos.
131 |
132 | CPU models can be downloaded using:
133 | ```bash
134 | bash pretrained_models/download_model.sh <name>_cpu
135 | ```
136 | where `<name>` can be `horse2zebra`, `style_monet`, etc. You just need to append `_cpu` to the target model name.
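For example, to grab the CPU version of the horse-to-zebra generator (one of the Model Zoo names listed above):

```bash
bash ./pretrained_models/download_model.sh horse2zebra_cpu
```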
137 |
138 | ## Training and Test Details
139 | To train a model,
140 | ```bash
141 | DATA_ROOT=/path/to/data/ name=expt_name th train.lua
142 | ```
143 | Models are saved to `./checkpoints/expt_name` (this can be changed by passing `checkpoints_dir=your_dir` to `train.lua`).
144 | See `opt_train` in `options.lua` for additional training options.
145 |
146 | To test the model,
147 | ```bash
148 | DATA_ROOT=/path/to/data/ name=expt_name phase=test th test.lua
149 | ```
150 | This will run the model named `expt_name` in both directions on all images in `/path/to/data/testA` and `/path/to/data/testB`.
151 | A webpage with result images will be saved to `./results/expt_name` (this can be changed by passing `results_dir=your_dir` to `test.lua`).
152 | See `opt_test` in `options.lua` for additional test options. Please use `model=one_direction_test` if you would like to generate outputs of the trained network in only one direction, and specify `which_direction=AtoB` or `which_direction=BtoA` to set the direction.
153 |
154 | There are other options that can be used. For example, you can specify the `resize_or_crop=crop` option to avoid resizing the images to squares. This is how we trained the GTA2Cityscapes model shown on the project [webpage](https://junyanz.github.io/CycleGAN/) and the [Cycada](https://arxiv.org/pdf/1711.03213.pdf) model. We prepared the images at 1024px resolution and used `resize_or_crop=crop fineSize=360` to train on 360x360 crops. We also used `lambda_identity=1.0`, as in the sketch below.
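As a rough sketch, such a training run would look like the following; the dataset path and experiment name here are placeholders, not files shipped with this repo, while the flags are the ones described above:

```bash
DATA_ROOT=/path/to/gta2cityscapes/ name=gta2cityscapes_cyclegan resize_or_crop=crop fineSize=360 lambda_identity=1.0 th train.lua
```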
155 |
156 | ## Datasets
157 | Download the datasets using the following script. Many of the datasets were collected by other researchers. Please cite their papers if you use the data.
158 | ```bash
159 | bash ./datasets/download_dataset.sh dataset_name
160 | ```
161 | - `facades`: 400 images from the [CMP Facades dataset](http://cmp.felk.cvut.cz/~tylecr1/facade/). [[Citation](datasets/bibtex/facades.tex)]
162 | - `cityscapes`: 2975 images from the [Cityscapes training set](https://www.cityscapes-dataset.com/). [[Citation](datasets/bibtex/cityscapes.tex)] Note: due to license issues, we do not host the dataset in our repo. Please download the dataset directly from the Cityscapes webpage and refer to `./datasets/prepare_cityscapes_dataset.py` for more details.
163 | - `maps`: 1096 training images scraped from Google Maps.
164 | - `horse2zebra`: 939 horse images and 1177 zebra images downloaded from [ImageNet](http://www.image-net.org/) using the keywords `wild horse` and `zebra`.
165 | - `apple2orange`: 996 apple images and 1020 orange images downloaded from [ImageNet](http://www.image-net.org/) using the keywords `apple` and `navel orange`.
166 | - `summer2winter_yosemite`: 1273 summer Yosemite images and 854 winter Yosemite images were downloaded using the Flickr API. See more details in our paper.
167 | - `monet2photo`, `vangogh2photo`, `ukiyoe2photo`, `cezanne2photo`: The art images were downloaded from [Wikiart](https://www.wikiart.org/). The real photos were downloaded from Flickr using the combination of the tags *landscape* and *landscapephotography*. The training set size of each class is Monet: 1074, Cezanne: 584, Van Gogh: 401, Ukiyo-e: 1433, Photographs: 6853.
168 | - `iphone2dslr_flower`: both classes of images were downloaded from Flickr. The training set size of each class is iPhone:1813, DSLR:3316. See more details in our paper.
169 |
170 |
171 | ## Display UI
172 | Optionally, for displaying images during training and test, use the [display package](https://github.com/szym/display).
173 |
174 | - Install it with: `luarocks install https://raw.githubusercontent.com/szym/display/master/display-scm-0.rockspec`
175 | - Then start the server with: `th -ldisplay.start`
176 | - Open this URL in your browser: [http://localhost:8000](http://localhost:8000)
177 |
178 | By default, the server listens on localhost. Pass `0.0.0.0` to allow external connections on any interface:
179 | ```bash
180 | th -ldisplay.start 8000 0.0.0.0
181 | ```
182 | Then open `http://(hostname):(port)/` in your browser to load the remote desktop.
183 |
184 | ## Setup Training and Test data
185 | To train a CycleGAN model on your own dataset, create a data folder with two subdirectories, `trainA` and `trainB`, that contain images from domain A and domain B respectively (see the sketch below). You can also create subdirectories `testA` and `testB` if you have test data. You can test your model on your training set by setting `phase=train` when running `test.lua`.
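A minimal sketch of the expected layout and training call, assuming your images live under `/path/to/data/` (both the path and the experiment name are placeholders):

```bash
mkdir -p /path/to/data/trainA /path/to/data/trainB /path/to/data/testA /path/to/data/testB
# copy domain A images into trainA/ and domain B images into trainB/, then:
DATA_ROOT=/path/to/data/ name=expt_name th train.lua
```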
186 |
187 | You should **not** expect our method to work on just any random combination of input and output datasets (e.g. `cats<->keyboards`). From our experiments, we find it works better if two datasets share similar visual content. For example, `landscape painting<->landscape photographs` works much better than `portrait painting <-> landscape photographs`. `zebras<->horses` achieves compelling results while `cats<->dogs` completely fails. See the following section for more discussion.
188 |
189 |
190 | ## Failure cases
191 |
192 |
193 | Our model does not work well when the test image looks very different from the images the model was trained on, as is the case in the figure to the left (we trained on horses and zebras without riders, but test here on a horse with a rider). See additional typical failure cases [here](https://junyanz.github.io/CycleGAN/images/failures.jpg). On translation tasks that involve color and texture changes, like many of those reported above, the method often succeeds. We have also explored tasks that require geometric changes, with little success. For example, on the task of `dog<->cat` transfiguration, the learned translation degenerates into making minimal changes to the input. We also observe a lingering gap between the results achievable with paired training data and those achieved by our unpaired method. In some cases, this gap may be very hard -- or even impossible -- to close: for example, our method sometimes permutes the labels for tree and building in the output of the cityscapes photos->labels task.
194 |
195 |
196 |
197 | ## Citation
198 | If you use this code for your research, please cite our [paper](https://junyanz.github.io/CycleGAN/):
199 |
200 | ```
201 | @inproceedings{CycleGAN2017,
202 | title={Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks},
203 | author={Zhu, Jun-Yan and Park, Taesung and Isola, Phillip and Efros, Alexei A},
204 | booktitle={Computer Vision (ICCV), 2017 IEEE International Conference on},
205 | year={2017}
206 | }
207 |
208 | ```
209 |
210 |
211 | ## Related Projects:
212 | **[contrastive-unpaired-translation](https://github.com/taesungp/contrastive-unpaired-translation) (CUT)**
213 | **[pix2pix-Torch](https://github.com/phillipi/pix2pix) | [pix2pixHD](https://github.com/NVIDIA/pix2pixHD) |
214 | [BicycleGAN](https://github.com/junyanz/BicycleGAN) | [vid2vid](https://tcwang0509.github.io/vid2vid/) | [SPADE/GauGAN](https://github.com/NVlabs/SPADE)**
215 | **[iGAN](https://github.com/junyanz/iGAN) | [GAN Dissection](https://github.com/CSAILVision/GANDissect) | [GAN Paint](http://ganpaint.io/)**
216 |
217 | ## Cat Paper Collection
218 | If you love cats and love reading cool graphics, vision, and ML papers, please check out the Cat Paper [Collection](https://github.com/junyanz/CatPapers).
219 |
220 |
221 | ## Acknowledgments
222 | Code borrows from [pix2pix](https://github.com/phillipi/pix2pix) and [DCGAN](https://github.com/soumith/dcgan.torch). The data loader is modified from [DCGAN](https://github.com/soumith/dcgan.torch) and [Context-Encoder](https://github.com/pathak22/context-encoder). The generative network is adapted from [neural-style](https://github.com/jcjohnson/neural-style) with [Instance Normalization](https://github.com/DmitryUlyanov/texture_nets/blob/master/InstanceNormalization.lua).
223 |
--------------------------------------------------------------------------------
/data/aligned_data_loader.lua:
--------------------------------------------------------------------------------
1 | --------------------------------------------------------------------------------
2 | -- Subclass of BaseDataLoader that provides data from two datasets.
3 | -- The samples from the datasets are aligned
4 | -- The datasets are of the same size
5 | --------------------------------------------------------------------------------
6 | require 'data.base_data_loader'
7 |
8 | local class = require 'class'
9 | data_util = paths.dofile('data_util.lua')
10 |
11 | AlignedDataLoader = class('AlignedDataLoader', 'BaseDataLoader')
12 |
13 | function AlignedDataLoader:__init(conf)
14 | BaseDataLoader.__init(self, conf)
15 | conf = conf or {}
16 | end
17 |
18 | function AlignedDataLoader:name()
19 | return 'AlignedDataLoader'
20 | end
21 |
22 | function AlignedDataLoader:Initialize(opt)
23 | opt.align_data = 1
24 | self.idx_A = {1, opt.input_nc}
25 | self.idx_B = {opt.input_nc+1, opt.input_nc+opt.output_nc}
26 | local nc = 3--opt.input_nc + opt.output_nc
27 | self.data = data_util.load_dataset('', opt, nc)
28 | end
29 |
30 | -- actually fetches the data
31 | -- |return|: a table of two tables, each corresponding to
32 | -- the batch for dataset A and dataset B
33 | function AlignedDataLoader:LoadBatchForAllDatasets()
34 | local batch_data, path = self.data:getBatch()
35 | local batchA = batch_data[{ {}, self.idx_A, {}, {} }]
36 | local batchB = batch_data[{ {}, self.idx_B, {}, {} }]
37 |
38 | return batchA, batchB, path, path
39 | end
40 |
41 | -- returns the size of each dataset
42 | function AlignedDataLoader:size(dataset)
43 | return self.data:size()
44 | end
45 |
--------------------------------------------------------------------------------
/data/base_data_loader.lua:
--------------------------------------------------------------------------------
1 | --------------------------------------------------------------------------------
2 | -- Base Class for Providing Data
3 | --------------------------------------------------------------------------------
4 |
5 | local class = require 'class'
6 | require 'torch'
7 |
8 | BaseDataLoader = class('BaseDataLoader')
9 |
10 | function BaseDataLoader:__init(conf)
11 | conf = conf or {}
12 | self.data_tm = torch.Timer()
13 | end
14 |
15 | function BaseDataLoader:name()
16 | return 'BaseDataLoader'
17 | end
18 |
19 | function BaseDataLoader:Initialize(opt)
20 | end
21 |
22 | -- actually fetches the data
23 | -- |return|: a table of two tables, each corresponding to
24 | -- the batch for dataset A and dataset B
25 | function BaseDataLoader:LoadBatchForAllDatasets()
26 | return {},{},{},{}
27 | end
28 |
29 | -- returns the next batch
30 | -- a wrapper of getBatch(), which is meant to be overriden by subclasses
31 | -- |return|: a table of two tables, each corresponding to
32 | -- the batch for dataset A and dataset B
33 | function BaseDataLoader:GetNextBatch()
34 | self.data_tm:reset()
35 | self.data_tm:resume()
36 | local dataA, dataB, pathA, pathB = self:LoadBatchForAllDatasets()
37 | self.data_tm:stop()
38 | return dataA, dataB, pathA, pathB
39 | end
40 |
41 | function BaseDataLoader:time_elapsed_to_fetch_data()
42 | return self.data_tm:time().real
43 | end
44 |
45 | -- returns the size of each dataset
46 | function BaseDataLoader:size(dataset)
47 | return 0
48 | end
49 |
50 |
51 |
52 |
53 |
54 |
--------------------------------------------------------------------------------
/data/data.lua:
--------------------------------------------------------------------------------
1 | --[[
2 | This data loader is a modified version of the one from dcgan.torch
3 | (see https://github.com/soumith/dcgan.torch/blob/master/data/data.lua).
4 |
5 | Copyright (c) 2016, Deepak Pathak [See LICENSE file for details]
6 | ]]--
7 |
8 | local Threads = require 'threads'
9 | Threads.serialization('threads.sharedserialize')
10 |
11 | local data = {}
12 |
13 | local result = {}
14 | local unpack = unpack and unpack or table.unpack
15 |
16 | function data.new(n, opt_)
17 | opt_ = opt_ or {}
18 | local self = {}
19 | for k,v in pairs(data) do
20 | self[k] = v
21 | end
22 |
23 | local donkey_file = 'donkey_folder.lua'
24 | -- print('n..' .. n)
25 | if n > 0 then
26 | local options = opt_
27 | self.threads = Threads(n,
28 | function() require 'torch' end,
29 | function(idx)
30 | opt = options
31 | tid = idx
32 | local seed = (opt.manualSeed and opt.manualSeed or 0) + idx
33 | torch.manualSeed(seed)
34 | torch.setnumthreads(1)
35 | print(string.format('Starting donkey with id: %d seed: %d', tid, seed))
36 | assert(options, 'options not found')
37 | assert(opt, 'opt not given')
38 | print(opt)
39 | paths.dofile(donkey_file)
40 | end
41 |
42 | )
43 | else
44 | if donkey_file then paths.dofile(donkey_file) end
45 | -- print('empty threads')
46 | self.threads = {}
47 | function self.threads:addjob(f1, f2) f2(f1()) end
48 | function self.threads:dojob() end
49 | function self.threads:synchronize() end
50 | end
51 |
52 | local nSamples = 0
53 | self.threads:addjob(function() return trainLoader:size() end,
54 | function(c) nSamples = c end)
55 | self.threads:synchronize()
56 | self._size = nSamples
57 |
58 | for i = 1, n do
59 | self.threads:addjob(self._getFromThreads,
60 | self._pushResult)
61 | end
62 | -- print(self.threads)
63 | return self
64 | end
65 |
66 | function data._getFromThreads()
67 | assert(opt.batchSize, 'opt.batchSize not found')
68 | return trainLoader:sample(opt.batchSize)
69 | end
70 |
71 | function data._pushResult(...)
72 | local res = {...}
73 | if res == nil then
74 | self.threads:synchronize()
75 | end
76 | result[1] = res
77 | end
78 |
79 |
80 |
81 | function data:getBatch()
82 | -- queue another job
83 | self.threads:addjob(self._getFromThreads, self._pushResult)
84 | self.threads:dojob()
85 | local res = result[1]
86 |
87 | img_data = res[1]
88 | img_paths = res[3]
89 |
90 | result[1] = nil
91 | if torch.type(img_data) == 'table' then
92 | img_data = unpack(img_data)
93 | end
94 |
95 |
96 | return img_data, img_paths
97 | end
98 |
99 | function data:size()
100 | return self._size
101 | end
102 |
103 | return data
104 |
--------------------------------------------------------------------------------
/data/data_util.lua:
--------------------------------------------------------------------------------
1 | local data_util = {}
2 |
3 | require 'torch'
4 | -- options = require '../options.lua'
5 | -- load dataset from the file system
6 | -- |name|: name of the dataset. It's currently either 'A' or 'B'
7 | function data_util.load_dataset(name, opt, nc)
8 | local tensortype = torch.getdefaulttensortype()
9 | torch.setdefaulttensortype('torch.FloatTensor')
10 |
11 | local new_opt = options.clone(opt)
12 | new_opt.manualSeed = torch.random(1, 10000) -- fix seed
13 | new_opt.nc = nc
14 | torch.manualSeed(new_opt.manualSeed)
15 | local data_loader = paths.dofile('../data/data.lua')
16 | new_opt.phase = new_opt.phase .. name
17 | local data = data_loader.new(new_opt.nThreads, new_opt)
18 | print("Dataset Size " .. name .. ": ", data:size())
19 |
20 | torch.setdefaulttensortype(tensortype)
21 | return data
22 | end
23 |
24 | return data_util
25 |
--------------------------------------------------------------------------------
/data/dataset.lua:
--------------------------------------------------------------------------------
1 | --[[
2 | Copyright (c) 2015-present, Facebook, Inc.
3 | All rights reserved.
4 |
5 | This source code is licensed under the BSD-style license found in the
6 | LICENSE file in the root directory of this source tree. An additional grant
7 | of patent rights can be found in the PATENTS file in the same directory.
8 | ]]--
9 |
10 | require 'torch'
11 | torch.setdefaulttensortype('torch.FloatTensor')
12 | local ffi = require 'ffi'
13 | local class = require('pl.class')
14 | local dir = require 'pl.dir'
15 | local tablex = require 'pl.tablex'
16 | local argcheck = require 'argcheck'
17 | require 'sys'
18 | require 'xlua'
19 | require 'image'
20 |
21 | local dataset = torch.class('dataLoader')
22 |
23 | local initcheck = argcheck{
24 | pack=true,
25 | help=[[
26 | A dataset class for images in a flat folder structure (folder-name is class-name).
27 | Optimized for extremely large datasets (upwards of 14 million images).
28 | Tested only on Linux (as it uses command-line linux utilities to scale up)
29 | ]],
30 | {check=function(paths)
31 | local out = true;
32 | for k,v in ipairs(paths) do
33 | if type(v) ~= 'string' then
34 | print('paths can only be of string input');
35 | out = false
36 | end
37 | end
38 | return out
39 | end,
40 | name="paths",
41 | type="table",
42 | help="Multiple paths of directories with images"},
43 |
44 | {name="sampleSize",
45 | type="table",
46 | help="a consistent sample size to resize the images"},
47 |
48 | {name="split",
49 | type="number",
50 | help="Percentage of split to go to Training"
51 | },
52 | {name="serial_batches",
53 | type="number",
54 | help="if randomly sample training images"},
55 |
56 | {name="samplingMode",
57 | type="string",
58 | help="Sampling mode: random | balanced ",
59 | default = "balanced"},
60 |
61 | {name="verbose",
62 | type="boolean",
63 | help="Verbose mode during initialization",
64 | default = false},
65 |
66 | {name="loadSize",
67 | type="table",
68 | help="a size to load the images to, initially",
69 | opt = true},
70 |
71 | {name="forceClasses",
72 | type="table",
73 | help="If you want this loader to map certain classes to certain indices, "
74 | .. "pass a classes table that has {classname : classindex} pairs."
75 | .. " For example: {3 : 'dog', 5 : 'cat'}"
76 | .. "This function is very useful when you want two loaders to have the same "
77 | .. "class indices (trainLoader/testLoader for example)",
78 | opt = true},
79 |
80 | {name="sampleHookTrain",
81 | type="function",
82 | help="applied to sample during training(ex: for lighting jitter). "
83 | .. "It takes the image path as input",
84 | opt = true},
85 |
86 | {name="sampleHookTest",
87 | type="function",
88 | help="applied to sample during testing",
89 | opt = true},
90 | }
91 |
92 | function dataset:__init(...)
93 |
94 | -- argcheck
95 | local args = initcheck(...)
96 | print(args)
97 | for k,v in pairs(args) do self[k] = v end
98 |
99 | if not self.loadSize then self.loadSize = self.sampleSize; end
100 |
101 | if not self.sampleHookTrain then self.sampleHookTrain = self.defaultSampleHook end
102 | if not self.sampleHookTest then self.sampleHookTest = self.defaultSampleHook end
103 | self.image_count = 1
104 | -- print('image_count_init', self.image_count)
105 | -- find class names
106 | self.classes = {}
107 | local classPaths = {}
108 | if self.forceClasses then
109 | for k,v in pairs(self.forceClasses) do
110 | self.classes[k] = v
111 | classPaths[k] = {}
112 | end
113 | end
114 | local function tableFind(t, o) for k,v in pairs(t) do if v == o then return k end end end
115 | -- loop over each paths folder, get list of unique class names,
116 | -- also store the directory paths per class
117 | -- for each class,
118 | for k,path in ipairs(self.paths) do
119 | -- print('path', path)
120 | local dirs = {} -- hack
121 | dirs[1] = path
122 | -- local dirs = dir.getdirectories(path);
123 | for k,dirpath in ipairs(dirs) do
124 | local class = paths.basename(dirpath)
125 | local idx = tableFind(self.classes, class)
126 | -- print(class)
127 | -- print(idx)
128 | if not idx then
129 | table.insert(self.classes, class)
130 | idx = #self.classes
131 | classPaths[idx] = {}
132 | end
133 | if not tableFind(classPaths[idx], dirpath) then
134 | table.insert(classPaths[idx], dirpath);
135 | end
136 | end
137 | end
138 |
139 | self.classIndices = {}
140 | for k,v in ipairs(self.classes) do
141 | self.classIndices[v] = k
142 | end
143 |
144 | -- define command-line tools, try your best to maintain OSX compatibility
145 | local wc = 'wc'
146 | local cut = 'cut'
147 | local find = 'find -H' -- if folder name is symlink, do find inside it after dereferencing
148 |
149 | if ffi.os == 'OSX' then
150 | wc = 'gwc'
151 | cut = 'gcut'
152 | find = 'gfind'
153 | end
154 | ----------------------------------------------------------------------
155 | -- Options for the GNU find command
156 | local extensionList = {'jpg', 'png','JPG','PNG','JPEG', 'ppm', 'PPM', 'bmp', 'BMP'}
157 | local findOptions = ' -iname "*.' .. extensionList[1] .. '"'
158 | for i=2,#extensionList do
159 | findOptions = findOptions .. ' -o -iname "*.' .. extensionList[i] .. '"'
160 | end
161 |
162 | -- find the image path names
163 | self.imagePath = torch.CharTensor() -- path to each image in dataset
164 | self.imageClass = torch.LongTensor() -- class index of each image (class index in self.classes)
165 | self.classList = {} -- index of imageList to each image of a particular class
166 | self.classListSample = self.classList -- the main list used when sampling data
167 |
168 | print('running "find" on each class directory, and concatenate all'
169 | .. ' those filenames into a single file containing all image paths for a given class')
170 | -- so, generates one file per class
171 | local classFindFiles = {}
172 | for i=1,#self.classes do
173 | classFindFiles[i] = os.tmpname()
174 | end
175 | local combinedFindList = os.tmpname();
176 |
177 | local tmpfile = os.tmpname()
178 | local tmphandle = assert(io.open(tmpfile, 'w'))
179 | -- iterate over classes
180 | for i, class in ipairs(self.classes) do
181 | -- iterate over classPaths
182 | for j,path in ipairs(classPaths[i]) do
183 | local command = find .. ' "' .. path .. '" ' .. findOptions
184 | .. ' >>"' .. classFindFiles[i] .. '" \n'
185 | tmphandle:write(command)
186 | end
187 | end
188 | io.close(tmphandle)
189 | os.execute('bash ' .. tmpfile)
190 | os.execute('rm -f ' .. tmpfile)
191 |
192 | print('now combine all the files to a single large file')
193 | local tmpfile = os.tmpname()
194 | local tmphandle = assert(io.open(tmpfile, 'w'))
195 | -- concat all finds to a single large file in the order of self.classes
196 | for i=1,#self.classes do
197 | local command = 'cat "' .. classFindFiles[i] .. '" >>' .. combinedFindList .. ' \n'
198 | tmphandle:write(command)
199 | end
200 | io.close(tmphandle)
201 | os.execute('bash ' .. tmpfile)
202 | os.execute('rm -f ' .. tmpfile)
203 |
204 | --==========================================================================
205 | print('load the large concatenated list of sample paths to self.imagePath')
206 | local cmd = wc .. " -L '"
207 | .. combinedFindList .. "' |"
208 | .. cut .. " -f1 -d' '"
209 | print('cmd..' .. cmd)
210 | local maxPathLength = tonumber(sys.fexecute(wc .. " -L '"
211 | .. combinedFindList .. "' |"
212 | .. cut .. " -f1 -d' '")) + 1
213 | local length = tonumber(sys.fexecute(wc .. " -l '"
214 | .. combinedFindList .. "' |"
215 | .. cut .. " -f1 -d' '"))
216 | assert(length > 0, "Could not find any image file in the given input paths")
217 | assert(maxPathLength > 0, "paths of files are length 0?")
218 | self.imagePath:resize(length, maxPathLength):fill(0)
219 | local s_data = self.imagePath:data()
220 | local count = 0
221 | for line in io.lines(combinedFindList) do
222 | ffi.copy(s_data, line)
223 | s_data = s_data + maxPathLength
224 | if self.verbose and count % 10000 == 0 then
225 | xlua.progress(count, length)
226 | end;
227 | count = count + 1
228 | end
229 |
230 | self.numSamples = self.imagePath:size(1)
231 | if self.verbose then print(self.numSamples .. ' samples found.') end
232 | --==========================================================================
233 | print('Updating classList and imageClass appropriately')
234 | self.imageClass:resize(self.numSamples)
235 | local runningIndex = 0
236 | for i=1,#self.classes do
237 | if self.verbose then xlua.progress(i, #(self.classes)) end
238 | local length = tonumber(sys.fexecute(wc .. " -l '"
239 | .. classFindFiles[i] .. "' |"
240 | .. cut .. " -f1 -d' '"))
241 | if length == 0 then
242 | error('Class has zero samples')
243 | else
244 | self.classList[i] = torch.linspace(runningIndex + 1, runningIndex + length, length):long()
245 | self.imageClass[{{runningIndex + 1, runningIndex + length}}]:fill(i)
246 | end
247 | runningIndex = runningIndex + length
248 | end
249 |
250 | --==========================================================================
251 | -- clean up temporary files
252 | print('Cleaning up temporary files')
253 | local tmpfilelistall = ''
254 | for i=1,#(classFindFiles) do
255 | tmpfilelistall = tmpfilelistall .. ' "' .. classFindFiles[i] .. '"'
256 | if i % 1000 == 0 then
257 | os.execute('rm -f ' .. tmpfilelistall)
258 | tmpfilelistall = ''
259 | end
260 | end
261 | os.execute('rm -f ' .. tmpfilelistall)
262 | os.execute('rm -f "' .. combinedFindList .. '"')
263 | --==========================================================================
264 |
265 | if self.split == 100 then
266 | self.testIndicesSize = 0
267 | else
268 | print('Splitting training and test sets to a ratio of '
269 | .. self.split .. '/' .. (100-self.split))
270 | self.classListTrain = {}
271 | self.classListTest = {}
272 | self.classListSample = self.classListTrain
273 | local totalTestSamples = 0
274 | -- split the classList into classListTrain and classListTest
275 | for i=1,#self.classes do
276 | local list = self.classList[i]
277 | local count = self.classList[i]:size(1)
278 | local splitidx = math.floor((count * self.split / 100) + 0.5) -- +round
279 | local perm = torch.randperm(count)
280 | self.classListTrain[i] = torch.LongTensor(splitidx)
281 | for j=1,splitidx do
282 | self.classListTrain[i][j] = list[perm[j]]
283 | end
284 | if splitidx == count then -- all samples were allocated to train set
285 | self.classListTest[i] = torch.LongTensor()
286 | else
287 | self.classListTest[i] = torch.LongTensor(count-splitidx)
288 | totalTestSamples = totalTestSamples + self.classListTest[i]:size(1)
289 | local idx = 1
290 | for j=splitidx+1,count do
291 | self.classListTest[i][idx] = list[perm[j]]
292 | idx = idx + 1
293 | end
294 | end
295 | end
296 | -- Now combine classListTest into a single tensor
297 | self.testIndices = torch.LongTensor(totalTestSamples)
298 | self.testIndicesSize = totalTestSamples
299 | local tdata = self.testIndices:data()
300 | local tidx = 0
301 | for i=1,#self.classes do
302 | local list = self.classListTest[i]
303 | if list:dim() ~= 0 then
304 | local ldata = list:data()
305 | for j=0,list:size(1)-1 do
306 | tdata[tidx] = ldata[j]
307 | tidx = tidx + 1
308 | end
309 | end
310 | end
311 | end
312 | end
313 |
314 | -- size(), size(class)
315 | function dataset:size(class, list)
316 | list = list or self.classList
317 | if not class then
318 | return self.numSamples
319 | elseif type(class) == 'string' then
320 | return list[self.classIndices[class]]:size(1)
321 | elseif type(class) == 'number' then
322 | return list[class]:size(1)
323 | end
324 | end
325 |
326 | -- getByClass
327 | function dataset:getByClass(class)
328 | local index = 0
329 | if self.serial_batches == 1 then
330 | index = math.fmod(self.image_count-1, self.classListSample[class]:nElement())+1
331 | self.image_count = self.image_count +1
332 | else
333 | index = math.ceil(torch.uniform() * self.classListSample[class]:nElement())
334 | end
335 |
336 | local imgpath = ffi.string(torch.data(self.imagePath[self.classListSample[class][index]]))
337 | return self:sampleHookTrain(imgpath), imgpath
338 | end
339 |
340 | -- converts a table of samples (and corresponding labels) to a clean tensor
341 | local function tableToOutput(self, dataTable, scalarTable)
342 | local data, scalarLabels, labels
343 | if opt.resize_or_crop == 'crop' or opt.resize_or_crop == 'scale_width' or opt.resize_or_crop == 'scale_height' then
344 | assert(#scalarTable == 1)
345 | data = torch.Tensor(1,
346 | dataTable[1]:size(1), dataTable[1]:size(2), dataTable[1]:size(3))
347 | data[1]:copy(dataTable[1])
348 | scalarLabels = torch.LongTensor(#scalarTable):fill(-1111)
349 | else
350 | local quantity = #scalarTable
351 | data = torch.Tensor(quantity,
352 | self.sampleSize[1], self.sampleSize[2], self.sampleSize[3])
353 | scalarLabels = torch.LongTensor(quantity):fill(-1111)
354 | for i=1,#dataTable do
355 | data[i]:copy(dataTable[i])
356 | scalarLabels[i] = scalarTable[i]
357 | end
358 | end
359 | return data, scalarLabels
360 | end
361 |
362 | -- sampler, samples from the training set.
363 | function dataset:sample(quantity)
364 | assert(quantity)
365 | local dataTable = {}
366 | local scalarTable = {}
367 | local samplePaths = {}
368 | for i=1,quantity do
369 | local class = torch.random(1, #self.classes)
370 | local out, imgpath = self:getByClass(class)
371 | table.insert(dataTable, out)
372 | table.insert(scalarTable, class)
373 | samplePaths[i] = imgpath
374 | end
375 |
376 | local data, scalarLabels = tableToOutput(self, dataTable, scalarTable)
377 | return data, scalarLabels, samplePaths-- filePaths
378 | end
379 |
380 | function dataset:get(i1, i2)
381 | local indices = torch.range(i1, i2);
382 | local quantity = i2 - i1 + 1;
383 | assert(quantity > 0)
384 | -- now that indices has been initialized, get the samples
385 | local dataTable = {}
386 | local scalarTable = {}
387 | for i=1,quantity do
388 | -- load the sample
389 | local imgpath = ffi.string(torch.data(self.imagePath[indices[i]]))
390 | local out = self:sampleHookTest(imgpath)
391 | table.insert(dataTable, out)
392 | table.insert(scalarTable, self.imageClass[indices[i]])
393 | end
394 | local data, scalarLabels = tableToOutput(self, dataTable, scalarTable)
395 | return data, scalarLabels
396 | end
397 |
398 | return dataset
399 |
--------------------------------------------------------------------------------
/data/donkey_folder.lua:
--------------------------------------------------------------------------------
1 |
2 | --[[
3 | This data loader is a modified version of the one from dcgan.torch
4 | (see https://github.com/soumith/dcgan.torch/blob/master/data/donkey_folder.lua).
5 | Copyright (c) 2016, Deepak Pathak [See LICENSE file for details]
6 | Copyright (c) 2015-present, Facebook, Inc.
7 | All rights reserved.
8 | This source code is licensed under the BSD-style license found in the
9 | LICENSE file in the root directory of this source tree. An additional grant
10 | of patent rights can be found in the PATENTS file in the same directory.
11 | ]]--
12 |
13 | require 'image'
14 | paths.dofile('dataset.lua')
15 | -- This file contains the data-loading logic and details.
16 | -- It is run by each data-loader thread.
17 | ------------------------------------------
18 | -------- COMMON CACHES and PATHS
19 | -- Check for existence of opt.data
20 | if opt.DATA_ROOT then
21 | opt.data = paths.concat(opt.DATA_ROOT, opt.phase)
22 | else
23 | print(os.getenv('DATA_ROOT'))
24 | opt.data = paths.concat(os.getenv('DATA_ROOT'), opt.phase)
25 | end
26 |
27 | if not paths.dirp(opt.data) then
28 | error('Did not find directory: ' .. opt.data)
29 | end
30 |
31 | -- a cache file of the training metadata (if doesnt exist, will be created)
32 | local cache_prefix = opt.data:gsub('/', '_')
33 | os.execute(('mkdir -p %s'):format(opt.cache_dir))
34 | local trainCache = paths.concat(opt.cache_dir, cache_prefix .. '_trainCache.t7')
35 |
36 | --------------------------------------------------------------------------------------------
37 | local input_nc = opt.nc -- input channels
38 | local loadSize = {input_nc, opt.loadSize}
39 | local sampleSize = {input_nc, opt.fineSize}
40 |
41 | local function loadImage(path)
42 | local input = image.load(path, 3, 'float')
43 | local h = input:size(2)
44 | local w = input:size(3)
45 |
46 | local imA = image.crop(input, 0, 0, w/2, h)
47 | imA = image.scale(imA, loadSize[2], loadSize[2])
48 | local imB = image.crop(input, w/2, 0, w, h)
49 | imB = image.scale(imB, loadSize[2], loadSize[2])
50 |
51 | local perm = torch.LongTensor{3, 2, 1}
52 | imA = imA:index(1, perm)
53 | imA = imA:mul(2):add(-1)
54 | imB = imB:index(1, perm)
55 | imB = imB:mul(2):add(-1)
56 |
57 | assert(imA:max()<=1,"A: badly scaled inputs")
58 | assert(imA:min()>=-1,"A: badly scaled inputs")
59 | assert(imB:max()<=1,"B: badly scaled inputs")
60 | assert(imB:min()>=-1,"B: badly scaled inputs")
61 |
62 |
63 | local oW = sampleSize[2]
64 | local oH = sampleSize[2]
65 | local iH = imA:size(2)
66 | local iW = imA:size(3)
67 |
68 | local h1, w1 = 0, 0 -- crop offsets; default to 0 so they are always defined (they were previously undeclared globals)
69 | if iH~=oH then
70 | h1 = math.ceil(torch.uniform(1e-2, iH-oH))
71 | end
72 | if iW~=oW then
73 | w1 = math.ceil(torch.uniform(1e-2, iW-oW))
74 | end
75 | if iH ~= oH or iW ~= oW then
76 | imA = image.crop(imA, w1, h1, w1 + oW, h1 + oH)
77 | imB = image.crop(imB, w1, h1, w1 + oW, h1 + oH)
78 | end
79 |
80 | if opt.flip == 1 and torch.uniform() > 0.5 then
81 | imA = image.hflip(imA)
82 | imB = image.hflip(imB)
83 | end
84 |
85 | local concatenated = torch.cat(imA,imB,1)
86 |
87 | return concatenated
88 | end
89 |
90 |
91 | local function loadSingleImage(path)
92 | local im = image.load(path, input_nc, 'float')
93 | if opt.resize_or_crop == 'resize_and_crop' then
94 | im = image.scale(im, loadSize[2], loadSize[2])
95 | end
96 | if input_nc == 3 then
97 | local perm = torch.LongTensor{3, 2, 1}
98 | im = im:index(1, perm)--:mul(256.0): brg, rgb
99 | im = im:mul(2):add(-1)
100 | end
101 | assert(im:max()<=1,"A: badly scaled inputs")
102 | assert(im:min()>=-1,"A: badly scaled inputs")
103 |
104 | local oW = sampleSize[2]
105 | local oH = sampleSize[2]
106 | local iH = im:size(2)
107 | local iW = im:size(3)
108 | if (opt.resize_or_crop == 'resize_and_crop' ) then
109 | local h1, w1 = 0, 0
110 | if iH~=oH then
111 | h1 = math.ceil(torch.uniform(1e-2, iH-oH))
112 | end
113 | if iW~=oW then
114 | w1 = math.ceil(torch.uniform(1e-2, iW-oW))
115 | end
116 | if iH ~= oH or iW ~= oW then
117 | im = image.crop(im, w1, h1, w1 + oW, h1 + oH)
118 | end
119 | elseif (opt.resize_or_crop == 'combined') then
120 | local sH = math.min(math.ceil(oH * torch.uniform(1+1e-2, 2.0-1e-2)), iH-1e-2)
121 | local sW = math.min(math.ceil(oW * torch.uniform(1+1e-2, 2.0-1e-2)), iW-1e-2)
122 | local h1 = math.ceil(torch.uniform(1e-2, iH-sH))
123 | local w1 = math.ceil(torch.uniform(1e-2, iW-sW))
124 | im = image.crop(im, w1, h1, w1 + sW, h1 + sH)
125 | im = image.scale(im, oW, oH)
126 | elseif (opt.resize_or_crop == 'crop') then
127 | local w = math.min(math.min(oH, iH),iW)
128 | w = math.floor(w/4)*4
129 | local x = math.floor(torch.uniform(0, iW - w))
130 | local y = math.floor(torch.uniform(0, iH - w))
131 | im = image.crop(im, x, y, x+w, y+w)
132 | elseif (opt.resize_or_crop == 'scale_width') then
133 | w = oW
134 | h = torch.floor(iH * oW/iW)
135 | im = image.scale(im, w, h)
136 | elseif (opt.resize_or_crop == 'scale_height') then
137 | h = oH
138 | w = torch.floor(iW * oH / iH)
139 | im = image.scale(im, w, h)
140 | end
141 |
142 | if opt.flip == 1 and torch.uniform() > 0.5 then
143 | im = image.hflip(im)
144 | end
145 |
146 | return im
147 |
148 | end
149 |
150 | -- channel-wise mean and std. Calculate or load them from disk later in the script.
151 | local mean,std
152 | --------------------------------------------------------------------------------
153 | -- Hooks that are used for each image that is loaded
154 |
155 | -- function to load the image, jitter it appropriately (random crops etc.)
156 | local trainHook_singleimage = function(self, path)
157 | collectgarbage()
158 | -- print('load single image')
159 | local im = loadSingleImage(path)
160 | return im
161 | end
162 |
163 | -- function that loads images that have juxtaposition
164 | -- of two images from two domains
165 | local trainHook_doubleimage = function(self, path)
166 | -- print('load double image')
167 | collectgarbage()
168 |
169 | local im = loadImage(path)
170 | return im
171 | end
172 |
173 |
174 | if opt.align_data > 0 then
175 | sample_nc = input_nc*2
176 | trainHook = trainHook_doubleimage
177 | else
178 | sample_nc = input_nc
179 | trainHook = trainHook_singleimage
180 | end
181 |
182 | trainLoader = dataLoader{
183 | paths = {opt.data},
184 | loadSize = {input_nc, loadSize[2], loadSize[2]},
185 | sampleSize = {sample_nc, sampleSize[2], sampleSize[2]},
186 | split = 100,
187 | serial_batches = opt.serial_batches,
188 | verbose = true
189 | }
190 |
191 | trainLoader.sampleHookTrain = trainHook
192 | collectgarbage()
193 |
194 | -- do some sanity checks on trainLoader
195 | do
196 | local class = trainLoader.imageClass
197 | local nClasses = #trainLoader.classes
198 | assert(class:max() <= nClasses, "class logic has error")
199 | assert(class:min() >= 1, "class logic has error")
200 | end
201 |
--------------------------------------------------------------------------------
/data/unaligned_data_loader.lua:
--------------------------------------------------------------------------------
1 | --------------------------------------------------------------------------------
2 | -- Subclass of BaseDataLoader that provides data from two datasets.
3 | -- The samples from the datasets are not aligned.
4 | -- The datasets can have different sizes
5 | --------------------------------------------------------------------------------
6 | require 'data.base_data_loader'
7 |
8 | local class = require 'class'
9 | data_util = paths.dofile('data_util.lua')
10 |
11 | UnalignedDataLoader = class('UnalignedDataLoader', 'BaseDataLoader')
12 |
13 | function UnalignedDataLoader:__init(conf)
14 | BaseDataLoader.__init(self, conf)
15 | conf = conf or {}
16 | end
17 |
18 | function UnalignedDataLoader:name()
19 | return 'UnalignedDataLoader'
20 | end
21 |
22 | function UnalignedDataLoader:Initialize(opt)
23 | opt.align_data = 0
24 | self.dataA = data_util.load_dataset('A', opt, opt.input_nc)
25 | self.dataB = data_util.load_dataset('B', opt, opt.output_nc)
26 | end
27 |
28 | -- actually fetches the data
29 | -- |return|: a table of two tables, each corresponding to
30 | -- the batch for dataset A and dataset B
31 | function UnalignedDataLoader:LoadBatchForAllDatasets()
32 | local batchA, pathA = self.dataA:getBatch()
33 | local batchB, pathB = self.dataB:getBatch()
34 | return batchA, batchB, pathA, pathB
35 | end
36 |
37 | -- returns the size of each dataset
38 | function UnalignedDataLoader:size(dataset)
39 | if dataset == 'A' then
40 | return self.dataA:size()
41 | end
42 |
43 | if dataset == 'B' then
44 | return self.dataB:size()
45 | end
46 |
47 | return math.max(self.dataA:size(), self.dataB:size())
48 | -- return the size of the largest dataset by default
49 | end
50 |
--------------------------------------------------------------------------------
/datasets/bibtex/cityscapes.tex:
--------------------------------------------------------------------------------
1 | @inproceedings{Cordts2016Cityscapes,
2 | title={The Cityscapes Dataset for Semantic Urban Scene Understanding},
3 | author={Cordts, Marius and Omran, Mohamed and Ramos, Sebastian and Rehfeld, Timo and Enzweiler, Markus and Benenson, Rodrigo and Franke, Uwe and Roth, Stefan and Schiele, Bernt},
4 | booktitle={Proc. of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
5 | year={2016}
6 | }
7 |
--------------------------------------------------------------------------------
/datasets/bibtex/facades.tex:
--------------------------------------------------------------------------------
1 | @INPROCEEDINGS{Tylecek13,
2 | author = {Radim Tyle{\v c}ek and Radim {\v S}{\' a}ra},
3 | title = {Spatial Pattern Templates for Recognition of Objects with Regular Structure},
4 | booktitle = {Proc. GCPR},
5 | year = {2013},
6 | address = {Saarbrucken, Germany},
7 | }
8 |
--------------------------------------------------------------------------------
/datasets/download_dataset.sh:
--------------------------------------------------------------------------------
1 | FILE=$1
2 |
3 | if [[ $FILE != "ae_photos" && $FILE != "apple2orange" && $FILE != "summer2winter_yosemite" && $FILE != "horse2zebra" && $FILE != "monet2photo" && $FILE != "cezanne2photo" && $FILE != "ukiyoe2photo" && $FILE != "vangogh2photo" && $FILE != "maps" && $FILE != "cityscapes" && $FILE != "facades" && $FILE != "iphone2dslr_flower" ]]; then
4 | echo "Available datasets are: apple2orange, summer2winter_yosemite, horse2zebra, monet2photo, cezanne2photo, ukiyoe2photo, vangogh2photo, maps, cityscapes, facades, iphone2dslr_flower, ae_photos"
5 | exit 1
6 | fi
7 |
8 | if [[ $FILE == "cityscapes" ]]; then
9 | echo "Due to license issue, we cannot provide the Cityscapes dataset from our repository. Please download the Cityscapes dataset from https://cityscapes-dataset.com, and use the script ./datasets/prepare_cityscapes_dataset.py."
10 | echo "You need to download gtFine_trainvaltest.zip and leftImg8bit_trainvaltest.zip. For further instruction, please read ./datasets/prepare_cityscapes_dataset.py"
11 | exit 1
12 | fi
13 |
14 | URL=https://efrosgans.eecs.berkeley.edu/cyclegan/datasets/$FILE.zip
15 | ZIP_FILE=./datasets/$FILE.zip
16 | TARGET_DIR=./datasets/$FILE/
17 | wget -N $URL -O $ZIP_FILE
18 | mkdir $TARGET_DIR
19 | unzip $ZIP_FILE -d ./datasets/
20 | rm $ZIP_FILE
21 |
--------------------------------------------------------------------------------
/datasets/prepare_cityscapes_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | from PIL import Image
4 |
5 | help_msg = """
6 | The dataset can be downloaded from https://cityscapes-dataset.com.
7 | Please download the datasets [gtFine_trainvaltest.zip] and [leftImg8bit_trainvaltest.zip] and unzip them.
8 | gtFine contains the semantics segmentations. Use --gtFine_dir to specify the path to the unzipped gtFine_trainvaltest directory.
9 | leftImg8bit contains the dashcam photographs. Use --leftImg8bit_dir to specify the path to the unzipped leftImg8bit_trainvaltest directory.
10 | The processed images will be placed at --output_dir.
11 |
12 | Example usage:
13 |
14 | python prepare_cityscapes_dataset.py --gtFine_dir ./gtFine/ --leftImg8bit_dir ./leftImg8bit --output_dir ./datasets/cityscapes/
15 | """
16 |
17 | def load_resized_img(path):
18 | return Image.open(path).convert('RGB').resize((256, 256))
19 |
20 | def check_matching_pair(segmap_path, photo_path):
21 | segmap_identifier = os.path.basename(segmap_path).replace('_gtFine_color', '')
22 | photo_identifier = os.path.basename(photo_path).replace('_leftImg8bit', '')
23 |
24 | assert segmap_identifier == photo_identifier, \
25 | "[%s] and [%s] don't seem to be matching. Aborting." % (segmap_path, photo_path)
26 |
27 |
28 | def process_cityscapes(gtFine_dir, leftImg8bit_dir, output_dir, phase):
29 | save_phase = 'test' if phase == 'val' else 'train'
30 | savedir = os.path.join(output_dir, save_phase)
31 | os.makedirs(savedir, exist_ok=True)
32 | os.makedirs(savedir + 'A', exist_ok=True)
33 | os.makedirs(savedir + 'B', exist_ok=True)
34 | print("Directory structure prepared at %s" % output_dir)
35 |
36 | segmap_expr = os.path.join(gtFine_dir, phase) + "/*/*_color.png"
37 | segmap_paths = glob.glob(segmap_expr)
38 | segmap_paths = sorted(segmap_paths)
39 |
40 | photo_expr = os.path.join(leftImg8bit_dir, phase) + "/*/*_leftImg8bit.png"
41 | photo_paths = glob.glob(photo_expr)
42 | photo_paths = sorted(photo_paths)
43 |
44 | assert len(segmap_paths) == len(photo_paths), \
45 | "%d images that match [%s], and %d images that match [%s]. Aborting." % (len(segmap_paths), segmap_expr, len(photo_paths), photo_expr)
46 |
47 | for i, (segmap_path, photo_path) in enumerate(zip(segmap_paths, photo_paths)):
48 | check_matching_pair(segmap_path, photo_path)
49 | segmap = load_resized_img(segmap_path)
50 | photo = load_resized_img(photo_path)
51 |
52 | # data for pix2pix where the two images are placed side-by-side
53 | sidebyside = Image.new('RGB', (512, 256))
54 | sidebyside.paste(segmap, (256, 0))
55 | sidebyside.paste(photo, (0, 0))
56 | savepath = os.path.join(savedir, "%d.jpg" % i)
57 | sidebyside.save(savepath, format='JPEG', subsampling=0, quality=100)
58 |
59 | # data for cyclegan where the two images are stored at two distinct directories
60 | savepath = os.path.join(savedir + 'A', "%d_A.jpg" % i)
61 | photo.save(savepath, format='JPEG', subsampling=0, quality=100)
62 | savepath = os.path.join(savedir + 'B', "%d_B.jpg" % i)
63 | segmap.save(savepath, format='JPEG', subsampling=0, quality=100)
64 |
65 |         if i % max(len(segmap_paths) // 10, 1) == 0:
66 |             print("%d / %d: last image saved at %s" % (i, len(segmap_paths), savepath))
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 |
75 |
76 |
77 | if __name__ == '__main__':
78 | import argparse
79 | parser = argparse.ArgumentParser()
80 | parser.add_argument('--gtFine_dir', type=str, required=True,
81 | help='Path to the Cityscapes gtFine directory.')
82 | parser.add_argument('--leftImg8bit_dir', type=str, required=True,
83 | help='Path to the Cityscapes leftImg8bit_trainvaltest directory.')
84 |     parser.add_argument('--output_dir', type=str,
85 | default='./datasets/cityscapes',
86 | help='Directory the output images will be written to.')
87 | opt = parser.parse_args()
88 |
89 | print(help_msg)
90 |
91 | print('Preparing Cityscapes Dataset for val phase')
92 | process_cityscapes(opt.gtFine_dir, opt.leftImg8bit_dir, opt.output_dir, "val")
93 | print('Preparing Cityscapes Dataset for train phase')
94 | process_cityscapes(opt.gtFine_dir, opt.leftImg8bit_dir, opt.output_dir, "train")
95 |
96 | print('Done')
97 |
98 |
99 |
100 |
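101 | # Resulting layout (a sketch derived from process_cityscapes above; OUT stands for --output_dir):
102 | #   OUT/train/<i>.jpg      512x256 side-by-side pairs (photo left, segmentation map right) for pix2pix
103 | #   OUT/trainA/<i>_A.jpg   256x256 dashcam photos for CycleGAN
104 | #   OUT/trainB/<i>_B.jpg   256x256 segmentation maps for CycleGAN
105 | #   OUT/test*              same structure, built from the Cityscapes val split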
--------------------------------------------------------------------------------
/examples/test_vangogh_style_on_ae_photos.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | ## This script downloads the dataset and the pre-trained network,
4 | ## and generates style-transferred images.
5 |
6 | # Download the dataset. The downloaded dataset is stored in ./datasets/${DATASET_NAME}
7 | DATASET_NAME='ae_photos'
8 | bash ./datasets/download_dataset.sh $DATASET_NAME
9 |
10 | # Download the pre-trained model. The downloaded model is stored in ./models/${MODEL_NAME}_pretrained/latest_net_G.t7
11 | MODEL_NAME='style_vangogh'
12 | bash ./pretrained_models/download_model.sh $MODEL_NAME
13 |
14 | # Run style transfer using the downloaded dataset and model
15 | DATA_ROOT=./datasets/$DATASET_NAME name=${MODEL_NAME}_pretrained model=one_direction_test phase=test how_many='all' loadSize=256 fineSize=256 resize_or_crop='scale_width' th test.lua
16 |
17 | if [ $? -eq 0 ]; then
18 | echo "The result can be viewed at ./results/${MODEL_NAME}_pretrained/latest_test/index.html"
19 | fi
20 |
--------------------------------------------------------------------------------
/examples/train_maps.sh:
--------------------------------------------------------------------------------
1 | DB_NAME='maps'
2 | GPU_ID=1
3 | DISPLAY_ID=1
4 | NET_G=resnet_6blocks
5 | NET_D=basic
6 | MODEL=cycle_gan
7 | SAVE_EPOCH=5
8 | ALIGN_DATA=0
9 | LAMBDA=10
10 | NF=64
11 |
12 |
13 | EXPR_NAME=${DB_NAME}_${MODEL}_${LAMBDA}
14 |
15 | CHECKPOINT_DIR=./checkpoints/
16 | LOG_FILE=${CHECKPOINT_DIR}${EXPR_NAME}/log.txt
17 | mkdir -p ${CHECKPOINT_DIR}${EXPR_NAME}
18 |
19 | DATA_ROOT=./datasets/$DB_NAME align_data=$ALIGN_DATA use_lsgan=1 \
20 |   which_direction='AtoB' pool_size=50 niter=100 niter_decay=100 \
21 | which_model_netG=$NET_G which_model_netD=$NET_D model=$MODEL lr=0.0002 print_freq=200 lambda_A=$LAMBDA lambda_B=$LAMBDA \
22 | loadSize=143 fineSize=128 gpu=$GPU_ID display_winsize=128 \
23 | name=$EXPR_NAME flip=1 save_epoch_freq=$SAVE_EPOCH \
24 | continue_train=0 display_id=$DISPLAY_ID \
25 |   checkpoints_dir=$CHECKPOINT_DIR \
26 | th train.lua | tee -a $LOG_FILE
27 |
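28 | # The KEY=value pairs above are passed as environment variables to th and picked up by
29 | # options.lua to override its defaults (a sketch of typical usage, not extra commands).
30 | # To resume an interrupted run, relaunch the same command with continue_train=1 so the
31 | # latest checkpoint in ${CHECKPOINT_DIR}${EXPR_NAME} is loaded instead of fresh networks.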
--------------------------------------------------------------------------------
/imgs/failure_putin.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/junyanz/CycleGAN/40b4498526de6b566f94a000d98fe18b9b8921db/imgs/failure_putin.jpg
--------------------------------------------------------------------------------
/imgs/horse2zebra.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/junyanz/CycleGAN/40b4498526de6b566f94a000d98fe18b9b8921db/imgs/horse2zebra.gif
--------------------------------------------------------------------------------
/imgs/objects.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/junyanz/CycleGAN/40b4498526de6b566f94a000d98fe18b9b8921db/imgs/objects.jpg
--------------------------------------------------------------------------------
/imgs/painting2photo.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/junyanz/CycleGAN/40b4498526de6b566f94a000d98fe18b9b8921db/imgs/painting2photo.jpg
--------------------------------------------------------------------------------
/imgs/paper_thumbnail.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/junyanz/CycleGAN/40b4498526de6b566f94a000d98fe18b9b8921db/imgs/paper_thumbnail.jpg
--------------------------------------------------------------------------------
/imgs/photo2painting.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/junyanz/CycleGAN/40b4498526de6b566f94a000d98fe18b9b8921db/imgs/photo2painting.jpg
--------------------------------------------------------------------------------
/imgs/photo_enhancement.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/junyanz/CycleGAN/40b4498526de6b566f94a000d98fe18b9b8921db/imgs/photo_enhancement.jpg
--------------------------------------------------------------------------------
/imgs/season.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/junyanz/CycleGAN/40b4498526de6b566f94a000d98fe18b9b8921db/imgs/season.jpg
--------------------------------------------------------------------------------
/imgs/teaser.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/junyanz/CycleGAN/40b4498526de6b566f94a000d98fe18b9b8921db/imgs/teaser.jpg
--------------------------------------------------------------------------------
/models/architectures.lua:
--------------------------------------------------------------------------------
1 | require 'nngraph'
2 |
3 |
4 | ----------------------------------------------------------------------------
5 | local function weights_init(m)
6 | local name = torch.type(m)
7 | if name:find('Convolution') then
8 | m.weight:normal(0.0, 0.02)
9 | m.bias:fill(0)
10 | elseif name:find('Normalization') then
11 | if m.weight then m.weight:normal(1.0, 0.02) end
12 | if m.bias then m.bias:fill(0) end
13 | end
14 | end
15 |
16 |
17 | normalization = nil
18 |
19 | function set_normalization(norm)
20 | if norm == 'instance' then
21 | require 'util.InstanceNormalization'
22 | print('use InstanceNormalization')
23 | normalization = nn.InstanceNormalization
24 | elseif norm == 'batch' then
25 | print('use SpatialBatchNormalization')
26 | normalization = nn.SpatialBatchNormalization
27 | end
28 | end
29 |
30 | function defineG(input_nc, output_nc, ngf, which_model_netG, nz, arch)
31 | local netG = nil
32 | if which_model_netG == "encoder_decoder" then netG = defineG_encoder_decoder(input_nc, output_nc, ngf)
33 | elseif which_model_netG == "unet128" then netG = defineG_unet128(input_nc, output_nc, ngf)
34 | elseif which_model_netG == "unet256" then netG = defineG_unet256(input_nc, output_nc, ngf)
35 | elseif which_model_netG == "resnet_6blocks" then netG = defineG_resnet_6blocks(input_nc, output_nc, ngf)
36 | elseif which_model_netG == "resnet_9blocks" then netG = defineG_resnet_9blocks(input_nc, output_nc, ngf)
37 | else error("unsupported netG model")
38 | end
39 | netG:apply(weights_init)
40 |
41 | return netG
42 | end
43 |
44 | function defineD(input_nc, ndf, which_model_netD, n_layers_D, use_sigmoid)
45 | local netD = nil
46 | if which_model_netD == "basic" then netD = defineD_basic(input_nc, ndf, use_sigmoid)
47 | elseif which_model_netD == "imageGAN" then netD = defineD_imageGAN(input_nc, ndf, use_sigmoid)
48 | elseif which_model_netD == "n_layers" then netD = defineD_n_layers(input_nc, ndf, n_layers_D, use_sigmoid)
49 | else error("unsupported netD model")
50 | end
51 | netD:apply(weights_init)
52 |
53 | return netD
54 | end
55 |
56 | function defineG_encoder_decoder(input_nc, output_nc, ngf)
57 | -- input is (nc) x 256 x 256
58 | local e1 = - nn.SpatialConvolution(input_nc, ngf, 4, 4, 2, 2, 1, 1)
59 | -- input is (ngf) x 128 x 128
60 | local e2 = e1 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf, ngf * 2, 4, 4, 2, 2, 1, 1) - normalization(ngf * 2)
61 | -- input is (ngf * 2) x 64 x 64
62 | local e3 = e2 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 2, ngf * 4, 4, 4, 2, 2, 1, 1) - normalization(ngf * 4)
63 | -- input is (ngf * 4) x 32 x 32
64 | local e4 = e3 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 4, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8)
65 | -- input is (ngf * 8) x 16 x 16
66 | local e5 = e4 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8)
67 | -- input is (ngf * 8) x 8 x 8
68 | local e6 = e5 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8)
69 | -- input is (ngf * 8) x 4 x 4
70 | local e7 = e6 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8)
71 | -- input is (ngf * 8) x 2 x 2
72 | local e8 = e7 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) -- normalization(ngf * 8)
73 | -- input is (ngf * 8) x 1 x 1
74 |
75 | local d1 = e8 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8) - nn.Dropout(0.5)
76 | -- input is (ngf * 8) x 2 x 2
77 | local d2 = d1 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8) - nn.Dropout(0.5)
78 | -- input is (ngf * 8) x 4 x 4
79 | local d3 = d2 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8) - nn.Dropout(0.5)
80 | -- input is (ngf * 8) x 8 x 8
81 | local d4 = d3 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8)
82 | -- input is (ngf * 8) x 16 x 16
83 | local d5 = d4 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8, ngf * 4, 4, 4, 2, 2, 1, 1) - normalization(ngf * 4)
84 | -- input is (ngf * 4) x 32 x 32
85 | local d6 = d5 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 4, ngf * 2, 4, 4, 2, 2, 1, 1) - normalization(ngf * 2)
86 | -- input is (ngf * 2) x 64 x 64
87 | local d7 = d6 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 2, ngf, 4, 4, 2, 2, 1, 1) - normalization(ngf)
88 |   -- input is (ngf) x 128 x 128
89 | local d8 = d7 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf, output_nc, 4, 4, 2, 2, 1, 1)
90 | -- input is (nc) x 256 x 256
91 | local o1 = d8 - nn.Tanh()
92 |
93 | local netG = nn.gModule({e1},{o1})
94 | return netG
95 | end
96 |
97 |
98 | function defineG_unet128(input_nc, output_nc, ngf)
99 | local netG = nil
100 | -- input is (nc) x 128 x 128
101 | local e1 = - nn.SpatialConvolution(input_nc, ngf, 4, 4, 2, 2, 1, 1)
102 | -- input is (ngf) x 64 x 64
103 | local e2 = e1 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf, ngf * 2, 4, 4, 2, 2, 1, 1) - normalization(ngf * 2)
104 | -- input is (ngf * 2) x 32 x 32
105 | local e3 = e2 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 2, ngf * 4, 4, 4, 2, 2, 1, 1) - normalization(ngf * 4)
106 | -- input is (ngf * 4) x 16 x 16
107 | local e4 = e3 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 4, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8)
108 | -- input is (ngf * 8) x 8 x 8
109 | local e5 = e4 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8)
110 | -- input is (ngf * 8) x 4 x 4
111 | local e6 = e5 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8)
112 | -- input is (ngf * 8) x 2 x 2
113 | local e7 = e6 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) -- normalization(ngf * 8)
114 | -- input is (ngf * 8) x 1 x 1
115 |
116 | local d1_ = e7 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8) - nn.Dropout(0.5)
117 | -- input is (ngf * 8) x 2 x 2
118 | local d1 = {d1_,e6} - nn.JoinTable(2)
119 | local d2_ = d1 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8 * 2, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8) - nn.Dropout(0.5)
120 | -- input is (ngf * 8) x 4 x 4
121 | local d2 = {d2_,e5} - nn.JoinTable(2)
122 | local d3_ = d2 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8 * 2, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8) - nn.Dropout(0.5)
123 | -- input is (ngf * 8) x 8 x 8
124 | local d3 = {d3_,e4} - nn.JoinTable(2)
125 | local d4_ = d3 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8 * 2, ngf * 4, 4, 4, 2, 2, 1, 1) - normalization(ngf * 4)
126 | -- input is (ngf * 8) x 16 x 16
127 | local d4 = {d4_,e3} - nn.JoinTable(2)
128 | local d5_ = d4 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 4 * 2, ngf * 2, 4, 4, 2, 2, 1, 1) - normalization(ngf * 2)
129 | -- input is (ngf * 4) x 32 x 32
130 | local d5 = {d5_,e2} - nn.JoinTable(2)
131 | local d6_ = d5 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 2 * 2, ngf, 4, 4, 2, 2, 1, 1) - normalization(ngf)
132 | -- input is (ngf * 2) x 64 x 64
133 | local d6 = {d6_,e1} - nn.JoinTable(2)
134 |
135 | local d7 = d6 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 2, output_nc, 4, 4, 2, 2, 1, 1)
136 | -- input is (nc) x 128 x 128
137 |
138 | local o1 = d7 - nn.Tanh()
139 | local netG = nn.gModule({e1},{o1})
140 | return netG
141 | end
142 |
143 |
144 | function defineG_unet256(input_nc, output_nc, ngf)
145 | local netG = nil
146 | -- input is (nc) x 256 x 256
147 | local e1 = - nn.SpatialConvolution(input_nc, ngf, 4, 4, 2, 2, 1, 1)
148 | -- input is (ngf) x 128 x 128
149 | local e2 = e1 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf, ngf * 2, 4, 4, 2, 2, 1, 1) - normalization(ngf * 2)
150 | -- input is (ngf * 2) x 64 x 64
151 | local e3 = e2 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 2, ngf * 4, 4, 4, 2, 2, 1, 1) - normalization(ngf * 4)
152 | -- input is (ngf * 4) x 32 x 32
153 | local e4 = e3 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 4, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8)
154 | -- input is (ngf * 8) x 16 x 16
155 | local e5 = e4 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8)
156 | -- input is (ngf * 8) x 8 x 8
157 | local e6 = e5 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8)
158 | -- input is (ngf * 8) x 4 x 4
159 | local e7 = e6 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8)
160 | -- input is (ngf * 8) x 2 x 2
161 | local e8 = e7 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) -- - normalization(ngf * 8)
162 | -- input is (ngf * 8) x 1 x 1
163 |
164 | local d1_ = e8 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8) - nn.Dropout(0.5)
165 | -- input is (ngf * 8) x 2 x 2
166 | local d1 = {d1_,e7} - nn.JoinTable(2)
167 | local d2_ = d1 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8 * 2, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8) - nn.Dropout(0.5)
168 | -- input is (ngf * 8) x 4 x 4
169 | local d2 = {d2_,e6} - nn.JoinTable(2)
170 | local d3_ = d2 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8 * 2, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8) - nn.Dropout(0.5)
171 | -- input is (ngf * 8) x 8 x 8
172 | local d3 = {d3_,e5} - nn.JoinTable(2)
173 | local d4_ = d3 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8 * 2, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8)
174 | -- input is (ngf * 8) x 16 x 16
175 | local d4 = {d4_,e4} - nn.JoinTable(2)
176 | local d5_ = d4 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8 * 2, ngf * 4, 4, 4, 2, 2, 1, 1) - normalization(ngf * 4)
177 | -- input is (ngf * 4) x 32 x 32
178 | local d5 = {d5_,e3} - nn.JoinTable(2)
179 | local d6_ = d5 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 4 * 2, ngf * 2, 4, 4, 2, 2, 1, 1) - normalization(ngf * 2)
180 | -- input is (ngf * 2) x 64 x 64
181 | local d6 = {d6_,e2} - nn.JoinTable(2)
182 | local d7_ = d6 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 2 * 2, ngf, 4, 4, 2, 2, 1, 1) - normalization(ngf)
183 |   -- input is (ngf) x 128 x 128
184 | local d7 = {d7_,e1} - nn.JoinTable(2)
185 | local d8 = d7 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 2, output_nc, 4, 4, 2, 2, 1, 1)
186 | -- input is (nc) x 256 x 256
187 |
188 | local o1 = d8 - nn.Tanh()
189 | local netG = nn.gModule({e1},{o1})
190 | return netG
191 | end
192 |
193 | --------------------------------------------------------------------------------
194 | -- Justin Johnson's model from https://github.com/jcjohnson/fast-neural-style/
195 | --------------------------------------------------------------------------------
196 |
197 | local function build_conv_block(dim, padding_type)
198 | local conv_block = nn.Sequential()
199 | local p = 0
200 | if padding_type == 'reflect' then
201 | conv_block:add(nn.SpatialReflectionPadding(1, 1, 1, 1))
202 | elseif padding_type == 'replicate' then
203 | conv_block:add(nn.SpatialReplicationPadding(1, 1, 1, 1))
204 | elseif padding_type == 'zero' then
205 | p = 1
206 | end
207 | conv_block:add(nn.SpatialConvolution(dim, dim, 3, 3, 1, 1, p, p))
208 | conv_block:add(normalization(dim))
209 | conv_block:add(nn.ReLU(true))
210 | if padding_type == 'reflect' then
211 | conv_block:add(nn.SpatialReflectionPadding(1, 1, 1, 1))
212 | elseif padding_type == 'replicate' then
213 | conv_block:add(nn.SpatialReplicationPadding(1, 1, 1, 1))
214 | end
215 | conv_block:add(nn.SpatialConvolution(dim, dim, 3, 3, 1, 1, p, p))
216 | conv_block:add(normalization(dim))
217 | return conv_block
218 | end
219 |
220 |
221 | local function build_res_block(dim, padding_type)
222 | local conv_block = build_conv_block(dim, padding_type)
223 | local res_block = nn.Sequential()
224 | local concat = nn.ConcatTable()
225 | concat:add(conv_block)
226 | concat:add(nn.Identity())
227 |
228 | res_block:add(concat):add(nn.CAddTable())
229 | return res_block
230 | end
231 |
232 | function defineG_resnet_6blocks(input_nc, output_nc, ngf)
233 | padding_type = 'reflect'
234 | local ks = 3
235 | local netG = nil
236 | local f = 7
237 | local p = (f - 1) / 2
238 | local data = -nn.Identity()
239 | local e1 = data - nn.SpatialReflectionPadding(p, p, p, p) - nn.SpatialConvolution(input_nc, ngf, f, f, 1, 1) - normalization(ngf) - nn.ReLU(true)
240 | local e2 = e1 - nn.SpatialConvolution(ngf, ngf*2, ks, ks, 2, 2, 1, 1) - normalization(ngf*2) - nn.ReLU(true)
241 | local e3 = e2 - nn.SpatialConvolution(ngf*2, ngf*4, ks, ks, 2, 2, 1, 1) - normalization(ngf*4) - nn.ReLU(true)
242 | local d1 = e3 - build_res_block(ngf*4, padding_type) - build_res_block(ngf*4, padding_type) - build_res_block(ngf*4, padding_type)
243 | - build_res_block(ngf*4, padding_type) - build_res_block(ngf*4, padding_type) - build_res_block(ngf*4, padding_type)
244 | local d2 = d1 - nn.SpatialFullConvolution(ngf*4, ngf*2, ks, ks, 2, 2, 1, 1,1,1) - normalization(ngf*2) - nn.ReLU(true)
245 | local d3 = d2 - nn.SpatialFullConvolution(ngf*2, ngf, ks, ks, 2, 2, 1, 1,1,1) - normalization(ngf) - nn.ReLU(true)
246 | local d4 = d3 - nn.SpatialReflectionPadding(p, p, p, p) - nn.SpatialConvolution(ngf, output_nc, f, f, 1, 1) - nn.Tanh()
247 | netG = nn.gModule({data},{d4})
248 | return netG
249 | end
250 |
251 | function defineG_resnet_9blocks(input_nc, output_nc, ngf)
252 | padding_type = 'reflect'
253 | local ks = 3
254 | local netG = nil
255 | local f = 7
256 | local p = (f - 1) / 2
257 | local data = -nn.Identity()
258 | local e1 = data - nn.SpatialReflectionPadding(p, p, p, p) - nn.SpatialConvolution(input_nc, ngf, f, f, 1, 1) - normalization(ngf) - nn.ReLU(true)
259 | local e2 = e1 - nn.SpatialConvolution(ngf, ngf*2, ks, ks, 2, 2, 1, 1) - normalization(ngf*2) - nn.ReLU(true)
260 | local e3 = e2 - nn.SpatialConvolution(ngf*2, ngf*4, ks, ks, 2, 2, 1, 1) - normalization(ngf*4) - nn.ReLU(true)
261 | local d1 = e3 - build_res_block(ngf*4, padding_type) - build_res_block(ngf*4, padding_type) - build_res_block(ngf*4, padding_type)
262 | - build_res_block(ngf*4, padding_type) - build_res_block(ngf*4, padding_type) - build_res_block(ngf*4, padding_type)
263 | - build_res_block(ngf*4, padding_type) - build_res_block(ngf*4, padding_type) - build_res_block(ngf*4, padding_type)
264 | local d2 = d1 - nn.SpatialFullConvolution(ngf*4, ngf*2, ks, ks, 2, 2, 1, 1,1,1) - normalization(ngf*2) - nn.ReLU(true)
265 | local d3 = d2 - nn.SpatialFullConvolution(ngf*2, ngf, ks, ks, 2, 2, 1, 1,1,1) - normalization(ngf) - nn.ReLU(true)
266 | local d4 = d3 - nn.SpatialReflectionPadding(p, p, p, p) - nn.SpatialConvolution(ngf, output_nc, f, f, 1, 1) - nn.Tanh()
267 | netG = nn.gModule({data},{d4})
268 | return netG
269 | end
270 |
271 | function defineD_imageGAN(input_nc, ndf, use_sigmoid)
272 | local netD = nn.Sequential()
273 |
274 | -- input is (nc) x 256 x 256
275 | netD:add(nn.SpatialConvolution(input_nc, ndf, 4, 4, 2, 2, 1, 1))
276 | netD:add(nn.LeakyReLU(0.2, true))
277 | -- state size: (ndf) x 128 x 128
278 | netD:add(nn.SpatialConvolution(ndf, ndf * 2, 4, 4, 2, 2, 1, 1))
279 | netD:add(nn.SpatialBatchNormalization(ndf * 2)):add(nn.LeakyReLU(0.2, true))
280 | -- state size: (ndf*2) x 64 x 64
281 | netD:add(nn.SpatialConvolution(ndf * 2, ndf*4, 4, 4, 2, 2, 1, 1))
282 | netD:add(nn.SpatialBatchNormalization(ndf * 4)):add(nn.LeakyReLU(0.2, true))
283 | -- state size: (ndf*4) x 32 x 32
284 | netD:add(nn.SpatialConvolution(ndf * 4, ndf * 8, 4, 4, 2, 2, 1, 1))
285 | netD:add(nn.SpatialBatchNormalization(ndf * 8)):add(nn.LeakyReLU(0.2, true))
286 | -- state size: (ndf*8) x 16 x 16
287 | netD:add(nn.SpatialConvolution(ndf * 8, ndf * 8, 4, 4, 2, 2, 1, 1))
288 | netD:add(nn.SpatialBatchNormalization(ndf * 8)):add(nn.LeakyReLU(0.2, true))
289 | -- state size: (ndf*8) x 8 x 8
290 | netD:add(nn.SpatialConvolution(ndf * 8, ndf * 8, 4, 4, 2, 2, 1, 1))
291 | netD:add(nn.SpatialBatchNormalization(ndf * 8)):add(nn.LeakyReLU(0.2, true))
292 | -- state size: (ndf*8) x 4 x 4
293 | netD:add(nn.SpatialConvolution(ndf * 8, 1, 4, 4, 2, 2, 1, 1))
294 | -- state size: 1 x 1 x 1
295 | if use_sigmoid then
296 | netD:add(nn.Sigmoid())
297 | end
298 |
299 | return netD
300 | end
301 |
302 |
303 |
304 | function defineD_basic(input_nc, ndf, use_sigmoid)
305 | n_layers = 3
306 | return defineD_n_layers(input_nc, ndf, n_layers, use_sigmoid)
307 | end
308 |
309 | -- rf=1
310 | function defineD_pixelGAN(input_nc, ndf, use_sigmoid)
311 |
312 | local netD = nn.Sequential()
313 |
314 | -- input is (nc) x 256 x 256
315 | netD:add(nn.SpatialConvolution(input_nc, ndf, 1, 1, 1, 1, 0, 0))
316 | netD:add(nn.LeakyReLU(0.2, true))
317 | -- state size: (ndf) x 256 x 256
318 | netD:add(nn.SpatialConvolution(ndf, ndf * 2, 1, 1, 1, 1, 0, 0))
319 | netD:add(normalization(ndf * 2)):add(nn.LeakyReLU(0.2, true))
320 | -- state size: (ndf*2) x 256 x 256
321 | netD:add(nn.SpatialConvolution(ndf * 2, 1, 1, 1, 1, 1, 0, 0))
322 | -- state size: 1 x 256 x 256
323 | if use_sigmoid then
324 | netD:add(nn.Sigmoid())
325 |     -- state size: 1 x 256 x 256
326 | end
327 |
328 | return netD
329 | end
330 |
331 | -- if n=0, then use pixelGAN (rf=1)
332 | -- else rf is 16 if n=1
333 | -- 34 if n=2
334 | -- 70 if n=3
335 | -- 142 if n=4
336 | -- 286 if n=5
337 | -- 574 if n=6
338 | function defineD_n_layers(input_nc, ndf, n_layers, use_sigmoid, kw, dropout_ratio)
339 |
340 | if dropout_ratio == nil then
341 | dropout_ratio = 0.0
342 | end
343 |
344 | if kw == nil then
345 | kw = 4
346 | end
347 | padw = math.ceil((kw-1)/2)
348 |
349 | if n_layers==0 then
350 | return defineD_pixelGAN(input_nc, ndf, use_sigmoid)
351 | else
352 |
353 | local netD = nn.Sequential()
354 |
355 | -- input is (nc) x 256 x 256
356 | -- print('input_nc', input_nc)
357 | netD:add(nn.SpatialConvolution(input_nc, ndf, kw, kw, 2, 2, padw, padw))
358 | netD:add(nn.LeakyReLU(0.2, true))
359 |
360 | local nf_mult = 1
361 | local nf_mult_prev = 1
362 | for n = 1, n_layers-1 do
363 | nf_mult_prev = nf_mult
364 | nf_mult = math.min(2^n,8)
365 | netD:add(nn.SpatialConvolution(ndf * nf_mult_prev, ndf * nf_mult, kw, kw, 2, 2, padw,padw))
366 | netD:add(normalization(ndf * nf_mult)):add(nn.Dropout(dropout_ratio))
367 | netD:add(nn.LeakyReLU(0.2, true))
368 | end
369 |
370 | -- state size: (ndf*M) x N x N
371 | nf_mult_prev = nf_mult
372 | nf_mult = math.min(2^n_layers,8)
373 | netD:add(nn.SpatialConvolution(ndf * nf_mult_prev, ndf * nf_mult, kw, kw, 1, 1, padw, padw))
374 | netD:add(normalization(ndf * nf_mult)):add(nn.LeakyReLU(0.2, true))
375 | -- state size: (ndf*M*2) x (N-1) x (N-1)
376 | netD:add(nn.SpatialConvolution(ndf * nf_mult, 1, kw, kw, 1, 1, padw,padw))
377 | -- state size: 1 x (N-2) x (N-2)
378 | if use_sigmoid then
379 | netD:add(nn.Sigmoid())
380 | end
381 | -- state size: 1 x (N-2) x (N-2)
382 | return netD
383 | end
384 | end
385 |
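386 | -- Usage sketch: callers are expected to bind `normalization` via set_normalization()
387 | -- before building networks (hypothetical values shown for illustration):
388 | --   set_normalization('instance')
389 | --   local netG = defineG(3, 3, 64, 'resnet_9blocks')  -- 3-channel in/out, ngf = 64
390 | --   local netD = defineD(3, 64, 'basic', 3, false)    -- 70x70 PatchGAN (n_layers = 3), no sigmoid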
--------------------------------------------------------------------------------
/models/base_model.lua:
--------------------------------------------------------------------------------
1 | --------------------------------------------------------------------------------
2 | -- Base Class for Providing Models
3 | --------------------------------------------------------------------------------
4 |
5 | local class = require 'class'
6 |
7 | BaseModel = class('BaseModel')
8 |
9 | function BaseModel:__init(conf)
10 | conf = conf or {}
11 | end
12 |
13 | -- Returns the name of the model
14 | function BaseModel:model_name()
15 | return 'DoesNothingModel'
16 | end
17 |
18 | -- Defines models and networks
19 | function BaseModel:Initialize(opt)
20 | models = {}
21 | return models
22 | end
23 |
24 | -- Runs the forward pass of the network
25 | function BaseModel:Forward(input, opt)
26 | output = {}
27 | return output
28 | end
29 |
30 | -- Runs the backprop gradient descent
31 | -- Corresponds to a single batch of data
32 | function BaseModel:OptimizeParameters(opt)
33 | end
34 |
35 | -- This function can be used to reset momentum after each epoch
36 | function BaseModel:RefreshParameters(opt)
37 | end
38 |
39 | -- This function can be used to update the learning rate after each epoch
40 | function BaseModel:UpdateLearningRate(opt)
41 | end
42 | -- Save the current model to the file system
43 | function BaseModel:Save(prefix, opt)
44 | end
45 |
46 | -- returns a string that describes the current errors
47 | function BaseModel:GetCurrentErrorDescription()
48 | return "No Error exists in BaseModel"
49 | end
50 |
51 | -- returns current errors
52 | function BaseModel:GetCurrentErrors(opt)
53 | return {}
54 | end
55 |
56 | -- returns a table of image/label pairs that describe
57 | -- the current results.
58 | -- |return|: a table of table. List of image/label pairs
59 | function BaseModel:GetCurrentVisuals(opt, size)
60 | return {}
61 | end
62 |
63 | -- returns a string that describes the display plot configuration
64 | function BaseModel:DisplayPlot(opt)
65 | return {}
66 | end
67 |
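68 | -- Subclassing sketch (the pattern used by the model files in this directory, e.g.
69 | -- models/cycle_gan_model.lua):
70 | --   MyModel = class('MyModel', 'BaseModel')
71 | --   function MyModel:__init(conf) BaseModel.__init(self, conf) end
72 | --   -- then override Initialize/Forward/OptimizeParameters/... as needed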
--------------------------------------------------------------------------------
/models/bigan_model.lua:
--------------------------------------------------------------------------------
1 | local class = require 'class'
2 | require 'models.base_model'
3 | require 'models.architectures'
4 | require 'util.image_pool'
5 | util = paths.dofile('../util/util.lua')
6 | content = paths.dofile('../util/content_loss.lua')
7 |
8 | BiGANModel = class('BiGANModel', 'BaseModel')
9 |
10 | function BiGANModel:__init(conf)
11 | BaseModel.__init(self, conf)
12 | conf = conf or {}
13 | end
14 |
15 | function BiGANModel:model_name()
16 | return 'BiGANModel'
17 | end
18 |
19 | function BiGANModel:InitializeStates(use_wgan)
20 | optimState = {learningRate=opt.lr, beta1=opt.beta1,}
21 | return optimState
22 | end
23 | -- Defines models and networks
24 | function BiGANModel:Initialize(opt)
25 | if opt.test == 0 then
26 | self.realABPool = ImagePool(opt.pool_size)
27 | self.fakeABPool = ImagePool(opt.pool_size)
28 | end
29 | -- define tensors
30 | local d_input_nc = opt.input_nc + opt.output_nc
31 | self.real_AB = torch.Tensor(opt.batchSize, d_input_nc, opt.fineSize, opt.fineSize)
32 | self.fake_AB = torch.Tensor(opt.batchSize, d_input_nc, opt.fineSize, opt.fineSize)
33 | -- load/define models
34 | self.criterionGAN = nn.MSECriterion()
35 |
36 | local netG, netE, netD = nil, nil, nil
37 | if opt.continue_train == 1 then
38 | if opt.test == 1 then -- which_epoch option exists in test mode
39 | netG = util.load_test_model('G', opt)
40 | netE = util.load_test_model('E', opt)
41 | netD = util.load_test_model('D', opt)
42 | else
43 | netG = util.load_model('G', opt)
44 | netE = util.load_model('E', opt)
45 | netD = util.load_model('D', opt)
46 | end
47 | else
48 | -- netG_test = defineG(opt.input_nc, opt.output_nc, opt.ngf, "resnet_unet", opt.arch)
49 | -- os.exit()
50 | netD = defineD(d_input_nc, opt.ndf, opt.which_model_netD, opt.n_layers_D, false) -- no sigmoid layer
51 | print('netD...', netD)
52 | netG = defineG(opt.input_nc, opt.output_nc, opt.ngf, opt.which_model_netG, opt.arch)
53 | print('netG...', netG)
54 | netE = defineG(opt.output_nc, opt.input_nc, opt.ngf, opt.which_model_netG, opt.arch)
55 | print('netE...', netE)
56 |
57 | end
58 |
59 | self.netD = netD
60 | self.netG = netG
61 | self.netE = netE
62 |
63 | -- define real/fake labels
64 | netD_output_size = self.netD:forward(self.real_AB):size()
65 | self.fake_label = torch.Tensor(netD_output_size):fill(0.0)
66 | self.real_label = torch.Tensor(netD_output_size):fill(1.0) -- no soft smoothing
67 |
68 | self.optimStateD = self:InitializeStates()
69 | self.optimStateG = self:InitializeStates()
70 | self.optimStateE = self:InitializeStates()
71 | self.A_idx = {{}, {1, opt.input_nc}, {}, {}}
72 | self.B_idx = {{}, {opt.input_nc+1, opt.input_nc+opt.output_nc}, {}, {}}
73 | self:RefreshParameters()
74 |
75 | print('---------- # Learnable Parameters --------------')
76 | print(('G = %d'):format(self.parametersG:size(1)))
77 | print(('E = %d'):format(self.parametersE:size(1)))
78 | print(('D = %d'):format(self.parametersD:size(1)))
79 | print('------------------------------------------------')
80 | -- os.exit()
81 | end
82 |
83 | -- Runs the forward pass of the network and
84 | -- saves the result to member variables of the class
85 | function BiGANModel:Forward(input, opt)
86 | if opt.which_direction == 'BtoA' then
87 | local temp = input.real_A
88 | input.real_A = input.real_B
89 | input.real_B = temp
90 | end
91 | self.real_AB[self.A_idx]:copy(input.real_A)
92 | self.fake_AB[self.B_idx]:copy(input.real_B)
93 | self.real_A = self.real_AB[self.A_idx]
94 | self.real_B = self.fake_AB[self.B_idx]
95 | self.fake_B = self.netG:forward(self.real_A):clone()
96 | self.fake_A = self.netE:forward(self.real_B):clone()
97 | self.real_AB[self.B_idx]:copy(self.fake_B) -- real_AB: real_A, fake_B -> real_label
98 | self.fake_AB[self.A_idx]:copy(self.fake_A) -- fake_AB: fake_A, real_B -> fake_label
99 | -- if opt.test == 0 then
100 | -- self.real_AB = self.realABPool:Query(self.real_AB) -- batch history
101 | -- self.fake_AB = self.fakeABPool:Query(self.fake_AB) -- batch history
102 | -- end
103 | end
104 |
105 | -- create closure to evaluate f(X) and df/dX of discriminator
106 | function BiGANModel:fDx_basic(x, gradParams, netD, real_AB, fake_AB, opt)
107 | util.BiasZero(netD)
108 | gradParams:zero()
109 | -- Real log(D_A(B))
110 | local output = netD:forward(real_AB):clone()
111 | local errD_real = self.criterionGAN:forward(output, self.real_label)
112 | local df_do = self.criterionGAN:backward(output, self.real_label)
113 | netD:backward(real_AB, df_do)
114 | -- Fake + log(1 - D_A(G(A)))
115 | output = netD:forward(fake_AB):clone()
116 | local errD_fake = self.criterionGAN:forward(output, self.fake_label)
117 | local df_do2 = self.criterionGAN:backward(output, self.fake_label)
118 | netD:backward(fake_AB, df_do2)
119 | -- Compute loss
120 | local errD = (errD_real + errD_fake) / 2.0
121 | return errD, gradParams
122 | end
123 |
124 |
125 | function BiGANModel:fDx(x, opt)
126 | -- use image pool that stores the old fake images
127 | real_AB = self.realABPool:Query(self.real_AB)
128 | fake_AB = self.fakeABPool:Query(self.fake_AB)
129 | self.errD, gradParams = self:fDx_basic(x, self.gradParametersD, self.netD, real_AB, fake_AB, opt)
130 | return self.errD, gradParams
131 | end
132 |
133 |
134 |
135 | function BiGANModel:fGx_basic(x, netG, netD, gradParametersG, opt)
136 | util.BiasZero(netG)
137 | util.BiasZero(netD)
138 | gradParametersG:zero()
139 |
140 | -- First. G(A) should fake the discriminator
141 | local output = netD:forward(self.real_AB):clone()
142 | local errG = self.criterionGAN:forward(output, self.fake_label)
143 | local dgan_loss_dd = self.criterionGAN:backward(output, self.fake_label)
144 | local dgan_loss_do = netD:updateGradInput(self.real_AB, dgan_loss_dd)
145 | netG:backward(self.real_A, dgan_loss_do[self.B_idx]) -- real_AB: real_A, fake_B -> real_label
146 | return gradParametersG, errG
147 | end
148 |
149 |
150 | function BiGANModel:fGx(x, opt)
151 | self.gradParametersG, self.errG = self:fGx_basic(x, self.netG, self.netD,
152 | self.gradParametersG, opt)
153 | return self.errG, self.gradParametersG
154 | end
155 |
156 |
157 | function BiGANModel:fEx_basic(x, netE, netD, gradParametersE, opt)
158 | util.BiasZero(netE)
159 | util.BiasZero(netD)
160 | gradParametersE:zero()
161 |
162 | -- First. G(A) should fake the discriminator
163 | local output = netD:forward(self.fake_AB):clone()
164 | local errE= self.criterionGAN:forward(output, self.real_label)
165 | local dgan_loss_dd = self.criterionGAN:backward(output, self.real_label)
166 | local dgan_loss_do = netD:updateGradInput(self.fake_AB, dgan_loss_dd)
167 | netE:backward(self.real_B, dgan_loss_do[self.A_idx])-- fake_AB: fake_A, real_B -> fake_label
168 | return gradParametersE, errE
169 | end
170 |
171 |
172 | function BiGANModel:fEx(x, opt)
173 | self.gradParametersE, self.errE = self:fEx_basic(x, self.netE, self.netD,
174 | self.gradParametersE, opt)
175 | return self.errE, self.gradParametersE
176 | end
177 |
178 |
179 | function BiGANModel:OptimizeParameters(opt)
180 | local fG = function(x) return self:fGx(x, opt) end
181 | local fE = function(x) return self:fEx(x, opt) end
182 | local fD = function(x) return self:fDx(x, opt) end
183 | optim.adam(fD, self.parametersD, self.optimStateD)
184 | optim.adam(fG, self.parametersG, self.optimStateG)
185 | optim.adam(fE, self.parametersE, self.optimStateE)
186 | end
187 |
188 | function BiGANModel:RefreshParameters()
189 | self.parametersD, self.gradParametersD = nil, nil -- nil them to avoid spiking memory
190 | self.parametersG, self.gradParametersG = nil, nil
191 | self.parametersE, self.gradParametersE = nil, nil
192 | -- define parameters of optimization
193 | self.parametersD, self.gradParametersD = self.netD:getParameters()
194 | self.parametersG, self.gradParametersG = self.netG:getParameters()
195 | self.parametersE, self.gradParametersE = self.netE:getParameters()
196 | end
197 |
198 | function BiGANModel:Save(prefix, opt)
199 | util.save_model(self.netG, prefix .. '_net_G.t7', 1)
200 | util.save_model(self.netE, prefix .. '_net_E.t7', 1)
201 | util.save_model(self.netD, prefix .. '_net_D.t7', 1)
202 | end
203 |
204 | function BiGANModel:GetCurrentErrorDescription()
205 | description = ('D: %.4f G: %.4f E: %.4f'):format(
206 | self.errD and self.errD or -1,
207 | self.errG and self.errG or -1,
208 | self.errE and self.errE or -1)
209 | return description
210 | end
211 |
212 | function BiGANModel:GetCurrentErrors()
213 | local errors = {errD=self.errD, errG=self.errG, errE=self.errE}
214 | return errors
215 | end
216 |
217 | -- returns a string that describes the display plot configuration
218 | function BiGANModel:DisplayPlot(opt)
219 | return 'errD,errG,errE'
220 | end
221 | function BiGANModel:UpdateLearningRate(opt)
222 | local lrd = opt.lr / opt.niter_decay
223 | local old_lr = self.optimStateD['learningRate']
224 | local lr = old_lr - lrd
225 | self.optimStateD['learningRate'] = lr
226 | self.optimStateG['learningRate'] = lr
227 | self.optimStateE['learningRate'] = lr
228 | print(('update learning rate: %f -> %f'):format(old_lr, lr))
229 | end
230 |
231 | local function MakeIm3(im)
232 | -- print('before im_size', im:size())
233 | local im3 = nil
234 | if im:size(2) == 1 then
235 | im3 = torch.repeatTensor(im, 1,3,1,1)
236 | else
237 | im3 = im
238 | end
239 | -- print('after im_size', im:size())
240 | -- print('after im3_size', im3:size())
241 | return im3
242 | end
243 | function BiGANModel:GetCurrentVisuals(opt, size)
244 | if not size then
245 | size = opt.display_winsize
246 | end
247 |
248 | local visuals = {}
249 | table.insert(visuals, {img=MakeIm3(self.real_A), label='real_A'})
250 | table.insert(visuals, {img=MakeIm3(self.fake_B), label='fake_B'})
251 | table.insert(visuals, {img=MakeIm3(self.real_B), label='real_B'})
252 | table.insert(visuals, {img=MakeIm3(self.fake_A), label='fake_A'})
253 | return visuals
254 | end
255 |
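256 | -- Pairing summary (as set up in Forward above):
257 | --   real_AB = [real_A, G(real_A)]  D is trained to label it "real" (fDx); G is updated to push it toward "fake" (fGx)
258 | --   fake_AB = [E(real_B), real_B]  D is trained to label it "fake" (fDx); E is updated to push it toward "real" (fEx)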
--------------------------------------------------------------------------------
/models/content_gan_model.lua:
--------------------------------------------------------------------------------
1 | local class = require 'class'
2 | require 'models.base_model'
3 | require 'models.architectures'
4 | require 'util.image_pool'
5 | util = paths.dofile('../util/util.lua')
6 | content = paths.dofile('../util/content_loss.lua')
7 |
8 | ContentGANModel = class('ContentGANModel', 'BaseModel')
9 |
10 | function ContentGANModel:__init(conf)
11 | BaseModel.__init(self, conf)
12 | conf = conf or {}
13 | end
14 |
15 | function ContentGANModel:model_name()
16 | return 'ContentGANModel'
17 | end
18 |
19 | function ContentGANModel:InitializeStates()
20 | local optimState = {learningRate=opt.lr, beta1=opt.beta1,}
21 | return optimState
22 | end
23 | -- Defines models and networks
24 | function ContentGANModel:Initialize(opt)
25 | if opt.test == 0 then
26 | self.fakePool = ImagePool(opt.pool_size)
27 | end
28 | -- define tensors
29 | self.real_A = torch.Tensor(opt.batchSize, opt.input_nc, opt.fineSize, opt.fineSize)
30 | self.fake_B = torch.Tensor(opt.batchSize, opt.output_nc, opt.fineSize, opt.fineSize)
31 | self.real_B = self.fake_B:clone() --torch.Tensor(opt.batchSize, opt.output_nc, opt.fineSize, opt.fineSize)
32 |
33 | -- load/define models
34 | self.criterionGAN = nn.MSECriterion()
35 | self.criterionContent = nn.AbsCriterion()
36 | self.contentFunc = content.defineContent(opt.content_loss, opt.layer_name)
37 | self.netG, self.netD = nil, nil
38 | if opt.continue_train == 1 then
39 | if opt.which_epoch then -- which_epoch option exists in test mode
40 | self.netG = util.load_test_model('G_A', opt)
41 | self.netD = util.load_test_model('D_A', opt)
42 | else
43 | self.netG = util.load_model('G_A', opt)
44 | self.netD = util.load_model('D_A', opt)
45 | end
46 | else
47 | self.netG = defineG(opt.input_nc, opt.output_nc, opt.ngf, opt.which_model_netG)
48 | print('netG...', self.netG)
49 | self.netD = defineD(opt.output_nc, opt.ndf, opt.which_model_netD, opt.n_layers_D, false)
50 | print('netD...', self.netD)
51 | end
52 | -- define real/fake labels
53 | netD_output_size = self.netD:forward(self.real_A):size()
54 | self.fake_label = torch.Tensor(netD_output_size):fill(0.0)
55 | self.real_label = torch.Tensor(netD_output_size):fill(1.0) -- no soft smoothing
56 | self.optimStateD = self:InitializeStates()
57 | self.optimStateG = self:InitializeStates()
58 | self:RefreshParameters()
59 | print('---------- # Learnable Parameters --------------')
60 | print(('G = %d'):format(self.parametersG:size(1)))
61 | print(('D = %d'):format(self.parametersD:size(1)))
62 | print('------------------------------------------------')
63 | -- os.exit()
64 | end
65 |
66 | -- Runs the forward pass of the network and
67 | -- saves the result to member variables of the class
68 | function ContentGANModel:Forward(input, opt)
69 | if opt.which_direction == 'BtoA' then
70 | local temp = input.real_A
71 | input.real_A = input.real_B
72 | input.real_B = temp
73 | end
74 |
75 | self.real_A:copy(input.real_A)
76 | self.real_B:copy(input.real_B)
77 | self.fake_B = self.netG:forward(self.real_A):clone()
78 | -- output = {self.fake_B}
79 | output = {}
80 | -- if opt.test == 1 then
81 |
82 | -- end
83 | return output
84 | end
85 |
86 | -- create closure to evaluate f(X) and df/dX of discriminator
87 | function ContentGANModel:fDx_basic(x, gradParams, netD, netG,
88 | real_target, fake_target, opt)
89 | util.BiasZero(netD)
90 | util.BiasZero(netG)
91 | gradParams:zero()
92 |
93 | local errD_real, errD_rec, errD_fake, errD = 0, 0, 0, 0
94 | -- Real log(D_A(B))
95 | local output = netD:forward(real_target)
96 | errD_real = self.criterionGAN:forward(output, self.real_label)
97 | df_do = self.criterionGAN:backward(output, self.real_label)
98 | netD:backward(real_target, df_do)
99 |
100 | -- Fake + log(1 - D_A(G_A(A)))
101 | output = netD:forward(fake_target)
102 | errD_fake = self.criterionGAN:forward(output, self.fake_label)
103 | df_do = self.criterionGAN:backward(output, self.fake_label)
104 | netD:backward(fake_target, df_do)
105 | errD = (errD_real + errD_fake) / 2.0
106 | -- print('errD', errD
107 | return errD, gradParams
108 | end
109 |
110 |
111 | function ContentGANModel:fDx(x, opt)
112 | fake_B = self.fakePool:Query(self.fake_B)
113 | self.errD, gradParams = self:fDx_basic(x, self.gradparametersD, self.netD, self.netG,
114 | self.real_B, fake_B, opt)
115 | return self.errD, gradParams
116 | end
117 |
118 | function ContentGANModel:fGx_basic(x, netG_source, netD_source, real_source, real_target, fake_target,
119 | gradParametersG_source, opt)
120 | util.BiasZero(netD_source)
121 | util.BiasZero(netG_source)
122 | gradParametersG_source:zero()
123 | -- GAN loss
124 | -- local df_d_GAN = torch.zeros(fake_target:size())
125 | -- local errGAN = 0
126 | -- local errRec = 0
127 | --- Domain GAN loss: D_A(G_A(A))
128 | local output = netD_source.output -- [hack] forward was already executed in fDx, so save computation netD_source:forward(fake_B) ---
129 | local errGAN = self.criterionGAN:forward(output, self.real_label)
130 | local df_do = self.criterionGAN:backward(output, self.real_label)
131 | local df_d_GAN = netD_source:updateGradInput(fake_target, df_do) ---:narrow(2,fake_AB:size(2)-output_nc+1, output_nc)
132 |
133 | -- content loss
134 | -- print('content_loss', opt.content_loss)
135 | -- function content.lossUpdate(criterionContent, real_source, fake_target, contentFunc, loss_type, weight)
136 | local errContent, df_d_content = content.lossUpdate(self.criterionContent, real_source, fake_target, self.contentFunc, opt.content_loss, opt.lambda_A)
137 | netG_source:forward(real_source)
138 | netG_source:backward(real_source, df_d_GAN + df_d_content)
139 | -- print('errD', errGAN)
140 | return gradParametersG_source, errGAN, errContent
141 | end
142 |
143 | function ContentGANModel:fGx(x, opt)
144 | self.gradparametersG, self.errG, self.errCont =
145 | self:fGx_basic(x, self.netG, self.netD,
146 | self.real_A, self.real_B, self.fake_B,
147 | self.gradparametersG, opt)
148 | return self.errG, self.gradparametersG
149 | end
150 |
151 | function ContentGANModel:OptimizeParameters(opt)
152 | local fDx = function(x) return self:fDx(x, opt) end
153 | local fGx = function(x) return self:fGx(x, opt) end
154 | optim.adam(fDx, self.parametersD, self.optimStateD)
155 | optim.adam(fGx, self.parametersG, self.optimStateG)
156 | end
157 |
158 | function ContentGANModel:RefreshParameters()
159 | self.parametersD, self.gradparametersD = nil, nil -- nil them to avoid spiking memory
160 | self.parametersG, self.gradparametersG = nil, nil
161 | -- define parameters of optimization
162 | self.parametersG, self.gradparametersG = self.netG:getParameters()
163 | self.parametersD, self.gradparametersD = self.netD:getParameters()
164 | end
165 |
166 | function ContentGANModel:Save(prefix, opt)
167 | util.save_model(self.netG, prefix .. '_net_G_A.t7', 1.0)
168 | util.save_model(self.netD, prefix .. '_net_D_A.t7', 1.0)
169 | end
170 |
171 | function ContentGANModel:GetCurrentErrorDescription()
172 | description = ('G: %.4f D: %.4f Content: %.4f'):format(self.errG and self.errG or -1,
173 | self.errD and self.errD or -1,
174 | self.errCont and self.errCont or -1)
175 | return description
176 | end
177 |
178 |
179 | function ContentGANModel:GetCurrentErrors()
180 | local errors = {errG=self.errG and self.errG or -1, errD=self.errD and self.errD or -1,
181 | errCont=self.errCont and self.errCont or -1}
182 | return errors
183 | end
184 |
185 | -- returns a string that describes the display plot configuration
186 | function ContentGANModel:DisplayPlot(opt)
187 | return 'errG,errD,errCont'
188 | end
189 |
190 |
191 | function ContentGANModel:GetCurrentVisuals(opt, size)
192 | if not size then
193 | size = opt.display_winsize
194 | end
195 |
196 | local visuals = {}
197 | table.insert(visuals, {img=self.real_A, label='real_A'})
198 | table.insert(visuals, {img=self.fake_B, label='fake_B'})
199 | table.insert(visuals, {img=self.real_B, label='real_B'})
200 | return visuals
201 | end
202 |
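203 | -- Generator objective, as assembled in fGx_basic above: errGAN, the LSGAN term on D(G(A)),
204 | -- plus errContent from util/content_loss.lua weighted by lambda_A.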
--------------------------------------------------------------------------------
/models/cycle_gan_model.lua:
--------------------------------------------------------------------------------
1 | local class = require 'class'
2 | require 'models.base_model'
3 | require 'models.architectures'
4 | require 'util.image_pool'
5 |
6 | util = paths.dofile('../util/util.lua')
7 | CycleGANModel = class('CycleGANModel', 'BaseModel')
8 |
9 | function CycleGANModel:__init(conf)
10 | BaseModel.__init(self, conf)
11 | conf = conf or {}
12 | end
13 |
14 | function CycleGANModel:model_name()
15 | return 'CycleGANModel'
16 | end
17 |
18 | function CycleGANModel:InitializeStates(use_wgan)
19 | optimState = {learningRate=opt.lr, beta1=opt.beta1,}
20 | return optimState
21 | end
22 | -- Defines models and networks
23 | function CycleGANModel:Initialize(opt)
24 | if opt.test == 0 then
25 | self.fakeAPool = ImagePool(opt.pool_size)
26 | self.fakeBPool = ImagePool(opt.pool_size)
27 | end
28 | -- define tensors
29 | if opt.test == 0 then -- allocate tensors for training
30 | self.real_A = torch.Tensor(opt.batchSize, opt.input_nc, opt.fineSize, opt.fineSize)
31 | self.real_B = torch.Tensor(opt.batchSize, opt.output_nc, opt.fineSize, opt.fineSize)
32 | self.fake_A = torch.Tensor(opt.batchSize, opt.input_nc, opt.fineSize, opt.fineSize)
33 | self.fake_B = torch.Tensor(opt.batchSize, opt.output_nc, opt.fineSize, opt.fineSize)
34 | self.rec_A = torch.Tensor(opt.batchSize, opt.input_nc, opt.fineSize, opt.fineSize)
35 | self.rec_B = torch.Tensor(opt.batchSize, opt.output_nc, opt.fineSize, opt.fineSize)
36 | end
37 | -- load/define models
38 | local use_lsgan = ((opt.use_lsgan ~= nil) and (opt.use_lsgan == 1))
39 | if not use_lsgan then
40 | self.criterionGAN = nn.BCECriterion()
41 | else
42 | self.criterionGAN = nn.MSECriterion()
43 | end
44 | self.criterionRec = nn.AbsCriterion()
45 |
46 | local netG_A, netD_A, netG_B, netD_B = nil, nil, nil, nil
47 | if opt.continue_train == 1 then
48 | if opt.test == 1 then -- test mode
49 | netG_A = util.load_test_model('G_A', opt)
50 | netG_B = util.load_test_model('G_B', opt)
51 |
52 | --setup optnet to save a little bit of memory
53 | if opt.use_optnet == 1 then
54 | local sample_input = torch.randn(1, opt.input_nc, 2, 2)
55 | local optnet = require 'optnet'
56 | optnet.optimizeMemory(netG_A, sample_input, {inplace=true, reuseBuffers=true})
57 | optnet.optimizeMemory(netG_B, sample_input, {inplace=true, reuseBuffers=true})
58 | end
59 | else
60 | netG_A = util.load_model('G_A', opt)
61 | netG_B = util.load_model('G_B', opt)
62 | netD_A = util.load_model('D_A', opt)
63 | netD_B = util.load_model('D_B', opt)
64 | end
65 | else
66 | local use_sigmoid = (not use_lsgan)
67 | -- netG_test = defineG(opt.input_nc, opt.output_nc, opt.ngf, "resnet_unet", opt.arch)
68 | -- os.exit()
69 | netG_A = defineG(opt.input_nc, opt.output_nc, opt.ngf, opt.which_model_netG, opt.arch)
70 | print('netG_A...', netG_A)
71 | netD_A = defineD(opt.output_nc, opt.ndf, opt.which_model_netD, opt.n_layers_D, use_sigmoid) -- no sigmoid layer
72 | print('netD_A...', netD_A)
73 | netG_B = defineG(opt.output_nc, opt.input_nc, opt.ngf, opt.which_model_netG, opt.arch)
74 | print('netG_B...', netG_B)
75 | netD_B = defineD(opt.input_nc, opt.ndf, opt.which_model_netD, opt.n_layers_D, use_sigmoid) -- no sigmoid layer
76 | print('netD_B', netD_B)
77 | end
78 |
79 | self.netD_A = netD_A
80 | self.netG_A = netG_A
81 | self.netG_B = netG_B
82 | self.netD_B = netD_B
83 |
84 | -- define real/fake labels
85 | if opt.test == 0 then
86 | local D_A_size = self.netD_A:forward(self.real_B):size() -- hack: assume D_size_A = D_size_B
87 | self.fake_label_A = torch.Tensor(D_A_size):fill(0.0)
88 | self.real_label_A = torch.Tensor(D_A_size):fill(1.0) -- no soft smoothing
89 | local D_B_size = self.netD_B:forward(self.real_A):size() -- hack: assume D_size_A = D_size_B
90 | self.fake_label_B = torch.Tensor(D_B_size):fill(0.0)
91 | self.real_label_B = torch.Tensor(D_B_size):fill(1.0) -- no soft smoothing
92 | self.optimStateD_A = self:InitializeStates()
93 | self.optimStateG_A = self:InitializeStates()
94 | self.optimStateD_B = self:InitializeStates()
95 | self.optimStateG_B = self:InitializeStates()
96 | self:RefreshParameters()
97 | print('---------- # Learnable Parameters --------------')
98 | print(('G_A = %d'):format(self.parametersG_A:size(1)))
99 | print(('D_A = %d'):format(self.parametersD_A:size(1)))
100 | print(('G_B = %d'):format(self.parametersG_B:size(1)))
101 | print(('D_B = %d'):format(self.parametersD_B:size(1)))
102 | print('------------------------------------------------')
103 | end
104 | end
105 |
106 | -- Runs the forward pass of the network and
107 | -- saves the result to member variables of the class
108 | function CycleGANModel:Forward(input, opt)
109 | if opt.which_direction == 'BtoA' then
110 | local temp = input.real_A:clone()
111 | input.real_A = input.real_B:clone()
112 | input.real_B = temp
113 | end
114 |
115 | if opt.test == 0 then
116 | self.real_A:copy(input.real_A)
117 | self.real_B:copy(input.real_B)
118 | end
119 |
120 | if opt.test == 1 then -- forward for test
121 | if opt.gpu > 0 then
122 | self.real_A = input.real_A:cuda()
123 | self.real_B = input.real_B:cuda()
124 | else
125 | self.real_A = input.real_A:clone()
126 | self.real_B = input.real_B:clone()
127 | end
128 | self.fake_B = self.netG_A:forward(self.real_A):clone()
129 | self.fake_A = self.netG_B:forward(self.real_B):clone()
130 | self.rec_A = self.netG_B:forward(self.fake_B):clone()
131 | self.rec_B = self.netG_A:forward(self.fake_A):clone()
132 | end
133 | end
134 |
135 | -- create closure to evaluate f(X) and df/dX of discriminator
136 | function CycleGANModel:fDx_basic(x, gradParams, netD, netG, real, fake, real_label, fake_label, opt)
137 | util.BiasZero(netD)
138 | util.BiasZero(netG)
139 | gradParams:zero()
140 | -- Real log(D_A(B))
141 | local output = netD:forward(real)
142 | local errD_real = self.criterionGAN:forward(output, real_label)
143 | local df_do = self.criterionGAN:backward(output, real_label)
144 | netD:backward(real, df_do)
145 | -- Fake + log(1 - D_A(G_A(A)))
146 | output = netD:forward(fake)
147 | local errD_fake = self.criterionGAN:forward(output, fake_label)
148 | local df_do2 = self.criterionGAN:backward(output, fake_label)
149 | netD:backward(fake, df_do2)
150 | -- Compute loss
151 | local errD = (errD_real + errD_fake) / 2.0
152 | return errD, gradParams
153 | end
154 |
155 |
156 | function CycleGANModel:fDAx(x, opt)
157 | -- use image pool that stores the old fake images
158 | fake_B = self.fakeBPool:Query(self.fake_B)
159 | self.errD_A, gradParams = self:fDx_basic(x, self.gradparametersD_A, self.netD_A, self.netG_A,
160 | self.real_B, fake_B, self.real_label_A, self.fake_label_A, opt)
161 | return self.errD_A, gradParams
162 | end
163 |
164 |
165 | function CycleGANModel:fDBx(x, opt)
166 | -- use image pool that stores the old fake images
167 | fake_A = self.fakeAPool:Query(self.fake_A)
168 | self.errD_B, gradParams = self:fDx_basic(x, self.gradparametersD_B, self.netD_B, self.netG_B,
169 | self.real_A, fake_A, self.real_label_B, self.fake_label_B, opt)
170 | return self.errD_B, gradParams
171 | end
172 |
173 |
174 | function CycleGANModel:fGx_basic(x, gradParams, netG, netD, netE, real, real2, real_label, lambda1, lambda2, opt)
175 | util.BiasZero(netD)
176 | util.BiasZero(netG)
177 | util.BiasZero(netE) -- inverse mapping
178 | gradParams:zero()
179 |
180 | -- G should be identity if real2 is fed.
181 | local errI = nil
182 | local identity = nil
183 | if opt.lambda_identity > 0 then
184 | identity = netG:forward(real2):clone()
185 | errI = self.criterionRec:forward(identity, real2) * lambda2 * opt.lambda_identity
186 | local didentity_loss_do = self.criterionRec:backward(identity, real2):mul(lambda2):mul(opt.lambda_identity)
187 | netG:backward(real2, didentity_loss_do)
188 | end
189 |
190 | --- GAN loss: D_A(G_A(A))
191 | local fake = netG:forward(real):clone()
192 | local output = netD:forward(fake)
193 | local errG = self.criterionGAN:forward(output, real_label)
194 | local df_do1 = self.criterionGAN:backward(output, real_label)
195 | local df_d_GAN = netD:updateGradInput(fake, df_do1) --
196 |
197 | -- forward cycle loss
198 | local rec = netE:forward(fake):clone()
199 | local errRec = self.criterionRec:forward(rec, real) * lambda1
200 | local df_do2 = self.criterionRec:backward(rec, real):mul(lambda1)
201 | local df_do_rec = netE:updateGradInput(fake, df_do2)
202 |
203 | netG:backward(real, df_d_GAN + df_do_rec)
204 |
205 | -- backward cycle loss
206 | local fake2 = netE:forward(real2)--:clone()
207 | local rec2 = netG:forward(fake2)--:clone()
208 | local errAdapt = self.criterionRec:forward(rec2, real2) * lambda2
209 | local df_do_coadapt = self.criterionRec:backward(rec2, real2):mul(lambda2)
210 | netG:backward(fake2, df_do_coadapt)
211 |
212 | return gradParams, errG, errRec, errI, fake, rec, identity
213 | end
214 |
215 | function CycleGANModel:fGAx(x, opt)
216 | self.gradparametersG_A, self.errG_A, self.errRec_A, self.errI_A, self.fake_B, self.rec_A, self.identity_B =
217 | self:fGx_basic(x, self.gradparametersG_A, self.netG_A, self.netD_A, self.netG_B, self.real_A, self.real_B,
218 | self.real_label_A, opt.lambda_A, opt.lambda_B, opt)
219 | return self.errG_A, self.gradparametersG_A
220 | end
221 |
222 | function CycleGANModel:fGBx(x, opt)
223 | self.gradparametersG_B, self.errG_B, self.errRec_B, self.errI_B, self.fake_A, self.rec_B, self.identity_A =
224 | self:fGx_basic(x, self.gradparametersG_B, self.netG_B, self.netD_B, self.netG_A, self.real_B, self.real_A,
225 | self.real_label_B, opt.lambda_B, opt.lambda_A, opt)
226 | return self.errG_B, self.gradparametersG_B
227 | end
228 |
229 |
230 | function CycleGANModel:OptimizeParameters(opt)
231 | local fDA = function(x) return self:fDAx(x, opt) end
232 | local fGA = function(x) return self:fGAx(x, opt) end
233 | local fDB = function(x) return self:fDBx(x, opt) end
234 | local fGB = function(x) return self:fGBx(x, opt) end
235 |
236 | optim.adam(fGA, self.parametersG_A, self.optimStateG_A)
237 | optim.adam(fDA, self.parametersD_A, self.optimStateD_A)
238 | optim.adam(fGB, self.parametersG_B, self.optimStateG_B)
239 | optim.adam(fDB, self.parametersD_B, self.optimStateD_B)
240 | end
241 |
242 | function CycleGANModel:RefreshParameters()
243 | self.parametersD_A, self.gradparametersD_A = nil, nil -- nil them to avoid spiking memory
244 | self.parametersG_A, self.gradparametersG_A = nil, nil
245 | self.parametersG_B, self.gradparametersG_B = nil, nil
246 | self.parametersD_B, self.gradparametersD_B = nil, nil
247 | -- define parameters of optimization
248 | self.parametersG_A, self.gradparametersG_A = self.netG_A:getParameters()
249 | self.parametersD_A, self.gradparametersD_A = self.netD_A:getParameters()
250 | self.parametersG_B, self.gradparametersG_B = self.netG_B:getParameters()
251 | self.parametersD_B, self.gradparametersD_B = self.netD_B:getParameters()
252 | end
253 |
254 | function CycleGANModel:Save(prefix, opt)
255 | util.save_model(self.netG_A, prefix .. '_net_G_A.t7', 1)
256 | util.save_model(self.netD_A, prefix .. '_net_D_A.t7', 1)
257 | util.save_model(self.netG_B, prefix .. '_net_G_B.t7', 1)
258 | util.save_model(self.netD_B, prefix .. '_net_D_B.t7', 1)
259 | end
260 |
261 | function CycleGANModel:GetCurrentErrorDescription()
262 | description = ('[A] G: %.4f D: %.4f Rec: %.4f I: %.4f || [B] G: %.4f D: %.4f Rec: %.4f I:%.4f'):format(
263 | self.errG_A and self.errG_A or -1,
264 | self.errD_A and self.errD_A or -1,
265 | self.errRec_A and self.errRec_A or -1,
266 | self.errI_A and self.errI_A or -1,
267 | self.errG_B and self.errG_B or -1,
268 | self.errD_B and self.errD_B or -1,
269 | self.errRec_B and self.errRec_B or -1,
270 | self.errI_B and self.errI_B or -1)
271 | return description
272 | end
273 |
274 | function CycleGANModel:GetCurrentErrors()
275 | local errors = {errG_A=self.errG_A, errD_A=self.errD_A, errRec_A=self.errRec_A, errI_A=self.errI_A,
276 | errG_B=self.errG_B, errD_B=self.errD_B, errRec_B=self.errRec_B, errI_B=self.errI_B}
277 | return errors
278 | end
279 |
280 | -- returns a string that describes the display plot configuration
281 | function CycleGANModel:DisplayPlot(opt)
282 | if opt.lambda_identity > 0 then
283 | return 'errG_A,errD_A,errRec_A,errI_A,errG_B,errD_B,errRec_B,errI_B'
284 | else
285 | return 'errG_A,errD_A,errRec_A,errG_B,errD_B,errRec_B'
286 | end
287 | end
288 |
289 | function CycleGANModel:UpdateLearningRate(opt)
290 | local lrd = opt.lr / opt.niter_decay
291 | local old_lr = self.optimStateD_A['learningRate']
292 | local lr = old_lr - lrd
293 | self.optimStateD_A['learningRate'] = lr
294 | self.optimStateD_B['learningRate'] = lr
295 | self.optimStateG_A['learningRate'] = lr
296 | self.optimStateG_B['learningRate'] = lr
297 | print(('update learning rate: %f -> %f'):format(old_lr, lr))
298 | end
299 |
300 | local function MakeIm3(im)
301 | if im:size(2) == 1 then
302 | local im3 = torch.repeatTensor(im, 1,3,1,1)
303 | return im3
304 | else
305 | return im
306 | end
307 | end
308 |
309 | function CycleGANModel:GetCurrentVisuals(opt, size)
310 | local visuals = {}
311 | table.insert(visuals, {img=MakeIm3(self.real_A), label='real_A'})
312 | table.insert(visuals, {img=MakeIm3(self.fake_B), label='fake_B'})
313 | table.insert(visuals, {img=MakeIm3(self.rec_A), label='rec_A'})
314 | if opt.test == 0 and opt.lambda_identity > 0 then
315 | table.insert(visuals, {img=MakeIm3(self.identity_A), label='identity_A'})
316 | end
317 | table.insert(visuals, {img=MakeIm3(self.real_B), label='real_B'})
318 | table.insert(visuals, {img=MakeIm3(self.fake_A), label='fake_A'})
319 | table.insert(visuals, {img=MakeIm3(self.rec_B), label='rec_B'})
320 | if opt.test == 0 and opt.lambda_identity > 0 then
321 | table.insert(visuals, {img=MakeIm3(self.identity_B), label='identity_B'})
322 | end
323 | return visuals
324 | end
325 |
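
Editorial note (not part of the file): the closures above accumulate several gradients into netG before a single adam step. Assuming the default LSGAN (MSE) setting for criterionGAN and an L1 criterionRec, the per-direction generator objective they assemble can be summarized as:

-- Summary sketch of the loss built by fGx_basic (lambda1/lambda2 are opt.lambda_A/opt.lambda_B
-- for G_A and are swapped for G_B; the identity term exists only when opt.lambda_identity > 0):
--   L_G = MSE(D(G(real)), 1)                                   -- errG, GAN term
--       + lambda1 * L1(E(G(real)), real)                       -- errRec, forward cycle
--       + lambda2 * L1(G(E(real2)), real2)                     -- errAdapt, backward cycle
--       + lambda2 * opt.lambda_identity * L1(G(real2), real2)  -- errI, identity mapping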
--------------------------------------------------------------------------------
/models/one_direction_test_model.lua:
--------------------------------------------------------------------------------
1 | local class = require 'class'
2 | require 'models.base_model'
3 | require 'models.architectures'
4 | require 'util.image_pool'
5 |
6 | util = paths.dofile('../util/util.lua')
7 | OneDirectionTestModel = class('OneDirectionTestModel', 'BaseModel')
8 |
9 | function OneDirectionTestModel:__init(conf)
10 | BaseModel.__init(self, conf)
11 | conf = conf or {}
12 | end
13 |
14 | function OneDirectionTestModel:model_name()
15 | return 'OneDirectionTestModel'
16 | end
17 |
18 | -- Defines models and networks
19 | function OneDirectionTestModel:Initialize(opt)
20 | -- define tensors
21 | self.real_A = torch.Tensor(opt.batchSize, opt.input_nc, opt.fineSize, opt.fineSize)
22 |
23 | -- load/define models
24 | self.netG_A = util.load_test_model('G', opt)
25 |
26 | -- setup optnet to save a bit of memory
27 | if opt.use_optnet == 1 then
28 | local optnet = require 'optnet'
29 | local sample_input = torch.randn(1, opt.input_nc, 2, 2)
30 | optnet.optimizeMemory(self.netG_A, sample_input, {inplace=true, reuseBuffers=true})
31 | end
32 |
33 | self:RefreshParameters()
34 |
35 | print('---------- # Learnable Parameters --------------')
36 | print(('G_A = %d'):format(self.parametersG_A:size(1)))
37 | print('------------------------------------------------')
38 | end
39 |
40 | -- Runs the forward pass of the network and
41 | -- saves the result to member variables of the class
42 | function OneDirectionTestModel:Forward(input, opt)
43 | if opt.which_direction == 'BtoA' then
44 | input.real_A = input.real_B:clone()
45 | end
46 |
47 | self.real_A = input.real_A:clone()
48 | if opt.gpu > 0 then
49 | self.real_A = self.real_A:cuda()
50 | end
51 |
52 | self.fake_B = self.netG_A:forward(self.real_A):clone()
53 | end
54 |
55 | function OneDirectionTestModel:RefreshParameters()
56 | self.parametersG_A, self.gradparametersG_A = nil, nil
57 | self.parametersG_A, self.gradparametersG_A = self.netG_A:getParameters()
58 | end
59 |
60 |
61 | local function MakeIm3(im)
62 | if im:size(2) == 1 then
63 | local im3 = torch.repeatTensor(im, 1,3,1,1)
64 | return im3
65 | else
66 | return im
67 | end
68 | end
69 |
70 | function OneDirectionTestModel:GetCurrentVisuals(opt, size)
71 | if not size then
72 | size = opt.display_winsize
73 | end
74 |
75 | local visuals = {}
76 | table.insert(visuals, {img=MakeIm3(self.real_A), label='real_A'})
77 | table.insert(visuals, {img=MakeIm3(self.fake_B), label='fake_B'})
78 | return visuals
79 | end
80 |
--------------------------------------------------------------------------------
/models/pix2pix_model.lua:
--------------------------------------------------------------------------------
1 | local class = require 'class'
2 | require 'models.base_model'
3 | require 'models.architectures'
4 | require 'util.image_pool'
5 | util = paths.dofile('../util/util.lua')
6 | Pix2PixModel = class('Pix2PixModel', 'BaseModel')
7 |
8 | function Pix2PixModel:__init(conf)
9 | conf = conf or {}
10 | end
11 |
12 | -- Returns the name of the model
13 | function Pix2PixModel:model_name()
14 | return 'Pix2PixModel'
15 | end
16 |
17 | function Pix2PixModel:InitializeStates()
18 | return {learningRate=opt.lr, beta1=opt.beta1,}
19 | end
20 |
21 | -- Defines models and networks
22 | function Pix2PixModel:Initialize(opt) -- use lsgan
23 | -- define tensors
24 | local d_input_nc = opt.input_nc + opt.output_nc
25 | self.real_AB = torch.Tensor(opt.batchSize, d_input_nc, opt.fineSize, opt.fineSize)
26 | self.fake_AB = torch.Tensor(opt.batchSize, d_input_nc, opt.fineSize, opt.fineSize)
27 | if opt.test == 0 then
28 | self.fakeABPool = ImagePool(opt.pool_size)
29 | end
30 | -- load/define models
31 | self.criterionGAN = nn.MSECriterion()
32 | self.criterionL1 = nn.AbsCriterion()
33 |
34 | local netG, netD = nil, nil
35 | if opt.continue_train == 1 then
36 | if opt.test == 1 then -- only load model G for test
37 | netG = util.load_test_model('G', opt)
38 | else
39 | netG = util.load_model('G', opt)
40 | netD = util.load_model('D', opt)
41 | end
42 | else
43 | netG = defineG(opt.input_nc, opt.output_nc, opt.ngf, opt.which_model_netG)
44 | netD = defineD(d_input_nc, opt.ndf, opt.which_model_netD, opt.n_layers_D, false) -- with sigmoid
45 | end
46 |
47 | self.netD = netD
48 | self.netG = netG
49 |
50 | -- define real/fake labels
51 | if opt.test == 0 then
52 | netD_output_size = self.netD:forward(self.real_AB):size()
53 | self.fake_label = torch.Tensor(netD_output_size):fill(0.0)
54 | self.real_label = torch.Tensor(netD_output_size):fill(1.0) -- no soft smoothing
55 |
56 | self.optimStateD = self:InitializeStates()
57 | self.optimStateG = self:InitializeStates()
58 |
59 | self:RefreshParameters()
60 |
61 | print('---------- # Learnable Parameters --------------')
62 | print(('G = %d'):format(self.parametersG:size(1)))
63 | print(('D = %d'):format(self.parametersD:size(1)))
64 | print('------------------------------------------------')
65 | end
66 |
67 | self.A_idx = {{}, {1, opt.input_nc}, {}, {}}
68 | self.B_idx = {{}, {opt.input_nc+1, opt.input_nc+opt.output_nc}, {}, {}}
69 | end
70 |
71 | -- Runs the forward pass of the network
72 | function Pix2PixModel:Forward(input, opt)
73 | if opt.which_direction == 'BtoA' then
74 | local temp = input.real_A
75 | input.real_A = input.real_B
76 | input.real_B = temp
77 | end
78 |
79 | if opt.test == 0 then
80 | self.real_AB[self.A_idx]:copy(input.real_A)
81 | self.real_AB[self.B_idx]:copy(input.real_B)
82 | self.real_A = self.real_AB[self.A_idx]
83 | self.real_B = self.real_AB[self.B_idx]
84 |
85 | self.fake_AB[self.A_idx]:copy(self.real_A)
86 | self.fake_B = self.netG:forward(self.real_A):clone()
87 | self.fake_AB[self.B_idx]:copy(self.fake_B)
88 | else
89 | if opt.gpu > 0 then
90 | self.real_A = input.real_A:cuda()
91 | self.real_B = input.real_B:cuda()
92 | else
93 | self.real_A = input.real_A:clone()
94 | self.real_B = input.real_B:clone()
95 | end
96 | self.fake_B = self.netG:forward(self.real_A):clone()
97 | end
98 | end
99 |
100 | -- create closure to evaluate f(X) and df/dX of discriminator
101 | function Pix2PixModel:fDx_basic(x, gradParams, netD, netG, real, fake, opt)
102 | util.BiasZero(netD)
103 | util.BiasZero(netG)
104 | gradParams:zero()
105 |
106 | -- Real log(D(B))
107 | local output = netD:forward(real)
108 | local errD_real = self.criterionGAN:forward(output, self.real_label)
109 | local df_do = self.criterionGAN:backward(output, self.real_label)
110 | netD:backward(real, df_do)
111 | -- Fake + log(1 - D(G(A)))
112 | output = netD:forward(fake)
113 | local errD_fake = self.criterionGAN:forward(output, self.fake_label)
114 | local df_do2 = self.criterionGAN:backward(output, self.fake_label)
115 | netD:backward(fake, df_do2)
116 | -- calculate loss
117 | local errD = (errD_real + errD_fake) / 2.0
118 | return errD, gradParams
119 | end
120 |
121 |
122 | function Pix2PixModel:fDx(x, opt)
123 | fake_AB = self.fakeABPool:Query(self.fake_AB)
124 | self.errD, gradParams = self:fDx_basic(x, self.gradParametersD, self.netD, self.netG,
125 | self.real_AB, fake_AB, opt)
126 | return self.errD, gradParams
127 | end
128 |
129 | function Pix2PixModel:fGx_basic(x, netG, netD, real, fake, gradParametersG, opt)
130 | util.BiasZero(netG)
131 | util.BiasZero(netD)
132 | gradParametersG:zero()
133 |
134 | -- First. G(A) should fake the discriminator
135 | local output = netD:forward(fake)
136 | local errG = self.criterionGAN:forward(output, self.real_label)
137 | local dgan_loss_dd = self.criterionGAN:backward(output, self.real_label)
138 | local dgan_loss_do = netD:updateGradInput(fake, dgan_loss_dd)
139 |
140 | -- Second. G(A) should be close to the real
141 | real_B = real[self.B_idx]
142 | real_A = real[self.A_idx]
143 | fake_B = fake[self.B_idx]
144 | local errL1 = self.criterionL1:forward(fake_B, real_B) * opt.lambda_A
145 | local dl1_loss_do = self.criterionL1:backward(fake_B, real_B) * opt.lambda_A
146 | netG:backward(real_A, dgan_loss_do[self.B_idx] + dl1_loss_do)
147 |
148 | return gradParametersG, errG, errL1
149 | end
150 |
151 | function Pix2PixModel:fGx(x, opt)
152 | self.gradParametersG, self.errG, self.errL1 = self:fGx_basic(x, self.netG, self.netD,
153 | self.real_AB, self.fake_AB, self.gradParametersG, opt)
154 | return self.errG, self.gradParametersG
155 | end
156 |
157 | -- Runs the backprop gradient descent
158 | -- Corresponds to a single batch of data
159 | function Pix2PixModel:OptimizeParameters(opt)
160 | local fD = function(x) return self:fDx(x, opt) end
161 | local fG = function(x) return self:fGx(x, opt) end
162 | optim.adam(fD, self.parametersD, self.optimStateD)
163 | optim.adam(fG, self.parametersG, self.optimStateG)
164 | end
165 |
166 | -- This function can be used to reset momentum after each epoch
167 | function Pix2PixModel:RefreshParameters()
168 | self.parametersD, self.gradParametersD = nil, nil -- nil them to avoid spiking memory
169 | self.parametersG, self.gradParametersG = nil, nil
170 |
171 | -- define parameters of optimization
172 | self.parametersG, self.gradParametersG = self.netG:getParameters()
173 | self.parametersD, self.gradParametersD = self.netD:getParameters()
174 | end
175 |
176 | -- This function updates the learning rate: it keeps lr at opt.lr for the first opt.niter epochs, then gradually decays it to 0 over the next opt.niter_decay epochs
177 | function Pix2PixModel:UpdateLearningRate(opt)
178 | local lrd = opt.lr / opt.niter_decay
179 | local old_lr = self.optimStateD['learningRate']
180 | local lr = old_lr - lrd
181 | self.optimStateD['learningRate'] = lr
182 | self.optimStateG['learningRate'] = lr
183 | print(('update learning rate: %f -> %f'):format(old_lr, lr))
184 | end
185 |
186 |
187 | -- Save the current model to the file system
188 | function Pix2PixModel:Save(prefix, opt)
189 | util.save_model(self.netG, prefix .. '_net_G.t7', 1.0)
190 | util.save_model(self.netD, prefix .. '_net_D.t7', 1.0)
191 | end
192 |
193 | -- returns a string that describes the current errors
194 | function Pix2PixModel:GetCurrentErrorDescription()
195 | description = ('G: %.4f D: %.4f L1: %.4f'):format(
196 | self.errG and self.errG or -1, self.errD and self.errD or -1, self.errL1 and self.errL1 or -1)
197 | return description
198 |
199 | end
200 |
201 |
202 | -- returns a string that describes the display plot configuration
203 | function Pix2PixModel:DisplayPlot(opt)
204 | return 'errG,errD,errL1'
205 | end
206 |
207 |
208 | -- returns current errors
209 | function Pix2PixModel:GetCurrentErrors()
210 | local errors = {errG=self.errG, errD=self.errD, errL1=self.errL1}
211 | return errors
212 | end
213 |
214 | -- returns a table of image/label pairs that describe
215 | -- the current results.
216 | -- |return|: a table of table. List of image/label pairs
217 | function Pix2PixModel:GetCurrentVisuals(opt, size)
218 | if not size then
219 | size = opt.display_winsize
220 | end
221 |
222 | local visuals = {}
223 | table.insert(visuals, {img=self.real_A, label='real_A'})
224 | table.insert(visuals, {img=self.fake_B, label='fake_B'})
225 | table.insert(visuals, {img=self.real_B, label='real_B'})
226 |
227 | return visuals
228 | end
229 |
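
For orientation (this note is not part of the file): Forward and the closures above implement the conditional GAN objective of pix2pix. The discriminator always sees the input image and an output stacked along channels, and the generator combines the GAN term with an L1 term weighted by opt.lambda_A:

-- Comment-only summary, with criterionGAN = nn.MSECriterion() and criterionL1 = nn.AbsCriterion()
-- as defined in Initialize; [A, x] denotes channel-wise concatenation via A_idx/B_idx:
--   L_D = ( MSE(D([A, B]), 1) + MSE(D([A, G(A)]), 0) ) / 2
--   L_G = MSE(D([A, G(A)]), 1) + opt.lambda_A * L1(G(A), B)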
--------------------------------------------------------------------------------
/options.lua:
--------------------------------------------------------------------------------
1 | --------------------------------------------------------------------------------
2 | -- Configure options
3 | --------------------------------------------------------------------------------
4 |
5 | local options = {}
6 | -- options for train
7 | local opt_train = {
8 | DATA_ROOT = '', -- path to images (should have subfolders 'train', 'val', etc)
9 | batchSize = 1, -- # images in batch
10 | loadSize = 143, -- scale images to this size
11 | fineSize = 128, -- then crop to this size
12 | ngf = 64, -- # of gen filters in first conv layer
13 | ndf = 64, -- # of discrim filters in first conv layer
14 | input_nc = 3, -- # of input image channels
15 | output_nc = 3, -- # of output image channels
16 | niter = 100, -- # of iter at starting learning rate
17 | niter_decay = 100, -- # of iter to linearly decay learning rate to zero
18 | lr = 0.0002, -- initial learning rate for adam
19 | beta1 = 0.5, -- momentum term of adam
20 | ntrain = math.huge, -- # of examples per epoch. math.huge for full dataset
21 | flip = 1, -- whether to flip the images for data augmentation
22 | display_id = 10, -- display window id.
23 | display_winsize = 128, -- display window size
24 | display_freq = 25, -- display the current results every display_freq iterations
25 | gpu = 1, -- gpu = 0 is CPU mode. gpu=X is GPU mode on GPU X
26 | name = '', -- name of the experiment, should generally be passed on the command line
27 | which_direction = 'AtoB', -- AtoB or BtoA
28 | phase = 'train', -- train, val, test, etc
29 | nThreads = 2, -- # threads for loading data
30 | save_epoch_freq = 1, -- save a model every save_epoch_freq epochs (does not overwrite previously saved models)
31 | save_latest_freq = 5000, -- save the latest model every save_latest_freq iterations (overwrites the previous latest model)
32 | print_freq = 50, -- print the debug information every print_freq iterations
33 | save_display_freq = 2500, -- save the current display of results every save_display_freq iterations
34 | continue_train = 0, -- if continue training, load the latest model: 1: true, 0: false
35 | serial_batches = 0, -- if 1, takes images in order to make batches, otherwise takes them randomly
36 | checkpoints_dir = './checkpoints', -- models are saved here
37 | cache_dir = './cache', -- cache files are saved here
38 | cudnn = 1, -- set to 0 to not use cudnn
39 | which_model_netD = 'basic', -- selects model to use for netD
40 | which_model_netG = 'resnet_6blocks', -- selects model to use for netG
41 | norm = 'instance', -- batch or instance normalization
42 | n_layers_D = 3, -- only used if which_model_netD=='n_layers'
43 | content_loss = 'pixel', -- content loss type: pixel, vgg
44 | layer_name = 'pixel', -- layer used in content loss (e.g. relu4_2)
45 | lambda_A = 10.0, -- weight for cycle loss (A -> B -> A)
46 | lambda_B = 10.0, -- weight for cycle loss (B -> A -> B)
47 | model = 'cycle_gan', -- which model to run: 'cycle_gan', 'pix2pix', 'bigan', 'content_gan'
48 | use_lsgan = 1, -- if 1, use least square GAN, if 0, use vanilla GAN
49 | align_data = 0, -- if > 0, use the data loader for aligned image pairs (e.g. for pix2pix)
50 | pool_size = 50, -- the size of image buffer that stores previously generated images
51 | resize_or_crop = 'resize_and_crop', -- resizing/cropping strategy: resize_and_crop | crop | scale_width | scale_height
52 | lambda_identity = 0.5, -- weight of the identity mapping loss. Setting opt.lambda_identity to a nonzero value enables the identity loss and scales it relative to the reconstruction (cycle) loss; for example, if the identity loss should be 10 times smaller than the reconstruction loss, set opt.lambda_identity = 0.1
53 | use_optnet = 0, -- use optnet to save GPU memory during test
54 | }
55 |
56 | -- options for test
57 | local opt_test = {
58 | DATA_ROOT = '', -- path to images (should have subfolders 'train', 'val', etc)
59 | loadSize = 128, -- scale images to this size
60 | fineSize = 128, -- then crop to this size
61 | flip = 0, -- horizontal mirroring data augmentation
62 | display = 1, -- display samples while training. 0 = false
63 | display_id = 200, -- display window id.
64 | gpu = 1, -- gpu = 0 is CPU mode. gpu=X is GPU mode on GPU X
65 | how_many = 'all', -- how many test images to run (set to all to run on every image found in the data/phase folder)
66 | phase = 'test', -- train, val, test, etc
67 | aspect_ratio = 1.0, -- aspect ratio of result images
68 | norm = 'instance', -- batch or instance normalization
69 | name = '', -- name of the experiment; selects which model to run and should generally be passed on the command line
70 | input_nc = 3, -- # of input image channels
71 | output_nc = 3, -- # of output image channels
72 | serial_batches = 1, -- if 1, takes images in order to make batches, otherwise takes them randomly
73 | cudnn = 1, -- set to 0 to not use cudnn (untested)
74 | checkpoints_dir = './checkpoints', -- loads models from here
75 | cache_dir = './cache', -- cache files are saved here
76 | results_dir='./results/', -- saves results here
77 | which_epoch = 'latest', -- which epoch to test? set to 'latest' to use latest cached model
78 | model = 'cycle_gan', -- which model to run: 'cycle_gan', 'pix2pix', 'bigan', 'content_gan'; to use a pretrained model, select `one_direction_test`
79 | align_data = 0, -- if > 0, use the dataloader for pix2pix
80 | which_direction = 'AtoB', -- AtoB or BtoA
81 | resize_or_crop = 'resize_and_crop', -- resizing/cropping strategy: resize_and_crop | crop | scale_width | scale_height
82 | }
83 |
84 | --------------------------------------------------------------------------------
85 | -- util functions
86 | --------------------------------------------------------------------------------
87 | function options.clone(opt)
88 | local copy = {}
89 | for orig_key, orig_value in pairs(opt) do
90 | copy[orig_key] = orig_value
91 | end
92 | return copy
93 | end
94 |
95 | function options.parse_options(mode)
96 | if mode == 'train' then
97 | opt = opt_train
98 | opt.test = 0
99 | elseif mode == 'test' then
100 | opt = opt_test
101 | opt.test = 1
102 | else
103 | print("Invalid option [" .. mode .. "]")
104 | return nil
105 | end
106 |
107 | -- one-line argument parser: parses environment variables to override the defaults
108 | for k,v in pairs(opt) do opt[k] = tonumber(os.getenv(k)) or os.getenv(k) or opt[k] end
109 | if mode == 'test' then
110 | opt.nThreads = 1
111 | opt.continue_train = 1
112 | opt.batchSize = 1 -- test code only supports batchSize=1
113 | end
114 |
115 | -- print by keys
116 | keyset = {}
117 | for k,v in pairs(opt) do
118 | table.insert(keyset, k)
119 | end
120 | table.sort(keyset)
121 | print("------------------- Options -------------------")
122 | for i,k in ipairs(keyset) do
123 | print(('%+25s: %s'):format(k, opt[k]))
124 | end
125 | print("-----------------------------------------------")
126 |
127 | -- save opt to checkpoints
128 | paths.mkdir(opt.checkpoints_dir)
129 | paths.mkdir(paths.concat(opt.checkpoints_dir, opt.name))
130 | opt.visual_dir = paths.concat(opt.checkpoints_dir, opt.name, 'visuals')
131 | paths.mkdir(opt.visual_dir)
132 | -- save opt to the disk
133 | fd = io.open(paths.concat(opt.checkpoints_dir, opt.name, 'opt_' .. mode .. '.txt'), 'w')
134 | for i,k in ipairs(keyset) do
135 | fd:write(("%+25s: %s\n"):format(k, opt[k]))
136 | end
137 | fd:close()
138 |
139 | return opt
140 | end
141 |
142 |
143 | return options
144 |
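
A standalone illustration (not part of options.lua) of the one-line environment-variable parser in parse_options: numeric values are converted with tonumber, other strings are kept verbatim, and unset keys fall back to the defaults. The keys and values below are made up for the example.

local defaults = {niter = 100, lr = 0.0002, name = ''}
-- e.g. launched as: niter=200 name=expt1 th train.lua
for k, v in pairs(defaults) do
  defaults[k] = tonumber(os.getenv(k)) or os.getenv(k) or defaults[k]
end
print(defaults.niter, defaults.name)  -- 200 (a number) and 'expt1' under that invocation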
--------------------------------------------------------------------------------
/pretrained_models/download_model.sh:
--------------------------------------------------------------------------------
1 | FILE=$1
2 |
3 | echo "Note: available models are apple2orange, facades_photo2label, map2sat, orange2apple, style_cezanne, style_ukiyoe, summer2winter_yosemite, zebra2horse, facades_label2photo, horse2zebra, monet2photo, sat2map, style_monet, style_vangogh, winter2summer_yosemite, iphone2dslr_flower"
4 |
5 | echo "Specified [$FILE]"
6 |
7 | mkdir -p ./checkpoints/${FILE}_pretrained
8 | URL=https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/models/$FILE.t7
9 | MODEL_FILE=./checkpoints/${FILE}_pretrained/latest_net_G.t7
10 | wget -N $URL -O $MODEL_FILE
11 |
--------------------------------------------------------------------------------
/pretrained_models/download_vgg.sh:
--------------------------------------------------------------------------------
1 | URL1=https://people.eecs.berkeley.edu/~taesung_park/projects/CycleGAN/models/places_vgg.caffemodel
2 | MODEL_FILE1=./models/places_vgg.caffemodel
3 | URL2=https://people.eecs.berkeley.edu/~taesung_park/projects/CycleGAN/models/places_vgg.prototxt
4 | MODEL_FILE2=./models/places_vgg.prototxt
5 | wget -N $URL1 -O $MODEL_FILE1
6 | wget -N $URL2 -O $MODEL_FILE2
7 |
--------------------------------------------------------------------------------
/pretrained_models/places_vgg.prototxt:
--------------------------------------------------------------------------------
1 | name: "VGG-Places365"
2 | input: "data"
3 | input_dim: 1
4 | input_dim: 3
5 | input_dim: 224
6 | input_dim: 224
7 | layer {
8 | name: "conv1_1"
9 | type: "Convolution"
10 | bottom: "data"
11 | top: "conv1_1"
12 | param {
13 | lr_mult: 1.0
14 | decay_mult: 1.0
15 | }
16 | param {
17 | lr_mult: 2.0
18 | decay_mult: 0.0
19 | }
20 | convolution_param {
21 | num_output: 64
22 | pad: 1
23 | kernel_size: 3
24 | weight_filler {
25 | type: "gaussian"
26 | std: 0.01
27 | }
28 | bias_filler {
29 | type: "constant"
30 | value: 0.0
31 | }
32 | }
33 | }
34 | layer {
35 | name: "relu1_1"
36 | type: "ReLU"
37 | bottom: "conv1_1"
38 | top: "conv1_1"
39 | }
40 | layer {
41 | name: "conv1_2"
42 | type: "Convolution"
43 | bottom: "conv1_1"
44 | top: "conv1_2"
45 | param {
46 | lr_mult: 1.0
47 | decay_mult: 1.0
48 | }
49 | param {
50 | lr_mult: 2.0
51 | decay_mult: 0.0
52 | }
53 | convolution_param {
54 | num_output: 64
55 | pad: 1
56 | kernel_size: 3
57 | weight_filler {
58 | type: "gaussian"
59 | std: 0.01
60 | }
61 | bias_filler {
62 | type: "constant"
63 | value: 0.0
64 | }
65 | }
66 | }
67 | layer {
68 | name: "relu1_2"
69 | type: "ReLU"
70 | bottom: "conv1_2"
71 | top: "conv1_2"
72 | }
73 | layer {
74 | name: "pool1"
75 | type: "Pooling"
76 | bottom: "conv1_2"
77 | top: "pool1"
78 | pooling_param {
79 | pool: MAX
80 | kernel_size: 2
81 | stride: 2
82 | }
83 | }
84 | layer {
85 | name: "conv2_1"
86 | type: "Convolution"
87 | bottom: "pool1"
88 | top: "conv2_1"
89 | param {
90 | lr_mult: 1.0
91 | decay_mult: 1.0
92 | }
93 | param {
94 | lr_mult: 2.0
95 | decay_mult: 0.0
96 | }
97 | convolution_param {
98 | num_output: 128
99 | pad: 1
100 | kernel_size: 3
101 | weight_filler {
102 | type: "gaussian"
103 | std: 0.01
104 | }
105 | bias_filler {
106 | type: "constant"
107 | value: 0.0
108 | }
109 | }
110 | }
111 | layer {
112 | name: "relu2_1"
113 | type: "ReLU"
114 | bottom: "conv2_1"
115 | top: "conv2_1"
116 | }
117 | layer {
118 | name: "conv2_2"
119 | type: "Convolution"
120 | bottom: "conv2_1"
121 | top: "conv2_2"
122 | param {
123 | lr_mult: 1.0
124 | decay_mult: 1.0
125 | }
126 | param {
127 | lr_mult: 2.0
128 | decay_mult: 0.0
129 | }
130 | convolution_param {
131 | num_output: 128
132 | pad: 1
133 | kernel_size: 3
134 | weight_filler {
135 | type: "gaussian"
136 | std: 0.01
137 | }
138 | bias_filler {
139 | type: "constant"
140 | value: 0.0
141 | }
142 | }
143 | }
144 | layer {
145 | name: "relu2_2"
146 | type: "ReLU"
147 | bottom: "conv2_2"
148 | top: "conv2_2"
149 | }
150 | layer {
151 | name: "pool2"
152 | type: "Pooling"
153 | bottom: "conv2_2"
154 | top: "pool2"
155 | pooling_param {
156 | pool: MAX
157 | kernel_size: 2
158 | stride: 2
159 | }
160 | }
161 | layer {
162 | name: "conv3_1"
163 | type: "Convolution"
164 | bottom: "pool2"
165 | top: "conv3_1"
166 | param {
167 | lr_mult: 1.0
168 | decay_mult: 1.0
169 | }
170 | param {
171 | lr_mult: 2.0
172 | decay_mult: 0.0
173 | }
174 | convolution_param {
175 | num_output: 256
176 | pad: 1
177 | kernel_size: 3
178 | weight_filler {
179 | type: "gaussian"
180 | std: 0.01
181 | }
182 | bias_filler {
183 | type: "constant"
184 | value: 0.0
185 | }
186 | }
187 | }
188 | layer {
189 | name: "relu3_1"
190 | type: "ReLU"
191 | bottom: "conv3_1"
192 | top: "conv3_1"
193 | }
194 | layer {
195 | name: "conv3_2"
196 | type: "Convolution"
197 | bottom: "conv3_1"
198 | top: "conv3_2"
199 | param {
200 | lr_mult: 1.0
201 | decay_mult: 1.0
202 | }
203 | param {
204 | lr_mult: 2.0
205 | decay_mult: 0.0
206 | }
207 | convolution_param {
208 | num_output: 256
209 | pad: 1
210 | kernel_size: 3
211 | weight_filler {
212 | type: "gaussian"
213 | std: 0.01
214 | }
215 | bias_filler {
216 | type: "constant"
217 | value: 0.0
218 | }
219 | }
220 | }
221 | layer {
222 | name: "relu3_2"
223 | type: "ReLU"
224 | bottom: "conv3_2"
225 | top: "conv3_2"
226 | }
227 | layer {
228 | name: "conv3_3"
229 | type: "Convolution"
230 | bottom: "conv3_2"
231 | top: "conv3_3"
232 | param {
233 | lr_mult: 1.0
234 | decay_mult: 1.0
235 | }
236 | param {
237 | lr_mult: 2.0
238 | decay_mult: 0.0
239 | }
240 | convolution_param {
241 | num_output: 256
242 | pad: 1
243 | kernel_size: 3
244 | weight_filler {
245 | type: "gaussian"
246 | std: 0.01
247 | }
248 | bias_filler {
249 | type: "constant"
250 | value: 0.0
251 | }
252 | }
253 | }
254 | layer {
255 | name: "relu3_3"
256 | type: "ReLU"
257 | bottom: "conv3_3"
258 | top: "conv3_3"
259 | }
260 | layer {
261 | name: "pool3"
262 | type: "Pooling"
263 | bottom: "conv3_3"
264 | top: "pool3"
265 | pooling_param {
266 | pool: MAX
267 | kernel_size: 2
268 | stride: 2
269 | }
270 | }
271 | layer {
272 | name: "conv4_1"
273 | type: "Convolution"
274 | bottom: "pool3"
275 | top: "conv4_1"
276 | param {
277 | lr_mult: 1.0
278 | decay_mult: 1.0
279 | }
280 | param {
281 | lr_mult: 2.0
282 | decay_mult: 0.0
283 | }
284 | convolution_param {
285 | num_output: 512
286 | pad: 1
287 | kernel_size: 3
288 | weight_filler {
289 | type: "gaussian"
290 | std: 0.01
291 | }
292 | bias_filler {
293 | type: "constant"
294 | value: 0.0
295 | }
296 | }
297 | }
298 | layer {
299 | name: "relu4_1"
300 | type: "ReLU"
301 | bottom: "conv4_1"
302 | top: "conv4_1"
303 | }
304 | layer {
305 | name: "conv4_2"
306 | type: "Convolution"
307 | bottom: "conv4_1"
308 | top: "conv4_2"
309 | param {
310 | lr_mult: 1.0
311 | decay_mult: 1.0
312 | }
313 | param {
314 | lr_mult: 2.0
315 | decay_mult: 0.0
316 | }
317 | convolution_param {
318 | num_output: 512
319 | pad: 1
320 | kernel_size: 3
321 | weight_filler {
322 | type: "gaussian"
323 | std: 0.01
324 | }
325 | bias_filler {
326 | type: "constant"
327 | value: 0.0
328 | }
329 | }
330 | }
331 | layer {
332 | name: "relu4_2"
333 | type: "ReLU"
334 | bottom: "conv4_2"
335 | top: "conv4_2"
336 | }
337 | layer {
338 | name: "conv4_3"
339 | type: "Convolution"
340 | bottom: "conv4_2"
341 | top: "conv4_3"
342 | param {
343 | lr_mult: 1.0
344 | decay_mult: 1.0
345 | }
346 | param {
347 | lr_mult: 2.0
348 | decay_mult: 0.0
349 | }
350 | convolution_param {
351 | num_output: 512
352 | pad: 1
353 | kernel_size: 3
354 | weight_filler {
355 | type: "gaussian"
356 | std: 0.01
357 | }
358 | bias_filler {
359 | type: "constant"
360 | value: 0.0
361 | }
362 | }
363 | }
364 | layer {
365 | name: "relu4_3"
366 | type: "ReLU"
367 | bottom: "conv4_3"
368 | top: "conv4_3"
369 | }
370 | layer {
371 | name: "pool4"
372 | type: "Pooling"
373 | bottom: "conv4_3"
374 | top: "pool4"
375 | pooling_param {
376 | pool: MAX
377 | kernel_size: 2
378 | stride: 2
379 | }
380 | }
381 | layer {
382 | name: "conv5_1"
383 | type: "Convolution"
384 | bottom: "pool4"
385 | top: "conv5_1"
386 | param {
387 | lr_mult: 1.0
388 | decay_mult: 1.0
389 | }
390 | param {
391 | lr_mult: 2.0
392 | decay_mult: 0.0
393 | }
394 | convolution_param {
395 | num_output: 512
396 | pad: 1
397 | kernel_size: 3
398 | weight_filler {
399 | type: "gaussian"
400 | std: 0.01
401 | }
402 | bias_filler {
403 | type: "constant"
404 | value: 0.0
405 | }
406 | }
407 | }
408 | layer {
409 | name: "relu5_1"
410 | type: "ReLU"
411 | bottom: "conv5_1"
412 | top: "conv5_1"
413 | }
414 | layer {
415 | name: "conv5_2"
416 | type: "Convolution"
417 | bottom: "conv5_1"
418 | top: "conv5_2"
419 | param {
420 | lr_mult: 1.0
421 | decay_mult: 1.0
422 | }
423 | param {
424 | lr_mult: 2.0
425 | decay_mult: 0.0
426 | }
427 | convolution_param {
428 | num_output: 512
429 | pad: 1
430 | kernel_size: 3
431 | weight_filler {
432 | type: "gaussian"
433 | std: 0.01
434 | }
435 | bias_filler {
436 | type: "constant"
437 | value: 0.0
438 | }
439 | }
440 | }
441 | layer {
442 | name: "relu5_2"
443 | type: "ReLU"
444 | bottom: "conv5_2"
445 | top: "conv5_2"
446 | }
447 | layer {
448 | name: "conv5_3"
449 | type: "Convolution"
450 | bottom: "conv5_2"
451 | top: "conv5_3"
452 | param {
453 | lr_mult: 1.0
454 | decay_mult: 1.0
455 | }
456 | param {
457 | lr_mult: 2.0
458 | decay_mult: 0.0
459 | }
460 | convolution_param {
461 | num_output: 512
462 | pad: 1
463 | kernel_size: 3
464 | weight_filler {
465 | type: "gaussian"
466 | std: 0.01
467 | }
468 | bias_filler {
469 | type: "constant"
470 | value: 0.0
471 | }
472 | }
473 | }
474 | layer {
475 | name: "relu5_3"
476 | type: "ReLU"
477 | bottom: "conv5_3"
478 | top: "conv5_3"
479 | }
480 | layer {
481 | name: "pool5"
482 | type: "Pooling"
483 | bottom: "conv5_3"
484 | top: "pool5"
485 | pooling_param {
486 | pool: MAX
487 | kernel_size: 2
488 | stride: 2
489 | }
490 | }
491 | layer {
492 | name: "fc6"
493 | type: "InnerProduct"
494 | bottom: "pool5"
495 | top: "fc6"
496 | param {
497 | lr_mult: 1.0
498 | decay_mult: 1.0
499 | }
500 | param {
501 | lr_mult: 2.0
502 | decay_mult: 0.0
503 | }
504 | inner_product_param {
505 | num_output: 4096
506 | weight_filler {
507 | type: "gaussian"
508 | std: 0.01
509 | }
510 | bias_filler {
511 | type: "constant"
512 | value: 0.0
513 | }
514 | }
515 | }
516 | layer {
517 | name: "relu6"
518 | type: "ReLU"
519 | bottom: "fc6"
520 | top: "fc6"
521 | }
522 | layer {
523 | name: "drop6"
524 | type: "Dropout"
525 | bottom: "fc6"
526 | top: "fc6"
527 | dropout_param {
528 | dropout_ratio: 0.5
529 | }
530 | }
531 | layer {
532 | name: "fc7"
533 | type: "InnerProduct"
534 | bottom: "fc6"
535 | top: "fc7"
536 | param {
537 | lr_mult: 1.0
538 | decay_mult: 1.0
539 | }
540 | param {
541 | lr_mult: 2.0
542 | decay_mult: 0.0
543 | }
544 | inner_product_param {
545 | num_output: 4096
546 | weight_filler {
547 | type: "gaussian"
548 | std: 0.01
549 | }
550 | bias_filler {
551 | type: "constant"
552 | value: 0.0
553 | }
554 | }
555 | }
556 | layer {
557 | name: "relu7"
558 | type: "ReLU"
559 | bottom: "fc7"
560 | top: "fc7"
561 | }
562 | layer {
563 | name: "drop7"
564 | type: "Dropout"
565 | bottom: "fc7"
566 | top: "fc7"
567 | dropout_param {
568 | dropout_ratio: 0.5
569 | }
570 | }
571 | layer {
572 | name: "fc8a"
573 | type: "InnerProduct"
574 | bottom: "fc7"
575 | top: "fc8a"
576 | param {
577 | lr_mult: 1.0
578 | decay_mult: 1.0
579 | }
580 | param {
581 | lr_mult: 2.0
582 | decay_mult: 0.0
583 | }
584 | inner_product_param {
585 | num_output: 365
586 | }
587 | }
588 | layer {
589 | name: "prob"
590 | type: "Softmax"
591 | bottom: "fc8a"
592 | top: "prob"
593 | }
594 |
--------------------------------------------------------------------------------
/test.lua:
--------------------------------------------------------------------------------
1 | -- usage: DATA_ROOT=/path/to/data/ name=expt1 which_direction=BtoA th test.lua
2 | --
3 | -- code derived from https://github.com/soumith/dcgan.torch and https://github.com/phillipi/pix2pix
4 | require 'image'
5 | require 'nn'
6 | require 'nngraph'
7 | require 'models.architectures'
8 |
9 |
10 | util = paths.dofile('util/util.lua')
11 | options = require 'options'
12 | opt = options.parse_options('test')
13 |
14 | -- initialize torch GPU/CPU mode
15 | if opt.gpu > 0 then
16 | require 'cutorch'
17 | require 'cunn'
18 | cutorch.setDevice(opt.gpu)
19 | print ("GPU Mode")
20 | torch.setdefaulttensortype('torch.CudaTensor')
21 | else
22 | torch.setdefaulttensortype('torch.FloatTensor')
23 | print ("CPU Mode")
24 | end
25 |
26 | -- setup visualization
27 | visualizer = require 'util/visualizer'
28 |
29 | function TableConcat(t1,t2)
30 | for i=1,#t2 do
31 | t1[#t1+1] = t2[i]
32 | end
33 | return t1
34 | end
35 |
36 |
37 | -- load data
38 | local data_loader = nil
39 | if opt.align_data > 0 then
40 | require 'data.aligned_data_loader'
41 | data_loader = AlignedDataLoader()
42 | else
43 | require 'data.unaligned_data_loader'
44 | data_loader = UnalignedDataLoader()
45 | end
46 | print( "DataLoader " .. data_loader:name() .. " was created.")
47 | data_loader:Initialize(opt)
48 |
49 | if opt.how_many == 'all' then
50 | opt.how_many = data_loader:size()
51 | end
52 |
53 | opt.how_many = math.min(opt.how_many, data_loader:size())
54 |
55 | -- set batch/instance normalization
56 | set_normalization(opt.norm)
57 |
58 | -- load model
59 | opt.continue_train = 1
60 | -- define model
61 | if opt.model == 'cycle_gan' then
62 | require 'models.cycle_gan_model'
63 | model = CycleGANModel()
64 | elseif opt.model == 'one_direction_test' then
65 | require 'models.one_direction_test_model'
66 | model = OneDirectionTestModel()
67 | elseif opt.model == 'pix2pix' then
68 | require 'models.pix2pix_model'
69 | model = Pix2PixModel()
70 | elseif opt.model == 'bigan' then
71 | require 'models.bigan_model'
72 | model = BiGANModel()
73 | elseif opt.model == 'content_gan' then
74 | require 'models.content_gan_model'
75 | model = ContentGANModel()
76 | else
77 | error('Please specify a correct model')
78 | end
79 | model:Initialize(opt)
80 |
81 | local pathsA = {} -- paths to images A tested on
82 | local pathsB = {} -- paths to images B tested on
83 | local web_dir = paths.concat(opt.results_dir, opt.name .. '/' .. opt.which_epoch .. '_' .. opt.phase)
84 | paths.mkdir(web_dir)
85 | local image_dir = paths.concat(web_dir, 'images')
86 | paths.mkdir(image_dir)
87 | s1 = opt.fineSize
88 | s2 = opt.fineSize / opt.aspect_ratio
89 |
90 | visuals = {}
91 |
92 | for n = 1, math.floor(opt.how_many) do
93 | print('processing batch ' .. n)
94 | local cur_dataA, cur_dataB, cur_pathsA, cur_pathsB = data_loader:GetNextBatch()
95 |
96 | cur_pathsA = util.basename_batch(cur_pathsA)
97 | cur_pathsB = util.basename_batch(cur_pathsB)
98 | print('pathsA', cur_pathsA)
99 | print('pathsB', cur_pathsB)
100 | model:Forward({real_A=cur_dataA, real_B=cur_dataB}, opt)
101 |
102 | visuals = model:GetCurrentVisuals(opt, opt.fineSize)
103 |
104 | for i,visual in ipairs(visuals) do
105 | if opt.resize_or_crop == 'scale_width' or opt.resize_or_crop == 'scale_height' then
106 | s1 = nil
107 | s2 = nil
108 | end
109 | visualizer.save_images(visual.img, paths.concat(image_dir, visual.label), {string.gsub(cur_pathsA[1],'.jpg','.png')}, s1, s2)
110 | end
111 |
112 |
113 | print('Saved images to: ', image_dir)
114 | pathsA = TableConcat(pathsA, cur_pathsA)
115 | pathsB = TableConcat(pathsB, cur_pathsB)
116 | end
117 |
118 | labels = {}
119 | for i,visual in ipairs(visuals) do
120 | table.insert(labels, visual.label)
121 | end
122 |
123 | -- make webpage
124 | io.output(paths.concat(web_dir, 'index.html'))
125 | io.write('<table style="text-align:center;">')
126 | io.write('<tr><td>Image #</td>')
127 | for i = 1, #labels do
128 | io.write('<td>' .. labels[i] .. '</td>')
129 | end
130 | io.write('</tr>')
131 |
132 | for n = 1,math.floor(opt.how_many) do
133 | io.write('<tr>')
134 | io.write('<td>' .. tostring(n) .. '</td>')
135 | for j = 1, #labels do
136 | label = labels[j]
137 | io.write('<td><img src="./images/' .. label .. '/' .. string.gsub(pathsA[n], '.jpg', '.png') .. '"/></td>')
138 | end
139 | io.write('</tr>')
140 | end
141 |
142 | io.write('</table>')
143 |
--------------------------------------------------------------------------------
/train.lua:
--------------------------------------------------------------------------------
1 | -- usage example: DATA_ROOT=/path/to/data/ which_direction=BtoA name=expt1 th train.lua
2 | -- code derived from https://github.com/soumith/dcgan.torch and https://github.com/phillipi/pix2pix
3 |
4 | require 'torch'
5 | require 'nn'
6 | require 'optim'
7 | util = paths.dofile('util/util.lua')
8 | content = paths.dofile('util/content_loss.lua')
9 | require 'image'
10 | require 'models.architectures'
11 |
12 | -- load configuration file
13 | options = require 'options'
14 | opt = options.parse_options('train')
15 |
16 | -- setup visualization
17 | visualizer = require 'util/visualizer'
18 |
19 | -- initialize torch GPU/CPU mode
20 | if opt.gpu > 0 then
21 | require 'cutorch'
22 | require 'cunn'
23 | cutorch.setDevice(opt.gpu)
24 | print ("GPU Mode")
25 | torch.setdefaulttensortype('torch.CudaTensor')
26 | else
27 | torch.setdefaulttensortype('torch.FloatTensor')
28 | print ("CPU Mode")
29 | end
30 |
31 | -- load data
32 | local data_loader = nil
33 | if opt.align_data > 0 then
34 | require 'data.aligned_data_loader'
35 | data_loader = AlignedDataLoader()
36 | else
37 | require 'data.unaligned_data_loader'
38 | data_loader = UnalignedDataLoader()
39 | end
40 | print( "DataLoader " .. data_loader:name() .. " was created.")
41 | data_loader:Initialize(opt)
42 |
43 | -- set batch/instance normalization
44 | set_normalization(opt.norm)
45 |
46 | --- timer
47 | local epoch_tm = torch.Timer()
48 | local tm = torch.Timer()
49 |
50 | -- define model
51 | local model = nil
52 | local display_plot = nil
53 | if opt.model == 'cycle_gan' then
54 | assert(data_loader:name() == 'UnalignedDataLoader')
55 | require 'models.cycle_gan_model'
56 | model = CycleGANModel()
57 | elseif opt.model == 'pix2pix' then
58 | require 'models.pix2pix_model'
59 | assert(data_loader:name() == 'AlignedDataLoader')
60 | model = Pix2PixModel()
61 | elseif opt.model == 'bigan' then
62 | assert(data_loader:name() == 'UnalignedDataLoader')
63 | require 'models.bigan_model'
64 | model = BiGANModel()
65 | elseif opt.model == 'content_gan' then
66 | require 'models.content_gan_model'
67 | assert(data_loader:name() == 'UnalignedDataLoader')
68 | model = ContentGANModel()
69 | else
70 | error('Please specify a correct model')
71 | end
72 |
73 | -- print the model name
74 | print('Model ' .. model:model_name() .. ' was specified.')
75 | model:Initialize(opt)
76 |
77 | -- set up the loss plot
78 | require 'util/plot_util'
79 | plotUtil = PlotUtil()
80 | display_plot = model:DisplayPlot(opt)
81 | plotUtil:Initialize(display_plot, opt.display_id, opt.name)
82 |
83 | --------------------------------------------------------------------------------
84 | -- Helper Functions
85 | --------------------------------------------------------------------------------
86 | function visualize_current_results()
87 | local visuals = model:GetCurrentVisuals(opt)
88 | for i,visual in ipairs(visuals) do
89 | visualizer.disp_image(visual.img, opt.display_winsize,
90 | opt.display_id+i, opt.name .. ' ' .. visual.label)
91 | end
92 | end
93 |
94 | function save_current_results(epoch, counter)
95 | local visuals = model:GetCurrentVisuals(opt)
96 | for i,visual in ipairs(visuals) do
97 | output_path = paths.concat(opt.visual_dir, 'train_epoch' .. epoch .. '_iter' .. counter .. '_' .. visual.label .. '.jpg')
98 | visualizer.save_results(visual.img, output_path)
99 | end
100 | end
101 |
102 | function print_current_errors(epoch, counter_in_epoch)
103 | print(('Epoch: [%d][%8d / %8d]\t Time: %.3f DataTime: %.3f '
104 | .. '%s'):
105 | format(epoch, ((counter_in_epoch-1) / opt.batchSize),
106 | math.floor(math.min(data_loader:size(), opt.ntrain) / opt.batchSize),
107 | tm:time().real / opt.batchSize,
108 | data_loader:time_elapsed_to_fetch_data() / opt.batchSize,
109 | model:GetCurrentErrorDescription()
110 | ))
111 | end
112 |
113 | function plot_current_errors(epoch, counter_ratio, opt)
114 | local errs = model:GetCurrentErrors(opt)
115 | local plot_vals = { epoch + counter_ratio}
116 | plotUtil:Display(plot_vals, errs)
117 | end
118 |
119 | --------------------------------------------------------------------------------
120 | -- Main Training Loop
121 | --------------------------------------------------------------------------------
122 | local counter = 0
123 | local num_batches = math.floor(math.min(data_loader:size(), opt.ntrain) / opt.batchSize)
124 | print('#training iterations: ' .. opt.niter+opt.niter_decay )
125 |
126 | for epoch = 1, opt.niter+opt.niter_decay do
127 | epoch_tm:reset()
128 | for counter_in_epoch = 1, math.min(data_loader:size(), opt.ntrain), opt.batchSize do
129 | tm:reset()
130 | -- load a batch and run G on that batch
131 | local real_dataA, real_dataB, _, _ = data_loader:GetNextBatch()
132 |
133 | model:Forward({real_A=real_dataA, real_B=real_dataB}, opt)
134 | -- run forward pass
135 | opt.counter = counter
136 | -- run backward pass
137 | model:OptimizeParameters(opt)
138 | -- display on the web server
139 | if counter % opt.display_freq == 0 and opt.display_id > 0 then
140 | visualize_current_results()
141 | end
142 |
143 | -- logging
144 | if counter % opt.print_freq == 0 then
145 | print_current_errors(epoch, counter_in_epoch)
146 | plot_current_errors(epoch, counter_in_epoch/num_batches, opt)
147 | end
148 |
149 | -- save latest model
150 | if counter % opt.save_latest_freq == 0 and counter > 0 then
151 | print(('saving the latest model (epoch %d, iters %d)'):format(epoch, counter))
152 | model:Save('latest', opt)
153 | end
154 |
155 | -- save latest results
156 | if counter % opt.save_display_freq == 0 then
157 | save_current_results(epoch, counter)
158 | end
159 | counter = counter + 1
160 | end
161 |
162 | -- save model at the end of epoch
163 | if epoch % opt.save_epoch_freq == 0 then
164 | print(('saving the model (epoch %d, iters %d)'):format(epoch, counter))
165 | model:Save('latest', opt)
166 | model:Save(epoch, opt)
167 | end
168 | -- print the timing information after each epoch
169 | print(('End of epoch %d / %d \t Time Taken: %.3f'):
170 | format(epoch, opt.niter+opt.niter_decay, epoch_tm:time().real))
171 |
172 | -- update learning rate
173 | if epoch > opt.niter then
174 | model:UpdateLearningRate(opt)
175 | end
176 | -- refresh parameters
177 | model:RefreshParameters(opt)
178 | end
179 |
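
Because UpdateLearningRate is only called once epoch > opt.niter, the effective schedule is: constant opt.lr for the first opt.niter epochs, then a linear decay to zero over the next opt.niter_decay epochs. A minimal standalone sketch of that schedule (values are the defaults from options.lua):

local niter, niter_decay, base_lr = 100, 100, 0.0002
local lr = base_lr
for epoch = 1, niter + niter_decay do
  if epoch > niter then
    lr = lr - base_lr / niter_decay   -- mirrors model:UpdateLearningRate(opt)
  end
end
print(lr)  -- 0 (up to floating-point rounding) after the final epoch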
--------------------------------------------------------------------------------
/util/InstanceNormalization.lua:
--------------------------------------------------------------------------------
1 | require 'nn'
2 |
3 | --[[
4 | Implements instance normalization as described in the paper
5 |
6 | Instance Normalization: The Missing Ingredient for Fast Stylization
7 | Dmitry Ulyanov, Andrea Vedaldi, Victor Lempitsky
8 | https://arxiv.org/abs/1607.08022
9 | This implementation is based on
10 | https://github.com/DmitryUlyanov/texture_nets
11 | ]]
12 |
13 | local InstanceNormalization, parent = torch.class('nn.InstanceNormalization', 'nn.Module')
14 |
15 | function InstanceNormalization:__init(nOutput, eps, momentum, affine)
16 | parent.__init(self)
17 | self.running_mean = torch.zeros(nOutput)
18 | self.running_var = torch.ones(nOutput)
19 |
20 | self.eps = eps or 1e-5
21 | self.momentum = momentum or 0.0
22 | if affine ~= nil then
23 | assert(type(affine) == 'boolean', 'affine has to be true/false')
24 | self.affine = affine
25 | else
26 | self.affine = true
27 | end
28 |
29 | self.nOutput = nOutput
30 | self.prev_batch_size = -1
31 |
32 | if self.affine then
33 | self.weight = torch.Tensor(nOutput):uniform()
34 | self.bias = torch.Tensor(nOutput):zero()
35 | self.gradWeight = torch.Tensor(nOutput)
36 | self.gradBias = torch.Tensor(nOutput)
37 | end
38 | end
39 |
40 | function InstanceNormalization:updateOutput(input)
41 | self.output = self.output or input.new()
42 | assert(input:size(2) == self.nOutput)
43 |
44 | local batch_size = input:size(1)
45 |
46 | if batch_size ~= self.prev_batch_size or (self.bn and self:type() ~= self.bn:type()) then
47 | self.bn = nn.SpatialBatchNormalization(input:size(1)*input:size(2), self.eps, self.momentum, self.affine)
48 | self.bn:type(self:type())
49 | self.bn.running_mean:copy(self.running_mean:repeatTensor(batch_size))
50 | self.bn.running_var:copy(self.running_var:repeatTensor(batch_size))
51 |
52 | self.prev_batch_size = input:size(1)
53 | end
54 |
55 | -- Get statistics
56 | self.running_mean:copy(self.bn.running_mean:view(input:size(1),self.nOutput):mean(1))
57 | self.running_var:copy(self.bn.running_var:view(input:size(1),self.nOutput):mean(1))
58 |
59 | -- Set params for BN
60 | if self.affine then
61 | self.bn.weight:copy(self.weight:repeatTensor(batch_size))
62 | self.bn.bias:copy(self.bias:repeatTensor(batch_size))
63 | end
64 |
65 | local input_1obj = input:contiguous():view(1,input:size(1)*input:size(2),input:size(3),input:size(4))
66 | self.output = self.bn:forward(input_1obj):viewAs(input)
67 |
68 | return self.output
69 | end
70 |
71 | function InstanceNormalization:updateGradInput(input, gradOutput)
72 | self.gradInput = self.gradInput or gradOutput.new()
73 |
74 | assert(self.bn)
75 |
76 | local input_1obj = input:contiguous():view(1,input:size(1)*input:size(2),input:size(3),input:size(4))
77 | local gradOutput_1obj = gradOutput:contiguous():view(1,input:size(1)*input:size(2),input:size(3),input:size(4))
78 |
79 | if self.affine then
80 | self.bn.gradWeight:zero()
81 | self.bn.gradBias:zero()
82 | end
83 |
84 | self.gradInput = self.bn:backward(input_1obj, gradOutput_1obj):viewAs(input)
85 |
86 | if self.affine then
87 | self.gradWeight:add(self.bn.gradWeight:view(input:size(1),self.nOutput):sum(1))
88 | self.gradBias:add(self.bn.gradBias:view(input:size(1),self.nOutput):sum(1))
89 | end
90 | return self.gradInput
91 | end
92 |
93 | function InstanceNormalization:clearState()
94 | self.output = self.output.new()
95 | self.gradInput = self.gradInput.new()
96 |
97 | if self.bn then
98 | self.bn:clearState()
99 | end
100 | end
101 |
102 | function InstanceNormalization:evaluate()
103 | end
104 |
105 | function InstanceNormalization:training()
106 | end
107 |
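
The module above implements instance normalization by viewing an (N, C, H, W) batch as a single (1, N*C, H, W) sample and reusing nn.SpatialBatchNormalization, so each (sample, channel) feature map is normalized independently. A small standalone sketch of that reshape trick (illustrative only):

require 'nn'
local N, C, H, W = 2, 3, 4, 4
local x = torch.randn(N, C, H, W)
local bn = nn.SpatialBatchNormalization(N * C, 1e-5, 0.0, false)  -- no affine params, for clarity
local y = bn:forward(x:view(1, N * C, H, W)):viewAs(x)
print(y[1][1]:mean(), y[1][1]:std())  -- each feature map is roughly zero-mean, unit-variance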
--------------------------------------------------------------------------------
/util/VGG_preprocess.lua:
--------------------------------------------------------------------------------
1 | -- define nn module for VGG postprocessing
2 | local VGG_postprocess, parent = torch.class('nn.VGG_postprocess', 'nn.Module')
3 |
4 | function VGG_postprocess:__init()
5 | parent.__init(self)
6 | end
7 |
8 | function VGG_postprocess:updateOutput(input)
9 | self.output = input:add(1):mul(127.5)
10 | -- print(self.output:max(), self.output:min())
11 | if self.output:max() > 255 or self.output:min() < 0 then
12 | print(self.output:min(), self.output:max())
13 | end
14 | -- assert(self.output:min()>=0,"badly scaled inputs")
15 | -- assert(self.output:max()<=255,"badly scaled inputs")
16 |
17 | local mean_pixel = torch.FloatTensor({103.939, 116.779, 123.68})
18 | mean_pixel = mean_pixel:reshape(1,3,1,1)
19 | mean_pixel = mean_pixel:repeatTensor(input:size(1), 1, input:size(3), input:size(4)):cuda()
20 | self.output:add(-1, mean_pixel)
21 | return self.output
22 | end
23 |
24 | function VGG_postprocess:updateGradInput(input, gradOutput)
25 | self.gradInput = gradOutput:div(127.5)
26 | return self.gradInput
27 | end
28 |
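
A worked numeric example (editorial, not in the file) of what VGG_postprocess does to a single pixel value, assuming the BGR channel ordering produced by util.preprocess:

--   x = 0  ->  (0 + 1) * 127.5 = 127.5  ->  subtract the BGR mean {103.939, 116.779, 123.68}
--   x = 1  ->  255                      ->  151.06 / 138.22 / 131.32 per channel
--   x = -1 ->  0                        ->  -103.94 / -116.78 / -123.68 per channel
-- updateGradInput only divides by 127.5, since the mean subtraction has zero gradient.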
--------------------------------------------------------------------------------
/util/content_loss.lua:
--------------------------------------------------------------------------------
1 | require 'torch'
2 | require 'nn'
3 | local content = {}
4 |
5 | function content.defineVGG(content_layer)
6 | local contentFunc = nn.Sequential()
7 | require 'loadcaffe'
8 | require 'util/VGG_preprocess'
9 | cnn = loadcaffe.load('../models/vgg.prototxt', '../models/vgg.caffemodel', 'cudnn')
10 | contentFunc:add(nn.SpatialUpSamplingBilinear({oheight=224, owidth=224}))
11 | contentFunc:add(nn.VGG_postprocess())
12 | for i = 1, #cnn do
13 | local layer = cnn:get(i):clone()
14 | local name = layer.name
15 | local layer_type = torch.type(layer)
16 | contentFunc:add(layer)
17 | if name == content_layer then
18 | print("Setting up content layer: ", layer.name)
19 | break
20 | end
21 | end
22 | cnn = nil
23 | collectgarbage()
24 | print(contentFunc)
25 | return contentFunc
26 | end
27 |
28 | function content.defineAlexNet(content_layer)
29 | local contentFunc = nn.Sequential()
30 | require 'loadcaffe'
31 | require 'util/VGG_preprocess'
32 | cnn = loadcaffe.load('../models/alexnet.prototxt', '../models/alexnet.caffemodel', 'cudnn')
33 | contentFunc:add(nn.SpatialUpSamplingBilinear({oheight=224, owidth=224}))
34 | contentFunc:add(nn.VGG_postprocess())
35 | for i = 1, #cnn do
36 | local layer = cnn:get(i):clone()
37 | local name = layer.name
38 | local layer_type = torch.type(layer)
39 | contentFunc:add(layer)
40 | if name == content_layer then
41 | print("Setting up content layer: ", layer.name)
42 | break
43 | end
44 | end
45 | cnn = nil
46 | collectgarbage()
47 | print(contentFunc)
48 | return contentFunc
49 | end
50 |
51 |
52 |
53 | function content.defineContent(content_loss, layer_name)
54 | -- print('content_loss_define', content_loss)
55 | if content_loss == 'pixel' or content_loss == 'none' then
56 | return nil
57 | elseif content_loss == 'vgg' then
58 | return content.defineVGG(layer_name)
59 | else
60 | print("unsupported content loss")
61 | return nil
62 | end
63 | end
64 |
65 |
66 | function content.lossUpdate(criterionContent, real_source, fake_target, contentFunc, loss_type, weight)
67 | if loss_type == 'none' then
68 | local errCont = 0.0
69 | local df_d_content = torch.zeros(fake_target:size())
70 | return errCont, df_d_content
71 | elseif loss_type == 'pixel' then
72 | local errCont = criterionContent:forward(fake_target, real_source) * weight
73 | local df_do_content = criterionContent:backward(fake_target, real_source)*weight
74 | return errCont, df_do_content
75 | elseif loss_type == 'vgg' then
76 | local f_fake = contentFunc:forward(fake_target):clone()
77 | local f_real = contentFunc:forward(real_source):clone()
78 | local errCont = criterionContent:forward(f_fake, f_real) * weight
79 | local df_do_tmp = criterionContent:backward(f_fake, f_real) * weight
80 | local df_do_content = contentFunc:updateGradInput(fake_target, df_do_tmp)--:mul(weight)
81 | return errCont, df_do_content
82 | else error("unsupported content loss")
83 | end
84 | end
85 |
86 |
87 | return content
88 |
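
A hypothetical usage sketch (not part of the file; it assumes the VGG prototxt/caffemodel paths referenced in defineVGG exist on disk and that loadcaffe and CUDA are installed):

-- local content = paths.dofile('util/content_loss.lua')
-- local criterionContent = nn.AbsCriterion()
-- local contentFunc = content.defineContent('vgg', 'relu4_2')  -- returns nil for 'pixel'/'none'
-- local errCont, df_do = content.lossUpdate(criterionContent, real_source, fake_target,
--                                           contentFunc, 'vgg', 1.0)
-- With loss_type 'pixel' the images are compared directly; with 'vgg' their features at the
-- chosen layer are compared and the gradient is propagated back through the truncated network.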
--------------------------------------------------------------------------------
/util/cudnn_convert_custom.lua:
--------------------------------------------------------------------------------
1 | -- modified from https://github.com/NVIDIA/torch-cudnn/blob/master/convert.lua
2 | -- removed error on nngraph
3 |
4 | -- modules that can be converted to nn seamlessly
5 | local layer_list = {
6 | 'BatchNormalization',
7 | 'SpatialBatchNormalization',
8 | 'SpatialConvolution',
9 | 'SpatialCrossMapLRN',
10 | 'SpatialFullConvolution',
11 | 'SpatialMaxPooling',
12 | 'SpatialAveragePooling',
13 | 'ReLU',
14 | 'Tanh',
15 | 'Sigmoid',
16 | 'SoftMax',
17 | 'LogSoftMax',
18 | 'VolumetricBatchNormalization',
19 | 'VolumetricConvolution',
20 | 'VolumetricFullConvolution',
21 | 'VolumetricMaxPooling',
22 | 'VolumetricAveragePooling',
23 | }
24 |
25 | -- goes over a given net and converts all layers to dst backend
26 | -- for example: net = cudnn_convert_custom(net, cudnn)
27 | -- same as cudnn.convert with gModule check commented out
28 | function cudnn_convert_custom(net, dst, exclusion_fn)
29 | return net:replace(function(x)
30 | --if torch.type(x) == 'nn.gModule' then
31 | -- io.stderr:write('Warning: cudnn.convert does not work with nngraph yet. Ignoring nn.gModule')
32 | -- return x
33 | --end
34 | local y = 0
35 | local src = dst == nn and cudnn or nn
36 | local src_prefix = src == nn and 'nn.' or 'cudnn.'
37 | local dst_prefix = dst == nn and 'nn.' or 'cudnn.'
38 |
39 | local function convert(v)
40 | local y = {}
41 | torch.setmetatable(y, dst_prefix..v)
42 | if v == 'ReLU' then y = dst.ReLU() end -- because parameters
43 | for k,u in pairs(x) do y[k] = u end
44 | if src == cudnn and x.clearDesc then x.clearDesc(y) end
45 | if src == cudnn and v == 'SpatialAveragePooling' then
46 | y.divide = true
47 | y.count_include_pad = v.mode == 'CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING'
48 | end
49 | if src == nn and string.find(v, 'Convolution') then
50 | y.groups = 1
51 | end
52 | return y
53 | end
54 |
55 | if exclusion_fn and exclusion_fn(x) then
56 | return x
57 | end
58 | local t = torch.typename(x)
59 | if t == 'nn.SpatialConvolutionMM' then
60 | y = convert('SpatialConvolution')
61 | elseif t == 'inn.SpatialCrossResponseNormalization' then
62 | y = convert('SpatialCrossMapLRN')
63 | else
64 | for i,v in ipairs(layer_list) do
65 | if torch.typename(x) == src_prefix..v then
66 | y = convert(v)
67 | end
68 | end
69 | end
70 | return y == 0 and x or y
71 | end)
72 | end
73 |
--------------------------------------------------------------------------------
/util/image_pool.lua:
--------------------------------------------------------------------------------
1 | local class = require 'class'
2 | ImagePool= class('ImagePool')
3 |
4 | require 'torch'
5 | require 'image'
6 |
7 | function ImagePool:__init(pool_size)
8 | self.pool_size = pool_size
9 | if pool_size > 0 then
10 | self.num_imgs = 0
11 | self.images = {}
12 | end
13 | end
14 |
15 | function ImagePool:model_name()
16 | return 'ImagePool'
17 | end
18 | --
19 | -- function ImagePool:Initialize(pool_size)
20 | -- -- torch.manualSeed(0)
21 | -- -- assert(pool_size > 0)
22 | -- self.pool_size = pool_size
23 | -- if pool_size > 0 then
24 | -- self.num_imgs = 0
25 | -- self.images = {}
26 | -- end
27 | -- end
28 |
29 | function ImagePool:Query(image)
30 | -- print('query image')
31 | if self.pool_size == 0 then
32 | -- print('get identical image')
33 | return image
34 | end
35 | if self.num_imgs < self.pool_size then
36 | -- self.images.insert(image:clone())
37 | self.num_imgs = self.num_imgs + 1
38 | self.images[self.num_imgs] = image
39 | return image
40 | else
41 | local p = math.random()
42 | -- print('p' ,p)
43 | -- os.exit()
44 | if p > 0.5 then
45 | -- print('use old image')
46 | -- random_id = torch.Tensor(1)
47 | -- random_id:random(1, self.pool_size)
48 | local random_id = math.random(self.pool_size)
49 | -- print('random_id', random_id)
50 | local tmp = self.images[random_id]:clone()
51 | self.images[random_id] = image:clone()
52 | return tmp
53 | else
54 | return image
55 | end
56 |
57 | end
58 |
59 | end
60 |
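
A minimal usage sketch (editorial; the fake tensor below is only a stand-in for a generator output). Once the buffer holds pool_size images, each query returns a previously generated fake with probability 0.5, storing the new one in its place; otherwise it returns the new fake unchanged. This is how the models feed their discriminator updates:

require 'util.image_pool'
local pool = ImagePool(50)
local fake_B = torch.randn(1, 3, 128, 128)    -- stand-in for netG's output on the current batch
local fake_for_D = pool:Query(fake_B)         -- what the discriminator update actually sees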
--------------------------------------------------------------------------------
/util/plot_util.lua:
--------------------------------------------------------------------------------
1 | local class = require 'class'
2 | PlotUtil = class('PlotUtil')
3 |
4 |
5 | require 'torch'
6 | disp = require 'display'
7 | util = require 'util/util'
8 | require 'image'
9 |
10 | local unpack = unpack or table.unpack
11 |
12 | function PlotUtil:__init(conf)
13 | conf = conf or {}
14 | end
15 |
16 | function PlotUtil:model_name()
17 | return 'PlotUtil'
18 | end
19 | -- |display_plot|: comma-separated string of loss names to plot (whitespace is stripped).
20 | function PlotUtil:Initialize(display_plot, display_id, name)
21 | self.display_plot = string.split(string.gsub(display_plot, "%s+", ""), ",")
22 |
23 | self.plot_config = {
24 | title = name .. ' loss over time',
25 | labels = {'epoch', unpack(self.display_plot)},
26 | ylabel = 'loss',
27 | win = display_id,
28 | }
29 |
30 | self.plot_data = {}
31 | print('display_opt', self.display_plot)
32 | end
33 |
34 | -- Append the loss values named in display_plot to plot_vals, record the data point, and refresh the plot window.
35 | function PlotUtil:Display(plot_vals, loss)
36 | for k, v in ipairs(self.display_plot) do
37 | if loss[v] ~= nil then
38 | plot_vals[#plot_vals + 1] = loss[v]
39 | end
40 | end
41 |
42 | table.insert(self.plot_data, plot_vals)
43 | disp.plot(self.plot_data, self.plot_config)
44 | end
45 |
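46 | -- Example usage (illustrative sketch; the loss names and counters are placeholders):
47 | --   local plotter = PlotUtil()
48 | --   plotter:Initialize('errG,errD', opt.display_id + 10, opt.name)
49 | --   plotter:Display({epoch + curr_iter / iters_per_epoch}, {errG = errG, errD = errD})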
--------------------------------------------------------------------------------
/util/util.lua:
--------------------------------------------------------------------------------
1 | --
2 | -- code derived from https://github.com/soumith/dcgan.torch
3 | --
4 |
5 | local util = {}
6 |
7 | require 'torch'
8 |
9 |
10 | function util.BiasZero(net)
11 | net:apply(function(m) if torch.type(m):find('Convolution') then m.bias:zero() end end)
12 | end
13 |
14 |
15 | function util.checkEqual(A, B, name)
16 | local dif = (A:float()-B:float()):abs():mean()
17 | print(name, dif)
18 | end
19 |
20 | function util.containsValue(table, value)
21 | for k, v in pairs(table) do
22 | if v == value then return true end
23 | end
24 | return false
25 | end
26 |
27 |
28 | function util.CheckTensor(A, name)
29 | print(name, A:min(), A:max(), A:mean())
30 | end
31 |
32 |
33 | function util.normalize(img)
34 | -- rescale image to 0 .. 1
35 | local min = img:min()
36 | local max = img:max()
37 |
38 | img = torch.FloatTensor(img:size()):copy(img)
39 | img:add(-min):mul(1/(max-min))
40 | return img
41 | end
42 |
43 | function util.normalizeBatch(batch)
44 | for i = 1, batch:size(1) do
45 | batch[i] = util.normalize(batch[i]:squeeze())
46 | end
47 | return batch
48 | end
49 |
50 | function util.basename_batch(batch)
51 | for i = 1, #batch do
52 | batch[i] = paths.basename(batch[i])
53 | end
54 | return batch
55 | end
56 |
57 |
58 |
59 | -- Default preprocessing
60 | --
61 | -- Preprocesses an image before passing it to a net:
62 | -- converts from RGB to BGR and rescales from [0,1] to [-1,1] (see the usage sketch at the end of this file).
63 | function util.preprocess(img)
64 | -- RGB to BGR
65 | if img:size(1) == 3 then
66 | local perm = torch.LongTensor{3, 2, 1}
67 | img = img:index(1, perm)
68 | end
69 | -- [0,1] to [-1,1]
70 | img = img:mul(2):add(-1)
71 |
72 | -- check that input is in expected range
73 | assert(img:max()<=1,"badly scaled inputs")
74 | assert(img:min()>=-1,"badly scaled inputs")
75 |
76 | return img
77 | end
78 |
79 | -- Undo the above preprocessing.
80 | function util.deprocess(img)
81 | -- BGR to RGB
82 | if img:size(1) == 3 then
83 | local perm = torch.LongTensor{3, 2, 1}
84 | img = img:index(1, perm)
85 | end
86 |
87 | -- [-1,1] to [0,1]
88 | img = img:add(1):div(2)
89 |
90 | return img
91 | end
92 |
93 | function util.preprocess_batch(batch)
94 | for i = 1, batch:size(1) do
95 | batch[i] = util.preprocess(batch[i]:squeeze())
96 | end
97 | return batch
98 | end
99 |
100 | function util.print_tensor(name, x)
101 | print(name, x:size(), x:min(), x:mean(), x:max())
102 | end
103 |
104 | function util.deprocess_batch(batch)
105 | for i = 1, batch:size(1) do
106 | batch[i] = util.deprocess(batch[i]:squeeze())
107 | end
108 | return batch
109 | end
110 |
111 |
112 | function util.scaleBatch(batch,s1,s2)
113 | -- print('s1', s1)
114 | -- print('s2', s2)
115 | local scaled_batch = torch.Tensor(batch:size(1),batch:size(2),s1,s2)
116 | for i = 1, batch:size(1) do
117 | scaled_batch[i] = image.scale(batch[i],s1,s2):squeeze()
118 | end
119 | return scaled_batch
120 | end
121 |
122 |
123 |
124 | function util.toTrivialBatch(input)
125 | return input:reshape(1,input:size(1),input:size(2),input:size(3))
126 | end
127 | function util.fromTrivialBatch(input)
128 | return input[1]
129 | end
130 |
131 | -- input is between -1 and 1
132 | function util.jitter(input)
133 | local noise = torch.rand(input:size())/256.0
134 | input:add(1.0):mul(0.5*255.0/256.0):add(noise):add(-0.5):mul(2.0)
135 | --local scaled = (input+1.0)*0.5
136 | --local jittered = scaled*255.0/256.0 + torch.rand(input:size())/256.0
137 | --local scaled_back = (jittered-0.5)*2.0
138 | --return scaled_back
139 | end
140 |
141 | function util.scaleImage(input, loadSize)
142 |
143 | -- replicate bw images to 3 channels
144 | if input:size(1)==1 then
145 | input = torch.repeatTensor(input,3,1,1)
146 | end
147 |
148 | input = image.scale(input, loadSize, loadSize)
149 |
150 | return input
151 | end
152 |
153 | function util.getAspectRatio(path)
154 | local input = image.load(path, 3, 'float')
155 | local ar = input:size(3)/input:size(2)
156 | return ar
157 | end
158 |
159 | function util.loadImage(path, loadSize, nc)
160 | local input = image.load(path, 3, 'float')
161 | input = util.preprocess(util.scaleImage(input, loadSize))
162 |
163 | if nc == 1 then
164 | input = input[{{1}, {}, {}}]
165 | end
166 |
167 | return input
168 | end
169 |
170 | function file_exists(filename)
171 | local f = io.open(filename,"r")
172 | if f ~= nil then io.close(f) return true else return false end
173 | end
174 |
175 | -- TO DO: loading code is rather hacky; clean it up and make sure it works on all types of nets / cpu/gpu configurations
176 | function load_helper(filename, opt)
177 | local fileExists = file_exists(filename)
178 | if not fileExists then
179 | print('model not found! ' .. filename)
180 | return nil
181 | end
182 | print(('loading previously trained model (%s)'):format(filename))
183 | if opt.norm == 'instance' then
184 | print('use InstanceNormalization')
185 | require 'util.InstanceNormalization'
186 | end
187 |
188 | if opt.cudnn>0 then
189 | require 'cudnn'
190 | end
191 |
192 | local net = torch.load(filename)
193 | if opt.gpu > 0 then
194 | require 'cunn'
195 | net:cuda()
196 |
197 | -- calling cuda on cudnn saved nngraphs doesn't change all variables to cuda, so do it below
198 | if net.forwardnodes then
199 | for i=1,#net.forwardnodes do
200 | if net.forwardnodes[i].data.module then
201 | net.forwardnodes[i].data.module:cuda()
202 | end
203 | end
204 | end
205 |
206 | else
207 | net:float()
208 | end
209 | net:apply(function(m) if m.weight then
210 | m.gradWeight = m.weight:clone():zero();
211 | m.gradBias = m.bias:clone():zero(); end end)
212 | return net
213 | end
214 |
215 | function util.load_model(name, opt)
216 | -- if opt['lambda_'.. name] > 0.0 then
217 | -- print('not loading model '.. opt.checkpoints_dir .. opt.name ..
218 | -- 'latest_net_' .. name .. '.t7' .. ' because opt.lambda is not greater than zero')
219 | return load_helper(paths.concat(opt.checkpoints_dir, opt.name,
220 | 'latest_net_' .. name .. '.t7'), opt)
221 | -- end
222 | end
223 |
224 | function util.load_test_model(name, opt)
225 | return load_helper(paths.concat(opt.checkpoints_dir, opt.name,
226 | opt.which_epoch .. '_net_' .. name .. '.t7'), opt)
227 | end
228 |
229 |
230 | -- load dataset from the file system
231 | -- |name|: name of the dataset. It's currently either 'A' or 'B'
232 | -- function util.load_dataset(name, nc, opt, nc)
233 | -- local tensortype = torch.getdefaulttensortype()
234 | -- torch.setdefaulttensortype('torch.FloatTensor')
235 | --
236 | -- local new_opt = options.clone(opt)
237 | -- new_opt.manualSeed = torch.random(1, 10000) -- fix seed
238 | -- new_opt.nc = nc
239 | -- torch.manualSeed(new_opt.manualSeed)
240 | -- local data_loader = paths.dofile('../data/data.lua')
241 | -- new_opt.phase = new_opt.phase .. name
242 | -- local data = data_loader.new(new_opt.nThreads, new_opt)
243 | -- print("Dataset Size " .. name .. ": ", data:size())
244 | --
245 | -- torch.setdefaulttensortype(tensortype)
246 | -- return data
247 | -- end
248 |
249 |
250 |
251 | function util.cudnn(net)
252 | require 'cudnn'
253 | require 'util/cudnn_convert_custom'
254 | return cudnn_convert_custom(net, cudnn)
255 | end
256 | -- Save |net| to <opt.checkpoints_dir>/<opt.name>/<net_name>; networks whose loss weight is 0 are skipped.
257 | function util.save_model(net, net_name, weight)
258 | if weight > 0.0 then
259 | torch.save(paths.concat(opt.checkpoints_dir, opt.name, net_name), net:clearState())
260 | end
261 | end
262 |
263 |
264 |
265 |
266 | return util
267 |
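268 | -- Example usage (illustrative sketch; 'netG', the net name 'G_A', and the image path are
269 | -- hypothetical, and GPU handling is omitted):
270 | --   require 'image'
271 | --   local util = require 'util/util'
272 | --   local netG = util.load_test_model('G_A', opt)            -- loads <which_epoch>_net_G_A.t7
273 | --   local img = util.loadImage('path/to/image.jpg', 256, 3)  -- scaled, BGR, values in [-1,1]
274 | --   local out = util.fromTrivialBatch(netG:forward(util.toTrivialBatch(img)))
275 | --   image.save('output.png', util.deprocess(out:float()))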
--------------------------------------------------------------------------------
/util/visualizer.lua:
--------------------------------------------------------------------------------
1 | -------------------------------------------------------------
2 | -- Various utilities for visualization through the web server
3 | -------------------------------------------------------------
4 |
5 | local visualizer = {}
6 |
7 | require 'torch'
8 | disp = nil
9 | -- [hack]: assumes that a global 'opt' has already been defined by the file requiring this module
10 | if opt.display_id > 0 then
11 | disp = require 'display'
12 | end
13 | util = require 'util/util'
14 | require 'image'
15 |
16 | -- Display a batch of images in the browser via the display server.
17 | function visualizer.disp_image(img_data, win_size, display_id, title)
18 | images = util.deprocess_batch(util.scaleBatch(img_data:float(),win_size,win_size))
19 | disp.image(images, {win=display_id, title=title})
20 | end
21 |
22 | function visualizer.save_results(img_data, output_path)
23 | local tensortype = torch.getdefaulttensortype()
24 | torch.setdefaulttensortype('torch.FloatTensor')
25 | local image_out = nil
26 | local win_size = opt.display_winsize
27 | images = torch.squeeze(util.deprocess_batch(util.scaleBatch(img_data:float(), win_size, win_size)))
28 |
29 | if images:dim() == 3 then
30 | image_out = images
31 | else
32 | for i = 1,images:size(1) do
33 | img = images[i]
34 | if image_out == nil then
35 | image_out = img
36 | else
37 | image_out = torch.cat(image_out, img)
38 | end
39 | end
40 | end
41 | image.save(output_path, image_out)
42 | torch.setdefaulttensortype(tensortype)
43 | end
44 |
45 | function visualizer.save_images(imgs, save_dir, impaths, s1, s2)
46 | local tensortype = torch.getdefaulttensortype()
47 | torch.setdefaulttensortype('torch.FloatTensor')
48 | batchSize = imgs:size(1)
49 | imgs_f = util.deprocess_batch(imgs):float()
50 | paths.mkdir(save_dir)
51 | for i = 1, batchSize do -- imgs_f[i]:size(2), imgs_f[i]:size(3)/opt.aspect_ratio
52 | if s1 ~= nil and s2 ~= nil then
53 | im_s = image.scale(imgs_f[i], s1, s2):float()
54 | else
55 | im_s = imgs_f[i]:float()
56 | end
57 | img_to_save = torch.FloatTensor(im_s:size()):copy(im_s)
58 | image.save(paths.concat(save_dir, impaths[i]), img_to_save)
59 | end
60 | torch.setdefaulttensortype(tensortype)
61 | end
62 |
63 | return visualizer
64 |
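65 | -- Example usage (illustrative sketch; 'fake_B' and the result directory are hypothetical; note
66 | -- that a global 'opt' must already be defined when this module is required):
67 | --   local visualizer = require 'util/visualizer'
68 | --   visualizer.disp_image(fake_B, opt.display_winsize, opt.display_id + 1, 'fake_B')
69 | --   visualizer.save_images(fake_B, paths.concat(opt.results_dir, 'fake_B'), {'0001.png'})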
--------------------------------------------------------------------------------