├── .gitignore ├── LICENSE ├── README.md ├── data ├── data.lua ├── dataset.lua └── donkey_folder.lua ├── datasets ├── bibtex │ ├── cityscapes.tex │ ├── facades.tex │ ├── handbags.tex │ ├── shoes.tex │ └── transattr.tex └── download_dataset.sh ├── imgs └── examples.jpg ├── models.lua ├── models └── download_model.sh ├── scripts ├── combine_A_and_B.py ├── edges │ ├── PostprocessHED.m │ └── batch_hed.py ├── eval_cityscapes │ ├── caffemodel │ │ └── deploy.prototxt │ ├── cityscapes.py │ ├── download_fcn8s.sh │ ├── evaluate.py │ └── util.py └── receptive_field_sizes.m ├── test.lua ├── train.lua └── util ├── cudnn_convert_custom.lua └── util.lua /.gitignore: -------------------------------------------------------------------------------- 1 | # 2 | *~ 3 | 4 | *.DS_Store 5 | cache/ 6 | results/ 7 | checkpoints/ 8 | 9 | # luarocks build files 10 | *.src.rock 11 | *.zip 12 | *.tar.gz 13 | *.t7 14 | 15 | # Object files 16 | *.o 17 | *.os 18 | *.ko 19 | *.obj 20 | *.elf 21 | 22 | # Precompiled Headers 23 | *.gch 24 | *.pch 25 | 26 | # Libraries 27 | *.lib 28 | *.a 29 | *.la 30 | *.lo 31 | *.def 32 | *.exp 33 | 34 | # Shared objects (inc. Windows DLLs) 35 | *.dll 36 | *.so 37 | *.so.* 38 | *.dylib 39 | 40 | # Executables 41 | *.exe 42 | *.out 43 | *.app 44 | *.i*86 45 | *.x86_64 46 | *.hex -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2016, Phillip Isola and Jun-Yan Zhu 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 15 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 16 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 17 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 18 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 19 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 20 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 21 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 22 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 23 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 24 | 25 | 26 | 27 | ----------------------------- LICENSE FOR DCGAN -------------------------------- 28 | BSD License 29 | 30 | For dcgan.torch software 31 | 32 | Copyright (c) 2015, Facebook, Inc. All rights reserved. 33 | 34 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 35 | 36 | Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 37 | 38 | Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 39 | 40 | Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 41 | 42 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 43 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # pix2pix 3 | [Project](https://phillipi.github.io/pix2pix/) | [Arxiv](https://arxiv.org/abs/1611.07004) | 4 | [PyTorch](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix) 5 | 6 | Torch implementation for learning a mapping from input images to output images, for example: 7 | 8 | 9 | 10 | Image-to-Image Translation with Conditional Adversarial Networks 11 | [Phillip Isola](http://web.mit.edu/phillipi/), [Jun-Yan Zhu](https://www.cs.cmu.edu/~junyanz/), [Tinghui Zhou](https://people.eecs.berkeley.edu/~tinghuiz/), [Alexei A. Efros](https://people.eecs.berkeley.edu/~efros/) 12 | CVPR, 2017. 13 | 14 | On some tasks, decent results can be obtained fairly quickly and on small datasets. For example, to learn to generate facades (example shown above), we trained on just 400 images for about 2 hours (on a single Pascal Titan X GPU). However, for harder problems it may be important to train on far larger datasets, and for many hours or even days. 15 | 16 | **Note**: Please check out our [PyTorch](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix) implementation for pix2pix and CycleGAN. The PyTorch version is under active development and can produce results comparable to or better than this Torch version. 17 | 18 | ## Setup 19 | 20 | ### Prerequisites 21 | - Linux or OSX 22 | - NVIDIA GPU + CUDA CuDNN (CPU mode and CUDA without CuDNN may work with minimal modification, but untested) 23 | 24 | ### Getting Started 25 | - Install torch and dependencies from https://github.com/torch/distro 26 | - Install torch packages `nngraph` and `display` 27 | ```bash 28 | luarocks install nngraph 29 | luarocks install https://raw.githubusercontent.com/szym/display/master/display-scm-0.rockspec 30 | ``` 31 | - Clone this repo: 32 | ```bash 33 | git clone git@github.com:phillipi/pix2pix.git 34 | cd pix2pix 35 | ``` 36 | - Download the dataset (e.g., [CMP Facades](http://cmp.felk.cvut.cz/~tylecr1/facade/)): 37 | ```bash 38 | bash ./datasets/download_dataset.sh facades 39 | ``` 40 | - Train the model 41 | ```bash 42 | DATA_ROOT=./datasets/facades name=facades_generation which_direction=BtoA th train.lua 43 | ``` 44 | - (CPU only) The same training command without using a GPU or CUDNN. Setting the environment variables ```gpu=0 cudnn=0``` forces CPU only 45 | ```bash 46 | DATA_ROOT=./datasets/facades name=facades_generation which_direction=BtoA gpu=0 cudnn=0 batchSize=10 save_epoch_freq=5 th train.lua 47 | ``` 48 | - (Optionally) start the display server to view results as the model trains. ( See [Display UI](#display-ui) for more details): 49 | ```bash 50 | th -ldisplay.start 8000 0.0.0.0 51 | ``` 52 | 53 | - Finally, test the model: 54 | ```bash 55 | DATA_ROOT=./datasets/facades name=facades_generation which_direction=BtoA phase=val th test.lua 56 | ``` 57 | The test results will be saved to an html file here: `./results/facades_generation/latest_net_G_val/index.html`. 58 | 59 | ## Train 60 | ```bash 61 | DATA_ROOT=/path/to/data/ name=expt_name which_direction=AtoB th train.lua 62 | ``` 63 | Switch `AtoB` to `BtoA` to train translation in opposite direction. 64 | 65 | Models are saved to `./checkpoints/expt_name` (can be changed by passing `checkpoint_dir=your_dir` in train.lua). 66 | 67 | See `opt` in train.lua for additional training options. 68 | 69 | ## Test 70 | ```bash 71 | DATA_ROOT=/path/to/data/ name=expt_name which_direction=AtoB phase=val th test.lua 72 | ``` 73 | 74 | This will run the model named `expt_name` in direction `AtoB` on all images in `/path/to/data/val`. 75 | 76 | Result images, and a webpage to view them, are saved to `./results/expt_name` (can be changed by passing `results_dir=your_dir` in test.lua). 77 | 78 | See `opt` in test.lua for additional testing options. 79 | 80 | 81 | ## Datasets 82 | Download the datasets using the following script. Some of the datasets are collected by other researchers. Please cite their papers if you use the data. 83 | ```bash 84 | bash ./datasets/download_dataset.sh dataset_name 85 | ``` 86 | - `facades`: 400 images from [CMP Facades dataset](http://cmp.felk.cvut.cz/~tylecr1/facade/). [[Citation](datasets/bibtex/facades.tex)] 87 | - `cityscapes`: 2975 images from the [Cityscapes training set](https://www.cityscapes-dataset.com/). [[Citation](datasets/bibtex/cityscapes.tex)] 88 | - `maps`: 1096 training images scraped from Google Maps 89 | - `edges2shoes`: 50k training images from [UT Zappos50K dataset](http://vision.cs.utexas.edu/projects/finegrained/utzap50k/). Edges are computed by [HED](https://github.com/s9xie/hed) edge detector + post-processing. 90 | [[Citation](datasets/bibtex/shoes.tex)] 91 | - `edges2handbags`: 137K Amazon Handbag images from [iGAN project](https://github.com/junyanz/iGAN). Edges are computed by [HED](https://github.com/s9xie/hed) edge detector + post-processing. [[Citation](datasets/bibtex/handbags.tex)] 92 | - `night2day`: around 20K natural scene images from [Transient Attributes dataset](http://transattr.cs.brown.edu/) [[Citation](datasets/bibtex/transattr.tex)]. To train a `day2night` pix2pix model, you need to add `which_direction=BtoA`. 93 | 94 | ## Models 95 | Download the pre-trained models with the following script. You need to rename the model (e.g., `facades_label2image` to `/checkpoints/facades/latest_net_G.t7`) after the download has finished. 96 | ```bash 97 | bash ./models/download_model.sh model_name 98 | ``` 99 | - `facades_label2image` (label -> facade): trained on the CMP Facades dataset. 100 | - `cityscapes_label2image` (label -> street scene): trained on the Cityscapes dataset. 101 | - `cityscapes_image2label` (street scene -> label): trained on the Cityscapes dataset. 102 | - `edges2shoes` (edge -> photo): trained on UT Zappos50K dataset. 103 | - `edges2handbags` (edge -> photo): trained on Amazon handbags images. 104 | - `day2night` (daytime scene -> nighttime scene): trained on around 100 [webcams](http://transattr.cs.brown.edu/). 105 | 106 | ## Setup Training and Test data 107 | ### Generating Pairs 108 | We provide a python script to generate training data in the form of pairs of images {A,B}, where A and B are two different depictions of the same underlying scene. For example, these might be pairs {label map, photo} or {bw image, color image}. Then we can learn to translate A to B or B to A: 109 | 110 | Create folder `/path/to/data` with subfolders `A` and `B`. `A` and `B` should each have their own subfolders `train`, `val`, `test`, etc. In `/path/to/data/A/train`, put training images in style A. In `/path/to/data/B/train`, put the corresponding images in style B. Repeat same for other data splits (`val`, `test`, etc). 111 | 112 | Corresponding images in a pair {A,B} must be the same size and have the same filename, e.g., `/path/to/data/A/train/1.jpg` is considered to correspond to `/path/to/data/B/train/1.jpg`. 113 | 114 | Once the data is formatted this way, call: 115 | ```bash 116 | python scripts/combine_A_and_B.py --fold_A /path/to/data/A --fold_B /path/to/data/B --fold_AB /path/to/data 117 | ``` 118 | 119 | This will combine each pair of images (A,B) into a single image file, ready for training. 120 | 121 | ### Notes on Colorization 122 | No need to run `combine_A_and_B.py` for colorization. Instead, you need to prepare some natural images and set `preprocess=colorization` in the script. The program will automatically convert each RGB image into Lab color space, and create `L -> ab` image pair during the training. Also set `input_nc=1` and `output_nc=2`. 123 | 124 | ### Extracting Edges 125 | We provide python and Matlab scripts to extract coarse edges from photos. Run `scripts/edges/batch_hed.py` to compute [HED](https://github.com/s9xie/hed) edges. Run `scripts/edges/PostprocessHED.m` to simplify edges with additional post-processing steps. Check the code documentation for more details. 126 | 127 | ### Evaluating Labels2Photos on Cityscapes 128 | We provide scripts for running the evaluation of the Labels2Photos task on the Cityscapes **validation** set. We assume that you have installed `caffe` (and `pycaffe`) in your system. If not, see the [official website](http://caffe.berkeleyvision.org/installation.html) for installation instructions. Once `caffe` is successfully installed, download the pre-trained FCN-8s semantic segmentation model (512MB) by running 129 | ```bash 130 | bash ./scripts/eval_cityscapes/download_fcn8s.sh 131 | ``` 132 | Then make sure `./scripts/eval_cityscapes/` is in your system's python path. If not, run the following command to add it 133 | ```bash 134 | export PYTHONPATH=${PYTHONPATH}:./scripts/eval_cityscapes/ 135 | ``` 136 | Now you can run the following command to evaluate your predictions: 137 | ```bash 138 | python ./scripts/eval_cityscapes/evaluate.py --cityscapes_dir /path/to/original/cityscapes/dataset/ --result_dir /path/to/your/predictions/ --output_dir /path/to/output/directory/ 139 | ``` 140 | Images stored under `--result_dir` should contain your model predictions on the Cityscapes **validation** split, and have the original Cityscapes naming convention (e.g., `frankfurt_000001_038418_leftImg8bit.png`). The script will output a text file under `--output_dir` containing the metric. 141 | 142 | **Further notes**: Our pre-trained FCN model is **not** supposed to work on Cityscapes in the original resolution (1024x2048) as it was trained on 256x256 images that are then upsampled to 1024x2048 during training. The purpose of the resizing during training was to 1) keep the label maps in the original high resolution untouched and 2) avoid the need to change the standard FCN training code and the architecture for Cityscapes. During test time, you need to synthesize 256x256 results. Our test code will automatically upsample your results to 1024x2048 before feeding them to the pre-trained FCN model. The output is at 1024x2048 resolution and will be compared to 1024x2048 ground truth labels. You do not need to resize the ground truth labels. The best way to verify whether everything is correct is to reproduce the numbers for real images in the paper first. To achieve it, you need to resize the original/real Cityscapes images (**not** labels) to 256x256 and feed them to the evaluation code. 143 | 144 | 145 | ## Display UI 146 | Optionally, for displaying images during training and test, use the [display package](https://github.com/szym/display). 147 | 148 | - Install it with: `luarocks install https://raw.githubusercontent.com/szym/display/master/display-scm-0.rockspec` 149 | - Then start the server with: `th -ldisplay.start` 150 | - Open this URL in your browser: [http://localhost:8000](http://localhost:8000) 151 | 152 | By default, the server listens on localhost. Pass `0.0.0.0` to allow external connections on any interface: 153 | ```bash 154 | th -ldisplay.start 8000 0.0.0.0 155 | ``` 156 | Then open `http://(hostname):(port)/` in your browser to load the remote desktop. 157 | 158 | L1 error is plotted to the display by default. Set the environment variable `display_plot` to a comma-separated list of values `errL1`, `errG` and `errD` to visualize the L1, generator, and discriminator error respectively. For example, to plot only the generator and discriminator errors to the display instead of the default L1 error, set `display_plot="errG,errD"`. 159 | 160 | ## Citation 161 | If you use this code for your research, please cite our paper Image-to-Image Translation Using Conditional Adversarial Networks: 162 | 163 | ``` 164 | @article{pix2pix2017, 165 | title={Image-to-Image Translation with Conditional Adversarial Networks}, 166 | author={Isola, Phillip and Zhu, Jun-Yan and Zhou, Tinghui and Efros, Alexei A}, 167 | journal={CVPR}, 168 | year={2017} 169 | } 170 | ``` 171 | 172 | ## Cat Paper Collection 173 | If you love cats, and love reading cool graphics, vision, and learning papers, please check out the Cat Paper Collection: 174 | [[Github]](https://github.com/junyanz/CatPapers) [[Webpage]](https://www.cs.cmu.edu/~junyanz/cat/cat_papers.html) 175 | 176 | ## Acknowledgments 177 | Code borrows heavily from [DCGAN](https://github.com/soumith/dcgan.torch). The data loader is modified from [DCGAN](https://github.com/soumith/dcgan.torch) and [Context-Encoder](https://github.com/pathak22/context-encoder). 178 | -------------------------------------------------------------------------------- /data/data.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | This data loader is a modified version of the one from dcgan.torch 3 | (see https://github.com/soumith/dcgan.torch/blob/master/data/data.lua). 4 | 5 | Copyright (c) 2016, Deepak Pathak [See LICENSE file for details] 6 | ]]-- 7 | 8 | local Threads = require 'threads' 9 | Threads.serialization('threads.sharedserialize') 10 | 11 | local data = {} 12 | 13 | local result = {} 14 | local unpack = unpack and unpack or table.unpack 15 | 16 | function data.new(n, opt_) 17 | opt_ = opt_ or {} 18 | local self = {} 19 | for k,v in pairs(data) do 20 | self[k] = v 21 | end 22 | 23 | local donkey_file = 'donkey_folder.lua' 24 | if n > 0 then 25 | local options = opt_ 26 | self.threads = Threads(n, 27 | function() require 'torch' end, 28 | function(idx) 29 | opt = options 30 | tid = idx 31 | local seed = (opt.manualSeed and opt.manualSeed or 0) + idx 32 | torch.manualSeed(seed) 33 | torch.setnumthreads(1) 34 | print(string.format('Starting donkey with id: %d seed: %d', tid, seed)) 35 | assert(options, 'options not found') 36 | assert(opt, 'opt not given') 37 | print(opt) 38 | paths.dofile(donkey_file) 39 | end 40 | 41 | ) 42 | else 43 | if donkey_file then paths.dofile(donkey_file) end 44 | self.threads = {} 45 | function self.threads:addjob(f1, f2) f2(f1()) end 46 | function self.threads:dojob() end 47 | function self.threads:synchronize() end 48 | end 49 | 50 | local nSamples = 0 51 | self.threads:addjob(function() return trainLoader:size() end, 52 | function(c) nSamples = c end) 53 | self.threads:synchronize() 54 | self._size = nSamples 55 | 56 | for i = 1, n do 57 | self.threads:addjob(self._getFromThreads, 58 | self._pushResult) 59 | end 60 | return self 61 | end 62 | 63 | function data._getFromThreads() 64 | assert(opt.batchSize, 'opt.batchSize not found') 65 | return trainLoader:sample(opt.batchSize) 66 | end 67 | 68 | function data._pushResult(...) 69 | local res = {...} 70 | if res == nil then 71 | self.threads:synchronize() 72 | end 73 | result[1] = res 74 | end 75 | 76 | 77 | 78 | function data:getBatch() 79 | self.threads:addjob(self._getFromThreads, self._pushResult) 80 | self.threads:dojob() 81 | local res = result[1] 82 | 83 | img_data = res[1] 84 | img_paths = res[3] 85 | 86 | result[1] = nil 87 | if torch.type(img_data) == 'table' then 88 | img_data = unpack(img_data) 89 | end 90 | 91 | return img_data, img_paths 92 | end 93 | 94 | function data:size() 95 | return self._size 96 | end 97 | 98 | return data 99 | -------------------------------------------------------------------------------- /data/dataset.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2015-present, Facebook, Inc. 3 | All rights reserved. 4 | 5 | This source code is licensed under the BSD-style license found in the 6 | LICENSE file in the root directory of this source tree. An additional grant 7 | of patent rights can be found in the PATENTS file in the same directory. 8 | ]]-- 9 | 10 | require 'torch' 11 | torch.setdefaulttensortype('torch.FloatTensor') 12 | local ffi = require 'ffi' 13 | local class = require('pl.class') 14 | local dir = require 'pl.dir' 15 | local tablex = require 'pl.tablex' 16 | local argcheck = require 'argcheck' 17 | require 'sys' 18 | require 'xlua' 19 | require 'image' 20 | 21 | local dataset = torch.class('dataLoader') 22 | 23 | local initcheck = argcheck{ 24 | pack=true, 25 | help=[[ 26 | A dataset class for images in a flat folder structure (folder-name is class-name). 27 | Optimized for extremely large datasets (upwards of 14 million images). 28 | Tested only on Linux (as it uses command-line linux utilities to scale up) 29 | ]], 30 | {check=function(paths) 31 | local out = true; 32 | for k,v in ipairs(paths) do 33 | if type(v) ~= 'string' then 34 | print('paths can only be of string input'); 35 | out = false 36 | end 37 | end 38 | return out 39 | end, 40 | name="paths", 41 | type="table", 42 | help="Multiple paths of directories with images"}, 43 | 44 | {name="sampleSize", 45 | type="table", 46 | help="a consistent sample size to resize the images"}, 47 | 48 | {name="split", 49 | type="number", 50 | help="Percentage of split to go to Training" 51 | }, 52 | {name="serial_batches", 53 | type="number", 54 | help="if randomly sample training images"}, 55 | 56 | {name="samplingMode", 57 | type="string", 58 | help="Sampling mode: random | balanced ", 59 | default = "balanced"}, 60 | 61 | {name="verbose", 62 | type="boolean", 63 | help="Verbose mode during initialization", 64 | default = false}, 65 | 66 | {name="loadSize", 67 | type="table", 68 | help="a size to load the images to, initially", 69 | opt = true}, 70 | 71 | {name="forceClasses", 72 | type="table", 73 | help="If you want this loader to map certain classes to certain indices, " 74 | .. "pass a classes table that has {classname : classindex} pairs." 75 | .. " For example: {3 : 'dog', 5 : 'cat'}" 76 | .. "This function is very useful when you want two loaders to have the same " 77 | .. "class indices (trainLoader/testLoader for example)", 78 | opt = true}, 79 | 80 | {name="sampleHookTrain", 81 | type="function", 82 | help="applied to sample during training(ex: for lighting jitter). " 83 | .. "It takes the image path as input", 84 | opt = true}, 85 | 86 | {name="sampleHookTest", 87 | type="function", 88 | help="applied to sample during testing", 89 | opt = true}, 90 | } 91 | 92 | function dataset:__init(...) 93 | 94 | -- argcheck 95 | local args = initcheck(...) 96 | print(args) 97 | for k,v in pairs(args) do self[k] = v end 98 | 99 | if not self.loadSize then self.loadSize = self.sampleSize; end 100 | 101 | if not self.sampleHookTrain then self.sampleHookTrain = self.defaultSampleHook end 102 | if not self.sampleHookTest then self.sampleHookTest = self.defaultSampleHook end 103 | self.image_count = 1 104 | -- find class names 105 | self.classes = {} 106 | local classPaths = {} 107 | if self.forceClasses then 108 | for k,v in pairs(self.forceClasses) do 109 | self.classes[k] = v 110 | classPaths[k] = {} 111 | end 112 | end 113 | local function tableFind(t, o) for k,v in pairs(t) do if v == o then return k end end end 114 | -- loop over each paths folder, get list of unique class names, 115 | -- also store the directory paths per class 116 | -- for each class, 117 | for k,path in ipairs(self.paths) do 118 | local dirs = {} -- hack 119 | dirs[1] = path 120 | for k,dirpath in ipairs(dirs) do 121 | local class = paths.basename(dirpath) 122 | local idx = tableFind(self.classes, class) 123 | if not idx then 124 | table.insert(self.classes, class) 125 | idx = #self.classes 126 | classPaths[idx] = {} 127 | end 128 | if not tableFind(classPaths[idx], dirpath) then 129 | table.insert(classPaths[idx], dirpath); 130 | end 131 | end 132 | end 133 | 134 | self.classIndices = {} 135 | for k,v in ipairs(self.classes) do 136 | self.classIndices[v] = k 137 | end 138 | 139 | -- define command-line tools, try your best to maintain OSX compatibility 140 | local wc = 'wc' 141 | local cut = 'cut' 142 | local find = 'find -H' -- if folder name is symlink, do find inside it after dereferencing 143 | if ffi.os == 'OSX' then 144 | wc = 'gwc' 145 | cut = 'gcut' 146 | find = 'gfind' 147 | end 148 | ---------------------------------------------------------------------- 149 | -- Options for the GNU find command 150 | local extensionList = {'jpg', 'png','JPG','PNG','JPEG', 'ppm', 'PPM', 'bmp', 'BMP'} 151 | local findOptions = ' -iname "*.' .. extensionList[1] .. '"' 152 | for i=2,#extensionList do 153 | findOptions = findOptions .. ' -o -iname "*.' .. extensionList[i] .. '"' 154 | end 155 | 156 | -- find the image path names 157 | self.imagePath = torch.CharTensor() -- path to each image in dataset 158 | self.imageClass = torch.LongTensor() -- class index of each image (class index in self.classes) 159 | self.classList = {} -- index of imageList to each image of a particular class 160 | self.classListSample = self.classList -- the main list used when sampling data 161 | 162 | print('running "find" on each class directory, and concatenate all' 163 | .. ' those filenames into a single file containing all image paths for a given class') 164 | -- so, generates one file per class 165 | local classFindFiles = {} 166 | for i=1,#self.classes do 167 | classFindFiles[i] = os.tmpname() 168 | end 169 | local combinedFindList = os.tmpname(); 170 | 171 | local tmpfile = os.tmpname() 172 | local tmphandle = assert(io.open(tmpfile, 'w')) 173 | -- iterate over classes 174 | for i, class in ipairs(self.classes) do 175 | -- iterate over classPaths 176 | for j,path in ipairs(classPaths[i]) do 177 | local command = find .. ' "' .. path .. '" ' .. findOptions 178 | .. ' >>"' .. classFindFiles[i] .. '" \n' 179 | tmphandle:write(command) 180 | end 181 | end 182 | io.close(tmphandle) 183 | os.execute('bash ' .. tmpfile) 184 | os.execute('rm -f ' .. tmpfile) 185 | 186 | print('now combine all the files to a single large file') 187 | local tmpfile = os.tmpname() 188 | local tmphandle = assert(io.open(tmpfile, 'w')) 189 | -- concat all finds to a single large file in the order of self.classes 190 | for i=1,#self.classes do 191 | local command = 'cat "' .. classFindFiles[i] .. '" >>' .. combinedFindList .. ' \n' 192 | tmphandle:write(command) 193 | end 194 | io.close(tmphandle) 195 | os.execute('bash ' .. tmpfile) 196 | os.execute('rm -f ' .. tmpfile) 197 | 198 | --========================================================================== 199 | print('load the large concatenated list of sample paths to self.imagePath') 200 | local cmd = wc .. " -L '" 201 | .. combinedFindList .. "' |" 202 | .. cut .. " -f1 -d' '" 203 | print('cmd..' .. cmd) 204 | local maxPathLength = tonumber(sys.fexecute(wc .. " -L '" 205 | .. combinedFindList .. "' |" 206 | .. cut .. " -f1 -d' '")) + 1 207 | local length = tonumber(sys.fexecute(wc .. " -l '" 208 | .. combinedFindList .. "' |" 209 | .. cut .. " -f1 -d' '")) 210 | assert(length > 0, "Could not find any image file in the given input paths") 211 | assert(maxPathLength > 0, "paths of files are length 0?") 212 | self.imagePath:resize(length, maxPathLength):fill(0) 213 | local s_data = self.imagePath:data() 214 | local count = 0 215 | for line in io.lines(combinedFindList) do 216 | ffi.copy(s_data, line) 217 | s_data = s_data + maxPathLength 218 | if self.verbose and count % 10000 == 0 then 219 | xlua.progress(count, length) 220 | end; 221 | count = count + 1 222 | end 223 | 224 | self.numSamples = self.imagePath:size(1) 225 | if self.verbose then print(self.numSamples .. ' samples found.') end 226 | --========================================================================== 227 | print('Updating classList and imageClass appropriately') 228 | self.imageClass:resize(self.numSamples) 229 | local runningIndex = 0 230 | for i=1,#self.classes do 231 | if self.verbose then xlua.progress(i, #(self.classes)) end 232 | local length = tonumber(sys.fexecute(wc .. " -l '" 233 | .. classFindFiles[i] .. "' |" 234 | .. cut .. " -f1 -d' '")) 235 | if length == 0 then 236 | error('Class has zero samples') 237 | else 238 | self.classList[i] = torch.range(runningIndex + 1, runningIndex + length):long() 239 | self.imageClass[{{runningIndex + 1, runningIndex + length}}]:fill(i) 240 | end 241 | runningIndex = runningIndex + length 242 | end 243 | 244 | --========================================================================== 245 | -- clean up temporary files 246 | print('Cleaning up temporary files') 247 | local tmpfilelistall = '' 248 | for i=1,#(classFindFiles) do 249 | tmpfilelistall = tmpfilelistall .. ' "' .. classFindFiles[i] .. '"' 250 | if i % 1000 == 0 then 251 | os.execute('rm -f ' .. tmpfilelistall) 252 | tmpfilelistall = '' 253 | end 254 | end 255 | os.execute('rm -f ' .. tmpfilelistall) 256 | os.execute('rm -f "' .. combinedFindList .. '"') 257 | --========================================================================== 258 | 259 | if self.split == 100 then 260 | self.testIndicesSize = 0 261 | else 262 | print('Splitting training and test sets to a ratio of ' 263 | .. self.split .. '/' .. (100-self.split)) 264 | self.classListTrain = {} 265 | self.classListTest = {} 266 | self.classListSample = self.classListTrain 267 | local totalTestSamples = 0 268 | -- split the classList into classListTrain and classListTest 269 | for i=1,#self.classes do 270 | local list = self.classList[i] 271 | local count = self.classList[i]:size(1) 272 | local splitidx = math.floor((count * self.split / 100) + 0.5) -- +round 273 | local perm = torch.randperm(count) 274 | self.classListTrain[i] = torch.LongTensor(splitidx) 275 | for j=1,splitidx do 276 | self.classListTrain[i][j] = list[perm[j]] 277 | end 278 | if splitidx == count then -- all samples were allocated to train set 279 | self.classListTest[i] = torch.LongTensor() 280 | else 281 | self.classListTest[i] = torch.LongTensor(count-splitidx) 282 | totalTestSamples = totalTestSamples + self.classListTest[i]:size(1) 283 | local idx = 1 284 | for j=splitidx+1,count do 285 | self.classListTest[i][idx] = list[perm[j]] 286 | idx = idx + 1 287 | end 288 | end 289 | end 290 | -- Now combine classListTest into a single tensor 291 | self.testIndices = torch.LongTensor(totalTestSamples) 292 | self.testIndicesSize = totalTestSamples 293 | local tdata = self.testIndices:data() 294 | local tidx = 0 295 | for i=1,#self.classes do 296 | local list = self.classListTest[i] 297 | if list:dim() ~= 0 then 298 | local ldata = list:data() 299 | for j=0,list:size(1)-1 do 300 | tdata[tidx] = ldata[j] 301 | tidx = tidx + 1 302 | end 303 | end 304 | end 305 | end 306 | end 307 | 308 | -- size(), size(class) 309 | function dataset:size(class, list) 310 | list = list or self.classList 311 | if not class then 312 | return self.numSamples 313 | elseif type(class) == 'string' then 314 | return list[self.classIndices[class]]:size(1) 315 | elseif type(class) == 'number' then 316 | return list[class]:size(1) 317 | end 318 | end 319 | 320 | -- getByClass 321 | function dataset:getByClass(class) 322 | local index = 0 323 | if self.serial_batches == 1 then 324 | index = math.fmod(self.image_count-1, self.classListSample[class]:nElement())+1 325 | self.image_count = self.image_count +1 326 | else 327 | index = math.ceil(torch.uniform() * self.classListSample[class]:nElement()) 328 | end 329 | local imgpath = ffi.string(torch.data(self.imagePath[self.classListSample[class][index]])) 330 | return self:sampleHookTrain(imgpath), imgpath 331 | end 332 | 333 | -- converts a table of samples (and corresponding labels) to a clean tensor 334 | local function tableToOutput(self, dataTable, scalarTable) 335 | local data, scalarLabels, labels 336 | local quantity = #scalarTable 337 | assert(dataTable[1]:dim() == 3) 338 | data = torch.Tensor(quantity, 339 | self.sampleSize[1], self.sampleSize[2], self.sampleSize[3]) 340 | scalarLabels = torch.LongTensor(quantity):fill(-1111) 341 | for i=1,#dataTable do 342 | data[i]:copy(dataTable[i]) 343 | scalarLabels[i] = scalarTable[i] 344 | end 345 | return data, scalarLabels 346 | end 347 | 348 | -- sampler, samples from the training set. 349 | function dataset:sample(quantity) 350 | assert(quantity) 351 | local dataTable = {} 352 | local scalarTable = {} 353 | local samplePaths = {} 354 | for i=1,quantity do 355 | local class = torch.random(1, #self.classes) 356 | local out, imgpath = self:getByClass(class) 357 | table.insert(dataTable, out) 358 | table.insert(scalarTable, class) 359 | samplePaths[i] = imgpath 360 | end 361 | local data, scalarLabels = tableToOutput(self, dataTable, scalarTable) 362 | return data, scalarLabels, samplePaths-- filePaths 363 | end 364 | 365 | function dataset:get(i1, i2) 366 | local indices = torch.range(i1, i2); 367 | local quantity = i2 - i1 + 1; 368 | assert(quantity > 0) 369 | -- now that indices has been initialized, get the samples 370 | local dataTable = {} 371 | local scalarTable = {} 372 | for i=1,quantity do 373 | -- load the sample 374 | local imgpath = ffi.string(torch.data(self.imagePath[indices[i]])) 375 | local out = self:sampleHookTest(imgpath) 376 | table.insert(dataTable, out) 377 | table.insert(scalarTable, self.imageClass[indices[i]]) 378 | end 379 | local data, scalarLabels = tableToOutput(self, dataTable, scalarTable) 380 | return data, scalarLabels 381 | end 382 | 383 | return dataset 384 | -------------------------------------------------------------------------------- /data/donkey_folder.lua: -------------------------------------------------------------------------------- 1 | 2 | --[[ 3 | This data loader is a modified version of the one from dcgan.torch 4 | (see https://github.com/soumith/dcgan.torch/blob/master/data/donkey_folder.lua). 5 | Copyright (c) 2016, Deepak Pathak [See LICENSE file for details] 6 | Copyright (c) 2015-present, Facebook, Inc. 7 | All rights reserved. 8 | This source code is licensed under the BSD-style license found in the 9 | LICENSE file in the root directory of this source tree. An additional grant 10 | of patent rights can be found in the PATENTS file in the same directory. 11 | ]]-- 12 | 13 | require 'image' 14 | paths.dofile('dataset.lua') 15 | -- This file contains the data-loading logic and details. 16 | -- It is run by each data-loader thread. 17 | ------------------------------------------ 18 | -------- COMMON CACHES and PATHS 19 | -- Check for existence of opt.data 20 | print(os.getenv('DATA_ROOT')) 21 | opt.data = paths.concat(os.getenv('DATA_ROOT'), opt.phase) 22 | 23 | if not paths.dirp(opt.data) then 24 | error('Did not find directory: ' .. opt.data) 25 | end 26 | 27 | -- a cache file of the training metadata (if doesnt exist, will be created) 28 | local cache = "cache" 29 | local cache_prefix = opt.data:gsub('/', '_') 30 | os.execute('mkdir -p cache') 31 | local trainCache = paths.concat(cache, cache_prefix .. '_trainCache.t7') 32 | 33 | -------------------------------------------------------------------------------------------- 34 | local input_nc = opt.input_nc -- input channels 35 | local output_nc = opt.output_nc 36 | local loadSize = {input_nc, opt.loadSize} 37 | local sampleSize = {input_nc, opt.fineSize} 38 | 39 | local preprocessAandB = function(imA, imB) 40 | imA = image.scale(imA, loadSize[2], loadSize[2]) 41 | imB = image.scale(imB, loadSize[2], loadSize[2]) 42 | local perm = torch.LongTensor{3, 2, 1} 43 | imA = imA:index(1, perm)--:mul(256.0): brg, rgb 44 | imA = imA:mul(2):add(-1) 45 | imB = imB:index(1, perm) 46 | imB = imB:mul(2):add(-1) 47 | -- print(img:size()) 48 | assert(imA:max()<=1,"A: badly scaled inputs") 49 | assert(imA:min()>=-1,"A: badly scaled inputs") 50 | assert(imB:max()<=1,"B: badly scaled inputs") 51 | assert(imB:min()>=-1,"B: badly scaled inputs") 52 | 53 | 54 | local oW = sampleSize[2] 55 | local oH = sampleSize[2] 56 | local iH = imA:size(2) 57 | local iW = imA:size(3) 58 | 59 | if iH~=oH then 60 | h1 = math.ceil(torch.uniform(1e-2, iH-oH)) 61 | end 62 | 63 | if iW~=oW then 64 | w1 = math.ceil(torch.uniform(1e-2, iW-oW)) 65 | end 66 | if iH ~= oH or iW ~= oW then 67 | imA = image.crop(imA, w1, h1, w1 + oW, h1 + oH) 68 | imB = image.crop(imB, w1, h1, w1 + oW, h1 + oH) 69 | end 70 | 71 | if opt.flip == 1 and torch.uniform() > 0.5 then 72 | imA = image.hflip(imA) 73 | imB = image.hflip(imB) 74 | end 75 | 76 | return imA, imB 77 | end 78 | 79 | 80 | 81 | local function loadImageChannel(path) 82 | local input = image.load(path, 3, 'float') 83 | input = image.scale(input, loadSize[2], loadSize[2]) 84 | 85 | local oW = sampleSize[2] 86 | local oH = sampleSize[2] 87 | local iH = input:size(2) 88 | local iW = input:size(3) 89 | 90 | if iH~=oH then 91 | h1 = math.ceil(torch.uniform(1e-2, iH-oH)) 92 | end 93 | 94 | if iW~=oW then 95 | w1 = math.ceil(torch.uniform(1e-2, iW-oW)) 96 | end 97 | if iH ~= oH or iW ~= oW then 98 | input = image.crop(input, w1, h1, w1 + oW, h1 + oH) 99 | end 100 | 101 | 102 | if opt.flip == 1 and torch.uniform() > 0.5 then 103 | input = image.hflip(input) 104 | end 105 | 106 | local input_lab = image.rgb2lab(input) 107 | local imA = input_lab[{{1}, {}, {} }]:div(50.0) - 1.0 108 | local imB = input_lab[{{2,3},{},{}}]:div(110.0) 109 | local imAB = torch.cat(imA, imB, 1) 110 | assert(imAB:max()<=1,"A: badly scaled inputs") 111 | assert(imAB:min()>=-1,"A: badly scaled inputs") 112 | 113 | return imAB 114 | end 115 | 116 | --local function loadImage 117 | 118 | local function loadImage(path) 119 | local input = image.load(path, 3, 'float') 120 | local h = input:size(2) 121 | local w = input:size(3) 122 | 123 | local imA = image.crop(input, 0, 0, w/2, h) 124 | local imB = image.crop(input, w/2, 0, w, h) 125 | 126 | return imA, imB 127 | end 128 | 129 | local function loadImageInpaint(path) 130 | local imB = image.load(path, 3, 'float') 131 | imB = image.scale(imB, loadSize[2], loadSize[2]) 132 | local perm = torch.LongTensor{3, 2, 1} 133 | imB = imB:index(1, perm)--:mul(256.0): brg, rgb 134 | imB = imB:mul(2):add(-1) 135 | assert(imB:max()<=1,"A: badly scaled inputs") 136 | assert(imB:min()>=-1,"A: badly scaled inputs") 137 | local oW = sampleSize[2] 138 | local oH = sampleSize[2] 139 | local iH = imB:size(2) 140 | local iW = imB:size(3) 141 | if iH~=oH then 142 | h1 = math.ceil(torch.uniform(1e-2, iH-oH)) 143 | end 144 | 145 | if iW~=oW then 146 | w1 = math.ceil(torch.uniform(1e-2, iW-oW)) 147 | end 148 | if iH ~= oH or iW ~= oW then 149 | imB = image.crop(imB, w1, h1, w1 + oW, h1 + oH) 150 | end 151 | local imA = imB:clone() 152 | imA[{{},{1 + oH/4, oH/2 + oH/4},{1 + oW/4, oW/2 + oW/4}}] = 1.0 153 | if opt.flip == 1 and torch.uniform() > 0.5 then 154 | imA = image.hflip(imA) 155 | imB = image.hflip(imB) 156 | end 157 | imAB = torch.cat(imA, imB, 1) 158 | return imAB 159 | end 160 | 161 | -- channel-wise mean and std. Calculate or load them from disk later in the script. 162 | local mean,std 163 | -------------------------------------------------------------------------------- 164 | -- Hooks that are used for each image that is loaded 165 | 166 | -- function to load the image, jitter it appropriately (random crops etc.) 167 | local trainHook = function(self, path) 168 | collectgarbage() 169 | if opt.preprocess == 'regular' then 170 | local imA, imB = loadImage(path) 171 | imA, imB = preprocessAandB(imA, imB) 172 | imAB = torch.cat(imA, imB, 1) 173 | end 174 | 175 | if opt.preprocess == 'colorization' then 176 | imAB = loadImageChannel(path) 177 | end 178 | 179 | if opt.preprocess == 'inpaint' then 180 | imAB = loadImageInpaint(path) 181 | end 182 | return imAB 183 | end 184 | 185 | -------------------------------------- 186 | -- trainLoader 187 | print('trainCache', trainCache) 188 | print('Creating train metadata') 189 | print('serial batch:, ', opt.serial_batches) 190 | trainLoader = dataLoader{ 191 | paths = {opt.data}, 192 | loadSize = {input_nc, loadSize[2], loadSize[2]}, 193 | sampleSize = {input_nc+output_nc, sampleSize[2], sampleSize[2]}, 194 | split = 100, 195 | serial_batches = opt.serial_batches, 196 | verbose = true 197 | } 198 | 199 | trainLoader.sampleHookTrain = trainHook 200 | collectgarbage() 201 | 202 | -- do some sanity checks on trainLoader 203 | do 204 | local class = trainLoader.imageClass 205 | local nClasses = #trainLoader.classes 206 | assert(class:max() <= nClasses, "class logic has error") 207 | assert(class:min() >= 1, "class logic has error") 208 | end 209 | -------------------------------------------------------------------------------- /datasets/bibtex/cityscapes.tex: -------------------------------------------------------------------------------- 1 | @inproceedings{Cordts2016Cityscapes, 2 | title={The Cityscapes Dataset for Semantic Urban Scene Understanding}, 3 | author={Cordts, Marius and Omran, Mohamed and Ramos, Sebastian and Rehfeld, Timo and Enzweiler, Markus and Benenson, Rodrigo and Franke, Uwe and Roth, Stefan and Schiele, Bernt}, 4 | booktitle={Proc. of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, 5 | year={2016} 6 | } 7 | -------------------------------------------------------------------------------- /datasets/bibtex/facades.tex: -------------------------------------------------------------------------------- 1 | @INPROCEEDINGS{Tylecek13, 2 | author = {Radim Tyle{\v c}ek, Radim {\v S}{\' a}ra}, 3 | title = {Spatial Pattern Templates for Recognition of Objects with Regular Structure}, 4 | booktitle = {Proc. GCPR}, 5 | year = {2013}, 6 | address = {Saarbrucken, Germany}, 7 | } 8 | -------------------------------------------------------------------------------- /datasets/bibtex/handbags.tex: -------------------------------------------------------------------------------- 1 | @inproceedings{zhu2016generative, 2 | title={Generative Visual Manipulation on the Natural Image Manifold}, 3 | author={Zhu, Jun-Yan and Kr{\"a}henb{\"u}hl, Philipp and Shechtman, Eli and Efros, Alexei A.}, 4 | booktitle={Proceedings of European Conference on Computer Vision (ECCV)}, 5 | year={2016} 6 | } 7 | 8 | @InProceedings{xie15hed, 9 | author = {"Xie, Saining and Tu, Zhuowen"}, 10 | Title = {Holistically-Nested Edge Detection}, 11 | Booktitle = "Proceedings of IEEE International Conference on Computer Vision", 12 | Year = {2015}, 13 | } 14 | -------------------------------------------------------------------------------- /datasets/bibtex/shoes.tex: -------------------------------------------------------------------------------- 1 | @InProceedings{fine-grained, 2 | author = {A. Yu and K. Grauman}, 3 | title = {{F}ine-{G}rained {V}isual {C}omparisons with {L}ocal {L}earning}, 4 | booktitle = {Computer Vision and Pattern Recognition (CVPR)}, 5 | month = {June}, 6 | year = {2014} 7 | } 8 | 9 | @InProceedings{xie15hed, 10 | author = {"Xie, Saining and Tu, Zhuowen"}, 11 | Title = {Holistically-Nested Edge Detection}, 12 | Booktitle = "Proceedings of IEEE International Conference on Computer Vision", 13 | Year = {2015}, 14 | } 15 | -------------------------------------------------------------------------------- /datasets/bibtex/transattr.tex: -------------------------------------------------------------------------------- 1 | @article {Laffont14, 2 | title = {Transient Attributes for High-Level Understanding and Editing of Outdoor Scenes}, 3 | author = {Pierre-Yves Laffont and Zhile Ren and Xiaofeng Tao and Chao Qian and James Hays}, 4 | journal = {ACM Transactions on Graphics (proceedings of SIGGRAPH)}, 5 | volume = {33}, 6 | number = {4}, 7 | year = {2014} 8 | } 9 | -------------------------------------------------------------------------------- /datasets/download_dataset.sh: -------------------------------------------------------------------------------- 1 | FILE=$1 2 | 3 | if [[ $FILE != "cityscapes" && $FILE != "night2day" && $FILE != "edges2handbags" && $FILE != "edges2shoes" && $FILE != "facades" && $FILE != "maps" ]]; then 4 | echo "Available datasets are cityscapes, night2day, edges2handbags, edges2shoes, facades, maps" 5 | exit 1 6 | fi 7 | 8 | echo "Specified [$FILE]" 9 | 10 | URL=http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/$FILE.tar.gz 11 | TAR_FILE=./datasets/$FILE.tar.gz 12 | TARGET_DIR=./datasets/$FILE/ 13 | wget -N $URL -O $TAR_FILE 14 | mkdir -p $TARGET_DIR 15 | tar -zxvf $TAR_FILE -C ./datasets/ 16 | rm $TAR_FILE 17 | -------------------------------------------------------------------------------- /imgs/examples.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/phillipi/pix2pix/89ff2a81ce441fbe1f1b13eca463b87f1e539df8/imgs/examples.jpg -------------------------------------------------------------------------------- /models.lua: -------------------------------------------------------------------------------- 1 | require 'nngraph' 2 | 3 | function defineG_encoder_decoder(input_nc, output_nc, ngf) 4 | local netG = nil 5 | -- input is (nc) x 256 x 256 6 | local e1 = - nn.SpatialConvolution(input_nc, ngf, 4, 4, 2, 2, 1, 1) 7 | -- input is (ngf) x 128 x 128 8 | local e2 = e1 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf, ngf * 2, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 2) 9 | -- input is (ngf * 2) x 64 x 64 10 | local e3 = e2 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 2, ngf * 4, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 4) 11 | -- input is (ngf * 4) x 32 x 32 12 | local e4 = e3 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 4, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8) 13 | -- input is (ngf * 8) x 16 x 16 14 | local e5 = e4 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8) 15 | -- input is (ngf * 8) x 8 x 8 16 | local e6 = e5 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8) 17 | -- input is (ngf * 8) x 4 x 4 18 | local e7 = e6 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8) 19 | -- input is (ngf * 8) x 2 x 2 20 | local e8 = e7 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) 21 | -- input is (ngf * 8) x 1 x 1 22 | 23 | local d1 = e8 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8) - nn.Dropout(0.5) 24 | -- input is (ngf * 8) x 2 x 2 25 | local d2 = d1 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8) - nn.Dropout(0.5) 26 | -- input is (ngf * 8) x 4 x 4 27 | local d3 = d2 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8) - nn.Dropout(0.5) 28 | -- input is (ngf * 8) x 8 x 8 29 | local d4 = d3 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8) 30 | -- input is (ngf * 8) x 16 x 16 31 | local d5 = d4 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8, ngf * 4, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 4) 32 | -- input is (ngf * 4) x 32 x 32 33 | local d6 = d5 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 4, ngf * 2, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 2) 34 | -- input is (ngf * 2) x 64 x 64 35 | local d7 = d6 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 2, ngf, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf) 36 | -- input is (ngf) x128 x 128 37 | local d8 = d7 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf, output_nc, 4, 4, 2, 2, 1, 1) 38 | -- input is (nc) x 256 x 256 39 | 40 | local o1 = d8 - nn.Tanh() 41 | 42 | netG = nn.gModule({e1},{o1}) 43 | 44 | return netG 45 | end 46 | 47 | function defineG_unet(input_nc, output_nc, ngf) 48 | local netG = nil 49 | -- input is (nc) x 256 x 256 50 | local e1 = - nn.SpatialConvolution(input_nc, ngf, 4, 4, 2, 2, 1, 1) 51 | -- input is (ngf) x 128 x 128 52 | local e2 = e1 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf, ngf * 2, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 2) 53 | -- input is (ngf * 2) x 64 x 64 54 | local e3 = e2 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 2, ngf * 4, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 4) 55 | -- input is (ngf * 4) x 32 x 32 56 | local e4 = e3 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 4, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8) 57 | -- input is (ngf * 8) x 16 x 16 58 | local e5 = e4 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8) 59 | -- input is (ngf * 8) x 8 x 8 60 | local e6 = e5 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8) 61 | -- input is (ngf * 8) x 4 x 4 62 | local e7 = e6 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8) 63 | -- input is (ngf * 8) x 2 x 2 64 | local e8 = e7 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) 65 | -- input is (ngf * 8) x 1 x 1 66 | 67 | local d1_ = e8 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8) - nn.Dropout(0.5) 68 | -- input is (ngf * 8) x 2 x 2 69 | local d1 = {d1_,e7} - nn.JoinTable(2) 70 | local d2_ = d1 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8 * 2, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8) - nn.Dropout(0.5) 71 | -- input is (ngf * 8) x 4 x 4 72 | local d2 = {d2_,e6} - nn.JoinTable(2) 73 | local d3_ = d2 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8 * 2, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8) - nn.Dropout(0.5) 74 | -- input is (ngf * 8) x 8 x 8 75 | local d3 = {d3_,e5} - nn.JoinTable(2) 76 | local d4_ = d3 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8 * 2, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8) 77 | -- input is (ngf * 8) x 16 x 16 78 | local d4 = {d4_,e4} - nn.JoinTable(2) 79 | local d5_ = d4 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8 * 2, ngf * 4, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 4) 80 | -- input is (ngf * 4) x 32 x 32 81 | local d5 = {d5_,e3} - nn.JoinTable(2) 82 | local d6_ = d5 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 4 * 2, ngf * 2, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 2) 83 | -- input is (ngf * 2) x 64 x 64 84 | local d6 = {d6_,e2} - nn.JoinTable(2) 85 | local d7_ = d6 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 2 * 2, ngf, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf) 86 | -- input is (ngf) x128 x 128 87 | local d7 = {d7_,e1} - nn.JoinTable(2) 88 | local d8 = d7 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 2, output_nc, 4, 4, 2, 2, 1, 1) 89 | -- input is (nc) x 256 x 256 90 | 91 | local o1 = d8 - nn.Tanh() 92 | 93 | netG = nn.gModule({e1},{o1}) 94 | 95 | --graph.dot(netG.fg,'netG') 96 | 97 | return netG 98 | end 99 | 100 | function defineG_unet_128(input_nc, output_nc, ngf) 101 | -- Two layer less than the default unet to handle 128x128 input 102 | local netG = nil 103 | -- input is (nc) x 128 x 128 104 | local e1 = - nn.SpatialConvolution(input_nc, ngf, 4, 4, 2, 2, 1, 1) 105 | -- input is (ngf) x 64 x 64 106 | local e2 = e1 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf, ngf * 2, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 2) 107 | -- input is (ngf * 2) x 32 x 32 108 | local e3 = e2 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 2, ngf * 4, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 4) 109 | -- input is (ngf * 4) x 16 x 16 110 | local e4 = e3 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 4, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8) 111 | -- input is (ngf * 8) x 8 x 8 112 | local e5 = e4 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8) 113 | -- input is (ngf * 8) x 4 x 4 114 | local e6 = e5 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8) 115 | -- input is (ngf * 8) x 2 x 2 116 | local e7 = e6 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) 117 | -- input is (ngf * 8) x 1 x 1 118 | 119 | local d1_ = e7 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8) - nn.Dropout(0.5) 120 | -- input is (ngf * 8) x 2 x 2 121 | local d1 = {d1_,e6} - nn.JoinTable(2) 122 | local d2_ = d1 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8 * 2, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8) - nn.Dropout(0.5) 123 | -- input is (ngf * 8) x 4 x 4 124 | local d2 = {d2_,e5} - nn.JoinTable(2) 125 | local d3_ = d2 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8 * 2, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8) - nn.Dropout(0.5) 126 | -- input is (ngf * 8) x 8 x 8 127 | local d3 = {d3_,e4} - nn.JoinTable(2) 128 | local d4_ = d3 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8 * 2, ngf * 4, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 4) 129 | -- input is (ngf * 8) x 16 x 16 130 | local d4 = {d4_,e3} - nn.JoinTable(2) 131 | local d5_ = d4 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 4 * 2, ngf * 2, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 2) 132 | -- input is (ngf * 4) x 32 x 32 133 | local d5 = {d5_,e2} - nn.JoinTable(2) 134 | local d6_ = d5 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 2 * 2, ngf, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf) 135 | -- input is (ngf * 2) x 64 x 64 136 | local d6 = {d6_,e1} - nn.JoinTable(2) 137 | local d7 = d6 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 2, output_nc, 4, 4, 2, 2, 1, 1) 138 | -- input is (ngf) x128 x 128 139 | 140 | local o1 = d7 - nn.Tanh() 141 | 142 | netG = nn.gModule({e1},{o1}) 143 | 144 | --graph.dot(netG.fg,'netG') 145 | 146 | return netG 147 | end 148 | 149 | function defineD_basic(input_nc, output_nc, ndf) 150 | n_layers = 3 151 | return defineD_n_layers(input_nc, output_nc, ndf, n_layers) 152 | end 153 | 154 | -- rf=1 155 | function defineD_pixelGAN(input_nc, output_nc, ndf) 156 | local netD = nn.Sequential() 157 | 158 | -- input is (nc) x 256 x 256 159 | netD:add(nn.SpatialConvolution(input_nc+output_nc, ndf, 1, 1, 1, 1, 0, 0)) 160 | netD:add(nn.LeakyReLU(0.2, true)) 161 | -- state size: (ndf) x 256 x 256 162 | netD:add(nn.SpatialConvolution(ndf, ndf * 2, 1, 1, 1, 1, 0, 0)) 163 | netD:add(nn.SpatialBatchNormalization(ndf * 2)):add(nn.LeakyReLU(0.2, true)) 164 | -- state size: (ndf*2) x 256 x 256 165 | netD:add(nn.SpatialConvolution(ndf * 2, 1, 1, 1, 1, 1, 0, 0)) 166 | -- state size: 1 x 256 x 256 167 | netD:add(nn.Sigmoid()) 168 | -- state size: 1 x 256 x 256 169 | 170 | return netD 171 | end 172 | 173 | -- if n=0, then use pixelGAN (rf=1) 174 | -- else rf is 16 if n=1 175 | -- 34 if n=2 176 | -- 70 if n=3 177 | -- 142 if n=4 178 | -- 286 if n=5 179 | -- 574 if n=6 180 | function defineD_n_layers(input_nc, output_nc, ndf, n_layers) 181 | if n_layers==0 then 182 | return defineD_pixelGAN(input_nc, output_nc, ndf) 183 | else 184 | 185 | local netD = nn.Sequential() 186 | 187 | -- input is (nc) x 256 x 256 188 | netD:add(nn.SpatialConvolution(input_nc+output_nc, ndf, 4, 4, 2, 2, 1, 1)) 189 | netD:add(nn.LeakyReLU(0.2, true)) 190 | 191 | local nf_mult = 1 192 | local nf_mult_prev = 1 193 | for n = 1, n_layers-1 do 194 | nf_mult_prev = nf_mult 195 | nf_mult = math.min(2^n,8) 196 | netD:add(nn.SpatialConvolution(ndf * nf_mult_prev, ndf * nf_mult, 4, 4, 2, 2, 1, 1)) 197 | netD:add(nn.SpatialBatchNormalization(ndf * nf_mult)):add(nn.LeakyReLU(0.2, true)) 198 | end 199 | 200 | -- state size: (ndf*M) x N x N 201 | nf_mult_prev = nf_mult 202 | nf_mult = math.min(2^n_layers,8) 203 | netD:add(nn.SpatialConvolution(ndf * nf_mult_prev, ndf * nf_mult, 4, 4, 1, 1, 1, 1)) 204 | netD:add(nn.SpatialBatchNormalization(ndf * nf_mult)):add(nn.LeakyReLU(0.2, true)) 205 | -- state size: (ndf*M*2) x (N-1) x (N-1) 206 | netD:add(nn.SpatialConvolution(ndf * nf_mult, 1, 4, 4, 1, 1, 1, 1)) 207 | -- state size: 1 x (N-2) x (N-2) 208 | 209 | netD:add(nn.Sigmoid()) 210 | -- state size: 1 x (N-2) x (N-2) 211 | 212 | return netD 213 | end 214 | end 215 | -------------------------------------------------------------------------------- /models/download_model.sh: -------------------------------------------------------------------------------- 1 | FILE=$1 2 | URL=http://efrosgans.eecs.berkeley.edu/pix2pix/models/$FILE.t7 3 | MODEL_FILE=./models/$FILE.t7 4 | wget -N $URL -O $MODEL_FILE 5 | -------------------------------------------------------------------------------- /scripts/combine_A_and_B.py: -------------------------------------------------------------------------------- 1 | from pdb import set_trace as st 2 | import os 3 | import numpy as np 4 | import cv2 5 | import argparse 6 | 7 | parser = argparse.ArgumentParser('create image pairs') 8 | parser.add_argument('--fold_A', dest='fold_A', help='input directory for image A', type=str, default='../dataset/50kshoes_edges') 9 | parser.add_argument('--fold_B', dest='fold_B', help='input directory for image B', type=str, default='../dataset/50kshoes_jpg') 10 | parser.add_argument('--fold_AB', dest='fold_AB', help='output directory', type=str, default='../dataset/test_AB') 11 | parser.add_argument('--num_imgs', dest='num_imgs', help='number of images',type=int, default=1000000) 12 | parser.add_argument('--use_AB', dest='use_AB', help='if true: (0001_A, 0001_B) to (0001_AB)',action='store_true') 13 | args = parser.parse_args() 14 | 15 | for arg in vars(args): 16 | print('[%s] = ' % arg, getattr(args, arg)) 17 | 18 | splits = filter( lambda f: not f.startswith('.'), os.listdir(args.fold_A)) # ignore hidden folders like .DS_Store 19 | 20 | for sp in splits: 21 | img_fold_A = os.path.join(args.fold_A, sp) 22 | img_fold_B = os.path.join(args.fold_B, sp) 23 | img_list = filter( lambda f: not f.startswith('.'), os.listdir(img_fold_A)) # ignore hidden folders like .DS_Store 24 | img_list = list(img_list) 25 | if args.use_AB: 26 | img_list = [img_path for img_path in img_list if '_A.' in img_path] 27 | 28 | num_imgs = min(args.num_imgs, len(img_list)) 29 | print('split = %s, use %d/%d images' % (sp, num_imgs, len(img_list))) 30 | img_fold_AB = os.path.join(args.fold_AB, sp) 31 | if not os.path.isdir(img_fold_AB): 32 | os.makedirs(img_fold_AB) 33 | print('split = %s, number of images = %d' % (sp, num_imgs)) 34 | for n in range(num_imgs): 35 | name_A = img_list[n] 36 | path_A = os.path.join(img_fold_A, name_A) 37 | if args.use_AB: 38 | name_B = name_A.replace('_A.', '_B.') 39 | else: 40 | name_B = name_A 41 | path_B = os.path.join(img_fold_B, name_B) 42 | if os.path.isfile(path_A) and os.path.isfile(path_B): 43 | name_AB = name_A 44 | if args.use_AB: 45 | name_AB = name_AB.replace('_A.', '.') # remove _A 46 | path_AB = os.path.join(img_fold_AB, name_AB) 47 | im_A = cv2.imread(path_A, cv2.IMREAD_COLOR) 48 | im_B = cv2.imread(path_B, cv2.IMREAD_COLOR) 49 | im_AB = np.concatenate([im_A, im_B], 1) 50 | cv2.imwrite(path_AB, im_AB) 51 | 52 | -------------------------------------------------------------------------------- /scripts/edges/PostprocessHED.m: -------------------------------------------------------------------------------- 1 | %%% Prerequisites 2 | % You need to get the cpp file edgesNmsMex.cpp from https://raw.githubusercontent.com/pdollar/edges/master/private/edgesNmsMex.cpp 3 | % and compile it in Matlab: mex edgesNmsMex.cpp 4 | % You also need to download and install Piotr's Computer Vision Matlab Toolbox: https://pdollar.github.io/toolbox/ 5 | 6 | %%% parameters 7 | % hed_mat_dir: the hed mat file directory (the output of 'batch_hed.py') 8 | % edge_dir: the output HED edges directory 9 | % image_width: resize the edge map to [image_width, image_width] 10 | % threshold: threshold for image binarization (default 25.0/255.0) 11 | % small_edge: remove small edges (default 5) 12 | 13 | function [] = PostprocessHED(hed_mat_dir, edge_dir, image_width, threshold, small_edge) 14 | 15 | if ~exist(edge_dir, 'dir') 16 | mkdir(edge_dir); 17 | end 18 | fileList = dir(fullfile(hed_mat_dir, '*.mat')); 19 | nFiles = numel(fileList); 20 | fprintf('find %d mat files\n', nFiles); 21 | 22 | for n = 1 : nFiles 23 | if mod(n, 1000) == 0 24 | fprintf('process %d/%d images\n', n, nFiles); 25 | end 26 | fileName = fileList(n).name; 27 | filePath = fullfile(hed_mat_dir, fileName); 28 | jpgName = strrep(fileName, '.mat', '.jpg'); 29 | edge_path = fullfile(edge_dir, jpgName); 30 | 31 | if ~exist(edge_path, 'file') 32 | E = GetEdge(filePath); 33 | E = imresize(E,[image_width,image_width]); 34 | E_simple = SimpleEdge(E, threshold, small_edge); 35 | E_simple = uint8(E_simple*255); 36 | imwrite(E_simple, edge_path, 'Quality',100); 37 | end 38 | end 39 | end 40 | 41 | 42 | 43 | 44 | function [E] = GetEdge(filePath) 45 | load(filePath); 46 | E = 1-edge_predict; 47 | end 48 | 49 | function [E4] = SimpleEdge(E, threshold, small_edge) 50 | if nargin <= 1 51 | threshold = 25.0/255.0; 52 | end 53 | 54 | if nargin <= 2 55 | small_edge = 5; 56 | end 57 | 58 | if ndims(E) == 3 59 | E = E(:,:,1); 60 | end 61 | 62 | E1 = 1 - E; 63 | E2 = EdgeNMS(E1); 64 | E3 = double(E2>=max(eps,threshold)); 65 | E3 = bwmorph(E3,'thin',inf); 66 | E4 = bwareaopen(E3, small_edge); 67 | E4=1-E4; 68 | end 69 | 70 | function [E_nms] = EdgeNMS( E ) 71 | E=single(E); 72 | [Ox,Oy] = gradient2(convTri(E,4)); 73 | [Oxx,~] = gradient2(Ox); 74 | [Oxy,Oyy] = gradient2(Oy); 75 | O = mod(atan(Oyy.*sign(-Oxy)./(Oxx+1e-5)),pi); 76 | E_nms = edgesNmsMex(E,O,1,5,1.01,1); 77 | end 78 | -------------------------------------------------------------------------------- /scripts/edges/batch_hed.py: -------------------------------------------------------------------------------- 1 | # HED batch processing script; modified from https://github.com/s9xie/hed/blob/master/examples/hed/HED-tutorial.ipynb 2 | # Step 1: download the hed repo: https://github.com/s9xie/hed 3 | # Step 2: download the models and protoxt, and put them under {caffe_root}/examples/hed/ 4 | # Step 3: put this script under {caffe_root}/examples/hed/ 5 | # Step 4: run the following script: 6 | # python batch_hed.py --images_dir=/data/to/path/photos/ --hed_mat_dir=/data/to/path/hed_mat_files/ 7 | # The code sometimes crashes after computation is done. Error looks like "Check failed: ... driver shutting down". You can just kill the job. 8 | # For large images, it will produce gpu memory issue. Therefore, you better resize the images before running this script. 9 | # Step 5: run the MATLAB post-processing script "PostprocessHED.m" 10 | import scipy.io as sio 11 | import caffe 12 | import sys 13 | import numpy as np 14 | from PIL import Image 15 | import os 16 | import argparse 17 | 18 | 19 | def parse_args(): 20 | parser = argparse.ArgumentParser(description='batch proccesing: photos->edges') 21 | parser.add_argument('--caffe_root', dest='caffe_root', help='caffe root', default='../../', type=str) 22 | parser.add_argument('--caffemodel', dest='caffemodel', help='caffemodel', default='./hed_pretrained_bsds.caffemodel', type=str) 23 | parser.add_argument('--prototxt', dest='prototxt', help='caffe prototxt file', default='./deploy.prototxt', type=str) 24 | parser.add_argument('--images_dir', dest='images_dir', help='directory to store input photos', type=str) 25 | parser.add_argument('--hed_mat_dir', dest='hed_mat_dir', help='directory to store output hed edges in mat file', type=str) 26 | parser.add_argument('--border', dest='border', help='padding border', type=int, default=128) 27 | parser.add_argument('--gpu_id', dest='gpu_id', help='gpu id', type=int, default=1) 28 | args = parser.parse_args() 29 | return args 30 | 31 | 32 | args = parse_args() 33 | for arg in vars(args): 34 | print('[%s] =' % arg, getattr(args, arg)) 35 | # Make sure that caffe is on the python path: 36 | caffe_root = args.caffe_root # this file is expected to be in {caffe_root}/examples/hed/ 37 | sys.path.insert(0, caffe_root + 'python') 38 | 39 | 40 | if not os.path.exists(args.hed_mat_dir): 41 | print('create output directory %s' % args.hed_mat_dir) 42 | os.makedirs(args.hed_mat_dir) 43 | 44 | imgList = os.listdir(args.images_dir) 45 | nImgs = len(imgList) 46 | print('#images = %d' % nImgs) 47 | 48 | caffe.set_mode_gpu() 49 | caffe.set_device(args.gpu_id) 50 | # load net 51 | net = caffe.Net(args.prototxt, args.caffemodel, caffe.TEST) 52 | # pad border 53 | border = args.border 54 | 55 | for i in range(nImgs): 56 | if i % 500 == 0: 57 | print('processing image %d/%d' % (i, nImgs)) 58 | im = Image.open(os.path.join(args.images_dir, imgList[i])) 59 | 60 | in_ = np.array(im, dtype=np.float32) 61 | in_ = np.pad(in_, ((border, border), (border, border), (0, 0)), 'reflect') 62 | 63 | in_ = in_[:, :, 0:3] 64 | in_ = in_[:, :, ::-1] 65 | in_ -= np.array((104.00698793, 116.66876762, 122.67891434)) 66 | in_ = in_.transpose((2, 0, 1)) 67 | # remove the following two lines if testing with cpu 68 | 69 | # shape for input (data blob is N x C x H x W), set data 70 | net.blobs['data'].reshape(1, *in_.shape) 71 | net.blobs['data'].data[...] = in_ 72 | # run net and take argmax for prediction 73 | net.forward() 74 | fuse = net.blobs['sigmoid-fuse'].data[0][0, :, :] 75 | # get rid of the border 76 | fuse = fuse[(border+35):(-border+35), (border+35):(-border+35)] 77 | # save hed file to the disk 78 | name, ext = os.path.splitext(imgList[i]) 79 | sio.savemat(os.path.join(args.hed_mat_dir, name + '.mat'), {'edge_predict': fuse}) 80 | -------------------------------------------------------------------------------- /scripts/eval_cityscapes/caffemodel/deploy.prototxt: -------------------------------------------------------------------------------- 1 | layer { 2 | name: "data" 3 | type: "Input" 4 | top: "data" 5 | input_param { 6 | shape { 7 | dim: 1 8 | dim: 3 9 | dim: 500 10 | dim: 500 11 | } 12 | } 13 | } 14 | layer { 15 | name: "conv1_1" 16 | type: "Convolution" 17 | bottom: "data" 18 | top: "conv1_1" 19 | param { 20 | lr_mult: 1 21 | decay_mult: 1 22 | } 23 | param { 24 | lr_mult: 2 25 | decay_mult: 0 26 | } 27 | convolution_param { 28 | num_output: 64 29 | pad: 100 30 | kernel_size: 3 31 | stride: 1 32 | weight_filler { 33 | type: "gaussian" 34 | std: 0.01 35 | } 36 | bias_filler { 37 | type: "constant" 38 | value: 0 39 | } 40 | } 41 | } 42 | layer { 43 | name: "relu1_1" 44 | type: "ReLU" 45 | bottom: "conv1_1" 46 | top: "conv1_1" 47 | } 48 | layer { 49 | name: "conv1_2" 50 | type: "Convolution" 51 | bottom: "conv1_1" 52 | top: "conv1_2" 53 | param { 54 | lr_mult: 1 55 | decay_mult: 1 56 | } 57 | param { 58 | lr_mult: 2 59 | decay_mult: 0 60 | } 61 | convolution_param { 62 | num_output: 64 63 | pad: 1 64 | kernel_size: 3 65 | stride: 1 66 | weight_filler { 67 | type: "gaussian" 68 | std: 0.01 69 | } 70 | bias_filler { 71 | type: "constant" 72 | value: 0 73 | } 74 | } 75 | } 76 | layer { 77 | name: "relu1_2" 78 | type: "ReLU" 79 | bottom: "conv1_2" 80 | top: "conv1_2" 81 | } 82 | layer { 83 | name: "pool1" 84 | type: "Pooling" 85 | bottom: "conv1_2" 86 | top: "pool1" 87 | pooling_param { 88 | pool: MAX 89 | kernel_size: 2 90 | stride: 2 91 | } 92 | } 93 | layer { 94 | name: "conv2_1" 95 | type: "Convolution" 96 | bottom: "pool1" 97 | top: "conv2_1" 98 | param { 99 | lr_mult: 1 100 | decay_mult: 1 101 | } 102 | param { 103 | lr_mult: 2 104 | decay_mult: 0 105 | } 106 | convolution_param { 107 | num_output: 128 108 | pad: 1 109 | kernel_size: 3 110 | stride: 1 111 | weight_filler { 112 | type: "gaussian" 113 | std: 0.01 114 | } 115 | bias_filler { 116 | type: "constant" 117 | value: 0 118 | } 119 | } 120 | } 121 | layer { 122 | name: "relu2_1" 123 | type: "ReLU" 124 | bottom: "conv2_1" 125 | top: "conv2_1" 126 | } 127 | layer { 128 | name: "conv2_2" 129 | type: "Convolution" 130 | bottom: "conv2_1" 131 | top: "conv2_2" 132 | param { 133 | lr_mult: 1 134 | decay_mult: 1 135 | } 136 | param { 137 | lr_mult: 2 138 | decay_mult: 0 139 | } 140 | convolution_param { 141 | num_output: 128 142 | pad: 1 143 | kernel_size: 3 144 | stride: 1 145 | weight_filler { 146 | type: "gaussian" 147 | std: 0.01 148 | } 149 | bias_filler { 150 | type: "constant" 151 | value: 0 152 | } 153 | } 154 | } 155 | layer { 156 | name: "relu2_2" 157 | type: "ReLU" 158 | bottom: "conv2_2" 159 | top: "conv2_2" 160 | } 161 | layer { 162 | name: "pool2" 163 | type: "Pooling" 164 | bottom: "conv2_2" 165 | top: "pool2" 166 | pooling_param { 167 | pool: MAX 168 | kernel_size: 2 169 | stride: 2 170 | } 171 | } 172 | layer { 173 | name: "conv3_1" 174 | type: "Convolution" 175 | bottom: "pool2" 176 | top: "conv3_1" 177 | param { 178 | lr_mult: 1 179 | decay_mult: 1 180 | } 181 | param { 182 | lr_mult: 2 183 | decay_mult: 0 184 | } 185 | convolution_param { 186 | num_output: 256 187 | pad: 1 188 | kernel_size: 3 189 | stride: 1 190 | weight_filler { 191 | type: "gaussian" 192 | std: 0.01 193 | } 194 | bias_filler { 195 | type: "constant" 196 | value: 0 197 | } 198 | } 199 | } 200 | layer { 201 | name: "relu3_1" 202 | type: "ReLU" 203 | bottom: "conv3_1" 204 | top: "conv3_1" 205 | } 206 | layer { 207 | name: "conv3_2" 208 | type: "Convolution" 209 | bottom: "conv3_1" 210 | top: "conv3_2" 211 | param { 212 | lr_mult: 1 213 | decay_mult: 1 214 | } 215 | param { 216 | lr_mult: 2 217 | decay_mult: 0 218 | } 219 | convolution_param { 220 | num_output: 256 221 | pad: 1 222 | kernel_size: 3 223 | stride: 1 224 | weight_filler { 225 | type: "gaussian" 226 | std: 0.01 227 | } 228 | bias_filler { 229 | type: "constant" 230 | value: 0 231 | } 232 | } 233 | } 234 | layer { 235 | name: "relu3_2" 236 | type: "ReLU" 237 | bottom: "conv3_2" 238 | top: "conv3_2" 239 | } 240 | layer { 241 | name: "conv3_3" 242 | type: "Convolution" 243 | bottom: "conv3_2" 244 | top: "conv3_3" 245 | param { 246 | lr_mult: 1 247 | decay_mult: 1 248 | } 249 | param { 250 | lr_mult: 2 251 | decay_mult: 0 252 | } 253 | convolution_param { 254 | num_output: 256 255 | pad: 1 256 | kernel_size: 3 257 | stride: 1 258 | weight_filler { 259 | type: "gaussian" 260 | std: 0.01 261 | } 262 | bias_filler { 263 | type: "constant" 264 | value: 0 265 | } 266 | } 267 | } 268 | layer { 269 | name: "relu3_3" 270 | type: "ReLU" 271 | bottom: "conv3_3" 272 | top: "conv3_3" 273 | } 274 | layer { 275 | name: "pool3" 276 | type: "Pooling" 277 | bottom: "conv3_3" 278 | top: "pool3" 279 | pooling_param { 280 | pool: MAX 281 | kernel_size: 2 282 | stride: 2 283 | } 284 | } 285 | layer { 286 | name: "conv4_1" 287 | type: "Convolution" 288 | bottom: "pool3" 289 | top: "conv4_1" 290 | param { 291 | lr_mult: 1 292 | decay_mult: 1 293 | } 294 | param { 295 | lr_mult: 2 296 | decay_mult: 0 297 | } 298 | convolution_param { 299 | num_output: 512 300 | pad: 1 301 | kernel_size: 3 302 | stride: 1 303 | weight_filler { 304 | type: "gaussian" 305 | std: 0.01 306 | } 307 | bias_filler { 308 | type: "constant" 309 | value: 0 310 | } 311 | } 312 | } 313 | layer { 314 | name: "relu4_1" 315 | type: "ReLU" 316 | bottom: "conv4_1" 317 | top: "conv4_1" 318 | } 319 | layer { 320 | name: "conv4_2" 321 | type: "Convolution" 322 | bottom: "conv4_1" 323 | top: "conv4_2" 324 | param { 325 | lr_mult: 1 326 | decay_mult: 1 327 | } 328 | param { 329 | lr_mult: 2 330 | decay_mult: 0 331 | } 332 | convolution_param { 333 | num_output: 512 334 | pad: 1 335 | kernel_size: 3 336 | stride: 1 337 | weight_filler { 338 | type: "gaussian" 339 | std: 0.01 340 | } 341 | bias_filler { 342 | type: "constant" 343 | value: 0 344 | } 345 | } 346 | } 347 | layer { 348 | name: "relu4_2" 349 | type: "ReLU" 350 | bottom: "conv4_2" 351 | top: "conv4_2" 352 | } 353 | layer { 354 | name: "conv4_3" 355 | type: "Convolution" 356 | bottom: "conv4_2" 357 | top: "conv4_3" 358 | param { 359 | lr_mult: 1 360 | decay_mult: 1 361 | } 362 | param { 363 | lr_mult: 2 364 | decay_mult: 0 365 | } 366 | convolution_param { 367 | num_output: 512 368 | pad: 1 369 | kernel_size: 3 370 | stride: 1 371 | weight_filler { 372 | type: "gaussian" 373 | std: 0.01 374 | } 375 | bias_filler { 376 | type: "constant" 377 | value: 0 378 | } 379 | } 380 | } 381 | layer { 382 | name: "relu4_3" 383 | type: "ReLU" 384 | bottom: "conv4_3" 385 | top: "conv4_3" 386 | } 387 | layer { 388 | name: "pool4" 389 | type: "Pooling" 390 | bottom: "conv4_3" 391 | top: "pool4" 392 | pooling_param { 393 | pool: MAX 394 | kernel_size: 2 395 | stride: 2 396 | } 397 | } 398 | layer { 399 | name: "conv5_1" 400 | type: "Convolution" 401 | bottom: "pool4" 402 | top: "conv5_1" 403 | param { 404 | lr_mult: 1 405 | decay_mult: 1 406 | } 407 | param { 408 | lr_mult: 2 409 | decay_mult: 0 410 | } 411 | convolution_param { 412 | num_output: 512 413 | pad: 1 414 | kernel_size: 3 415 | stride: 1 416 | weight_filler { 417 | type: "gaussian" 418 | std: 0.01 419 | } 420 | bias_filler { 421 | type: "constant" 422 | value: 0 423 | } 424 | } 425 | } 426 | layer { 427 | name: "relu5_1" 428 | type: "ReLU" 429 | bottom: "conv5_1" 430 | top: "conv5_1" 431 | } 432 | layer { 433 | name: "conv5_2" 434 | type: "Convolution" 435 | bottom: "conv5_1" 436 | top: "conv5_2" 437 | param { 438 | lr_mult: 1 439 | decay_mult: 1 440 | } 441 | param { 442 | lr_mult: 2 443 | decay_mult: 0 444 | } 445 | convolution_param { 446 | num_output: 512 447 | pad: 1 448 | kernel_size: 3 449 | stride: 1 450 | weight_filler { 451 | type: "gaussian" 452 | std: 0.01 453 | } 454 | bias_filler { 455 | type: "constant" 456 | value: 0 457 | } 458 | } 459 | } 460 | layer { 461 | name: "relu5_2" 462 | type: "ReLU" 463 | bottom: "conv5_2" 464 | top: "conv5_2" 465 | } 466 | layer { 467 | name: "conv5_3" 468 | type: "Convolution" 469 | bottom: "conv5_2" 470 | top: "conv5_3" 471 | param { 472 | lr_mult: 1 473 | decay_mult: 1 474 | } 475 | param { 476 | lr_mult: 2 477 | decay_mult: 0 478 | } 479 | convolution_param { 480 | num_output: 512 481 | pad: 1 482 | kernel_size: 3 483 | stride: 1 484 | weight_filler { 485 | type: "gaussian" 486 | std: 0.01 487 | } 488 | bias_filler { 489 | type: "constant" 490 | value: 0 491 | } 492 | } 493 | } 494 | layer { 495 | name: "relu5_3" 496 | type: "ReLU" 497 | bottom: "conv5_3" 498 | top: "conv5_3" 499 | } 500 | layer { 501 | name: "pool5" 502 | type: "Pooling" 503 | bottom: "conv5_3" 504 | top: "pool5" 505 | pooling_param { 506 | pool: MAX 507 | kernel_size: 2 508 | stride: 2 509 | } 510 | } 511 | layer { 512 | name: "fc6_cs" 513 | type: "Convolution" 514 | bottom: "pool5" 515 | top: "fc6_cs" 516 | param { 517 | lr_mult: 1 518 | decay_mult: 1 519 | } 520 | param { 521 | lr_mult: 2 522 | decay_mult: 0 523 | } 524 | convolution_param { 525 | num_output: 4096 526 | pad: 0 527 | kernel_size: 7 528 | stride: 1 529 | weight_filler { 530 | type: "gaussian" 531 | std: 0.01 532 | } 533 | bias_filler { 534 | type: "constant" 535 | value: 0 536 | } 537 | } 538 | } 539 | layer { 540 | name: "relu6_cs" 541 | type: "ReLU" 542 | bottom: "fc6_cs" 543 | top: "fc6_cs" 544 | } 545 | layer { 546 | name: "fc7_cs" 547 | type: "Convolution" 548 | bottom: "fc6_cs" 549 | top: "fc7_cs" 550 | param { 551 | lr_mult: 1 552 | decay_mult: 1 553 | } 554 | param { 555 | lr_mult: 2 556 | decay_mult: 0 557 | } 558 | convolution_param { 559 | num_output: 4096 560 | pad: 0 561 | kernel_size: 1 562 | stride: 1 563 | weight_filler { 564 | type: "gaussian" 565 | std: 0.01 566 | } 567 | bias_filler { 568 | type: "constant" 569 | value: 0 570 | } 571 | } 572 | } 573 | layer { 574 | name: "relu7_cs" 575 | type: "ReLU" 576 | bottom: "fc7_cs" 577 | top: "fc7_cs" 578 | } 579 | layer { 580 | name: "score_fr" 581 | type: "Convolution" 582 | bottom: "fc7_cs" 583 | top: "score_fr" 584 | param { 585 | lr_mult: 1 586 | decay_mult: 1 587 | } 588 | param { 589 | lr_mult: 2 590 | decay_mult: 0 591 | } 592 | convolution_param { 593 | num_output: 20 594 | pad: 0 595 | kernel_size: 1 596 | weight_filler { 597 | type: "xavier" 598 | } 599 | bias_filler { 600 | type: "constant" 601 | } 602 | } 603 | } 604 | layer { 605 | name: "upscore2" 606 | type: "Deconvolution" 607 | bottom: "score_fr" 608 | top: "upscore2" 609 | param { 610 | lr_mult: 1 611 | } 612 | convolution_param { 613 | num_output: 20 614 | bias_term: false 615 | kernel_size: 4 616 | stride: 2 617 | weight_filler { 618 | type: "xavier" 619 | } 620 | bias_filler { 621 | type: "constant" 622 | } 623 | } 624 | } 625 | layer { 626 | name: "score_pool4" 627 | type: "Convolution" 628 | bottom: "pool4" 629 | top: "score_pool4" 630 | param { 631 | lr_mult: 1 632 | decay_mult: 1 633 | } 634 | param { 635 | lr_mult: 2 636 | decay_mult: 0 637 | } 638 | convolution_param { 639 | num_output: 20 640 | pad: 0 641 | kernel_size: 1 642 | weight_filler { 643 | type: "xavier" 644 | } 645 | bias_filler { 646 | type: "constant" 647 | } 648 | } 649 | } 650 | layer { 651 | name: "score_pool4c" 652 | type: "Crop" 653 | bottom: "score_pool4" 654 | bottom: "upscore2" 655 | top: "score_pool4c" 656 | crop_param { 657 | axis: 2 658 | offset: 5 659 | } 660 | } 661 | layer { 662 | name: "fuse_pool4" 663 | type: "Eltwise" 664 | bottom: "upscore2" 665 | bottom: "score_pool4c" 666 | top: "fuse_pool4" 667 | eltwise_param { 668 | operation: SUM 669 | } 670 | } 671 | layer { 672 | name: "upscore_pool4" 673 | type: "Deconvolution" 674 | bottom: "fuse_pool4" 675 | top: "upscore_pool4" 676 | param { 677 | lr_mult: 1 678 | } 679 | convolution_param { 680 | num_output: 20 681 | bias_term: false 682 | kernel_size: 4 683 | stride: 2 684 | weight_filler { 685 | type: "xavier" 686 | } 687 | bias_filler { 688 | type: "constant" 689 | } 690 | } 691 | } 692 | layer { 693 | name: "score_pool3" 694 | type: "Convolution" 695 | bottom: "pool3" 696 | top: "score_pool3" 697 | param { 698 | lr_mult: 1 699 | decay_mult: 1 700 | } 701 | param { 702 | lr_mult: 2 703 | decay_mult: 0 704 | } 705 | convolution_param { 706 | num_output: 20 707 | pad: 0 708 | kernel_size: 1 709 | weight_filler { 710 | type: "xavier" 711 | } 712 | bias_filler { 713 | type: "constant" 714 | } 715 | } 716 | } 717 | layer { 718 | name: "score_pool3c" 719 | type: "Crop" 720 | bottom: "score_pool3" 721 | bottom: "upscore_pool4" 722 | top: "score_pool3c" 723 | crop_param { 724 | axis: 2 725 | offset: 9 726 | } 727 | } 728 | layer { 729 | name: "fuse_pool3" 730 | type: "Eltwise" 731 | bottom: "upscore_pool4" 732 | bottom: "score_pool3c" 733 | top: "fuse_pool3" 734 | eltwise_param { 735 | operation: SUM 736 | } 737 | } 738 | layer { 739 | name: "upscore8" 740 | type: "Deconvolution" 741 | bottom: "fuse_pool3" 742 | top: "upscore8" 743 | param { 744 | lr_mult: 1 745 | } 746 | convolution_param { 747 | num_output: 20 748 | bias_term: false 749 | kernel_size: 16 750 | stride: 8 751 | weight_filler { 752 | type: "xavier" 753 | } 754 | bias_filler { 755 | type: "constant" 756 | } 757 | } 758 | } 759 | layer { 760 | name: "score" 761 | type: "Crop" 762 | bottom: "upscore8" 763 | bottom: "data" 764 | top: "score" 765 | crop_param { 766 | axis: 2 767 | offset: 31 768 | } 769 | } 770 | -------------------------------------------------------------------------------- /scripts/eval_cityscapes/cityscapes.py: -------------------------------------------------------------------------------- 1 | # The following code is modified from https://github.com/shelhamer/clockwork-fcn 2 | import sys 3 | import os 4 | import glob 5 | import numpy as np 6 | from PIL import Image 7 | 8 | class cityscapes: 9 | def __init__(self, data_path): 10 | # data_path something like /data2/cityscapes 11 | self.dir = data_path 12 | self.classes = ['road', 'sidewalk', 'building', 'wall', 'fence', 13 | 'pole', 'traffic light', 'traffic sign', 'vegetation', 'terrain', 14 | 'sky', 'person', 'rider', 'car', 'truck', 15 | 'bus', 'train', 'motorcycle', 'bicycle'] 16 | self.mean = np.array((72.78044, 83.21195, 73.45286), dtype=np.float32) 17 | # import cityscapes label helper and set up label mappings 18 | sys.path.insert(0, '{}/scripts/helpers/'.format(self.dir)) 19 | labels = __import__('labels') 20 | self.id2trainId = {label.id: label.trainId for label in labels.labels} # dictionary mapping from raw IDs to train IDs 21 | self.trainId2color = {label.trainId: label.color for label in labels.labels} # dictionary mapping train IDs to colors as 3-tuples 22 | 23 | def get_dset(self, split): 24 | ''' 25 | List images as (city, id) for the specified split 26 | 27 | TODO(shelhamer) generate splits from cityscapes itself, instead of 28 | relying on these separately made text files. 29 | ''' 30 | if split == 'train': 31 | dataset = open('{}/ImageSets/segFine/train.txt'.format(self.dir)).read().splitlines() 32 | else: 33 | dataset = open('{}/ImageSets/segFine/val.txt'.format(self.dir)).read().splitlines() 34 | return [(item.split('/')[0], item.split('/')[1]) for item in dataset] 35 | 36 | def load_image(self, split, city, idx): 37 | im = Image.open('{}/leftImg8bit_sequence/{}/{}/{}_leftImg8bit.png'.format(self.dir, split, city, idx)) 38 | return im 39 | 40 | def assign_trainIds(self, label): 41 | """ 42 | Map the given label IDs to the train IDs appropriate for training 43 | Use the label mapping provided in labels.py from the cityscapes scripts 44 | """ 45 | label = np.array(label, dtype=np.float32) 46 | if sys.version_info[0] < 3: 47 | for k, v in self.id2trainId.iteritems(): 48 | label[label == k] = v 49 | else: 50 | for k, v in self.id2trainId.items(): 51 | label[label == k] = v 52 | return label 53 | 54 | def load_label(self, split, city, idx): 55 | """ 56 | Load label image as 1 x height x width integer array of label indices. 57 | The leading singleton dimension is required by the loss. 58 | """ 59 | label = Image.open('{}/gtFine/{}/{}/{}_gtFine_labelIds.png'.format(self.dir, split, city, idx)) 60 | label = self.assign_trainIds(label) # get proper labels for eval 61 | label = np.array(label, dtype=np.uint8) 62 | label = label[np.newaxis, ...] 63 | return label 64 | 65 | def preprocess(self, im): 66 | """ 67 | Preprocess loaded image (by load_image) for Caffe: 68 | - cast to float 69 | - switch channels RGB -> BGR 70 | - subtract mean 71 | - transpose to channel x height x width order 72 | """ 73 | in_ = np.array(im, dtype=np.float32) 74 | in_ = in_[:, :, ::-1] 75 | in_ -= self.mean 76 | in_ = in_.transpose((2, 0, 1)) 77 | return in_ 78 | 79 | def palette(self, label): 80 | ''' 81 | Map trainIds to colors as specified in labels.py 82 | ''' 83 | if label.ndim == 3: 84 | label= label[0] 85 | color = np.empty((label.shape[0], label.shape[1], 3)) 86 | if sys.version_info[0] < 3: 87 | for k, v in self.trainId2color.iteritems(): 88 | color[label == k, :] = v 89 | else: 90 | for k, v in self.trainId2color.items(): 91 | color[label == k, :] = v 92 | return color 93 | 94 | def make_boundaries(label, thickness=None): 95 | """ 96 | Input is an image label, output is a numpy array mask encoding the boundaries of the objects 97 | Extract pixels at the true boundary by dilation - erosion of label. 98 | Don't just pick the void label as it is not exclusive to the boundaries. 99 | """ 100 | assert(thickness is not None) 101 | import skimage.morphology as skm 102 | void = 255 103 | mask = np.logical_and(label > 0, label != void)[0] 104 | selem = skm.disk(thickness) 105 | boundaries = np.logical_xor(skm.dilation(mask, selem), 106 | skm.erosion(mask, selem)) 107 | return boundaries 108 | 109 | def list_label_frames(self, split): 110 | """ 111 | Select labeled frames from a split for evaluation 112 | collected as (city, shot, idx) tuples 113 | """ 114 | def file2idx(f): 115 | """Helper to convert file path into frame ID""" 116 | city, shot, frame = (os.path.basename(f).split('_')[:3]) 117 | return "_".join([city, shot, frame]) 118 | frames = [] 119 | cities = [os.path.basename(f) for f in glob.glob('{}/gtFine/{}/*'.format(self.dir, split))] 120 | for c in cities: 121 | files = sorted(glob.glob('{}/gtFine/{}/{}/*labelIds.png'.format(self.dir, split, c))) 122 | frames.extend([file2idx(f) for f in files]) 123 | return frames 124 | 125 | def collect_frame_sequence(self, split, idx, length): 126 | """ 127 | Collect sequence of frames preceding (and including) a labeled frame 128 | as a list of Images. 129 | 130 | Note: 19 preceding frames are provided for each labeled frame. 131 | """ 132 | SEQ_LEN = length 133 | city, shot, frame = idx.split('_') 134 | frame = int(frame) 135 | frame_seq = [] 136 | for i in range(frame - SEQ_LEN, frame + 1): 137 | frame_path = '{0}/leftImg8bit_sequence/val/{1}/{1}_{2}_{3:0>6d}_leftImg8bit.png'.format( 138 | self.dir, city, shot, i) 139 | frame_seq.append(Image.open(frame_path)) 140 | return frame_seq 141 | -------------------------------------------------------------------------------- /scripts/eval_cityscapes/download_fcn8s.sh: -------------------------------------------------------------------------------- 1 | URL=http://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/fcn-8s-cityscapes/fcn-8s-cityscapes.caffemodel 2 | OUTPUT_FILE=./scripts/eval_cityscapes/caffemodel/fcn-8s-cityscapes.caffemodel 3 | wget -N $URL -O $OUTPUT_FILE 4 | -------------------------------------------------------------------------------- /scripts/eval_cityscapes/evaluate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import caffe 4 | import argparse 5 | import numpy as np 6 | import scipy.misc 7 | from PIL import Image 8 | from util import * 9 | from cityscapes import cityscapes 10 | 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument("--cityscapes_dir", type=str, required=True, help="Path to the original cityscapes dataset") 13 | parser.add_argument("--result_dir", type=str, required=True, help="Path to the generated images to be evaluated") 14 | parser.add_argument("--output_dir", type=str, required=True, help="Where to save the evaluation results") 15 | parser.add_argument("--caffemodel_dir", type=str, default='./scripts/eval_cityscapes/caffemodel/', help="Where the FCN-8s caffemodel stored") 16 | parser.add_argument("--gpu_id", type=int, default=0, help="Which gpu id to use") 17 | parser.add_argument("--split", type=str, default='val', help="Data split to be evaluated") 18 | parser.add_argument("--save_output_images", type=int, default=0, help="Whether to save the FCN output images") 19 | args = parser.parse_args() 20 | 21 | def main(): 22 | if not os.path.isdir(args.output_dir): 23 | os.makedirs(args.output_dir) 24 | if args.save_output_images > 0: 25 | output_image_dir = args.output_dir + 'image_outputs/' 26 | if not os.path.isdir(output_image_dir): 27 | os.makedirs(output_image_dir) 28 | CS = cityscapes(args.cityscapes_dir) 29 | n_cl = len(CS.classes) 30 | label_frames = CS.list_label_frames(args.split) 31 | caffe.set_device(args.gpu_id) 32 | caffe.set_mode_gpu() 33 | net = caffe.Net(args.caffemodel_dir + '/deploy.prototxt', 34 | args.caffemodel_dir + 'fcn-8s-cityscapes.caffemodel', 35 | caffe.TEST) 36 | 37 | hist_perframe = np.zeros((n_cl, n_cl)) 38 | for i, idx in enumerate(label_frames): 39 | if i % 10 == 0: 40 | print('Evaluating: %d/%d' % (i, len(label_frames))) 41 | city = idx.split('_')[0] 42 | # idx is city_shot_frame 43 | label = CS.load_label(args.split, city, idx) 44 | im_file = args.result_dir + '/' + idx + '_leftImg8bit.png' 45 | im = np.array(Image.open(im_file)) 46 | # im = scipy.misc.imresize(im, (256, 256)) 47 | im = scipy.misc.imresize(im, (label.shape[1], label.shape[2])) 48 | out = segrun(net, CS.preprocess(im)) 49 | hist_perframe += fast_hist(label.flatten(), out.flatten(), n_cl) 50 | if args.save_output_images > 0: 51 | label_im = CS.palette(label) 52 | pred_im = CS.palette(out) 53 | scipy.misc.imsave(output_image_dir + '/' + str(i) + '_pred.jpg', pred_im) 54 | scipy.misc.imsave(output_image_dir + '/' + str(i) + '_gt.jpg', label_im) 55 | scipy.misc.imsave(output_image_dir + '/' + str(i) + '_input.jpg', im) 56 | 57 | mean_pixel_acc, mean_class_acc, mean_class_iou, per_class_acc, per_class_iou = get_scores(hist_perframe) 58 | with open(args.output_dir + '/evaluation_results.txt', 'w') as f: 59 | f.write('Mean pixel accuracy: %f\n' % mean_pixel_acc) 60 | f.write('Mean class accuracy: %f\n' % mean_class_acc) 61 | f.write('Mean class IoU: %f\n' % mean_class_iou) 62 | f.write('************ Per class numbers below ************\n') 63 | for i, cl in enumerate(CS.classes): 64 | while len(cl) < 15: 65 | cl = cl + ' ' 66 | f.write('%s: acc = %f, iou = %f\n' % (cl, per_class_acc[i], per_class_iou[i])) 67 | main() -------------------------------------------------------------------------------- /scripts/eval_cityscapes/util.py: -------------------------------------------------------------------------------- 1 | # The following code is modified from https://github.com/shelhamer/clockwork-fcn 2 | import numpy as np 3 | import scipy.io as sio 4 | 5 | def get_out_scoremap(net): 6 | return net.blobs['score'].data[0].argmax(axis=0).astype(np.uint8) 7 | 8 | def feed_net(net, in_): 9 | """ 10 | Load prepared input into net. 11 | """ 12 | net.blobs['data'].reshape(1, *in_.shape) 13 | net.blobs['data'].data[...] = in_ 14 | 15 | def segrun(net, in_): 16 | feed_net(net, in_) 17 | net.forward() 18 | return get_out_scoremap(net) 19 | 20 | def fast_hist(a, b, n): 21 | # print('saving') 22 | # sio.savemat('/tmp/fcn_debug/xx.mat', {'a':a, 'b':b, 'n':n}) 23 | 24 | k = np.where((a >= 0) & (a < n))[0] 25 | bc = np.bincount(n * a[k].astype(int) + b[k], minlength=n**2) 26 | if len(bc) != n**2: 27 | # ignore this example if dimension mismatch 28 | return 0 29 | return bc.reshape(n, n) 30 | 31 | def get_scores(hist): 32 | # Mean pixel accuracy 33 | acc = np.diag(hist).sum() / (hist.sum() + 1e-12) 34 | 35 | # Per class accuracy 36 | cl_acc = np.diag(hist) / (hist.sum(1) + 1e-12) 37 | 38 | # Per class IoU 39 | iu = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist) + 1e-12) 40 | 41 | return acc, np.nanmean(cl_acc), np.nanmean(iu), cl_acc, iu -------------------------------------------------------------------------------- /scripts/receptive_field_sizes.m: -------------------------------------------------------------------------------- 1 | % modified from: https://github.com/rbgirshick/rcnn/blob/master/utils/receptive_field_sizes.m 2 | % 3 | % RCNN LICENSE: 4 | % 5 | % Copyright (c) 2014, The Regents of the University of California (Regents) 6 | % All rights reserved. 7 | % 8 | % Redistribution and use in source and binary forms, with or without 9 | % modification, are permitted provided that the following conditions are met: 10 | % 11 | % 1. Redistributions of source code must retain the above copyright notice, this 12 | % list of conditions and the following disclaimer. 13 | % 2. Redistributions in binary form must reproduce the above copyright notice, 14 | % this list of conditions and the following disclaimer in the documentation 15 | % and/or other materials provided with the distribution. 16 | % 17 | % THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 18 | % ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 19 | % WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 20 | % DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 21 | % ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 22 | % (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 23 | % LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 24 | % ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 25 | % (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 26 | % SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 27 | 28 | function receptive_field_sizes() 29 | 30 | 31 | % compute input size from a given output size 32 | f = @(output_size, ksize, stride) (output_size - 1) * stride + ksize; 33 | 34 | 35 | %% n=1 discriminator 36 | 37 | % fix the output size to 1 and derive the receptive field in the input 38 | out = ... 39 | f(f(f(1, 4, 1), ... % conv2 -> conv3 40 | 4, 1), ... % conv1 -> conv2 41 | 4, 2); % input -> conv1 42 | 43 | fprintf('n=1 discriminator receptive field size: %d\n', out); 44 | 45 | 46 | %% n=2 discriminator 47 | 48 | % fix the output size to 1 and derive the receptive field in the input 49 | out = ... 50 | f(f(f(f(1, 4, 1), ... % conv3 -> conv4 51 | 4, 1), ... % conv2 -> conv3 52 | 4, 2), ... % conv1 -> conv2 53 | 4, 2); % input -> conv1 54 | 55 | fprintf('n=2 discriminator receptive field size: %d\n', out); 56 | 57 | 58 | %% n=3 discriminator 59 | 60 | % fix the output size to 1 and derive the receptive field in the input 61 | out = ... 62 | f(f(f(f(f(1, 4, 1), ... % conv4 -> conv5 63 | 4, 1), ... % conv3 -> conv4 64 | 4, 2), ... % conv2 -> conv3 65 | 4, 2), ... % conv1 -> conv2 66 | 4, 2); % input -> conv1 67 | 68 | fprintf('n=3 discriminator receptive field size: %d\n', out); 69 | 70 | 71 | %% n=4 discriminator 72 | 73 | % fix the output size to 1 and derive the receptive field in the input 74 | out = ... 75 | f(f(f(f(f(f(1, 4, 1), ... % conv5 -> conv6 76 | 4, 1), ... % conv4 -> conv5 77 | 4, 2), ... % conv3 -> conv4 78 | 4, 2), ... % conv2 -> conv3 79 | 4, 2), ... % conv1 -> conv2 80 | 4, 2); % input -> conv1 81 | 82 | fprintf('n=4 discriminator receptive field size: %d\n', out); 83 | 84 | 85 | %% n=5 discriminator 86 | 87 | % fix the output size to 1 and derive the receptive field in the input 88 | out = ... 89 | f(f(f(f(f(f(f(1, 4, 1), ... % conv6 -> conv7 90 | 4, 1), ... % conv5 -> conv6 91 | 4, 2), ... % conv4 -> conv5 92 | 4, 2), ... % conv3 -> conv4 93 | 4, 2), ... % conv2 -> conv3 94 | 4, 2), ... % conv1 -> conv2 95 | 4, 2); % input -> conv1 96 | 97 | fprintf('n=5 discriminator receptive field size: %d\n', out); -------------------------------------------------------------------------------- /test.lua: -------------------------------------------------------------------------------- 1 | -- usage: DATA_ROOT=/path/to/data/ name=expt1 which_direction=BtoA th test.lua 2 | -- 3 | -- code derived from https://github.com/soumith/dcgan.torch 4 | -- 5 | 6 | require 'image' 7 | require 'nn' 8 | require 'nngraph' 9 | util = paths.dofile('util/util.lua') 10 | torch.setdefaulttensortype('torch.FloatTensor') 11 | 12 | opt = { 13 | DATA_ROOT = '', -- path to images (should have subfolders 'train', 'val', etc) 14 | batchSize = 1, -- # images in batch 15 | loadSize = 256, -- scale images to this size 16 | fineSize = 256, -- then crop to this size 17 | flip=0, -- horizontal mirroring data augmentation 18 | display = 1, -- display samples while training. 0 = false 19 | display_id = 200, -- display window id. 20 | gpu = 1, -- gpu = 0 is CPU mode. gpu=X is GPU mode on GPU X 21 | how_many = 'all', -- how many test images to run (set to all to run on every image found in the data/phase folder) 22 | which_direction = 'AtoB', -- AtoB or BtoA 23 | phase = 'val', -- train, val, test ,etc 24 | preprocess = 'regular', -- for special purpose preprocessing, e.g., for colorization, change this (selects preprocessing functions in util.lua) 25 | aspect_ratio = 1.0, -- aspect ratio of result images 26 | name = '', -- name of experiment, selects which model to run, should generally should be passed on command line 27 | input_nc = 3, -- # of input image channels 28 | output_nc = 3, -- # of output image channels 29 | serial_batches = 1, -- if 1, takes images in order to make batches, otherwise takes them randomly 30 | serial_batch_iter = 1, -- iter into serial image list 31 | cudnn = 1, -- set to 0 to not use cudnn (untested) 32 | checkpoints_dir = './checkpoints', -- loads models from here 33 | results_dir='./results/', -- saves results here 34 | which_epoch = 'latest', -- which epoch to test? set to 'latest' to use latest cached model 35 | } 36 | 37 | 38 | -- one-line argument parser. parses enviroment variables to override the defaults 39 | for k,v in pairs(opt) do opt[k] = tonumber(os.getenv(k)) or os.getenv(k) or opt[k] end 40 | opt.nThreads = 1 -- test only works with 1 thread... 41 | print(opt) 42 | if opt.display == 0 then opt.display = false end 43 | 44 | opt.manualSeed = torch.random(1, 10000) -- set seed 45 | print("Random Seed: " .. opt.manualSeed) 46 | torch.manualSeed(opt.manualSeed) 47 | torch.setdefaulttensortype('torch.FloatTensor') 48 | 49 | opt.netG_name = opt.name .. '/' .. opt.which_epoch .. '_net_G' 50 | 51 | local data_loader = paths.dofile('data/data.lua') 52 | print('#threads...' .. opt.nThreads) 53 | local data = data_loader.new(opt.nThreads, opt) 54 | print("Dataset Size: ", data:size()) 55 | 56 | -- translation direction 57 | local idx_A = nil 58 | local idx_B = nil 59 | local input_nc = opt.input_nc 60 | local output_nc = opt.output_nc 61 | if opt.which_direction=='AtoB' then 62 | idx_A = {1, input_nc} 63 | idx_B = {input_nc+1, input_nc+output_nc} 64 | elseif opt.which_direction=='BtoA' then 65 | idx_A = {input_nc+1, input_nc+output_nc} 66 | idx_B = {1, input_nc} 67 | else 68 | error(string.format('bad direction %s',opt.which_direction)) 69 | end 70 | ---------------------------------------------------------------------------- 71 | 72 | local input = torch.FloatTensor(opt.batchSize,3,opt.fineSize,opt.fineSize) 73 | local target = torch.FloatTensor(opt.batchSize,3,opt.fineSize,opt.fineSize) 74 | 75 | print('checkpoints_dir', opt.checkpoints_dir) 76 | local netG = util.load(paths.concat(opt.checkpoints_dir, opt.netG_name .. '.t7'), opt) 77 | --netG:evaluate() 78 | 79 | print(netG) 80 | 81 | 82 | function TableConcat(t1,t2) 83 | for i=1,#t2 do 84 | t1[#t1+1] = t2[i] 85 | end 86 | return t1 87 | end 88 | 89 | if opt.how_many=='all' then 90 | opt.how_many=data:size() 91 | end 92 | opt.how_many=math.min(opt.how_many, data:size()) 93 | 94 | local filepaths = {} -- paths to images tested on 95 | for n=1,math.floor(opt.how_many/opt.batchSize) do 96 | print('processing batch ' .. n) 97 | 98 | local data_curr, filepaths_curr = data:getBatch() 99 | filepaths_curr = util.basename_batch(filepaths_curr) 100 | print('filepaths_curr: ', filepaths_curr) 101 | 102 | input = data_curr[{ {}, idx_A, {}, {} }] 103 | target = data_curr[{ {}, idx_B, {}, {} }] 104 | 105 | if opt.gpu > 0 then 106 | input = input:cuda() 107 | end 108 | 109 | if opt.preprocess == 'colorization' then 110 | local output_AB = netG:forward(input):float() 111 | local input_L = input:float() 112 | output = util.deprocessLAB_batch(input_L, output_AB) 113 | local target_AB = target:float() 114 | target = util.deprocessLAB_batch(input_L, target_AB) 115 | input = util.deprocessL_batch(input_L) 116 | else 117 | output = util.deprocess_batch(netG:forward(input)) 118 | input = util.deprocess_batch(input):float() 119 | output = output:float() 120 | target = util.deprocess_batch(target):float() 121 | end 122 | paths.mkdir(paths.concat(opt.results_dir, opt.netG_name .. '_' .. opt.phase)) 123 | local image_dir = paths.concat(opt.results_dir, opt.netG_name .. '_' .. opt.phase, 'images') 124 | paths.mkdir(image_dir) 125 | paths.mkdir(paths.concat(image_dir,'input')) 126 | paths.mkdir(paths.concat(image_dir,'output')) 127 | paths.mkdir(paths.concat(image_dir,'target')) 128 | for i=1, opt.batchSize do 129 | image.save(paths.concat(image_dir,'input',filepaths_curr[i]), image.scale(input[i],input[i]:size(2),input[i]:size(3)/opt.aspect_ratio)) 130 | image.save(paths.concat(image_dir,'output',filepaths_curr[i]), image.scale(output[i],output[i]:size(2),output[i]:size(3)/opt.aspect_ratio)) 131 | image.save(paths.concat(image_dir,'target',filepaths_curr[i]), image.scale(target[i],target[i]:size(2),target[i]:size(3)/opt.aspect_ratio)) 132 | end 133 | print('Saved images to: ', image_dir) 134 | 135 | if opt.display then 136 | if opt.preprocess == 'regular' then 137 | disp = require 'display' 138 | disp.image(util.scaleBatch(input,100,100),{win=opt.display_id, title='input'}) 139 | disp.image(util.scaleBatch(output,100,100),{win=opt.display_id+1, title='output'}) 140 | disp.image(util.scaleBatch(target,100,100),{win=opt.display_id+2, title='target'}) 141 | 142 | print('Displayed images') 143 | end 144 | end 145 | 146 | filepaths = TableConcat(filepaths, filepaths_curr) 147 | end 148 | 149 | -- make webpage 150 | io.output(paths.concat(opt.results_dir,opt.netG_name .. '_' .. opt.phase, 'index.html')) 151 | 152 | io.write('') 153 | 154 | io.write('') 155 | for i=1, #filepaths do 156 | io.write('') 157 | io.write('') 158 | io.write('') 159 | io.write('') 160 | io.write('') 161 | io.write('') 162 | end 163 | 164 | io.write('
Image #InputOutputGround Truth
' .. filepaths[i] .. '
') 165 | -------------------------------------------------------------------------------- /train.lua: -------------------------------------------------------------------------------- 1 | -- usage example: DATA_ROOT=/path/to/data/ which_direction=BtoA name=expt1 th train.lua 2 | -- 3 | -- code derived from https://github.com/soumith/dcgan.torch 4 | -- 5 | 6 | require 'torch' 7 | require 'nn' 8 | require 'optim' 9 | util = paths.dofile('util/util.lua') 10 | require 'image' 11 | require 'models' 12 | 13 | 14 | opt = { 15 | DATA_ROOT = '', -- path to images (should have subfolders 'train', 'val', etc) 16 | batchSize = 1, -- # images in batch 17 | loadSize = 286, -- scale images to this size 18 | fineSize = 256, -- then crop to this size 19 | ngf = 64, -- # of gen filters in first conv layer 20 | ndf = 64, -- # of discrim filters in first conv layer 21 | input_nc = 3, -- # of input image channels 22 | output_nc = 3, -- # of output image channels 23 | niter = 200, -- # of iter at starting learning rate 24 | lr = 0.0002, -- initial learning rate for adam 25 | beta1 = 0.5, -- momentum term of adam 26 | ntrain = math.huge, -- # of examples per epoch. math.huge for full dataset 27 | flip = 1, -- if flip the images for data argumentation 28 | display = 1, -- display samples while training. 0 = false 29 | display_id = 10, -- display window id. 30 | display_plot = 'errL1', -- which loss values to plot over time. Accepted values include a comma seperated list of: errL1, errG, and errD 31 | gpu = 1, -- gpu = 0 is CPU mode. gpu=X is GPU mode on GPU X 32 | name = '', -- name of the experiment, should generally be passed on the command line 33 | which_direction = 'AtoB', -- AtoB or BtoA 34 | phase = 'train', -- train, val, test, etc 35 | preprocess = 'regular', -- for special purpose preprocessing, e.g., for colorization, change this (selects preprocessing functions in util.lua) 36 | nThreads = 2, -- # threads for loading data 37 | save_epoch_freq = 50, -- save a model every save_epoch_freq epochs (does not overwrite previously saved models) 38 | save_latest_freq = 5000, -- save the latest model every latest_freq sgd iterations (overwrites the previous latest model) 39 | print_freq = 50, -- print the debug information every print_freq iterations 40 | display_freq = 100, -- display the current results every display_freq iterations 41 | save_display_freq = 5000, -- save the current display of results every save_display_freq_iterations 42 | continue_train=0, -- if continue training, load the latest model: 1: true, 0: false 43 | serial_batches = 0, -- if 1, takes images in order to make batches, otherwise takes them randomly 44 | serial_batch_iter = 1, -- iter into serial image list 45 | checkpoints_dir = './checkpoints', -- models are saved here 46 | cudnn = 1, -- set to 0 to not use cudnn 47 | condition_GAN = 1, -- set to 0 to use unconditional discriminator 48 | use_GAN = 1, -- set to 0 to turn off GAN term 49 | use_L1 = 1, -- set to 0 to turn off L1 term 50 | which_model_netD = 'basic', -- selects model to use for netD 51 | which_model_netG = 'unet', -- selects model to use for netG 52 | n_layers_D = 0, -- only used if which_model_netD=='n_layers' 53 | lambda = 100, -- weight on L1 term in objective 54 | } 55 | 56 | -- one-line argument parser. parses enviroment variables to override the defaults 57 | for k,v in pairs(opt) do opt[k] = tonumber(os.getenv(k)) or os.getenv(k) or opt[k] end 58 | print(opt) 59 | 60 | local input_nc = opt.input_nc 61 | local output_nc = opt.output_nc 62 | -- translation direction 63 | local idx_A = nil 64 | local idx_B = nil 65 | 66 | if opt.which_direction=='AtoB' then 67 | idx_A = {1, input_nc} 68 | idx_B = {input_nc+1, input_nc+output_nc} 69 | elseif opt.which_direction=='BtoA' then 70 | idx_A = {input_nc+1, input_nc+output_nc} 71 | idx_B = {1, input_nc} 72 | else 73 | error(string.format('bad direction %s',opt.which_direction)) 74 | end 75 | 76 | if opt.display == 0 then opt.display = false end 77 | 78 | opt.manualSeed = torch.random(1, 10000) -- fix seed 79 | print("Random Seed: " .. opt.manualSeed) 80 | torch.manualSeed(opt.manualSeed) 81 | torch.setdefaulttensortype('torch.FloatTensor') 82 | 83 | -- create data loader 84 | local data_loader = paths.dofile('data/data.lua') 85 | print('#threads...' .. opt.nThreads) 86 | local data = data_loader.new(opt.nThreads, opt) 87 | print("Dataset Size: ", data:size()) 88 | 89 | ---------------------------------------------------------------------------- 90 | local function weights_init(m) 91 | local name = torch.type(m) 92 | if name:find('Convolution') then 93 | m.weight:normal(0.0, 0.02) 94 | m.bias:fill(0) 95 | elseif name:find('BatchNormalization') then 96 | if m.weight then m.weight:normal(1.0, 0.02) end 97 | if m.bias then m.bias:fill(0) end 98 | end 99 | end 100 | 101 | 102 | local ndf = opt.ndf 103 | local ngf = opt.ngf 104 | local real_label = 1 105 | local fake_label = 0 106 | 107 | function defineG(input_nc, output_nc, ngf) 108 | local netG = nil 109 | if opt.which_model_netG == "encoder_decoder" then netG = defineG_encoder_decoder(input_nc, output_nc, ngf) 110 | elseif opt.which_model_netG == "unet" then netG = defineG_unet(input_nc, output_nc, ngf) 111 | elseif opt.which_model_netG == "unet_128" then netG = defineG_unet_128(input_nc, output_nc, ngf) 112 | else error("unsupported netG model") 113 | end 114 | 115 | netG:apply(weights_init) 116 | 117 | return netG 118 | end 119 | 120 | function defineD(input_nc, output_nc, ndf) 121 | local netD = nil 122 | if opt.condition_GAN==1 then 123 | input_nc_tmp = input_nc 124 | else 125 | input_nc_tmp = 0 -- only penalizes structure in output channels 126 | end 127 | 128 | if opt.which_model_netD == "basic" then netD = defineD_basic(input_nc_tmp, output_nc, ndf) 129 | elseif opt.which_model_netD == "n_layers" then netD = defineD_n_layers(input_nc_tmp, output_nc, ndf, opt.n_layers_D) 130 | else error("unsupported netD model") 131 | end 132 | 133 | netD:apply(weights_init) 134 | 135 | return netD 136 | end 137 | 138 | 139 | -- load saved models and finetune 140 | if opt.continue_train == 1 then 141 | print('loading previously trained netG...') 142 | netG = util.load(paths.concat(opt.checkpoints_dir, opt.name, 'latest_net_G.t7'), opt) 143 | print('loading previously trained netD...') 144 | netD = util.load(paths.concat(opt.checkpoints_dir, opt.name, 'latest_net_D.t7'), opt) 145 | else 146 | print('define model netG...') 147 | netG = defineG(input_nc, output_nc, ngf) 148 | print('define model netD...') 149 | netD = defineD(input_nc, output_nc, ndf) 150 | end 151 | 152 | print(netG) 153 | print(netD) 154 | 155 | 156 | local criterion = nn.BCECriterion() 157 | local criterionAE = nn.AbsCriterion() 158 | --------------------------------------------------------------------------- 159 | optimStateG = { 160 | learningRate = opt.lr, 161 | beta1 = opt.beta1, 162 | } 163 | optimStateD = { 164 | learningRate = opt.lr, 165 | beta1 = opt.beta1, 166 | } 167 | ---------------------------------------------------------------------------- 168 | local real_A = torch.Tensor(opt.batchSize, input_nc, opt.fineSize, opt.fineSize) 169 | local real_B = torch.Tensor(opt.batchSize, output_nc, opt.fineSize, opt.fineSize) 170 | local fake_B = torch.Tensor(opt.batchSize, output_nc, opt.fineSize, opt.fineSize) 171 | local real_AB = torch.Tensor(opt.batchSize, output_nc + input_nc*opt.condition_GAN, opt.fineSize, opt.fineSize) 172 | local fake_AB = torch.Tensor(opt.batchSize, output_nc + input_nc*opt.condition_GAN, opt.fineSize, opt.fineSize) 173 | local errD, errG, errL1 = 0, 0, 0 174 | local epoch_tm = torch.Timer() 175 | local tm = torch.Timer() 176 | local data_tm = torch.Timer() 177 | ---------------------------------------------------------------------------- 178 | 179 | if opt.gpu > 0 then 180 | print('transferring to gpu...') 181 | require 'cunn' 182 | cutorch.setDevice(opt.gpu) 183 | real_A = real_A:cuda(); 184 | real_B = real_B:cuda(); fake_B = fake_B:cuda(); 185 | real_AB = real_AB:cuda(); fake_AB = fake_AB:cuda(); 186 | if opt.cudnn==1 then 187 | netG = util.cudnn(netG); netD = util.cudnn(netD); 188 | end 189 | netD:cuda(); netG:cuda(); criterion:cuda(); criterionAE:cuda(); 190 | print('done') 191 | else 192 | print('running model on CPU') 193 | end 194 | 195 | 196 | local parametersD, gradParametersD = netD:getParameters() 197 | local parametersG, gradParametersG = netG:getParameters() 198 | 199 | 200 | 201 | if opt.display then disp = require 'display' end 202 | 203 | 204 | function createRealFake() 205 | -- load real 206 | data_tm:reset(); data_tm:resume() 207 | local real_data, data_path = data:getBatch() 208 | data_tm:stop() 209 | 210 | real_A:copy(real_data[{ {}, idx_A, {}, {} }]) 211 | real_B:copy(real_data[{ {}, idx_B, {}, {} }]) 212 | 213 | if opt.condition_GAN==1 then 214 | real_AB = torch.cat(real_A,real_B,2) 215 | else 216 | real_AB = real_B -- unconditional GAN, only penalizes structure in B 217 | end 218 | 219 | -- create fake 220 | fake_B = netG:forward(real_A) 221 | 222 | if opt.condition_GAN==1 then 223 | fake_AB = torch.cat(real_A,fake_B,2) 224 | else 225 | fake_AB = fake_B -- unconditional GAN, only penalizes structure in B 226 | end 227 | end 228 | 229 | -- create closure to evaluate f(X) and df/dX of discriminator 230 | local fDx = function(x) 231 | netD:apply(function(m) if torch.type(m):find('Convolution') then m.bias:zero() end end) 232 | netG:apply(function(m) if torch.type(m):find('Convolution') then m.bias:zero() end end) 233 | 234 | gradParametersD:zero() 235 | 236 | -- Real 237 | local output = netD:forward(real_AB) 238 | local label = torch.FloatTensor(output:size()):fill(real_label) 239 | if opt.gpu>0 then 240 | label = label:cuda() 241 | end 242 | 243 | local errD_real = criterion:forward(output, label) 244 | local df_do = criterion:backward(output, label) 245 | netD:backward(real_AB, df_do) 246 | 247 | -- Fake 248 | local output = netD:forward(fake_AB) 249 | label:fill(fake_label) 250 | local errD_fake = criterion:forward(output, label) 251 | local df_do = criterion:backward(output, label) 252 | netD:backward(fake_AB, df_do) 253 | 254 | errD = (errD_real + errD_fake)/2 255 | 256 | return errD, gradParametersD 257 | end 258 | 259 | -- create closure to evaluate f(X) and df/dX of generator 260 | local fGx = function(x) 261 | netD:apply(function(m) if torch.type(m):find('Convolution') then m.bias:zero() end end) 262 | netG:apply(function(m) if torch.type(m):find('Convolution') then m.bias:zero() end end) 263 | 264 | gradParametersG:zero() 265 | 266 | -- GAN loss 267 | local df_dg = torch.zeros(fake_B:size()) 268 | if opt.gpu>0 then 269 | df_dg = df_dg:cuda(); 270 | end 271 | 272 | if opt.use_GAN==1 then 273 | local output = netD.output -- netD:forward{input_A,input_B} was already executed in fDx, so save computation 274 | local label = torch.FloatTensor(output:size()):fill(real_label) -- fake labels are real for generator cost 275 | if opt.gpu>0 then 276 | label = label:cuda(); 277 | end 278 | errG = criterion:forward(output, label) 279 | local df_do = criterion:backward(output, label) 280 | df_dg = netD:updateGradInput(fake_AB, df_do):narrow(2,fake_AB:size(2)-output_nc+1, output_nc) 281 | else 282 | errG = 0 283 | end 284 | 285 | -- unary loss 286 | local df_do_AE = torch.zeros(fake_B:size()) 287 | if opt.gpu>0 then 288 | df_do_AE = df_do_AE:cuda(); 289 | end 290 | if opt.use_L1==1 then 291 | errL1 = criterionAE:forward(fake_B, real_B) 292 | df_do_AE = criterionAE:backward(fake_B, real_B) 293 | else 294 | errL1 = 0 295 | end 296 | 297 | netG:backward(real_A, df_dg + df_do_AE:mul(opt.lambda)) 298 | 299 | return errG, gradParametersG 300 | end 301 | 302 | 303 | 304 | 305 | -- train 306 | local best_err = nil 307 | paths.mkdir(opt.checkpoints_dir) 308 | paths.mkdir(opt.checkpoints_dir .. '/' .. opt.name) 309 | 310 | -- save opt 311 | file = torch.DiskFile(paths.concat(opt.checkpoints_dir, opt.name, 'opt.txt'), 'w') 312 | file:writeObject(opt) 313 | file:close() 314 | 315 | -- parse diplay_plot string into table 316 | opt.display_plot = string.split(string.gsub(opt.display_plot, "%s+", ""), ",") 317 | for k, v in ipairs(opt.display_plot) do 318 | if not util.containsValue({"errG", "errD", "errL1"}, v) then 319 | error(string.format('bad display_plot value "%s"', v)) 320 | end 321 | end 322 | 323 | -- display plot config 324 | local plot_config = { 325 | title = "Loss over time", 326 | labels = {"epoch", unpack(opt.display_plot)}, 327 | ylabel = "loss", 328 | } 329 | 330 | -- display plot vars 331 | local plot_data = {} 332 | local plot_win 333 | 334 | local counter = 0 335 | for epoch = 1, opt.niter do 336 | epoch_tm:reset() 337 | for i = 1, math.min(data:size(), opt.ntrain), opt.batchSize do 338 | tm:reset() 339 | 340 | -- load a batch and run G on that batch 341 | createRealFake() 342 | 343 | -- (1) Update D network: maximize log(D(x,y)) + log(1 - D(x,G(x))) 344 | if opt.use_GAN==1 then optim.adam(fDx, parametersD, optimStateD) end 345 | 346 | -- (2) Update G network: maximize log(D(x,G(x))) + L1(y,G(x)) 347 | optim.adam(fGx, parametersG, optimStateG) 348 | 349 | -- display 350 | counter = counter + 1 351 | if counter % opt.display_freq == 0 and opt.display then 352 | createRealFake() 353 | if opt.preprocess == 'colorization' then 354 | local real_A_s = util.scaleBatch(real_A:float(),100,100) 355 | local fake_B_s = util.scaleBatch(fake_B:float(),100,100) 356 | local real_B_s = util.scaleBatch(real_B:float(),100,100) 357 | disp.image(util.deprocessL_batch(real_A_s), {win=opt.display_id, title=opt.name .. ' input'}) 358 | disp.image(util.deprocessLAB_batch(real_A_s, fake_B_s), {win=opt.display_id+1, title=opt.name .. ' output'}) 359 | disp.image(util.deprocessLAB_batch(real_A_s, real_B_s), {win=opt.display_id+2, title=opt.name .. ' target'}) 360 | else 361 | disp.image(util.deprocess_batch(util.scaleBatch(real_A:float(),100,100)), {win=opt.display_id, title=opt.name .. ' input'}) 362 | disp.image(util.deprocess_batch(util.scaleBatch(fake_B:float(),100,100)), {win=opt.display_id+1, title=opt.name .. ' output'}) 363 | disp.image(util.deprocess_batch(util.scaleBatch(real_B:float(),100,100)), {win=opt.display_id+2, title=opt.name .. ' target'}) 364 | end 365 | end 366 | 367 | -- write display visualization to disk 368 | -- runs on the first batchSize images in the opt.phase set 369 | if counter % opt.save_display_freq == 0 and opt.display then 370 | local serial_batches=opt.serial_batches 371 | opt.serial_batches=1 372 | opt.serial_batch_iter=1 373 | 374 | local image_out = nil 375 | local N_save_display = 10 376 | local N_save_iter = torch.max(torch.Tensor({1, torch.floor(N_save_display/opt.batchSize)})) 377 | for i3=1, N_save_iter do 378 | 379 | createRealFake() 380 | print('save to the disk') 381 | if opt.preprocess == 'colorization' then 382 | for i2=1, fake_B:size(1) do 383 | if image_out==nil then image_out = torch.cat(util.deprocessL(real_A[i2]:float()),util.deprocessLAB(real_A[i2]:float(), fake_B[i2]:float()),3)/255.0 384 | else image_out = torch.cat(image_out, torch.cat(util.deprocessL(real_A[i2]:float()),util.deprocessLAB(real_A[i2]:float(), fake_B[i2]:float()),3)/255.0, 2) end 385 | end 386 | else 387 | for i2=1, fake_B:size(1) do 388 | if image_out==nil then image_out = torch.cat(util.deprocess(real_A[i2]:float()),util.deprocess(fake_B[i2]:float()),3) 389 | else image_out = torch.cat(image_out, torch.cat(util.deprocess(real_A[i2]:float()),util.deprocess(fake_B[i2]:float()),3), 2) end 390 | end 391 | end 392 | end 393 | image.save(paths.concat(opt.checkpoints_dir, opt.name , counter .. '_train_res.png'), image_out) 394 | 395 | opt.serial_batches=serial_batches 396 | end 397 | 398 | -- logging and display plot 399 | if counter % opt.print_freq == 0 then 400 | local loss = {errG=errG and errG or -1, errD=errD and errD or -1, errL1=errL1 and errL1 or -1} 401 | local curItInBatch = ((i-1) / opt.batchSize) 402 | local totalItInBatch = math.floor(math.min(data:size(), opt.ntrain) / opt.batchSize) 403 | print(('Epoch: [%d][%8d / %8d]\t Time: %.3f DataTime: %.3f ' 404 | .. ' Err_G: %.4f Err_D: %.4f ErrL1: %.4f'):format( 405 | epoch, curItInBatch, totalItInBatch, 406 | tm:time().real / opt.batchSize, data_tm:time().real / opt.batchSize, 407 | errG, errD, errL1)) 408 | 409 | local plot_vals = { epoch + curItInBatch / totalItInBatch } 410 | for k, v in ipairs(opt.display_plot) do 411 | if loss[v] ~= nil then 412 | plot_vals[#plot_vals + 1] = loss[v] 413 | end 414 | end 415 | 416 | -- update display plot 417 | if opt.display then 418 | table.insert(plot_data, plot_vals) 419 | plot_config.win = plot_win 420 | plot_win = disp.plot(plot_data, plot_config) 421 | end 422 | end 423 | 424 | -- save latest model 425 | if counter % opt.save_latest_freq == 0 then 426 | print(('saving the latest model (epoch %d, iters %d)'):format(epoch, counter)) 427 | torch.save(paths.concat(opt.checkpoints_dir, opt.name, 'latest_net_G.t7'), netG:clearState()) 428 | torch.save(paths.concat(opt.checkpoints_dir, opt.name, 'latest_net_D.t7'), netD:clearState()) 429 | end 430 | 431 | end 432 | 433 | 434 | parametersD, gradParametersD = nil, nil -- nil them to avoid spiking memory 435 | parametersG, gradParametersG = nil, nil 436 | 437 | if epoch % opt.save_epoch_freq == 0 then 438 | torch.save(paths.concat(opt.checkpoints_dir, opt.name, epoch .. '_net_G.t7'), netG:clearState()) 439 | torch.save(paths.concat(opt.checkpoints_dir, opt.name, epoch .. '_net_D.t7'), netD:clearState()) 440 | end 441 | 442 | print(('End of epoch %d / %d \t Time Taken: %.3f'):format( 443 | epoch, opt.niter, epoch_tm:time().real)) 444 | parametersD, gradParametersD = netD:getParameters() -- reflatten the params and get them 445 | parametersG, gradParametersG = netG:getParameters() 446 | end 447 | -------------------------------------------------------------------------------- /util/cudnn_convert_custom.lua: -------------------------------------------------------------------------------- 1 | -- modified from https://github.com/NVIDIA/torch-cudnn/blob/master/convert.lua 2 | -- removed error on nngraph 3 | 4 | -- modules that can be converted to nn seamlessly 5 | local layer_list = { 6 | 'BatchNormalization', 7 | 'SpatialBatchNormalization', 8 | 'SpatialConvolution', 9 | 'SpatialCrossMapLRN', 10 | 'SpatialFullConvolution', 11 | 'SpatialMaxPooling', 12 | 'SpatialAveragePooling', 13 | 'ReLU', 14 | 'Tanh', 15 | 'Sigmoid', 16 | 'SoftMax', 17 | 'LogSoftMax', 18 | 'VolumetricBatchNormalization', 19 | 'VolumetricConvolution', 20 | 'VolumetricFullConvolution', 21 | 'VolumetricMaxPooling', 22 | 'VolumetricAveragePooling', 23 | } 24 | 25 | -- goes over a given net and converts all layers to dst backend 26 | -- for example: net = cudnn_convert_custom(net, cudnn) 27 | -- same as cudnn.convert with gModule check commented out 28 | function cudnn_convert_custom(net, dst, exclusion_fn) 29 | return net:replace(function(x) 30 | --if torch.type(x) == 'nn.gModule' then 31 | -- io.stderr:write('Warning: cudnn.convert does not work with nngraph yet. Ignoring nn.gModule') 32 | -- return x 33 | --end 34 | local y = 0 35 | local src = dst == nn and cudnn or nn 36 | local src_prefix = src == nn and 'nn.' or 'cudnn.' 37 | local dst_prefix = dst == nn and 'nn.' or 'cudnn.' 38 | 39 | local function convert(v) 40 | local y = {} 41 | torch.setmetatable(y, dst_prefix..v) 42 | if v == 'ReLU' then y = dst.ReLU() end -- because parameters 43 | for k,u in pairs(x) do y[k] = u end 44 | if src == cudnn and x.clearDesc then x.clearDesc(y) end 45 | if src == cudnn and v == 'SpatialAveragePooling' then 46 | y.divide = true 47 | y.count_include_pad = v.mode == 'CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING' 48 | end 49 | if src == nn and string.find(v, 'Convolution') then 50 | y.groups = 1 51 | end 52 | return y 53 | end 54 | 55 | if exclusion_fn and exclusion_fn(x) then 56 | return x 57 | end 58 | local t = torch.typename(x) 59 | if t == 'nn.SpatialConvolutionMM' then 60 | y = convert('SpatialConvolution') 61 | elseif t == 'inn.SpatialCrossResponseNormalization' then 62 | y = convert('SpatialCrossMapLRN') 63 | else 64 | for i,v in ipairs(layer_list) do 65 | if torch.typename(x) == src_prefix..v then 66 | y = convert(v) 67 | end 68 | end 69 | end 70 | return y == 0 and x or y 71 | end) 72 | end 73 | -------------------------------------------------------------------------------- /util/util.lua: -------------------------------------------------------------------------------- 1 | -- 2 | -- code derived from https://github.com/soumith/dcgan.torch 3 | -- 4 | 5 | local util = {} 6 | 7 | require 'torch' 8 | 9 | function util.normalize(img) 10 | -- rescale image to 0 .. 1 11 | local min = img:min() 12 | local max = img:max() 13 | 14 | img = torch.FloatTensor(img:size()):copy(img) 15 | img:add(-min):mul(1/(max-min)) 16 | return img 17 | end 18 | 19 | function util.normalizeBatch(batch) 20 | for i = 1, batch:size(1) do 21 | batch[i] = util.normalize(batch[i]:squeeze()) 22 | end 23 | return batch 24 | end 25 | 26 | function util.basename_batch(batch) 27 | for i = 1, #batch do 28 | batch[i] = paths.basename(batch[i]) 29 | end 30 | return batch 31 | end 32 | 33 | 34 | 35 | -- default preprocessing 36 | -- 37 | -- Preprocesses an image before passing it to a net 38 | -- Converts from RGB to BGR and rescales from [0,1] to [-1,1] 39 | function util.preprocess(img) 40 | -- RGB to BGR 41 | local perm = torch.LongTensor{3, 2, 1} 42 | img = img:index(1, perm) 43 | 44 | -- [0,1] to [-1,1] 45 | img = img:mul(2):add(-1) 46 | 47 | -- check that input is in expected range 48 | assert(img:max()<=1,"badly scaled inputs") 49 | assert(img:min()>=-1,"badly scaled inputs") 50 | 51 | return img 52 | end 53 | 54 | -- Undo the above preprocessing. 55 | function util.deprocess(img) 56 | -- BGR to RGB 57 | local perm = torch.LongTensor{3, 2, 1} 58 | img = img:index(1, perm) 59 | 60 | -- [-1,1] to [0,1] 61 | 62 | img = img:add(1):div(2) 63 | 64 | return img 65 | end 66 | 67 | function util.preprocess_batch(batch) 68 | for i = 1, batch:size(1) do 69 | batch[i] = util.preprocess(batch[i]:squeeze()) 70 | end 71 | return batch 72 | end 73 | 74 | function util.deprocess_batch(batch) 75 | for i = 1, batch:size(1) do 76 | batch[i] = util.deprocess(batch[i]:squeeze()) 77 | end 78 | return batch 79 | end 80 | 81 | 82 | 83 | -- preprocessing specific to colorization 84 | 85 | function util.deprocessLAB(L, AB) 86 | local L2 = torch.Tensor(L:size()):copy(L) 87 | if L2:dim() == 3 then 88 | L2 = L2[{1, {}, {} }] 89 | end 90 | local AB2 = torch.Tensor(AB:size()):copy(AB) 91 | AB2 = torch.clamp(AB2, -1.0, 1.0) 92 | -- local AB2 = AB 93 | L2 = L2:add(1):mul(50.0) 94 | AB2 = AB2:mul(110.0) 95 | 96 | L2 = L2:reshape(1, L2:size(1), L2:size(2)) 97 | 98 | im_lab = torch.cat(L2, AB2, 1) 99 | im_rgb = torch.clamp(image.lab2rgb(im_lab):mul(255.0), 0.0, 255.0)/255.0 100 | 101 | return im_rgb 102 | end 103 | 104 | function util.deprocessL(L) 105 | local L2 = torch.Tensor(L:size()):copy(L) 106 | L2 = L2:add(1):mul(255.0/2.0) 107 | 108 | if L2:dim()==2 then 109 | L2 = L2:reshape(1,L2:size(1),L2:size(2)) 110 | end 111 | L2 = L2:repeatTensor(L2,3,1,1)/255.0 112 | 113 | return L2 114 | end 115 | 116 | function util.deprocessL_batch(batch) 117 | local batch_new = {} 118 | for i = 1, batch:size(1) do 119 | batch_new[i] = util.deprocessL(batch[i]:squeeze()) 120 | end 121 | return batch_new 122 | end 123 | 124 | function util.deprocessLAB_batch(batchL, batchAB) 125 | local batch = {} 126 | 127 | for i = 1, batchL:size(1) do 128 | batch[i] = util.deprocessLAB(batchL[i]:squeeze(), batchAB[i]:squeeze()) 129 | end 130 | 131 | return batch 132 | end 133 | 134 | 135 | function util.scaleBatch(batch,s1,s2) 136 | local scaled_batch = torch.Tensor(batch:size(1),batch:size(2),s1,s2) 137 | for i = 1, batch:size(1) do 138 | scaled_batch[i] = image.scale(batch[i],s1,s2):squeeze() 139 | end 140 | return scaled_batch 141 | end 142 | 143 | 144 | 145 | function util.toTrivialBatch(input) 146 | return input:reshape(1,input:size(1),input:size(2),input:size(3)) 147 | end 148 | function util.fromTrivialBatch(input) 149 | return input[1] 150 | end 151 | 152 | 153 | 154 | function util.scaleImage(input, loadSize) 155 | -- replicate bw images to 3 channels 156 | if input:size(1)==1 then 157 | input = torch.repeatTensor(input,3,1,1) 158 | end 159 | 160 | input = image.scale(input, loadSize, loadSize) 161 | 162 | return input 163 | end 164 | 165 | function util.getAspectRatio(path) 166 | local input = image.load(path, 3, 'float') 167 | local ar = input:size(3)/input:size(2) 168 | return ar 169 | end 170 | 171 | function util.loadImage(path, loadSize, nc) 172 | local input = image.load(path, 3, 'float') 173 | input= util.preprocess(util.scaleImage(input, loadSize)) 174 | 175 | if nc == 1 then 176 | input = input[{{1}, {}, {}}] 177 | end 178 | 179 | return input 180 | end 181 | 182 | 183 | 184 | -- TO DO: loading code is rather hacky; clean it up and make sure it works on all types of nets / cpu/gpu configurations 185 | function util.load(filename, opt) 186 | if opt.cudnn>0 then 187 | require 'cudnn' 188 | end 189 | 190 | if opt.gpu > 0 then 191 | require 'cunn' 192 | end 193 | 194 | local net = torch.load(filename) 195 | 196 | if opt.gpu > 0 then 197 | net:cuda() 198 | 199 | -- calling cuda on cudnn saved nngraphs doesn't change all variables to cuda, so do it below 200 | if net.forwardnodes then 201 | for i=1,#net.forwardnodes do 202 | if net.forwardnodes[i].data.module then 203 | net.forwardnodes[i].data.module:cuda() 204 | end 205 | end 206 | end 207 | else 208 | net:float() 209 | end 210 | net:apply(function(m) if m.weight then 211 | m.gradWeight = m.weight:clone():zero(); 212 | m.gradBias = m.bias:clone():zero(); end end) 213 | return net 214 | end 215 | 216 | function util.cudnn(net) 217 | require 'cudnn' 218 | require 'util/cudnn_convert_custom' 219 | return cudnn_convert_custom(net, cudnn) 220 | end 221 | 222 | function util.containsValue(table, value) 223 | for k, v in pairs(table) do 224 | if v == value then return true end 225 | end 226 | return false 227 | end 228 | 229 | return util 230 | --------------------------------------------------------------------------------