├── .gitignore
├── LICENSE
├── README.md
├── data
│   ├── data.lua
│   ├── dataset.lua
│   └── donkey_folder.lua
├── datasets
│   ├── bibtex
│   │   ├── cityscapes.tex
│   │   ├── facades.tex
│   │   ├── handbags.tex
│   │   ├── shoes.tex
│   │   └── transattr.tex
│   └── download_dataset.sh
├── imgs
│   └── examples.jpg
├── models.lua
├── models
│   └── download_model.sh
├── scripts
│   ├── combine_A_and_B.py
│   ├── edges
│   │   ├── PostprocessHED.m
│   │   └── batch_hed.py
│   ├── eval_cityscapes
│   │   ├── caffemodel
│   │   │   └── deploy.prototxt
│   │   ├── cityscapes.py
│   │   ├── download_fcn8s.sh
│   │   ├── evaluate.py
│   │   └── util.py
│   └── receptive_field_sizes.m
├── test.lua
├── train.lua
└── util
    ├── cudnn_convert_custom.lua
    └── util.lua
/.gitignore:
--------------------------------------------------------------------------------
1 | #
2 | *~
3 |
4 | *.DS_Store
5 | cache/
6 | results/
7 | checkpoints/
8 |
9 | # luarocks build files
10 | *.src.rock
11 | *.zip
12 | *.tar.gz
13 | *.t7
14 |
15 | # Object files
16 | *.o
17 | *.os
18 | *.ko
19 | *.obj
20 | *.elf
21 |
22 | # Precompiled Headers
23 | *.gch
24 | *.pch
25 |
26 | # Libraries
27 | *.lib
28 | *.a
29 | *.la
30 | *.lo
31 | *.def
32 | *.exp
33 |
34 | # Shared objects (inc. Windows DLLs)
35 | *.dll
36 | *.so
37 | *.so.*
38 | *.dylib
39 |
40 | # Executables
41 | *.exe
42 | *.out
43 | *.app
44 | *.i*86
45 | *.x86_64
46 | *.hex
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2016, Phillip Isola and Jun-Yan Zhu
2 | All rights reserved.
3 |
4 | Redistribution and use in source and binary forms, with or without
5 | modification, are permitted provided that the following conditions are met:
6 |
7 | * Redistributions of source code must retain the above copyright notice, this
8 | list of conditions and the following disclaimer.
9 |
10 | * Redistributions in binary form must reproduce the above copyright notice,
11 | this list of conditions and the following disclaimer in the documentation
12 | and/or other materials provided with the distribution.
13 |
14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
15 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
16 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
17 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
18 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
19 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
20 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
21 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
22 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
23 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24 |
25 |
26 |
27 | ----------------------------- LICENSE FOR DCGAN --------------------------------
28 | BSD License
29 |
30 | For dcgan.torch software
31 |
32 | Copyright (c) 2015, Facebook, Inc. All rights reserved.
33 |
34 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
35 |
36 | Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
37 |
38 | Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
39 |
40 | Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
41 |
42 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
43 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # pix2pix
3 | [Project](https://phillipi.github.io/pix2pix/) | [Arxiv](https://arxiv.org/abs/1611.07004) |
4 | [PyTorch](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix)
5 |
6 | Torch implementation for learning a mapping from input images to output images, for example:
7 |
8 | <img src="imgs/examples.jpg"/>
9 |
10 | Image-to-Image Translation with Conditional Adversarial Networks
11 | [Phillip Isola](http://web.mit.edu/phillipi/), [Jun-Yan Zhu](https://www.cs.cmu.edu/~junyanz/), [Tinghui Zhou](https://people.eecs.berkeley.edu/~tinghuiz/), [Alexei A. Efros](https://people.eecs.berkeley.edu/~efros/)
12 | CVPR, 2017.
13 |
14 | On some tasks, decent results can be obtained fairly quickly and on small datasets. For example, to learn to generate facades (example shown above), we trained on just 400 images for about 2 hours (on a single Pascal Titan X GPU). However, for harder problems it may be important to train on far larger datasets, and for many hours or even days.
15 |
16 | **Note**: Please check out our [PyTorch](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix) implementation for pix2pix and CycleGAN. The PyTorch version is under active development and can produce results comparable to or better than this Torch version.
17 |
18 | ## Setup
19 |
20 | ### Prerequisites
21 | - Linux or OSX
22 | - NVIDIA GPU + CUDA CuDNN (CPU mode and CUDA without CuDNN may work with minimal modification, but untested)
23 |
24 | ### Getting Started
25 | - Install torch and dependencies from https://github.com/torch/distro
26 | - Install torch packages `nngraph` and `display`
27 | ```bash
28 | luarocks install nngraph
29 | luarocks install https://raw.githubusercontent.com/szym/display/master/display-scm-0.rockspec
30 | ```
31 | - Clone this repo:
32 | ```bash
33 | git clone git@github.com:phillipi/pix2pix.git
34 | cd pix2pix
35 | ```
36 | - Download the dataset (e.g., [CMP Facades](http://cmp.felk.cvut.cz/~tylecr1/facade/)):
37 | ```bash
38 | bash ./datasets/download_dataset.sh facades
39 | ```
40 | - Train the model
41 | ```bash
42 | DATA_ROOT=./datasets/facades name=facades_generation which_direction=BtoA th train.lua
43 | ```
44 | - (CPU only) The same training command without a GPU or cuDNN. Setting the environment variables ```gpu=0 cudnn=0``` forces CPU-only training:
45 | ```bash
46 | DATA_ROOT=./datasets/facades name=facades_generation which_direction=BtoA gpu=0 cudnn=0 batchSize=10 save_epoch_freq=5 th train.lua
47 | ```
48 | - (Optional) Start the display server to view results as the model trains (see [Display UI](#display-ui) for more details):
49 | ```bash
50 | th -ldisplay.start 8000 0.0.0.0
51 | ```
52 |
53 | - Finally, test the model:
54 | ```bash
55 | DATA_ROOT=./datasets/facades name=facades_generation which_direction=BtoA phase=val th test.lua
56 | ```
57 | The test results will be saved to an html file here: `./results/facades_generation/latest_net_G_val/index.html`.
58 |
59 | ## Train
60 | ```bash
61 | DATA_ROOT=/path/to/data/ name=expt_name which_direction=AtoB th train.lua
62 | ```
63 | Switch `AtoB` to `BtoA` to train the translation in the opposite direction.
64 |
65 | Models are saved to `./checkpoints/expt_name` (can be changed by passing `checkpoint_dir=your_dir` in train.lua).
66 |
67 | See `opt` in train.lua for additional training options.
68 |
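For example, options such as `batchSize` and `save_epoch_freq` can be overridden the same way, as environment variables on the command line (a sketch; the values below are illustrative):
```bash
# A sketch: override the batch size and checkpoint frequency for this run.
DATA_ROOT=/path/to/data/ name=expt_name which_direction=BtoA \
  batchSize=10 save_epoch_freq=5 th train.lua
```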
69 | ## Test
70 | ```bash
71 | DATA_ROOT=/path/to/data/ name=expt_name which_direction=AtoB phase=val th test.lua
72 | ```
73 |
74 | This will run the model named `expt_name` in direction `AtoB` on all images in `/path/to/data/val`.
75 |
76 | Result images, and a webpage to view them, are saved to `./results/expt_name` (can be changed by passing `results_dir=your_dir` in test.lua).
77 |
78 | See `opt` in test.lua for additional testing options.
79 |
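For example, a CPU-only test run mirrors the CPU-only training command above (a sketch; values are illustrative):
```bash
# Disable the GPU and cuDNN to test on the CPU, as in the CPU-only training example.
DATA_ROOT=/path/to/data/ name=expt_name which_direction=AtoB phase=val \
  gpu=0 cudnn=0 th test.lua
```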
80 |
81 | ## Datasets
82 | Download the datasets using the following script. Some of the datasets are collected by other researchers. Please cite their papers if you use the data.
83 | ```bash
84 | bash ./datasets/download_dataset.sh dataset_name
85 | ```
86 | - `facades`: 400 images from [CMP Facades dataset](http://cmp.felk.cvut.cz/~tylecr1/facade/). [[Citation](datasets/bibtex/facades.tex)]
87 | - `cityscapes`: 2975 images from the [Cityscapes training set](https://www.cityscapes-dataset.com/). [[Citation](datasets/bibtex/cityscapes.tex)]
88 | - `maps`: 1096 training images scraped from Google Maps
89 | - `edges2shoes`: 50k training images from [UT Zappos50K dataset](http://vision.cs.utexas.edu/projects/finegrained/utzap50k/). Edges are computed by [HED](https://github.com/s9xie/hed) edge detector + post-processing.
90 | [[Citation](datasets/bibtex/shoes.tex)]
91 | - `edges2handbags`: 137K Amazon Handbag images from [iGAN project](https://github.com/junyanz/iGAN). Edges are computed by [HED](https://github.com/s9xie/hed) edge detector + post-processing. [[Citation](datasets/bibtex/handbags.tex)]
92 | - `night2day`: around 20K natural scene images from [Transient Attributes dataset](http://transattr.cs.brown.edu/) [[Citation](datasets/bibtex/transattr.tex)]. To train a `day2night` pix2pix model, you need to add `which_direction=BtoA`; see the example after this list.
93 |
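For example, a sketch of training a `day2night` model on the `night2day` data (the experiment name is arbitrary):
```bash
bash ./datasets/download_dataset.sh night2day
DATA_ROOT=./datasets/night2day name=day2night which_direction=BtoA th train.lua
```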
94 | ## Models
95 | Download the pre-trained models with the following script. You need to rename the downloaded model file (e.g., `facades_label2image` to `/checkpoints/facades/latest_net_G.t7`) after the download has finished; see the example after the list below.
96 | ```bash
97 | bash ./models/download_model.sh model_name
98 | ```
99 | - `facades_label2image` (label -> facade): trained on the CMP Facades dataset.
100 | - `cityscapes_label2image` (label -> street scene): trained on the Cityscapes dataset.
101 | - `cityscapes_image2label` (street scene -> label): trained on the Cityscapes dataset.
102 | - `edges2shoes` (edge -> photo): trained on UT Zappos50K dataset.
103 | - `edges2handbags` (edge -> photo): trained on Amazon handbags images.
104 | - `day2night` (daytime scene -> nighttime scene): trained on around 100 [webcams](http://transattr.cs.brown.edu/).
105 |
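For example, a sketch of setting up the pre-trained facades generator for testing (the checkpoint layout follows the training convention above):
```bash
# Download the pre-trained generator and move it to where test.lua expects it.
bash ./models/download_model.sh facades_label2image
mkdir -p ./checkpoints/facades
mv ./models/facades_label2image.t7 ./checkpoints/facades/latest_net_G.t7
# Then test with name=facades, e.g.:
# DATA_ROOT=./datasets/facades name=facades which_direction=BtoA phase=val th test.lua
```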
106 | ## Setup Training and Test data
107 | ### Generating Pairs
108 | We provide a python script to generate training data in the form of pairs of images {A,B}, where A and B are two different depictions of the same underlying scene. For example, these might be pairs {label map, photo} or {bw image, color image}. Then we can learn to translate A to B or B to A:
109 |
110 | Create folder `/path/to/data` with subfolders `A` and `B`. `A` and `B` should each have their own subfolders `train`, `val`, `test`, etc. In `/path/to/data/A/train`, put training images in style A. In `/path/to/data/B/train`, put the corresponding images in style B. Repeat the same for other data splits (`val`, `test`, etc).
111 |
112 | Corresponding images in a pair {A,B} must be the same size and have the same filename, e.g., `/path/to/data/A/train/1.jpg` is considered to correspond to `/path/to/data/B/train/1.jpg`.
113 |
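A minimal sketch of the expected layout, using a hypothetical `/path/to/data`:
```bash
# Create the two-sided folder structure described above.
mkdir -p /path/to/data/A/{train,val,test}
mkdir -p /path/to/data/B/{train,val,test}
# e.g., /path/to/data/A/train/1.jpg pairs with /path/to/data/B/train/1.jpg
```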
114 | Once the data is formatted this way, call:
115 | ```bash
116 | python scripts/combine_A_and_B.py --fold_A /path/to/data/A --fold_B /path/to/data/B --fold_AB /path/to/data
117 | ```
118 |
119 | This will combine each pair of images (A,B) into a single image file, ready for training.
120 |
121 | ### Notes on Colorization
122 | There is no need to run `combine_A_and_B.py` for colorization. Instead, prepare a set of natural RGB images and set the `preprocess=colorization` option. The program will automatically convert each RGB image into Lab color space and create an `L -> ab` image pair during training. Also set `input_nc=1` and `output_nc=2`.
123 |
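A sketch of the corresponding training command (paths and the experiment name are placeholders; images are read from `$DATA_ROOT/<phase>`, i.e. a `train` subfolder by default):
```bash
# Colorization training: single-channel L input, two-channel ab output.
DATA_ROOT=/path/to/rgb/data name=colorization_expt preprocess=colorization \
  input_nc=1 output_nc=2 th train.lua
```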
124 | ### Extracting Edges
125 | We provide python and Matlab scripts to extract coarse edges from photos. Run `scripts/edges/batch_hed.py` to compute [HED](https://github.com/s9xie/hed) edges. Run `scripts/edges/PostprocessHED.m` to simplify edges with additional post-processing steps. Check the code documentation for more details.
126 |
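For example, once `batch_hed.py` has written its `.mat` outputs, the post-processing step can be invoked from the shell like this (a sketch; paths are placeholders, and the argument list and defaults are documented at the top of `PostprocessHED.m`):
```bash
# Simplify raw HED outputs into clean edge maps (threshold 25/255, minimum edge size 5).
cd scripts/edges
matlab -nodesktop -nosplash -r "PostprocessHED('/path/to/hed_mat/', '/path/to/edges/', 256, 25.0/255.0, 5); exit"
```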
127 | ### Evaluating Labels2Photos on Cityscapes
128 | We provide scripts for running the evaluation of the Labels2Photos task on the Cityscapes **validation** set. We assume that you have installed `caffe` (and `pycaffe`) in your system. If not, see the [official website](http://caffe.berkeleyvision.org/installation.html) for installation instructions. Once `caffe` is successfully installed, download the pre-trained FCN-8s semantic segmentation model (512MB) by running
129 | ```bash
130 | bash ./scripts/eval_cityscapes/download_fcn8s.sh
131 | ```
132 | Then make sure `./scripts/eval_cityscapes/` is in your system's python path. If not, run the following command to add it:
133 | ```bash
134 | export PYTHONPATH=${PYTHONPATH}:./scripts/eval_cityscapes/
135 | ```
136 | Now you can run the following command to evaluate your predictions:
137 | ```bash
138 | python ./scripts/eval_cityscapes/evaluate.py --cityscapes_dir /path/to/original/cityscapes/dataset/ --result_dir /path/to/your/predictions/ --output_dir /path/to/output/directory/
139 | ```
140 | Images stored under `--result_dir` should contain your model predictions on the Cityscapes **validation** split, and have the original Cityscapes naming convention (e.g., `frankfurt_000001_038418_leftImg8bit.png`). The script will output a text file under `--output_dir` containing the metric.
141 |
142 | **Further notes**: Our pre-trained FCN model is **not** supposed to work on Cityscapes in the original resolution (1024x2048) as it was trained on 256x256 images that are then upsampled to 1024x2048 during training. The purpose of the resizing during training was to 1) keep the label maps in the original high resolution untouched and 2) avoid the need to change the standard FCN training code and the architecture for Cityscapes. During test time, you need to synthesize 256x256 results. Our test code will automatically upsample your results to 1024x2048 before feeding them to the pre-trained FCN model. The output is at 1024x2048 resolution and will be compared to 1024x2048 ground truth labels. You do not need to resize the ground truth labels. The best way to verify whether everything is correct is to reproduce the numbers for real images in the paper first. To achieve it, you need to resize the original/real Cityscapes images (**not** labels) to 256x256 and feed them to the evaluation code.
143 |
144 |
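For example, the real validation images can be downsampled with ImageMagick before being passed to `evaluate.py` (a sketch, assuming ImageMagick is installed; paths are placeholders):
```bash
# Resize real Cityscapes val images (not labels) to 256x256 for the sanity check.
mkdir -p /path/to/resized_val
for f in /path/to/cityscapes/leftImg8bit/val/*/*_leftImg8bit.png; do
  convert "$f" -resize '256x256!' "/path/to/resized_val/$(basename "$f")"
done
```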
145 | ## Display UI
146 | Optionally, for displaying images during training and test, use the [display package](https://github.com/szym/display).
147 |
148 | - Install it with: `luarocks install https://raw.githubusercontent.com/szym/display/master/display-scm-0.rockspec`
149 | - Then start the server with: `th -ldisplay.start`
150 | - Open this URL in your browser: [http://localhost:8000](http://localhost:8000)
151 |
152 | By default, the server listens on localhost. Pass `0.0.0.0` to allow external connections on any interface:
153 | ```bash
154 | th -ldisplay.start 8000 0.0.0.0
155 | ```
156 | Then open `http://(hostname):(port)/` in your browser to load the remote desktop.
157 |
158 | L1 error is plotted to the display by default. Set the environment variable `display_plot` to a comma-separated list of values `errL1`, `errG` and `errD` to visualize the L1, generator, and discriminator error respectively. For example, to plot only the generator and discriminator errors to the display instead of the default L1 error, set `display_plot="errG,errD"`.
159 |
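For example (reusing the facades setup from Getting Started):
```bash
# Plot only the generator and discriminator losses in the display UI.
DATA_ROOT=./datasets/facades name=facades_generation which_direction=BtoA \
  display_plot="errG,errD" th train.lua
```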
160 | ## Citation
161 | If you use this code for your research, please cite our paper Image-to-Image Translation with Conditional Adversarial Networks:
162 |
163 | ```
164 | @article{pix2pix2017,
165 | title={Image-to-Image Translation with Conditional Adversarial Networks},
166 | author={Isola, Phillip and Zhu, Jun-Yan and Zhou, Tinghui and Efros, Alexei A},
167 | journal={CVPR},
168 | year={2017}
169 | }
170 | ```
171 |
172 | ## Cat Paper Collection
173 | If you love cats, and love reading cool graphics, vision, and learning papers, please check out the Cat Paper Collection:
174 | [[Github]](https://github.com/junyanz/CatPapers) [[Webpage]](https://www.cs.cmu.edu/~junyanz/cat/cat_papers.html)
175 |
176 | ## Acknowledgments
177 | Code borrows heavily from [DCGAN](https://github.com/soumith/dcgan.torch). The data loader is modified from [DCGAN](https://github.com/soumith/dcgan.torch) and [Context-Encoder](https://github.com/pathak22/context-encoder).
178 |
--------------------------------------------------------------------------------
/data/data.lua:
--------------------------------------------------------------------------------
1 | --[[
2 | This data loader is a modified version of the one from dcgan.torch
3 | (see https://github.com/soumith/dcgan.torch/blob/master/data/data.lua).
4 |
5 | Copyright (c) 2016, Deepak Pathak [See LICENSE file for details]
6 | ]]--
7 |
8 | local Threads = require 'threads'
9 | Threads.serialization('threads.sharedserialize')
10 |
11 | local data = {}
12 |
13 | local result = {}
14 | local unpack = unpack and unpack or table.unpack
15 |
16 | function data.new(n, opt_)
17 | opt_ = opt_ or {}
18 | local self = {}
19 | for k,v in pairs(data) do
20 | self[k] = v
21 | end
22 |
23 | local donkey_file = 'donkey_folder.lua'
24 | if n > 0 then
25 | local options = opt_
26 | self.threads = Threads(n,
27 | function() require 'torch' end,
28 | function(idx)
29 | opt = options
30 | tid = idx
31 | local seed = (opt.manualSeed and opt.manualSeed or 0) + idx
32 | torch.manualSeed(seed)
33 | torch.setnumthreads(1)
34 | print(string.format('Starting donkey with id: %d seed: %d', tid, seed))
35 | assert(options, 'options not found')
36 | assert(opt, 'opt not given')
37 | print(opt)
38 | paths.dofile(donkey_file)
39 | end
40 |
41 | )
42 | else
43 | if donkey_file then paths.dofile(donkey_file) end
44 | self.threads = {}
45 | function self.threads:addjob(f1, f2) f2(f1()) end
46 | function self.threads:dojob() end
47 | function self.threads:synchronize() end
48 | end
49 |
50 | local nSamples = 0
51 | self.threads:addjob(function() return trainLoader:size() end,
52 | function(c) nSamples = c end)
53 | self.threads:synchronize()
54 | self._size = nSamples
55 |
56 | for i = 1, n do
57 | self.threads:addjob(self._getFromThreads,
58 | self._pushResult)
59 | end
60 | return self
61 | end
62 |
63 | function data._getFromThreads()
64 | assert(opt.batchSize, 'opt.batchSize not found')
65 | return trainLoader:sample(opt.batchSize)
66 | end
67 |
68 | function data._pushResult(...)
69 | local res = {...}
70 | if res == nil then
71 | self.threads:synchronize()
72 | end
73 | result[1] = res
74 | end
75 |
76 |
77 |
78 | function data:getBatch()
79 | self.threads:addjob(self._getFromThreads, self._pushResult)
80 | self.threads:dojob()
81 | local res = result[1]
82 |
83 | img_data = res[1]
84 | img_paths = res[3]
85 |
86 | result[1] = nil
87 | if torch.type(img_data) == 'table' then
88 | img_data = unpack(img_data)
89 | end
90 |
91 | return img_data, img_paths
92 | end
93 |
94 | function data:size()
95 | return self._size
96 | end
97 |
98 | return data
99 |
--------------------------------------------------------------------------------
/data/dataset.lua:
--------------------------------------------------------------------------------
1 | --[[
2 | Copyright (c) 2015-present, Facebook, Inc.
3 | All rights reserved.
4 |
5 | This source code is licensed under the BSD-style license found in the
6 | LICENSE file in the root directory of this source tree. An additional grant
7 | of patent rights can be found in the PATENTS file in the same directory.
8 | ]]--
9 |
10 | require 'torch'
11 | torch.setdefaulttensortype('torch.FloatTensor')
12 | local ffi = require 'ffi'
13 | local class = require('pl.class')
14 | local dir = require 'pl.dir'
15 | local tablex = require 'pl.tablex'
16 | local argcheck = require 'argcheck'
17 | require 'sys'
18 | require 'xlua'
19 | require 'image'
20 |
21 | local dataset = torch.class('dataLoader')
22 |
23 | local initcheck = argcheck{
24 | pack=true,
25 | help=[[
26 | A dataset class for images in a flat folder structure (folder-name is class-name).
27 | Optimized for extremely large datasets (upwards of 14 million images).
28 | Tested only on Linux (as it uses command-line linux utilities to scale up)
29 | ]],
30 | {check=function(paths)
31 | local out = true;
32 | for k,v in ipairs(paths) do
33 | if type(v) ~= 'string' then
34 | print('paths can only be of string input');
35 | out = false
36 | end
37 | end
38 | return out
39 | end,
40 | name="paths",
41 | type="table",
42 | help="Multiple paths of directories with images"},
43 |
44 | {name="sampleSize",
45 | type="table",
46 | help="a consistent sample size to resize the images"},
47 |
48 | {name="split",
49 | type="number",
50 | help="Percentage of split to go to Training"
51 | },
52 | {name="serial_batches",
53 | type="number",
54 | help="if 1, load training images in serial order instead of sampling randomly"},
55 |
56 | {name="samplingMode",
57 | type="string",
58 | help="Sampling mode: random | balanced ",
59 | default = "balanced"},
60 |
61 | {name="verbose",
62 | type="boolean",
63 | help="Verbose mode during initialization",
64 | default = false},
65 |
66 | {name="loadSize",
67 | type="table",
68 | help="a size to load the images to, initially",
69 | opt = true},
70 |
71 | {name="forceClasses",
72 | type="table",
73 | help="If you want this loader to map certain classes to certain indices, "
74 | .. "pass a classes table that has {classindex : classname} pairs."
75 | .. " For example: {3 : 'dog', 5 : 'cat'}"
76 | .. "This function is very useful when you want two loaders to have the same "
77 | .. "class indices (trainLoader/testLoader for example)",
78 | opt = true},
79 |
80 | {name="sampleHookTrain",
81 | type="function",
82 | help="applied to sample during training(ex: for lighting jitter). "
83 | .. "It takes the image path as input",
84 | opt = true},
85 |
86 | {name="sampleHookTest",
87 | type="function",
88 | help="applied to sample during testing",
89 | opt = true},
90 | }
91 |
92 | function dataset:__init(...)
93 |
94 | -- argcheck
95 | local args = initcheck(...)
96 | print(args)
97 | for k,v in pairs(args) do self[k] = v end
98 |
99 | if not self.loadSize then self.loadSize = self.sampleSize; end
100 |
101 | if not self.sampleHookTrain then self.sampleHookTrain = self.defaultSampleHook end
102 | if not self.sampleHookTest then self.sampleHookTest = self.defaultSampleHook end
103 | self.image_count = 1
104 | -- find class names
105 | self.classes = {}
106 | local classPaths = {}
107 | if self.forceClasses then
108 | for k,v in pairs(self.forceClasses) do
109 | self.classes[k] = v
110 | classPaths[k] = {}
111 | end
112 | end
113 | local function tableFind(t, o) for k,v in pairs(t) do if v == o then return k end end end
114 | -- loop over each paths folder, get list of unique class names,
115 | -- also store the directory paths per class
116 | -- for each class,
117 | for k,path in ipairs(self.paths) do
118 | local dirs = {} -- hack
119 | dirs[1] = path
120 | for k,dirpath in ipairs(dirs) do
121 | local class = paths.basename(dirpath)
122 | local idx = tableFind(self.classes, class)
123 | if not idx then
124 | table.insert(self.classes, class)
125 | idx = #self.classes
126 | classPaths[idx] = {}
127 | end
128 | if not tableFind(classPaths[idx], dirpath) then
129 | table.insert(classPaths[idx], dirpath);
130 | end
131 | end
132 | end
133 |
134 | self.classIndices = {}
135 | for k,v in ipairs(self.classes) do
136 | self.classIndices[v] = k
137 | end
138 |
139 | -- define command-line tools, try your best to maintain OSX compatibility
140 | local wc = 'wc'
141 | local cut = 'cut'
142 | local find = 'find -H' -- if folder name is symlink, do find inside it after dereferencing
143 | if ffi.os == 'OSX' then
144 | wc = 'gwc'
145 | cut = 'gcut'
146 | find = 'gfind'
147 | end
148 | ----------------------------------------------------------------------
149 | -- Options for the GNU find command
150 | local extensionList = {'jpg', 'png','JPG','PNG','JPEG', 'ppm', 'PPM', 'bmp', 'BMP'}
151 | local findOptions = ' -iname "*.' .. extensionList[1] .. '"'
152 | for i=2,#extensionList do
153 | findOptions = findOptions .. ' -o -iname "*.' .. extensionList[i] .. '"'
154 | end
155 |
156 | -- find the image path names
157 | self.imagePath = torch.CharTensor() -- path to each image in dataset
158 | self.imageClass = torch.LongTensor() -- class index of each image (class index in self.classes)
159 | self.classList = {} -- index of imageList to each image of a particular class
160 | self.classListSample = self.classList -- the main list used when sampling data
161 |
162 | print('running "find" on each class directory, and concatenate all'
163 | .. ' those filenames into a single file containing all image paths for a given class')
164 | -- so, generates one file per class
165 | local classFindFiles = {}
166 | for i=1,#self.classes do
167 | classFindFiles[i] = os.tmpname()
168 | end
169 | local combinedFindList = os.tmpname();
170 |
171 | local tmpfile = os.tmpname()
172 | local tmphandle = assert(io.open(tmpfile, 'w'))
173 | -- iterate over classes
174 | for i, class in ipairs(self.classes) do
175 | -- iterate over classPaths
176 | for j,path in ipairs(classPaths[i]) do
177 | local command = find .. ' "' .. path .. '" ' .. findOptions
178 | .. ' >>"' .. classFindFiles[i] .. '" \n'
179 | tmphandle:write(command)
180 | end
181 | end
182 | io.close(tmphandle)
183 | os.execute('bash ' .. tmpfile)
184 | os.execute('rm -f ' .. tmpfile)
185 |
186 | print('now combine all the files to a single large file')
187 | local tmpfile = os.tmpname()
188 | local tmphandle = assert(io.open(tmpfile, 'w'))
189 | -- concat all finds to a single large file in the order of self.classes
190 | for i=1,#self.classes do
191 | local command = 'cat "' .. classFindFiles[i] .. '" >>' .. combinedFindList .. ' \n'
192 | tmphandle:write(command)
193 | end
194 | io.close(tmphandle)
195 | os.execute('bash ' .. tmpfile)
196 | os.execute('rm -f ' .. tmpfile)
197 |
198 | --==========================================================================
199 | print('load the large concatenated list of sample paths to self.imagePath')
200 | local cmd = wc .. " -L '"
201 | .. combinedFindList .. "' |"
202 | .. cut .. " -f1 -d' '"
203 | print('cmd..' .. cmd)
204 | local maxPathLength = tonumber(sys.fexecute(wc .. " -L '"
205 | .. combinedFindList .. "' |"
206 | .. cut .. " -f1 -d' '")) + 1
207 | local length = tonumber(sys.fexecute(wc .. " -l '"
208 | .. combinedFindList .. "' |"
209 | .. cut .. " -f1 -d' '"))
210 | assert(length > 0, "Could not find any image file in the given input paths")
211 | assert(maxPathLength > 0, "paths of files are length 0?")
212 | self.imagePath:resize(length, maxPathLength):fill(0)
213 | local s_data = self.imagePath:data()
214 | local count = 0
215 | for line in io.lines(combinedFindList) do
216 | ffi.copy(s_data, line)
217 | s_data = s_data + maxPathLength
218 | if self.verbose and count % 10000 == 0 then
219 | xlua.progress(count, length)
220 | end;
221 | count = count + 1
222 | end
223 |
224 | self.numSamples = self.imagePath:size(1)
225 | if self.verbose then print(self.numSamples .. ' samples found.') end
226 | --==========================================================================
227 | print('Updating classList and imageClass appropriately')
228 | self.imageClass:resize(self.numSamples)
229 | local runningIndex = 0
230 | for i=1,#self.classes do
231 | if self.verbose then xlua.progress(i, #(self.classes)) end
232 | local length = tonumber(sys.fexecute(wc .. " -l '"
233 | .. classFindFiles[i] .. "' |"
234 | .. cut .. " -f1 -d' '"))
235 | if length == 0 then
236 | error('Class has zero samples')
237 | else
238 | self.classList[i] = torch.range(runningIndex + 1, runningIndex + length):long()
239 | self.imageClass[{{runningIndex + 1, runningIndex + length}}]:fill(i)
240 | end
241 | runningIndex = runningIndex + length
242 | end
243 |
244 | --==========================================================================
245 | -- clean up temporary files
246 | print('Cleaning up temporary files')
247 | local tmpfilelistall = ''
248 | for i=1,#(classFindFiles) do
249 | tmpfilelistall = tmpfilelistall .. ' "' .. classFindFiles[i] .. '"'
250 | if i % 1000 == 0 then
251 | os.execute('rm -f ' .. tmpfilelistall)
252 | tmpfilelistall = ''
253 | end
254 | end
255 | os.execute('rm -f ' .. tmpfilelistall)
256 | os.execute('rm -f "' .. combinedFindList .. '"')
257 | --==========================================================================
258 |
259 | if self.split == 100 then
260 | self.testIndicesSize = 0
261 | else
262 | print('Splitting training and test sets to a ratio of '
263 | .. self.split .. '/' .. (100-self.split))
264 | self.classListTrain = {}
265 | self.classListTest = {}
266 | self.classListSample = self.classListTrain
267 | local totalTestSamples = 0
268 | -- split the classList into classListTrain and classListTest
269 | for i=1,#self.classes do
270 | local list = self.classList[i]
271 | local count = self.classList[i]:size(1)
272 | local splitidx = math.floor((count * self.split / 100) + 0.5) -- +round
273 | local perm = torch.randperm(count)
274 | self.classListTrain[i] = torch.LongTensor(splitidx)
275 | for j=1,splitidx do
276 | self.classListTrain[i][j] = list[perm[j]]
277 | end
278 | if splitidx == count then -- all samples were allocated to train set
279 | self.classListTest[i] = torch.LongTensor()
280 | else
281 | self.classListTest[i] = torch.LongTensor(count-splitidx)
282 | totalTestSamples = totalTestSamples + self.classListTest[i]:size(1)
283 | local idx = 1
284 | for j=splitidx+1,count do
285 | self.classListTest[i][idx] = list[perm[j]]
286 | idx = idx + 1
287 | end
288 | end
289 | end
290 | -- Now combine classListTest into a single tensor
291 | self.testIndices = torch.LongTensor(totalTestSamples)
292 | self.testIndicesSize = totalTestSamples
293 | local tdata = self.testIndices:data()
294 | local tidx = 0
295 | for i=1,#self.classes do
296 | local list = self.classListTest[i]
297 | if list:dim() ~= 0 then
298 | local ldata = list:data()
299 | for j=0,list:size(1)-1 do
300 | tdata[tidx] = ldata[j]
301 | tidx = tidx + 1
302 | end
303 | end
304 | end
305 | end
306 | end
307 |
308 | -- size(), size(class)
309 | function dataset:size(class, list)
310 | list = list or self.classList
311 | if not class then
312 | return self.numSamples
313 | elseif type(class) == 'string' then
314 | return list[self.classIndices[class]]:size(1)
315 | elseif type(class) == 'number' then
316 | return list[class]:size(1)
317 | end
318 | end
319 |
320 | -- getByClass
321 | function dataset:getByClass(class)
322 | local index = 0
323 | if self.serial_batches == 1 then
324 | index = math.fmod(self.image_count-1, self.classListSample[class]:nElement())+1
325 | self.image_count = self.image_count +1
326 | else
327 | index = math.ceil(torch.uniform() * self.classListSample[class]:nElement())
328 | end
329 | local imgpath = ffi.string(torch.data(self.imagePath[self.classListSample[class][index]]))
330 | return self:sampleHookTrain(imgpath), imgpath
331 | end
332 |
333 | -- converts a table of samples (and corresponding labels) to a clean tensor
334 | local function tableToOutput(self, dataTable, scalarTable)
335 | local data, scalarLabels, labels
336 | local quantity = #scalarTable
337 | assert(dataTable[1]:dim() == 3)
338 | data = torch.Tensor(quantity,
339 | self.sampleSize[1], self.sampleSize[2], self.sampleSize[3])
340 | scalarLabels = torch.LongTensor(quantity):fill(-1111)
341 | for i=1,#dataTable do
342 | data[i]:copy(dataTable[i])
343 | scalarLabels[i] = scalarTable[i]
344 | end
345 | return data, scalarLabels
346 | end
347 |
348 | -- sampler, samples from the training set.
349 | function dataset:sample(quantity)
350 | assert(quantity)
351 | local dataTable = {}
352 | local scalarTable = {}
353 | local samplePaths = {}
354 | for i=1,quantity do
355 | local class = torch.random(1, #self.classes)
356 | local out, imgpath = self:getByClass(class)
357 | table.insert(dataTable, out)
358 | table.insert(scalarTable, class)
359 | samplePaths[i] = imgpath
360 | end
361 | local data, scalarLabels = tableToOutput(self, dataTable, scalarTable)
362 | return data, scalarLabels, samplePaths-- filePaths
363 | end
364 |
365 | function dataset:get(i1, i2)
366 | local indices = torch.range(i1, i2);
367 | local quantity = i2 - i1 + 1;
368 | assert(quantity > 0)
369 | -- now that indices has been initialized, get the samples
370 | local dataTable = {}
371 | local scalarTable = {}
372 | for i=1,quantity do
373 | -- load the sample
374 | local imgpath = ffi.string(torch.data(self.imagePath[indices[i]]))
375 | local out = self:sampleHookTest(imgpath)
376 | table.insert(dataTable, out)
377 | table.insert(scalarTable, self.imageClass[indices[i]])
378 | end
379 | local data, scalarLabels = tableToOutput(self, dataTable, scalarTable)
380 | return data, scalarLabels
381 | end
382 |
383 | return dataset
384 |
--------------------------------------------------------------------------------
/data/donkey_folder.lua:
--------------------------------------------------------------------------------
1 |
2 | --[[
3 | This data loader is a modified version of the one from dcgan.torch
4 | (see https://github.com/soumith/dcgan.torch/blob/master/data/donkey_folder.lua).
5 | Copyright (c) 2016, Deepak Pathak [See LICENSE file for details]
6 | Copyright (c) 2015-present, Facebook, Inc.
7 | All rights reserved.
8 | This source code is licensed under the BSD-style license found in the
9 | LICENSE file in the root directory of this source tree. An additional grant
10 | of patent rights can be found in the PATENTS file in the same directory.
11 | ]]--
12 |
13 | require 'image'
14 | paths.dofile('dataset.lua')
15 | -- This file contains the data-loading logic and details.
16 | -- It is run by each data-loader thread.
17 | ------------------------------------------
18 | -------- COMMON CACHES and PATHS
19 | -- Check for existence of opt.data
20 | print(os.getenv('DATA_ROOT'))
21 | opt.data = paths.concat(os.getenv('DATA_ROOT'), opt.phase)
22 |
23 | if not paths.dirp(opt.data) then
24 | error('Did not find directory: ' .. opt.data)
25 | end
26 |
27 | -- a cache file of the training metadata (if doesnt exist, will be created)
28 | local cache = "cache"
29 | local cache_prefix = opt.data:gsub('/', '_')
30 | os.execute('mkdir -p cache')
31 | local trainCache = paths.concat(cache, cache_prefix .. '_trainCache.t7')
32 |
33 | --------------------------------------------------------------------------------------------
34 | local input_nc = opt.input_nc -- input channels
35 | local output_nc = opt.output_nc
36 | local loadSize = {input_nc, opt.loadSize}
37 | local sampleSize = {input_nc, opt.fineSize}
38 |
39 | local preprocessAandB = function(imA, imB)
40 | imA = image.scale(imA, loadSize[2], loadSize[2])
41 | imB = image.scale(imB, loadSize[2], loadSize[2])
42 | local perm = torch.LongTensor{3, 2, 1}
43 | imA = imA:index(1, perm)--:mul(256.0): brg, rgb
44 | imA = imA:mul(2):add(-1)
45 | imB = imB:index(1, perm)
46 | imB = imB:mul(2):add(-1)
47 | -- print(img:size())
48 | assert(imA:max()<=1,"A: badly scaled inputs")
49 | assert(imA:min()>=-1,"A: badly scaled inputs")
50 | assert(imB:max()<=1,"B: badly scaled inputs")
51 | assert(imB:min()>=-1,"B: badly scaled inputs")
52 |
53 |
54 | local oW = sampleSize[2]
55 | local oH = sampleSize[2]
56 | local iH = imA:size(2)
57 | local iW = imA:size(3)
58 |
59 | if iH~=oH then
60 | h1 = math.ceil(torch.uniform(1e-2, iH-oH))
61 | end
62 |
63 | if iW~=oW then
64 | w1 = math.ceil(torch.uniform(1e-2, iW-oW))
65 | end
66 | if iH ~= oH or iW ~= oW then
67 | imA = image.crop(imA, w1, h1, w1 + oW, h1 + oH)
68 | imB = image.crop(imB, w1, h1, w1 + oW, h1 + oH)
69 | end
70 |
71 | if opt.flip == 1 and torch.uniform() > 0.5 then
72 | imA = image.hflip(imA)
73 | imB = image.hflip(imB)
74 | end
75 |
76 | return imA, imB
77 | end
78 |
79 |
80 |
81 | local function loadImageChannel(path)
82 | local input = image.load(path, 3, 'float')
83 | input = image.scale(input, loadSize[2], loadSize[2])
84 |
85 | local oW = sampleSize[2]
86 | local oH = sampleSize[2]
87 | local iH = input:size(2)
88 | local iW = input:size(3)
89 |
90 | if iH~=oH then
91 | h1 = math.ceil(torch.uniform(1e-2, iH-oH))
92 | end
93 |
94 | if iW~=oW then
95 | w1 = math.ceil(torch.uniform(1e-2, iW-oW))
96 | end
97 | if iH ~= oH or iW ~= oW then
98 | input = image.crop(input, w1, h1, w1 + oW, h1 + oH)
99 | end
100 |
101 |
102 | if opt.flip == 1 and torch.uniform() > 0.5 then
103 | input = image.hflip(input)
104 | end
105 |
106 | local input_lab = image.rgb2lab(input)
107 | local imA = input_lab[{{1}, {}, {} }]:div(50.0) - 1.0
108 | local imB = input_lab[{{2,3},{},{}}]:div(110.0)
109 | local imAB = torch.cat(imA, imB, 1)
110 | assert(imAB:max()<=1,"A: badly scaled inputs")
111 | assert(imAB:min()>=-1,"A: badly scaled inputs")
112 |
113 | return imAB
114 | end
115 |
116 | --local function loadImage
117 |
118 | local function loadImage(path)
119 | local input = image.load(path, 3, 'float')
120 | local h = input:size(2)
121 | local w = input:size(3)
122 |
123 | local imA = image.crop(input, 0, 0, w/2, h)
124 | local imB = image.crop(input, w/2, 0, w, h)
125 |
126 | return imA, imB
127 | end
128 |
129 | local function loadImageInpaint(path)
130 | local imB = image.load(path, 3, 'float')
131 | imB = image.scale(imB, loadSize[2], loadSize[2])
132 | local perm = torch.LongTensor{3, 2, 1}
133 | imB = imB:index(1, perm)--:mul(256.0): brg, rgb
134 | imB = imB:mul(2):add(-1)
135 | assert(imB:max()<=1,"A: badly scaled inputs")
136 | assert(imB:min()>=-1,"A: badly scaled inputs")
137 | local oW = sampleSize[2]
138 | local oH = sampleSize[2]
139 | local iH = imB:size(2)
140 | local iW = imB:size(3)
141 | if iH~=oH then
142 | h1 = math.ceil(torch.uniform(1e-2, iH-oH))
143 | end
144 |
145 | if iW~=oW then
146 | w1 = math.ceil(torch.uniform(1e-2, iW-oW))
147 | end
148 | if iH ~= oH or iW ~= oW then
149 | imB = image.crop(imB, w1, h1, w1 + oW, h1 + oH)
150 | end
151 | local imA = imB:clone()
152 | imA[{{},{1 + oH/4, oH/2 + oH/4},{1 + oW/4, oW/2 + oW/4}}] = 1.0
153 | if opt.flip == 1 and torch.uniform() > 0.5 then
154 | imA = image.hflip(imA)
155 | imB = image.hflip(imB)
156 | end
157 | imAB = torch.cat(imA, imB, 1)
158 | return imAB
159 | end
160 |
161 | -- channel-wise mean and std. Calculate or load them from disk later in the script.
162 | local mean,std
163 | --------------------------------------------------------------------------------
164 | -- Hooks that are used for each image that is loaded
165 |
166 | -- function to load the image, jitter it appropriately (random crops etc.)
167 | local trainHook = function(self, path)
168 | collectgarbage()
169 | if opt.preprocess == 'regular' then
170 | local imA, imB = loadImage(path)
171 | imA, imB = preprocessAandB(imA, imB)
172 | imAB = torch.cat(imA, imB, 1)
173 | end
174 |
175 | if opt.preprocess == 'colorization' then
176 | imAB = loadImageChannel(path)
177 | end
178 |
179 | if opt.preprocess == 'inpaint' then
180 | imAB = loadImageInpaint(path)
181 | end
182 | return imAB
183 | end
184 |
185 | --------------------------------------
186 | -- trainLoader
187 | print('trainCache', trainCache)
188 | print('Creating train metadata')
189 | print('serial batches: ', opt.serial_batches)
190 | trainLoader = dataLoader{
191 | paths = {opt.data},
192 | loadSize = {input_nc, loadSize[2], loadSize[2]},
193 | sampleSize = {input_nc+output_nc, sampleSize[2], sampleSize[2]},
194 | split = 100,
195 | serial_batches = opt.serial_batches,
196 | verbose = true
197 | }
198 |
199 | trainLoader.sampleHookTrain = trainHook
200 | collectgarbage()
201 |
202 | -- do some sanity checks on trainLoader
203 | do
204 | local class = trainLoader.imageClass
205 | local nClasses = #trainLoader.classes
206 | assert(class:max() <= nClasses, "class logic has error")
207 | assert(class:min() >= 1, "class logic has error")
208 | end
209 |
--------------------------------------------------------------------------------
/datasets/bibtex/cityscapes.tex:
--------------------------------------------------------------------------------
1 | @inproceedings{Cordts2016Cityscapes,
2 | title={The Cityscapes Dataset for Semantic Urban Scene Understanding},
3 | author={Cordts, Marius and Omran, Mohamed and Ramos, Sebastian and Rehfeld, Timo and Enzweiler, Markus and Benenson, Rodrigo and Franke, Uwe and Roth, Stefan and Schiele, Bernt},
4 | booktitle={Proc. of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
5 | year={2016}
6 | }
7 |
--------------------------------------------------------------------------------
/datasets/bibtex/facades.tex:
--------------------------------------------------------------------------------
1 | @INPROCEEDINGS{Tylecek13,
2 | author = {Radim Tyle{\v c}ek and Radim {\v S}{\' a}ra},
3 | title = {Spatial Pattern Templates for Recognition of Objects with Regular Structure},
4 | booktitle = {Proc. GCPR},
5 | year = {2013},
6 | address = {Saarbrucken, Germany},
7 | }
8 |
--------------------------------------------------------------------------------
/datasets/bibtex/handbags.tex:
--------------------------------------------------------------------------------
1 | @inproceedings{zhu2016generative,
2 | title={Generative Visual Manipulation on the Natural Image Manifold},
3 | author={Zhu, Jun-Yan and Kr{\"a}henb{\"u}hl, Philipp and Shechtman, Eli and Efros, Alexei A.},
4 | booktitle={Proceedings of European Conference on Computer Vision (ECCV)},
5 | year={2016}
6 | }
7 |
8 | @InProceedings{xie15hed,
9 | author = {Xie, Saining and Tu, Zhuowen},
10 | Title = {Holistically-Nested Edge Detection},
11 | Booktitle = "Proceedings of IEEE International Conference on Computer Vision",
12 | Year = {2015},
13 | }
14 |
--------------------------------------------------------------------------------
/datasets/bibtex/shoes.tex:
--------------------------------------------------------------------------------
1 | @InProceedings{fine-grained,
2 | author = {A. Yu and K. Grauman},
3 | title = {{F}ine-{G}rained {V}isual {C}omparisons with {L}ocal {L}earning},
4 | booktitle = {Computer Vision and Pattern Recognition (CVPR)},
5 | month = {June},
6 | year = {2014}
7 | }
8 |
9 | @InProceedings{xie15hed,
10 | author = {Xie, Saining and Tu, Zhuowen},
11 | Title = {Holistically-Nested Edge Detection},
12 | Booktitle = "Proceedings of IEEE International Conference on Computer Vision",
13 | Year = {2015},
14 | }
15 |
--------------------------------------------------------------------------------
/datasets/bibtex/transattr.tex:
--------------------------------------------------------------------------------
1 | @article {Laffont14,
2 | title = {Transient Attributes for High-Level Understanding and Editing of Outdoor Scenes},
3 | author = {Pierre-Yves Laffont and Zhile Ren and Xiaofeng Tao and Chao Qian and James Hays},
4 | journal = {ACM Transactions on Graphics (proceedings of SIGGRAPH)},
5 | volume = {33},
6 | number = {4},
7 | year = {2014}
8 | }
9 |
--------------------------------------------------------------------------------
/datasets/download_dataset.sh:
--------------------------------------------------------------------------------
1 | FILE=$1
2 |
3 | if [[ $FILE != "cityscapes" && $FILE != "night2day" && $FILE != "edges2handbags" && $FILE != "edges2shoes" && $FILE != "facades" && $FILE != "maps" ]]; then
4 | echo "Available datasets are cityscapes, night2day, edges2handbags, edges2shoes, facades, maps"
5 | exit 1
6 | fi
7 |
8 | echo "Specified [$FILE]"
9 |
10 | URL=http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/$FILE.tar.gz
11 | TAR_FILE=./datasets/$FILE.tar.gz
12 | TARGET_DIR=./datasets/$FILE/
13 | wget -N $URL -O $TAR_FILE
14 | mkdir -p $TARGET_DIR
15 | tar -zxvf $TAR_FILE -C ./datasets/
16 | rm $TAR_FILE
17 |
--------------------------------------------------------------------------------
/imgs/examples.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/phillipi/pix2pix/89ff2a81ce441fbe1f1b13eca463b87f1e539df8/imgs/examples.jpg
--------------------------------------------------------------------------------
/models.lua:
--------------------------------------------------------------------------------
1 | require 'nngraph'
2 |
3 | function defineG_encoder_decoder(input_nc, output_nc, ngf)
4 | local netG = nil
5 | -- input is (nc) x 256 x 256
6 | local e1 = - nn.SpatialConvolution(input_nc, ngf, 4, 4, 2, 2, 1, 1)
7 | -- input is (ngf) x 128 x 128
8 | local e2 = e1 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf, ngf * 2, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 2)
9 | -- input is (ngf * 2) x 64 x 64
10 | local e3 = e2 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 2, ngf * 4, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 4)
11 | -- input is (ngf * 4) x 32 x 32
12 | local e4 = e3 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 4, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8)
13 | -- input is (ngf * 8) x 16 x 16
14 | local e5 = e4 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8)
15 | -- input is (ngf * 8) x 8 x 8
16 | local e6 = e5 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8)
17 | -- input is (ngf * 8) x 4 x 4
18 | local e7 = e6 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8)
19 | -- input is (ngf * 8) x 2 x 2
20 | local e8 = e7 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1)
21 | -- input is (ngf * 8) x 1 x 1
22 |
23 | local d1 = e8 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8) - nn.Dropout(0.5)
24 | -- input is (ngf * 8) x 2 x 2
25 | local d2 = d1 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8) - nn.Dropout(0.5)
26 | -- input is (ngf * 8) x 4 x 4
27 | local d3 = d2 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8) - nn.Dropout(0.5)
28 | -- input is (ngf * 8) x 8 x 8
29 | local d4 = d3 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8)
30 | -- input is (ngf * 8) x 16 x 16
31 | local d5 = d4 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8, ngf * 4, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 4)
32 | -- input is (ngf * 4) x 32 x 32
33 | local d6 = d5 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 4, ngf * 2, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 2)
34 | -- input is (ngf * 2) x 64 x 64
35 | local d7 = d6 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 2, ngf, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf)
36 | -- input is (ngf) x128 x 128
37 | local d8 = d7 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf, output_nc, 4, 4, 2, 2, 1, 1)
38 | -- input is (nc) x 256 x 256
39 |
40 | local o1 = d8 - nn.Tanh()
41 |
42 | netG = nn.gModule({e1},{o1})
43 |
44 | return netG
45 | end
46 |
47 | function defineG_unet(input_nc, output_nc, ngf)
48 | local netG = nil
49 | -- input is (nc) x 256 x 256
50 | local e1 = - nn.SpatialConvolution(input_nc, ngf, 4, 4, 2, 2, 1, 1)
51 | -- input is (ngf) x 128 x 128
52 | local e2 = e1 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf, ngf * 2, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 2)
53 | -- input is (ngf * 2) x 64 x 64
54 | local e3 = e2 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 2, ngf * 4, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 4)
55 | -- input is (ngf * 4) x 32 x 32
56 | local e4 = e3 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 4, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8)
57 | -- input is (ngf * 8) x 16 x 16
58 | local e5 = e4 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8)
59 | -- input is (ngf * 8) x 8 x 8
60 | local e6 = e5 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8)
61 | -- input is (ngf * 8) x 4 x 4
62 | local e7 = e6 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8)
63 | -- input is (ngf * 8) x 2 x 2
64 | local e8 = e7 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1)
65 | -- input is (ngf * 8) x 1 x 1
66 |
67 | local d1_ = e8 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8) - nn.Dropout(0.5)
68 | -- input is (ngf * 8) x 2 x 2
69 | local d1 = {d1_,e7} - nn.JoinTable(2)
70 | local d2_ = d1 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8 * 2, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8) - nn.Dropout(0.5)
71 | -- input is (ngf * 8) x 4 x 4
72 | local d2 = {d2_,e6} - nn.JoinTable(2)
73 | local d3_ = d2 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8 * 2, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8) - nn.Dropout(0.5)
74 | -- input is (ngf * 8) x 8 x 8
75 | local d3 = {d3_,e5} - nn.JoinTable(2)
76 | local d4_ = d3 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8 * 2, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8)
77 | -- input is (ngf * 8) x 16 x 16
78 | local d4 = {d4_,e4} - nn.JoinTable(2)
79 | local d5_ = d4 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8 * 2, ngf * 4, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 4)
80 | -- input is (ngf * 4) x 32 x 32
81 | local d5 = {d5_,e3} - nn.JoinTable(2)
82 | local d6_ = d5 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 4 * 2, ngf * 2, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 2)
83 | -- input is (ngf * 2) x 64 x 64
84 | local d6 = {d6_,e2} - nn.JoinTable(2)
85 | local d7_ = d6 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 2 * 2, ngf, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf)
86 | -- input is (ngf) x128 x 128
87 | local d7 = {d7_,e1} - nn.JoinTable(2)
88 | local d8 = d7 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 2, output_nc, 4, 4, 2, 2, 1, 1)
89 | -- input is (nc) x 256 x 256
90 |
91 | local o1 = d8 - nn.Tanh()
92 |
93 | netG = nn.gModule({e1},{o1})
94 |
95 | --graph.dot(netG.fg,'netG')
96 |
97 | return netG
98 | end
99 |
100 | function defineG_unet_128(input_nc, output_nc, ngf)
101 | -- Two layers fewer than the default unet, to handle 128x128 input
102 | local netG = nil
103 | -- input is (nc) x 128 x 128
104 | local e1 = - nn.SpatialConvolution(input_nc, ngf, 4, 4, 2, 2, 1, 1)
105 | -- input is (ngf) x 64 x 64
106 | local e2 = e1 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf, ngf * 2, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 2)
107 | -- input is (ngf * 2) x 32 x 32
108 | local e3 = e2 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 2, ngf * 4, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 4)
109 | -- input is (ngf * 4) x 16 x 16
110 | local e4 = e3 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 4, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8)
111 | -- input is (ngf * 8) x 8 x 8
112 | local e5 = e4 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8)
113 | -- input is (ngf * 8) x 4 x 4
114 | local e6 = e5 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8)
115 | -- input is (ngf * 8) x 2 x 2
116 | local e7 = e6 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1)
117 | -- input is (ngf * 8) x 1 x 1
118 |
119 | local d1_ = e7 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8) - nn.Dropout(0.5)
120 | -- input is (ngf * 8) x 2 x 2
121 | local d1 = {d1_,e6} - nn.JoinTable(2)
122 | local d2_ = d1 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8 * 2, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8) - nn.Dropout(0.5)
123 | -- input is (ngf * 8) x 4 x 4
124 | local d2 = {d2_,e5} - nn.JoinTable(2)
125 | local d3_ = d2 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8 * 2, ngf * 8, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 8) - nn.Dropout(0.5)
126 | -- input is (ngf * 8) x 8 x 8
127 | local d3 = {d3_,e4} - nn.JoinTable(2)
128 | local d4_ = d3 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8 * 2, ngf * 4, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 4)
129 | -- input is (ngf * 8) x 16 x 16
130 | local d4 = {d4_,e3} - nn.JoinTable(2)
131 | local d5_ = d4 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 4 * 2, ngf * 2, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf * 2)
132 | -- input is (ngf * 4) x 32 x 32
133 | local d5 = {d5_,e2} - nn.JoinTable(2)
134 | local d6_ = d5 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 2 * 2, ngf, 4, 4, 2, 2, 1, 1) - nn.SpatialBatchNormalization(ngf)
135 | -- input is (ngf * 2) x 64 x 64
136 | local d6 = {d6_,e1} - nn.JoinTable(2)
137 | local d7 = d6 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 2, output_nc, 4, 4, 2, 2, 1, 1)
138 | -- input is (ngf) x128 x 128
139 |
140 | local o1 = d7 - nn.Tanh()
141 |
142 | netG = nn.gModule({e1},{o1})
143 |
144 | --graph.dot(netG.fg,'netG')
145 |
146 | return netG
147 | end
148 |
149 | function defineD_basic(input_nc, output_nc, ndf)
150 | n_layers = 3
151 | return defineD_n_layers(input_nc, output_nc, ndf, n_layers)
152 | end
153 |
154 | -- rf=1
155 | function defineD_pixelGAN(input_nc, output_nc, ndf)
156 | local netD = nn.Sequential()
157 |
158 | -- input is (nc) x 256 x 256
159 | netD:add(nn.SpatialConvolution(input_nc+output_nc, ndf, 1, 1, 1, 1, 0, 0))
160 | netD:add(nn.LeakyReLU(0.2, true))
161 | -- state size: (ndf) x 256 x 256
162 | netD:add(nn.SpatialConvolution(ndf, ndf * 2, 1, 1, 1, 1, 0, 0))
163 | netD:add(nn.SpatialBatchNormalization(ndf * 2)):add(nn.LeakyReLU(0.2, true))
164 | -- state size: (ndf*2) x 256 x 256
165 | netD:add(nn.SpatialConvolution(ndf * 2, 1, 1, 1, 1, 1, 0, 0))
166 | -- state size: 1 x 256 x 256
167 | netD:add(nn.Sigmoid())
168 | -- state size: 1 x 256 x 256
169 |
170 | return netD
171 | end
172 |
173 | -- if n=0, then use pixelGAN (rf=1)
174 | -- else rf is 16 if n=1
175 | -- 34 if n=2
176 | -- 70 if n=3
177 | -- 142 if n=4
178 | -- 286 if n=5
179 | -- 574 if n=6
180 | function defineD_n_layers(input_nc, output_nc, ndf, n_layers)
181 | if n_layers==0 then
182 | return defineD_pixelGAN(input_nc, output_nc, ndf)
183 | else
184 |
185 | local netD = nn.Sequential()
186 |
187 | -- input is (nc) x 256 x 256
188 | netD:add(nn.SpatialConvolution(input_nc+output_nc, ndf, 4, 4, 2, 2, 1, 1))
189 | netD:add(nn.LeakyReLU(0.2, true))
190 |
191 | local nf_mult = 1
192 | local nf_mult_prev = 1
193 | for n = 1, n_layers-1 do
194 | nf_mult_prev = nf_mult
195 | nf_mult = math.min(2^n,8)
196 | netD:add(nn.SpatialConvolution(ndf * nf_mult_prev, ndf * nf_mult, 4, 4, 2, 2, 1, 1))
197 | netD:add(nn.SpatialBatchNormalization(ndf * nf_mult)):add(nn.LeakyReLU(0.2, true))
198 | end
199 |
200 | -- state size: (ndf*M) x N x N
201 | nf_mult_prev = nf_mult
202 | nf_mult = math.min(2^n_layers,8)
203 | netD:add(nn.SpatialConvolution(ndf * nf_mult_prev, ndf * nf_mult, 4, 4, 1, 1, 1, 1))
204 | netD:add(nn.SpatialBatchNormalization(ndf * nf_mult)):add(nn.LeakyReLU(0.2, true))
205 | -- state size: (ndf*M*2) x (N-1) x (N-1)
206 | netD:add(nn.SpatialConvolution(ndf * nf_mult, 1, 4, 4, 1, 1, 1, 1))
207 | -- state size: 1 x (N-2) x (N-2)
208 |
209 | netD:add(nn.Sigmoid())
210 | -- state size: 1 x (N-2) x (N-2)
211 |
212 | return netD
213 | end
214 | end
215 |
--------------------------------------------------------------------------------
/models/download_model.sh:
--------------------------------------------------------------------------------
1 | FILE=$1
2 | URL=http://efrosgans.eecs.berkeley.edu/pix2pix/models/$FILE.t7
3 | MODEL_FILE=./models/$FILE.t7
4 | wget -N $URL -O $MODEL_FILE
5 |
--------------------------------------------------------------------------------
/scripts/combine_A_and_B.py:
--------------------------------------------------------------------------------
1 | from pdb import set_trace as st
2 | import os
3 | import numpy as np
4 | import cv2
5 | import argparse
6 |
7 | parser = argparse.ArgumentParser(description='create image pairs')
8 | parser.add_argument('--fold_A', dest='fold_A', help='input directory for image A', type=str, default='../dataset/50kshoes_edges')
9 | parser.add_argument('--fold_B', dest='fold_B', help='input directory for image B', type=str, default='../dataset/50kshoes_jpg')
10 | parser.add_argument('--fold_AB', dest='fold_AB', help='output directory', type=str, default='../dataset/test_AB')
11 | parser.add_argument('--num_imgs', dest='num_imgs', help='number of images',type=int, default=1000000)
12 | parser.add_argument('--use_AB', dest='use_AB', help='if true: (0001_A, 0001_B) to (0001_AB)',action='store_true')
13 | args = parser.parse_args()
14 |
15 | for arg in vars(args):
16 | print('[%s] = ' % arg, getattr(args, arg))
17 |
18 | splits = filter( lambda f: not f.startswith('.'), os.listdir(args.fold_A)) # ignore hidden folders like .DS_Store
19 |
20 | for sp in splits:
21 | img_fold_A = os.path.join(args.fold_A, sp)
22 | img_fold_B = os.path.join(args.fold_B, sp)
23 | img_list = filter( lambda f: not f.startswith('.'), os.listdir(img_fold_A)) # ignore hidden folders like .DS_Store
24 | img_list = list(img_list)
25 | if args.use_AB:
26 | img_list = [img_path for img_path in img_list if '_A.' in img_path]
27 |
28 | num_imgs = min(args.num_imgs, len(img_list))
29 | print('split = %s, use %d/%d images' % (sp, num_imgs, len(img_list)))
30 | img_fold_AB = os.path.join(args.fold_AB, sp)
31 | if not os.path.isdir(img_fold_AB):
32 | os.makedirs(img_fold_AB)
33 | print('split = %s, number of images = %d' % (sp, num_imgs))
34 | for n in range(num_imgs):
35 | name_A = img_list[n]
36 | path_A = os.path.join(img_fold_A, name_A)
37 | if args.use_AB:
38 | name_B = name_A.replace('_A.', '_B.')
39 | else:
40 | name_B = name_A
41 | path_B = os.path.join(img_fold_B, name_B)
42 | if os.path.isfile(path_A) and os.path.isfile(path_B):
43 | name_AB = name_A
44 | if args.use_AB:
45 | name_AB = name_AB.replace('_A.', '.') # remove _A
46 | path_AB = os.path.join(img_fold_AB, name_AB)
47 | im_A = cv2.imread(path_A, cv2.IMREAD_COLOR)
48 | im_B = cv2.imread(path_B, cv2.IMREAD_COLOR)
49 | im_AB = np.concatenate([im_A, im_B], 1)
50 | cv2.imwrite(path_AB, im_AB)
51 |
52 |
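
As a reading aid (not from the repository docs): fold_A and fold_B are expected to contain the same split subfolders (train/, val/, ...) with matching file names, and each A/B pair must share the same height, since np.concatenate(..., 1) only stacks along the width. The core step, sketched for a single hypothetical pair of paths:

    # Minimal sketch of what the loop above does per image (paths are hypothetical):
    # read the aligned A and B images and write them side by side as one AB image.
    import cv2
    import numpy as np

    path_A = 'datasets/edges/train/1.jpg'    # hypothetical input A (e.g. an edge map)
    path_B = 'datasets/photos/train/1.jpg'   # hypothetical input B (the matching photo)
    path_AB = 'datasets/edges2photos/train/1.jpg'

    im_A = cv2.imread(path_A, cv2.IMREAD_COLOR)
    im_B = cv2.imread(path_B, cv2.IMREAD_COLOR)
    assert im_A.shape[0] == im_B.shape[0], 'A and B must have the same height'
    im_AB = np.concatenate([im_A, im_B], 1)  # A on the left, B on the right
    cv2.imwrite(path_AB, im_AB)
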
--------------------------------------------------------------------------------
/scripts/edges/PostprocessHED.m:
--------------------------------------------------------------------------------
1 | %%% Prerequisites
2 | % You need to get the cpp file edgesNmsMex.cpp from https://raw.githubusercontent.com/pdollar/edges/master/private/edgesNmsMex.cpp
3 | % and compile it in Matlab: mex edgesNmsMex.cpp
4 | % You also need to download and install Piotr's Computer Vision Matlab Toolbox: https://pdollar.github.io/toolbox/
5 |
6 | %%% parameters
7 | % hed_mat_dir: the hed mat file directory (the output of 'batch_hed.py')
8 | % edge_dir: the output HED edges directory
9 | % image_width: resize the edge map to [image_width, image_width]
10 | % threshold: threshold for image binarization (default 25.0/255.0)
11 | % small_edge: remove small edges (default 5)
12 |
13 | function [] = PostprocessHED(hed_mat_dir, edge_dir, image_width, threshold, small_edge)
14 |
15 | if ~exist(edge_dir, 'dir')
16 | mkdir(edge_dir);
17 | end
18 | fileList = dir(fullfile(hed_mat_dir, '*.mat'));
19 | nFiles = numel(fileList);
20 | fprintf('found %d mat files\n', nFiles);
21 |
22 | for n = 1 : nFiles
23 | if mod(n, 1000) == 0
24 | fprintf('processing %d/%d images\n', n, nFiles);
25 | end
26 | fileName = fileList(n).name;
27 | filePath = fullfile(hed_mat_dir, fileName);
28 | jpgName = strrep(fileName, '.mat', '.jpg');
29 | edge_path = fullfile(edge_dir, jpgName);
30 |
31 | if ~exist(edge_path, 'file')
32 | E = GetEdge(filePath);
33 | E = imresize(E,[image_width,image_width]);
34 | E_simple = SimpleEdge(E, threshold, small_edge);
35 | E_simple = uint8(E_simple*255);
36 | imwrite(E_simple, edge_path, 'Quality',100);
37 | end
38 | end
39 | end
40 |
41 |
42 |
43 |
44 | function [E] = GetEdge(filePath)
45 | load(filePath);
46 | E = 1-edge_predict;
47 | end
48 |
49 | function [E4] = SimpleEdge(E, threshold, small_edge)
50 | if nargin <= 1
51 | threshold = 25.0/255.0;
52 | end
53 |
54 | if nargin <= 2
55 | small_edge = 5;
56 | end
57 |
58 | if ndims(E) == 3
59 | E = E(:,:,1);
60 | end
61 |
62 | E1 = 1 - E;
63 | E2 = EdgeNMS(E1);
64 | E3 = double(E2>=max(eps,threshold));
65 | E3 = bwmorph(E3,'thin',inf);
66 | E4 = bwareaopen(E3, small_edge);
67 | E4=1-E4;
68 | end
69 |
70 | function [E_nms] = EdgeNMS( E )
71 | E=single(E);
72 | [Ox,Oy] = gradient2(convTri(E,4));
73 | [Oxx,~] = gradient2(Ox);
74 | [Oxy,Oyy] = gradient2(Oy);
75 | O = mod(atan(Oyy.*sign(-Oxy)./(Oxx+1e-5)),pi);
76 | E_nms = edgesNmsMex(E,O,1,5,1.01,1);
77 | end
78 |
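
For readers without MATLAB or Piotr's toolbox, roughly the same post-processing can be approximated with scikit-image; this sketch omits the edgesNmsMex non-maximum-suppression step and is an illustration only, not the pipeline used by the repository:

    # Rough Python approximation of SimpleEdge() above (no edgesNmsMex NMS step).
    import numpy as np
    from skimage.morphology import thin, remove_small_objects

    def simple_edge(E, threshold=25.0 / 255.0, small_edge=5):
        """E: HED map in [0, 1] with 1 = background, as GetEdge returns (1 - edge_predict)."""
        if E.ndim == 3:
            E = E[:, :, 0]
        edges = (1.0 - E) >= max(np.finfo(float).eps, threshold)   # binarize edge strength
        edges = thin(edges)                                        # thin to 1-px skeleton
        edges = remove_small_objects(edges, min_size=small_edge)   # drop tiny components
        # 1 = background, 0 = edge; scale by 255 before saving, as the MATLAB caller does
        return 1 - edges.astype(np.uint8)
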
--------------------------------------------------------------------------------
/scripts/edges/batch_hed.py:
--------------------------------------------------------------------------------
1 | # HED batch processing script; modified from https://github.com/s9xie/hed/blob/master/examples/hed/HED-tutorial.ipynb
2 | # Step 1: download the hed repo: https://github.com/s9xie/hed
3 | # Step 2: download the models and prototxt, and put them under {caffe_root}/examples/hed/
4 | # Step 3: put this script under {caffe_root}/examples/hed/
5 | # Step 4: run the following script:
6 | # python batch_hed.py --images_dir=/data/to/path/photos/ --hed_mat_dir=/data/to/path/hed_mat_files/
7 | # The code sometimes crashes after the computation is done. The error looks like "Check failed: ... driver shutting down". You can just kill the job.
8 | # For large images, it may run out of GPU memory, so it is better to resize the images before running this script.
9 | # Step 5: run the MATLAB post-processing script "PostprocessHED.m"
10 | import scipy.io as sio
11 | import caffe
12 | import sys
13 | import numpy as np
14 | from PIL import Image
15 | import os
16 | import argparse
17 |
18 |
19 | def parse_args():
20 | parser = argparse.ArgumentParser(description='batch processing: photos->edges')
21 | parser.add_argument('--caffe_root', dest='caffe_root', help='caffe root', default='../../', type=str)
22 | parser.add_argument('--caffemodel', dest='caffemodel', help='caffemodel', default='./hed_pretrained_bsds.caffemodel', type=str)
23 | parser.add_argument('--prototxt', dest='prototxt', help='caffe prototxt file', default='./deploy.prototxt', type=str)
24 | parser.add_argument('--images_dir', dest='images_dir', help='directory to store input photos', type=str)
25 | parser.add_argument('--hed_mat_dir', dest='hed_mat_dir', help='directory to store output hed edges in mat file', type=str)
26 | parser.add_argument('--border', dest='border', help='padding border', type=int, default=128)
27 | parser.add_argument('--gpu_id', dest='gpu_id', help='gpu id', type=int, default=1)
28 | args = parser.parse_args()
29 | return args
30 |
31 |
32 | args = parse_args()
33 | for arg in vars(args):
34 | print('[%s] =' % arg, getattr(args, arg))
35 | # Make sure that caffe is on the python path:
36 | caffe_root = args.caffe_root # this file is expected to be in {caffe_root}/examples/hed/
37 | sys.path.insert(0, caffe_root + 'python')
38 |
39 |
40 | if not os.path.exists(args.hed_mat_dir):
41 | print('create output directory %s' % args.hed_mat_dir)
42 | os.makedirs(args.hed_mat_dir)
43 |
44 | imgList = os.listdir(args.images_dir)
45 | nImgs = len(imgList)
46 | print('#images = %d' % nImgs)
47 |
48 | caffe.set_mode_gpu()
49 | caffe.set_device(args.gpu_id)
50 | # load net
51 | net = caffe.Net(args.prototxt, args.caffemodel, caffe.TEST)
52 | # pad border
53 | border = args.border
54 |
55 | for i in range(nImgs):
56 | if i % 500 == 0:
57 | print('processing image %d/%d' % (i, nImgs))
58 | im = Image.open(os.path.join(args.images_dir, imgList[i]))
59 |
60 | in_ = np.array(im, dtype=np.float32)
61 | in_ = np.pad(in_, ((border, border), (border, border), (0, 0)), 'reflect')
62 |
63 | in_ = in_[:, :, 0:3]
64 | in_ = in_[:, :, ::-1]
65 | in_ -= np.array((104.00698793, 116.66876762, 122.67891434))
66 | in_ = in_.transpose((2, 0, 1))
67 | # note: to test with CPU, remove the caffe.set_mode_gpu()/caffe.set_device() calls above and use caffe.set_mode_cpu() instead
68 |
69 | # shape for input (data blob is N x C x H x W), set data
70 | net.blobs['data'].reshape(1, *in_.shape)
71 | net.blobs['data'].data[...] = in_
72 | # run net and take argmax for prediction
73 | net.forward()
74 | fuse = net.blobs['sigmoid-fuse'].data[0][0, :, :]
75 | # get rid of the border
76 | fuse = fuse[(border+35):(-border+35), (border+35):(-border+35)]
77 | # save hed file to the disk
78 | name, ext = os.path.splitext(imgList[i])
79 | sio.savemat(os.path.join(args.hed_mat_dir, name + '.mat'), {'edge_predict': fuse})
80 |
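
One detail worth spelling out is the pad-then-crop bookkeeping: the image is reflect-padded by `border` on each side, and the `(border+35):(-border+35)` slice removes that padding again while shifting by 35 pixels to compensate for the spatial offset of the HED output. Assuming the 'sigmoid-fuse' blob has the same height and width as the network input (which holds for HED's fused side output), the crop ends up exactly the size of the original image. A small shapes-only sketch (no Caffe required, sizes are hypothetical):

    # Size bookkeeping for the border/crop logic above (shapes only).
    import numpy as np

    H, W, border = 240, 320, 128                                   # hypothetical image size
    im = np.zeros((H, W, 3), dtype=np.float32)
    padded = np.pad(im, ((border, border), (border, border), (0, 0)), 'reflect')
    fuse = np.zeros(padded.shape[:2], dtype=np.float32)            # stand-in for the net output
    crop = fuse[(border + 35):(-border + 35), (border + 35):(-border + 35)]
    print(crop.shape)   # (240, 320): original size, shifted 35 px to absorb the HED offset
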
--------------------------------------------------------------------------------
/scripts/eval_cityscapes/caffemodel/deploy.prototxt:
--------------------------------------------------------------------------------
1 | layer {
2 | name: "data"
3 | type: "Input"
4 | top: "data"
5 | input_param {
6 | shape {
7 | dim: 1
8 | dim: 3
9 | dim: 500
10 | dim: 500
11 | }
12 | }
13 | }
14 | layer {
15 | name: "conv1_1"
16 | type: "Convolution"
17 | bottom: "data"
18 | top: "conv1_1"
19 | param {
20 | lr_mult: 1
21 | decay_mult: 1
22 | }
23 | param {
24 | lr_mult: 2
25 | decay_mult: 0
26 | }
27 | convolution_param {
28 | num_output: 64
29 | pad: 100
30 | kernel_size: 3
31 | stride: 1
32 | weight_filler {
33 | type: "gaussian"
34 | std: 0.01
35 | }
36 | bias_filler {
37 | type: "constant"
38 | value: 0
39 | }
40 | }
41 | }
42 | layer {
43 | name: "relu1_1"
44 | type: "ReLU"
45 | bottom: "conv1_1"
46 | top: "conv1_1"
47 | }
48 | layer {
49 | name: "conv1_2"
50 | type: "Convolution"
51 | bottom: "conv1_1"
52 | top: "conv1_2"
53 | param {
54 | lr_mult: 1
55 | decay_mult: 1
56 | }
57 | param {
58 | lr_mult: 2
59 | decay_mult: 0
60 | }
61 | convolution_param {
62 | num_output: 64
63 | pad: 1
64 | kernel_size: 3
65 | stride: 1
66 | weight_filler {
67 | type: "gaussian"
68 | std: 0.01
69 | }
70 | bias_filler {
71 | type: "constant"
72 | value: 0
73 | }
74 | }
75 | }
76 | layer {
77 | name: "relu1_2"
78 | type: "ReLU"
79 | bottom: "conv1_2"
80 | top: "conv1_2"
81 | }
82 | layer {
83 | name: "pool1"
84 | type: "Pooling"
85 | bottom: "conv1_2"
86 | top: "pool1"
87 | pooling_param {
88 | pool: MAX
89 | kernel_size: 2
90 | stride: 2
91 | }
92 | }
93 | layer {
94 | name: "conv2_1"
95 | type: "Convolution"
96 | bottom: "pool1"
97 | top: "conv2_1"
98 | param {
99 | lr_mult: 1
100 | decay_mult: 1
101 | }
102 | param {
103 | lr_mult: 2
104 | decay_mult: 0
105 | }
106 | convolution_param {
107 | num_output: 128
108 | pad: 1
109 | kernel_size: 3
110 | stride: 1
111 | weight_filler {
112 | type: "gaussian"
113 | std: 0.01
114 | }
115 | bias_filler {
116 | type: "constant"
117 | value: 0
118 | }
119 | }
120 | }
121 | layer {
122 | name: "relu2_1"
123 | type: "ReLU"
124 | bottom: "conv2_1"
125 | top: "conv2_1"
126 | }
127 | layer {
128 | name: "conv2_2"
129 | type: "Convolution"
130 | bottom: "conv2_1"
131 | top: "conv2_2"
132 | param {
133 | lr_mult: 1
134 | decay_mult: 1
135 | }
136 | param {
137 | lr_mult: 2
138 | decay_mult: 0
139 | }
140 | convolution_param {
141 | num_output: 128
142 | pad: 1
143 | kernel_size: 3
144 | stride: 1
145 | weight_filler {
146 | type: "gaussian"
147 | std: 0.01
148 | }
149 | bias_filler {
150 | type: "constant"
151 | value: 0
152 | }
153 | }
154 | }
155 | layer {
156 | name: "relu2_2"
157 | type: "ReLU"
158 | bottom: "conv2_2"
159 | top: "conv2_2"
160 | }
161 | layer {
162 | name: "pool2"
163 | type: "Pooling"
164 | bottom: "conv2_2"
165 | top: "pool2"
166 | pooling_param {
167 | pool: MAX
168 | kernel_size: 2
169 | stride: 2
170 | }
171 | }
172 | layer {
173 | name: "conv3_1"
174 | type: "Convolution"
175 | bottom: "pool2"
176 | top: "conv3_1"
177 | param {
178 | lr_mult: 1
179 | decay_mult: 1
180 | }
181 | param {
182 | lr_mult: 2
183 | decay_mult: 0
184 | }
185 | convolution_param {
186 | num_output: 256
187 | pad: 1
188 | kernel_size: 3
189 | stride: 1
190 | weight_filler {
191 | type: "gaussian"
192 | std: 0.01
193 | }
194 | bias_filler {
195 | type: "constant"
196 | value: 0
197 | }
198 | }
199 | }
200 | layer {
201 | name: "relu3_1"
202 | type: "ReLU"
203 | bottom: "conv3_1"
204 | top: "conv3_1"
205 | }
206 | layer {
207 | name: "conv3_2"
208 | type: "Convolution"
209 | bottom: "conv3_1"
210 | top: "conv3_2"
211 | param {
212 | lr_mult: 1
213 | decay_mult: 1
214 | }
215 | param {
216 | lr_mult: 2
217 | decay_mult: 0
218 | }
219 | convolution_param {
220 | num_output: 256
221 | pad: 1
222 | kernel_size: 3
223 | stride: 1
224 | weight_filler {
225 | type: "gaussian"
226 | std: 0.01
227 | }
228 | bias_filler {
229 | type: "constant"
230 | value: 0
231 | }
232 | }
233 | }
234 | layer {
235 | name: "relu3_2"
236 | type: "ReLU"
237 | bottom: "conv3_2"
238 | top: "conv3_2"
239 | }
240 | layer {
241 | name: "conv3_3"
242 | type: "Convolution"
243 | bottom: "conv3_2"
244 | top: "conv3_3"
245 | param {
246 | lr_mult: 1
247 | decay_mult: 1
248 | }
249 | param {
250 | lr_mult: 2
251 | decay_mult: 0
252 | }
253 | convolution_param {
254 | num_output: 256
255 | pad: 1
256 | kernel_size: 3
257 | stride: 1
258 | weight_filler {
259 | type: "gaussian"
260 | std: 0.01
261 | }
262 | bias_filler {
263 | type: "constant"
264 | value: 0
265 | }
266 | }
267 | }
268 | layer {
269 | name: "relu3_3"
270 | type: "ReLU"
271 | bottom: "conv3_3"
272 | top: "conv3_3"
273 | }
274 | layer {
275 | name: "pool3"
276 | type: "Pooling"
277 | bottom: "conv3_3"
278 | top: "pool3"
279 | pooling_param {
280 | pool: MAX
281 | kernel_size: 2
282 | stride: 2
283 | }
284 | }
285 | layer {
286 | name: "conv4_1"
287 | type: "Convolution"
288 | bottom: "pool3"
289 | top: "conv4_1"
290 | param {
291 | lr_mult: 1
292 | decay_mult: 1
293 | }
294 | param {
295 | lr_mult: 2
296 | decay_mult: 0
297 | }
298 | convolution_param {
299 | num_output: 512
300 | pad: 1
301 | kernel_size: 3
302 | stride: 1
303 | weight_filler {
304 | type: "gaussian"
305 | std: 0.01
306 | }
307 | bias_filler {
308 | type: "constant"
309 | value: 0
310 | }
311 | }
312 | }
313 | layer {
314 | name: "relu4_1"
315 | type: "ReLU"
316 | bottom: "conv4_1"
317 | top: "conv4_1"
318 | }
319 | layer {
320 | name: "conv4_2"
321 | type: "Convolution"
322 | bottom: "conv4_1"
323 | top: "conv4_2"
324 | param {
325 | lr_mult: 1
326 | decay_mult: 1
327 | }
328 | param {
329 | lr_mult: 2
330 | decay_mult: 0
331 | }
332 | convolution_param {
333 | num_output: 512
334 | pad: 1
335 | kernel_size: 3
336 | stride: 1
337 | weight_filler {
338 | type: "gaussian"
339 | std: 0.01
340 | }
341 | bias_filler {
342 | type: "constant"
343 | value: 0
344 | }
345 | }
346 | }
347 | layer {
348 | name: "relu4_2"
349 | type: "ReLU"
350 | bottom: "conv4_2"
351 | top: "conv4_2"
352 | }
353 | layer {
354 | name: "conv4_3"
355 | type: "Convolution"
356 | bottom: "conv4_2"
357 | top: "conv4_3"
358 | param {
359 | lr_mult: 1
360 | decay_mult: 1
361 | }
362 | param {
363 | lr_mult: 2
364 | decay_mult: 0
365 | }
366 | convolution_param {
367 | num_output: 512
368 | pad: 1
369 | kernel_size: 3
370 | stride: 1
371 | weight_filler {
372 | type: "gaussian"
373 | std: 0.01
374 | }
375 | bias_filler {
376 | type: "constant"
377 | value: 0
378 | }
379 | }
380 | }
381 | layer {
382 | name: "relu4_3"
383 | type: "ReLU"
384 | bottom: "conv4_3"
385 | top: "conv4_3"
386 | }
387 | layer {
388 | name: "pool4"
389 | type: "Pooling"
390 | bottom: "conv4_3"
391 | top: "pool4"
392 | pooling_param {
393 | pool: MAX
394 | kernel_size: 2
395 | stride: 2
396 | }
397 | }
398 | layer {
399 | name: "conv5_1"
400 | type: "Convolution"
401 | bottom: "pool4"
402 | top: "conv5_1"
403 | param {
404 | lr_mult: 1
405 | decay_mult: 1
406 | }
407 | param {
408 | lr_mult: 2
409 | decay_mult: 0
410 | }
411 | convolution_param {
412 | num_output: 512
413 | pad: 1
414 | kernel_size: 3
415 | stride: 1
416 | weight_filler {
417 | type: "gaussian"
418 | std: 0.01
419 | }
420 | bias_filler {
421 | type: "constant"
422 | value: 0
423 | }
424 | }
425 | }
426 | layer {
427 | name: "relu5_1"
428 | type: "ReLU"
429 | bottom: "conv5_1"
430 | top: "conv5_1"
431 | }
432 | layer {
433 | name: "conv5_2"
434 | type: "Convolution"
435 | bottom: "conv5_1"
436 | top: "conv5_2"
437 | param {
438 | lr_mult: 1
439 | decay_mult: 1
440 | }
441 | param {
442 | lr_mult: 2
443 | decay_mult: 0
444 | }
445 | convolution_param {
446 | num_output: 512
447 | pad: 1
448 | kernel_size: 3
449 | stride: 1
450 | weight_filler {
451 | type: "gaussian"
452 | std: 0.01
453 | }
454 | bias_filler {
455 | type: "constant"
456 | value: 0
457 | }
458 | }
459 | }
460 | layer {
461 | name: "relu5_2"
462 | type: "ReLU"
463 | bottom: "conv5_2"
464 | top: "conv5_2"
465 | }
466 | layer {
467 | name: "conv5_3"
468 | type: "Convolution"
469 | bottom: "conv5_2"
470 | top: "conv5_3"
471 | param {
472 | lr_mult: 1
473 | decay_mult: 1
474 | }
475 | param {
476 | lr_mult: 2
477 | decay_mult: 0
478 | }
479 | convolution_param {
480 | num_output: 512
481 | pad: 1
482 | kernel_size: 3
483 | stride: 1
484 | weight_filler {
485 | type: "gaussian"
486 | std: 0.01
487 | }
488 | bias_filler {
489 | type: "constant"
490 | value: 0
491 | }
492 | }
493 | }
494 | layer {
495 | name: "relu5_3"
496 | type: "ReLU"
497 | bottom: "conv5_3"
498 | top: "conv5_3"
499 | }
500 | layer {
501 | name: "pool5"
502 | type: "Pooling"
503 | bottom: "conv5_3"
504 | top: "pool5"
505 | pooling_param {
506 | pool: MAX
507 | kernel_size: 2
508 | stride: 2
509 | }
510 | }
511 | layer {
512 | name: "fc6_cs"
513 | type: "Convolution"
514 | bottom: "pool5"
515 | top: "fc6_cs"
516 | param {
517 | lr_mult: 1
518 | decay_mult: 1
519 | }
520 | param {
521 | lr_mult: 2
522 | decay_mult: 0
523 | }
524 | convolution_param {
525 | num_output: 4096
526 | pad: 0
527 | kernel_size: 7
528 | stride: 1
529 | weight_filler {
530 | type: "gaussian"
531 | std: 0.01
532 | }
533 | bias_filler {
534 | type: "constant"
535 | value: 0
536 | }
537 | }
538 | }
539 | layer {
540 | name: "relu6_cs"
541 | type: "ReLU"
542 | bottom: "fc6_cs"
543 | top: "fc6_cs"
544 | }
545 | layer {
546 | name: "fc7_cs"
547 | type: "Convolution"
548 | bottom: "fc6_cs"
549 | top: "fc7_cs"
550 | param {
551 | lr_mult: 1
552 | decay_mult: 1
553 | }
554 | param {
555 | lr_mult: 2
556 | decay_mult: 0
557 | }
558 | convolution_param {
559 | num_output: 4096
560 | pad: 0
561 | kernel_size: 1
562 | stride: 1
563 | weight_filler {
564 | type: "gaussian"
565 | std: 0.01
566 | }
567 | bias_filler {
568 | type: "constant"
569 | value: 0
570 | }
571 | }
572 | }
573 | layer {
574 | name: "relu7_cs"
575 | type: "ReLU"
576 | bottom: "fc7_cs"
577 | top: "fc7_cs"
578 | }
579 | layer {
580 | name: "score_fr"
581 | type: "Convolution"
582 | bottom: "fc7_cs"
583 | top: "score_fr"
584 | param {
585 | lr_mult: 1
586 | decay_mult: 1
587 | }
588 | param {
589 | lr_mult: 2
590 | decay_mult: 0
591 | }
592 | convolution_param {
593 | num_output: 20
594 | pad: 0
595 | kernel_size: 1
596 | weight_filler {
597 | type: "xavier"
598 | }
599 | bias_filler {
600 | type: "constant"
601 | }
602 | }
603 | }
604 | layer {
605 | name: "upscore2"
606 | type: "Deconvolution"
607 | bottom: "score_fr"
608 | top: "upscore2"
609 | param {
610 | lr_mult: 1
611 | }
612 | convolution_param {
613 | num_output: 20
614 | bias_term: false
615 | kernel_size: 4
616 | stride: 2
617 | weight_filler {
618 | type: "xavier"
619 | }
620 | bias_filler {
621 | type: "constant"
622 | }
623 | }
624 | }
625 | layer {
626 | name: "score_pool4"
627 | type: "Convolution"
628 | bottom: "pool4"
629 | top: "score_pool4"
630 | param {
631 | lr_mult: 1
632 | decay_mult: 1
633 | }
634 | param {
635 | lr_mult: 2
636 | decay_mult: 0
637 | }
638 | convolution_param {
639 | num_output: 20
640 | pad: 0
641 | kernel_size: 1
642 | weight_filler {
643 | type: "xavier"
644 | }
645 | bias_filler {
646 | type: "constant"
647 | }
648 | }
649 | }
650 | layer {
651 | name: "score_pool4c"
652 | type: "Crop"
653 | bottom: "score_pool4"
654 | bottom: "upscore2"
655 | top: "score_pool4c"
656 | crop_param {
657 | axis: 2
658 | offset: 5
659 | }
660 | }
661 | layer {
662 | name: "fuse_pool4"
663 | type: "Eltwise"
664 | bottom: "upscore2"
665 | bottom: "score_pool4c"
666 | top: "fuse_pool4"
667 | eltwise_param {
668 | operation: SUM
669 | }
670 | }
671 | layer {
672 | name: "upscore_pool4"
673 | type: "Deconvolution"
674 | bottom: "fuse_pool4"
675 | top: "upscore_pool4"
676 | param {
677 | lr_mult: 1
678 | }
679 | convolution_param {
680 | num_output: 20
681 | bias_term: false
682 | kernel_size: 4
683 | stride: 2
684 | weight_filler {
685 | type: "xavier"
686 | }
687 | bias_filler {
688 | type: "constant"
689 | }
690 | }
691 | }
692 | layer {
693 | name: "score_pool3"
694 | type: "Convolution"
695 | bottom: "pool3"
696 | top: "score_pool3"
697 | param {
698 | lr_mult: 1
699 | decay_mult: 1
700 | }
701 | param {
702 | lr_mult: 2
703 | decay_mult: 0
704 | }
705 | convolution_param {
706 | num_output: 20
707 | pad: 0
708 | kernel_size: 1
709 | weight_filler {
710 | type: "xavier"
711 | }
712 | bias_filler {
713 | type: "constant"
714 | }
715 | }
716 | }
717 | layer {
718 | name: "score_pool3c"
719 | type: "Crop"
720 | bottom: "score_pool3"
721 | bottom: "upscore_pool4"
722 | top: "score_pool3c"
723 | crop_param {
724 | axis: 2
725 | offset: 9
726 | }
727 | }
728 | layer {
729 | name: "fuse_pool3"
730 | type: "Eltwise"
731 | bottom: "upscore_pool4"
732 | bottom: "score_pool3c"
733 | top: "fuse_pool3"
734 | eltwise_param {
735 | operation: SUM
736 | }
737 | }
738 | layer {
739 | name: "upscore8"
740 | type: "Deconvolution"
741 | bottom: "fuse_pool3"
742 | top: "upscore8"
743 | param {
744 | lr_mult: 1
745 | }
746 | convolution_param {
747 | num_output: 20
748 | bias_term: false
749 | kernel_size: 16
750 | stride: 8
751 | weight_filler {
752 | type: "xavier"
753 | }
754 | bias_filler {
755 | type: "constant"
756 | }
757 | }
758 | }
759 | layer {
760 | name: "score"
761 | type: "Crop"
762 | bottom: "upscore8"
763 | bottom: "data"
764 | top: "score"
765 | crop_param {
766 | axis: 2
767 | offset: 31
768 | }
769 | }
770 |
--------------------------------------------------------------------------------
/scripts/eval_cityscapes/cityscapes.py:
--------------------------------------------------------------------------------
1 | # The following code is modified from https://github.com/shelhamer/clockwork-fcn
2 | import sys
3 | import os
4 | import glob
5 | import numpy as np
6 | from PIL import Image
7 |
8 | class cityscapes:
9 | def __init__(self, data_path):
10 | # data_path something like /data2/cityscapes
11 | self.dir = data_path
12 | self.classes = ['road', 'sidewalk', 'building', 'wall', 'fence',
13 | 'pole', 'traffic light', 'traffic sign', 'vegetation', 'terrain',
14 | 'sky', 'person', 'rider', 'car', 'truck',
15 | 'bus', 'train', 'motorcycle', 'bicycle']
16 | self.mean = np.array((72.78044, 83.21195, 73.45286), dtype=np.float32)
17 | # import cityscapes label helper and set up label mappings
18 | sys.path.insert(0, '{}/scripts/helpers/'.format(self.dir))
19 | labels = __import__('labels')
20 | self.id2trainId = {label.id: label.trainId for label in labels.labels} # dictionary mapping from raw IDs to train IDs
21 | self.trainId2color = {label.trainId: label.color for label in labels.labels} # dictionary mapping train IDs to colors as 3-tuples
22 |
23 | def get_dset(self, split):
24 | '''
25 | List images as (city, id) for the specified split
26 |
27 | TODO(shelhamer) generate splits from cityscapes itself, instead of
28 | relying on these separately made text files.
29 | '''
30 | if split == 'train':
31 | dataset = open('{}/ImageSets/segFine/train.txt'.format(self.dir)).read().splitlines()
32 | else:
33 | dataset = open('{}/ImageSets/segFine/val.txt'.format(self.dir)).read().splitlines()
34 | return [(item.split('/')[0], item.split('/')[1]) for item in dataset]
35 |
36 | def load_image(self, split, city, idx):
37 | im = Image.open('{}/leftImg8bit_sequence/{}/{}/{}_leftImg8bit.png'.format(self.dir, split, city, idx))
38 | return im
39 |
40 | def assign_trainIds(self, label):
41 | """
42 | Map the given label IDs to the train IDs appropriate for training
43 | Use the label mapping provided in labels.py from the cityscapes scripts
44 | """
45 | label = np.array(label, dtype=np.float32)
46 | if sys.version_info[0] < 3:
47 | for k, v in self.id2trainId.iteritems():
48 | label[label == k] = v
49 | else:
50 | for k, v in self.id2trainId.items():
51 | label[label == k] = v
52 | return label
53 |
54 | def load_label(self, split, city, idx):
55 | """
56 | Load label image as 1 x height x width integer array of label indices.
57 | The leading singleton dimension is required by the loss.
58 | """
59 | label = Image.open('{}/gtFine/{}/{}/{}_gtFine_labelIds.png'.format(self.dir, split, city, idx))
60 | label = self.assign_trainIds(label) # get proper labels for eval
61 | label = np.array(label, dtype=np.uint8)
62 | label = label[np.newaxis, ...]
63 | return label
64 |
65 | def preprocess(self, im):
66 | """
67 | Preprocess loaded image (by load_image) for Caffe:
68 | - cast to float
69 | - switch channels RGB -> BGR
70 | - subtract mean
71 | - transpose to channel x height x width order
72 | """
73 | in_ = np.array(im, dtype=np.float32)
74 | in_ = in_[:, :, ::-1]
75 | in_ -= self.mean
76 | in_ = in_.transpose((2, 0, 1))
77 | return in_
78 |
79 | def palette(self, label):
80 | '''
81 | Map trainIds to colors as specified in labels.py
82 | '''
83 | if label.ndim == 3:
84 | label= label[0]
85 | color = np.empty((label.shape[0], label.shape[1], 3))
86 | if sys.version_info[0] < 3:
87 | for k, v in self.trainId2color.iteritems():
88 | color[label == k, :] = v
89 | else:
90 | for k, v in self.trainId2color.items():
91 | color[label == k, :] = v
92 | return color
93 |
94 | def make_boundaries(label, thickness=None):
95 | """
96 | Input is an image label, output is a numpy array mask encoding the boundaries of the objects
97 | Extract pixels at the true boundary by dilation - erosion of label.
98 | Don't just pick the void label as it is not exclusive to the boundaries.
99 | """
100 | assert(thickness is not None)
101 | import skimage.morphology as skm
102 | void = 255
103 | mask = np.logical_and(label > 0, label != void)[0]
104 | selem = skm.disk(thickness)
105 | boundaries = np.logical_xor(skm.dilation(mask, selem),
106 | skm.erosion(mask, selem))
107 | return boundaries
108 |
109 | def list_label_frames(self, split):
110 | """
111 | Select labeled frames from a split for evaluation
112 | collected as (city, shot, idx) tuples
113 | """
114 | def file2idx(f):
115 | """Helper to convert file path into frame ID"""
116 | city, shot, frame = (os.path.basename(f).split('_')[:3])
117 | return "_".join([city, shot, frame])
118 | frames = []
119 | cities = [os.path.basename(f) for f in glob.glob('{}/gtFine/{}/*'.format(self.dir, split))]
120 | for c in cities:
121 | files = sorted(glob.glob('{}/gtFine/{}/{}/*labelIds.png'.format(self.dir, split, c)))
122 | frames.extend([file2idx(f) for f in files])
123 | return frames
124 |
125 | def collect_frame_sequence(self, split, idx, length):
126 | """
127 | Collect sequence of frames preceding (and including) a labeled frame
128 | as a list of Images.
129 |
130 | Note: 19 preceding frames are provided for each labeled frame.
131 | """
132 | SEQ_LEN = length
133 | city, shot, frame = idx.split('_')
134 | frame = int(frame)
135 | frame_seq = []
136 | for i in range(frame - SEQ_LEN, frame + 1):
137 | frame_path = '{0}/leftImg8bit_sequence/val/{1}/{1}_{2}_{3:0>6d}_leftImg8bit.png'.format(
138 | self.dir, city, shot, i)
139 | frame_seq.append(Image.open(frame_path))
140 | return frame_seq
141 |
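
A minimal usage sketch of this helper, mirroring how evaluate.py drives it; the dataset root is a placeholder, and the layout must match what the class expects (gtFine/, leftImg8bit_sequence/, ImageSets/segFine/, scripts/helpers/labels.py):

    from cityscapes import cityscapes

    CS = cityscapes('/path/to/cityscapes')          # placeholder dataset root
    frames = CS.list_label_frames('val')            # ids like 'city_shot_frame'
    city = frames[0].split('_')[0]
    label = CS.load_label('val', city, frames[0])   # 1 x H x W array of train IDs
    im = CS.load_image('val', city, frames[0])      # PIL image
    in_ = CS.preprocess(im)                         # float32, BGR, mean-subtracted, C x H x W
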
--------------------------------------------------------------------------------
/scripts/eval_cityscapes/download_fcn8s.sh:
--------------------------------------------------------------------------------
1 | URL=http://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/fcn-8s-cityscapes/fcn-8s-cityscapes.caffemodel
2 | OUTPUT_FILE=./scripts/eval_cityscapes/caffemodel/fcn-8s-cityscapes.caffemodel
3 | wget -N $URL -O $OUTPUT_FILE
4 |
--------------------------------------------------------------------------------
/scripts/eval_cityscapes/evaluate.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import caffe
4 | import argparse
5 | import numpy as np
6 | import scipy.misc
7 | from PIL import Image
8 | from util import *
9 | from cityscapes import cityscapes
10 |
11 | parser = argparse.ArgumentParser()
12 | parser.add_argument("--cityscapes_dir", type=str, required=True, help="Path to the original cityscapes dataset")
13 | parser.add_argument("--result_dir", type=str, required=True, help="Path to the generated images to be evaluated")
14 | parser.add_argument("--output_dir", type=str, required=True, help="Where to save the evaluation results")
15 | parser.add_argument("--caffemodel_dir", type=str, default='./scripts/eval_cityscapes/caffemodel/', help="Where the FCN-8s caffemodel stored")
16 | parser.add_argument("--gpu_id", type=int, default=0, help="Which gpu id to use")
17 | parser.add_argument("--split", type=str, default='val', help="Data split to be evaluated")
18 | parser.add_argument("--save_output_images", type=int, default=0, help="Whether to save the FCN output images")
19 | args = parser.parse_args()
20 |
21 | def main():
22 | if not os.path.isdir(args.output_dir):
23 | os.makedirs(args.output_dir)
24 | if args.save_output_images > 0:
25 | output_image_dir = args.output_dir + 'image_outputs/'
26 | if not os.path.isdir(output_image_dir):
27 | os.makedirs(output_image_dir)
28 | CS = cityscapes(args.cityscapes_dir)
29 | n_cl = len(CS.classes)
30 | label_frames = CS.list_label_frames(args.split)
31 | caffe.set_device(args.gpu_id)
32 | caffe.set_mode_gpu()
33 | net = caffe.Net(args.caffemodel_dir + '/deploy.prototxt',
34 | args.caffemodel_dir + 'fcn-8s-cityscapes.caffemodel',
35 | caffe.TEST)
36 |
37 | hist_perframe = np.zeros((n_cl, n_cl))
38 | for i, idx in enumerate(label_frames):
39 | if i % 10 == 0:
40 | print('Evaluating: %d/%d' % (i, len(label_frames)))
41 | city = idx.split('_')[0]
42 | # idx is city_shot_frame
43 | label = CS.load_label(args.split, city, idx)
44 | im_file = args.result_dir + '/' + idx + '_leftImg8bit.png'
45 | im = np.array(Image.open(im_file))
46 | # im = scipy.misc.imresize(im, (256, 256))
47 | im = scipy.misc.imresize(im, (label.shape[1], label.shape[2]))
48 | out = segrun(net, CS.preprocess(im))
49 | hist_perframe += fast_hist(label.flatten(), out.flatten(), n_cl)
50 | if args.save_output_images > 0:
51 | label_im = CS.palette(label)
52 | pred_im = CS.palette(out)
53 | scipy.misc.imsave(output_image_dir + '/' + str(i) + '_pred.jpg', pred_im)
54 | scipy.misc.imsave(output_image_dir + '/' + str(i) + '_gt.jpg', label_im)
55 | scipy.misc.imsave(output_image_dir + '/' + str(i) + '_input.jpg', im)
56 |
57 | mean_pixel_acc, mean_class_acc, mean_class_iou, per_class_acc, per_class_iou = get_scores(hist_perframe)
58 | with open(args.output_dir + '/evaluation_results.txt', 'w') as f:
59 | f.write('Mean pixel accuracy: %f\n' % mean_pixel_acc)
60 | f.write('Mean class accuracy: %f\n' % mean_class_acc)
61 | f.write('Mean class IoU: %f\n' % mean_class_iou)
62 | f.write('************ Per class numbers below ************\n')
63 | for i, cl in enumerate(CS.classes):
64 | while len(cl) < 15:
65 | cl = cl + ' '
66 | f.write('%s: acc = %f, iou = %f\n' % (cl, per_class_acc[i], per_class_iou[i]))
67 | main()
--------------------------------------------------------------------------------
/scripts/eval_cityscapes/util.py:
--------------------------------------------------------------------------------
1 | # The following code is modified from https://github.com/shelhamer/clockwork-fcn
2 | import numpy as np
3 | import scipy.io as sio
4 |
5 | def get_out_scoremap(net):
6 | return net.blobs['score'].data[0].argmax(axis=0).astype(np.uint8)
7 |
8 | def feed_net(net, in_):
9 | """
10 | Load prepared input into net.
11 | """
12 | net.blobs['data'].reshape(1, *in_.shape)
13 | net.blobs['data'].data[...] = in_
14 |
15 | def segrun(net, in_):
16 | feed_net(net, in_)
17 | net.forward()
18 | return get_out_scoremap(net)
19 |
20 | def fast_hist(a, b, n):
21 | # print('saving')
22 | # sio.savemat('/tmp/fcn_debug/xx.mat', {'a':a, 'b':b, 'n':n})
23 |
24 | k = np.where((a >= 0) & (a < n))[0]
25 | bc = np.bincount(n * a[k].astype(int) + b[k], minlength=n**2)
26 | if len(bc) != n**2:
27 | # ignore this example if dimension mismatch
28 | return 0
29 | return bc.reshape(n, n)
30 |
31 | def get_scores(hist):
32 | # Mean pixel accuracy
33 | acc = np.diag(hist).sum() / (hist.sum() + 1e-12)
34 |
35 | # Per class accuracy
36 | cl_acc = np.diag(hist) / (hist.sum(1) + 1e-12)
37 |
38 | # Per class IoU
39 | iu = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist) + 1e-12)
40 |
41 | return acc, np.nanmean(cl_acc), np.nanmean(iu), cl_acc, iu
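
A tiny worked example of the two functions above: hist[i][j] counts pixels with ground-truth class i and predicted class j, so pixel accuracy is the diagonal mass over the total, per-class accuracy normalizes each diagonal entry by its row sum, and per-class IoU normalizes by row sum + column sum - diagonal.

    # Worked example on a 2-class confusion matrix.
    import numpy as np
    from util import fast_hist, get_scores

    gt   = np.array([0, 0, 0, 1, 1, 1])
    pred = np.array([0, 0, 1, 1, 1, 0])
    hist = fast_hist(gt, pred, 2)
    # hist = [[2, 1],
    #         [1, 2]]
    acc, mean_class_acc, mean_iou, per_class_acc, per_class_iou = get_scores(hist)
    # pixel acc = 4/6 ~ 0.667; class acc = [2/3, 2/3]; class IoU = [0.5, 0.5], mean 0.5
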
--------------------------------------------------------------------------------
/scripts/receptive_field_sizes.m:
--------------------------------------------------------------------------------
1 | % modified from: https://github.com/rbgirshick/rcnn/blob/master/utils/receptive_field_sizes.m
2 | %
3 | % RCNN LICENSE:
4 | %
5 | % Copyright (c) 2014, The Regents of the University of California (Regents)
6 | % All rights reserved.
7 | %
8 | % Redistribution and use in source and binary forms, with or without
9 | % modification, are permitted provided that the following conditions are met:
10 | %
11 | % 1. Redistributions of source code must retain the above copyright notice, this
12 | % list of conditions and the following disclaimer.
13 | % 2. Redistributions in binary form must reproduce the above copyright notice,
14 | % this list of conditions and the following disclaimer in the documentation
15 | % and/or other materials provided with the distribution.
16 | %
17 | % THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
18 | % ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
19 | % WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
20 | % DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
21 | % ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
22 | % (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
23 | % LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
24 | % ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
25 | % (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
26 | % SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27 |
28 | function receptive_field_sizes()
29 |
30 |
31 | % compute input size from a given output size
32 | f = @(output_size, ksize, stride) (output_size - 1) * stride + ksize;
33 |
34 |
35 | %% n=1 discriminator
36 |
37 | % fix the output size to 1 and derive the receptive field in the input
38 | out = ...
39 | f(f(f(1, 4, 1), ... % conv2 -> conv3
40 | 4, 1), ... % conv1 -> conv2
41 | 4, 2); % input -> conv1
42 |
43 | fprintf('n=1 discriminator receptive field size: %d\n', out);
44 |
45 |
46 | %% n=2 discriminator
47 |
48 | % fix the output size to 1 and derive the receptive field in the input
49 | out = ...
50 | f(f(f(f(1, 4, 1), ... % conv3 -> conv4
51 | 4, 1), ... % conv2 -> conv3
52 | 4, 2), ... % conv1 -> conv2
53 | 4, 2); % input -> conv1
54 |
55 | fprintf('n=2 discriminator receptive field size: %d\n', out);
56 |
57 |
58 | %% n=3 discriminator
59 |
60 | % fix the output size to 1 and derive the receptive field in the input
61 | out = ...
62 | f(f(f(f(f(1, 4, 1), ... % conv4 -> conv5
63 | 4, 1), ... % conv3 -> conv4
64 | 4, 2), ... % conv2 -> conv3
65 | 4, 2), ... % conv1 -> conv2
66 | 4, 2); % input -> conv1
67 |
68 | fprintf('n=3 discriminator receptive field size: %d\n', out);
69 |
70 |
71 | %% n=4 discriminator
72 |
73 | % fix the output size to 1 and derive the receptive field in the input
74 | out = ...
75 | f(f(f(f(f(f(1, 4, 1), ... % conv5 -> conv6
76 | 4, 1), ... % conv4 -> conv5
77 | 4, 2), ... % conv3 -> conv4
78 | 4, 2), ... % conv2 -> conv3
79 | 4, 2), ... % conv1 -> conv2
80 | 4, 2); % input -> conv1
81 |
82 | fprintf('n=4 discriminator receptive field size: %d\n', out);
83 |
84 |
85 | %% n=5 discriminator
86 |
87 | % fix the output size to 1 and derive the receptive field in the input
88 | out = ...
89 | f(f(f(f(f(f(f(1, 4, 1), ... % conv6 -> conv7
90 | 4, 1), ... % conv5 -> conv6
91 | 4, 2), ... % conv4 -> conv5
92 | 4, 2), ... % conv3 -> conv4
93 | 4, 2), ... % conv2 -> conv3
94 | 4, 2), ... % conv1 -> conv2
95 | 4, 2); % input -> conv1
96 |
97 | fprintf('n=5 discriminator receptive field size: %d\n', out);
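
For this particular discriminator family (kernel 4 everywhere, n stride-2 convolutions followed by two stride-1 convolutions), the nested calls above collapse to the closed form rf(n) = 9 * 2^n - 2, which reproduces 16, 34, 70, 142, 286 (and 574 for n=6). A short check in Python, using the same recurrence as the MATLAB helper f:

    # Check that the closed form matches the recurrence used above.
    f = lambda out, k, s: (out - 1) * s + k          # same helper as the MATLAB f
    def rf(n):                                       # two stride-1 convs + n stride-2 convs, k = 4
        out = f(f(1, 4, 1), 4, 1)
        for _ in range(n):
            out = f(out, 4, 2)
        return out
    print([(rf(n), 9 * 2 ** n - 2) for n in range(1, 7)])
    # [(16, 16), (34, 34), (70, 70), (142, 142), (286, 286), (574, 574)]
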
--------------------------------------------------------------------------------
/test.lua:
--------------------------------------------------------------------------------
1 | -- usage: DATA_ROOT=/path/to/data/ name=expt1 which_direction=BtoA th test.lua
2 | --
3 | -- code derived from https://github.com/soumith/dcgan.torch
4 | --
5 |
6 | require 'image'
7 | require 'nn'
8 | require 'nngraph'
9 | util = paths.dofile('util/util.lua')
10 | torch.setdefaulttensortype('torch.FloatTensor')
11 |
12 | opt = {
13 | DATA_ROOT = '', -- path to images (should have subfolders 'train', 'val', etc)
14 | batchSize = 1, -- # images in batch
15 | loadSize = 256, -- scale images to this size
16 | fineSize = 256, -- then crop to this size
17 | flip=0, -- horizontal mirroring data augmentation
18 | display = 1, -- display samples while training. 0 = false
19 | display_id = 200, -- display window id.
20 | gpu = 1, -- gpu = 0 is CPU mode. gpu=X is GPU mode on GPU X
21 | how_many = 'all', -- how many test images to run (set to all to run on every image found in the data/phase folder)
22 | which_direction = 'AtoB', -- AtoB or BtoA
23 | phase = 'val', -- train, val, test, etc.
24 | preprocess = 'regular', -- for special purpose preprocessing, e.g., for colorization, change this (selects preprocessing functions in util.lua)
25 | aspect_ratio = 1.0, -- aspect ratio of result images
26 | name = '', -- name of experiment, selects which model to run, should generally be passed on the command line
27 | input_nc = 3, -- # of input image channels
28 | output_nc = 3, -- # of output image channels
29 | serial_batches = 1, -- if 1, takes images in order to make batches, otherwise takes them randomly
30 | serial_batch_iter = 1, -- iter into serial image list
31 | cudnn = 1, -- set to 0 to not use cudnn (untested)
32 | checkpoints_dir = './checkpoints', -- loads models from here
33 | results_dir='./results/', -- saves results here
34 | which_epoch = 'latest', -- which epoch to test? set to 'latest' to use latest cached model
35 | }
36 |
37 |
38 | -- one-line argument parser. parses environment variables to override the defaults
39 | for k,v in pairs(opt) do opt[k] = tonumber(os.getenv(k)) or os.getenv(k) or opt[k] end
40 | opt.nThreads = 1 -- test only works with 1 thread...
41 | print(opt)
42 | if opt.display == 0 then opt.display = false end
43 |
44 | opt.manualSeed = torch.random(1, 10000) -- set seed
45 | print("Random Seed: " .. opt.manualSeed)
46 | torch.manualSeed(opt.manualSeed)
47 | torch.setdefaulttensortype('torch.FloatTensor')
48 |
49 | opt.netG_name = opt.name .. '/' .. opt.which_epoch .. '_net_G'
50 |
51 | local data_loader = paths.dofile('data/data.lua')
52 | print('#threads...' .. opt.nThreads)
53 | local data = data_loader.new(opt.nThreads, opt)
54 | print("Dataset Size: ", data:size())
55 |
56 | -- translation direction
57 | local idx_A = nil
58 | local idx_B = nil
59 | local input_nc = opt.input_nc
60 | local output_nc = opt.output_nc
61 | if opt.which_direction=='AtoB' then
62 | idx_A = {1, input_nc}
63 | idx_B = {input_nc+1, input_nc+output_nc}
64 | elseif opt.which_direction=='BtoA' then
65 | idx_A = {input_nc+1, input_nc+output_nc}
66 | idx_B = {1, input_nc}
67 | else
68 | error(string.format('bad direction %s',opt.which_direction))
69 | end
70 | ----------------------------------------------------------------------------
71 |
72 | local input = torch.FloatTensor(opt.batchSize,3,opt.fineSize,opt.fineSize)
73 | local target = torch.FloatTensor(opt.batchSize,3,opt.fineSize,opt.fineSize)
74 |
75 | print('checkpoints_dir', opt.checkpoints_dir)
76 | local netG = util.load(paths.concat(opt.checkpoints_dir, opt.netG_name .. '.t7'), opt)
77 | --netG:evaluate()
78 |
79 | print(netG)
80 |
81 |
82 | function TableConcat(t1,t2)
83 | for i=1,#t2 do
84 | t1[#t1+1] = t2[i]
85 | end
86 | return t1
87 | end
88 |
89 | if opt.how_many=='all' then
90 | opt.how_many=data:size()
91 | end
92 | opt.how_many=math.min(opt.how_many, data:size())
93 |
94 | local filepaths = {} -- paths to images tested on
95 | for n=1,math.floor(opt.how_many/opt.batchSize) do
96 | print('processing batch ' .. n)
97 |
98 | local data_curr, filepaths_curr = data:getBatch()
99 | filepaths_curr = util.basename_batch(filepaths_curr)
100 | print('filepaths_curr: ', filepaths_curr)
101 |
102 | input = data_curr[{ {}, idx_A, {}, {} }]
103 | target = data_curr[{ {}, idx_B, {}, {} }]
104 |
105 | if opt.gpu > 0 then
106 | input = input:cuda()
107 | end
108 |
109 | if opt.preprocess == 'colorization' then
110 | local output_AB = netG:forward(input):float()
111 | local input_L = input:float()
112 | output = util.deprocessLAB_batch(input_L, output_AB)
113 | local target_AB = target:float()
114 | target = util.deprocessLAB_batch(input_L, target_AB)
115 | input = util.deprocessL_batch(input_L)
116 | else
117 | output = util.deprocess_batch(netG:forward(input))
118 | input = util.deprocess_batch(input):float()
119 | output = output:float()
120 | target = util.deprocess_batch(target):float()
121 | end
122 | paths.mkdir(paths.concat(opt.results_dir, opt.netG_name .. '_' .. opt.phase))
123 | local image_dir = paths.concat(opt.results_dir, opt.netG_name .. '_' .. opt.phase, 'images')
124 | paths.mkdir(image_dir)
125 | paths.mkdir(paths.concat(image_dir,'input'))
126 | paths.mkdir(paths.concat(image_dir,'output'))
127 | paths.mkdir(paths.concat(image_dir,'target'))
128 | for i=1, opt.batchSize do
129 | image.save(paths.concat(image_dir,'input',filepaths_curr[i]), image.scale(input[i],input[i]:size(2),input[i]:size(3)/opt.aspect_ratio))
130 | image.save(paths.concat(image_dir,'output',filepaths_curr[i]), image.scale(output[i],output[i]:size(2),output[i]:size(3)/opt.aspect_ratio))
131 | image.save(paths.concat(image_dir,'target',filepaths_curr[i]), image.scale(target[i],target[i]:size(2),target[i]:size(3)/opt.aspect_ratio))
132 | end
133 | print('Saved images to: ', image_dir)
134 |
135 | if opt.display then
136 | if opt.preprocess == 'regular' then
137 | disp = require 'display'
138 | disp.image(util.scaleBatch(input,100,100),{win=opt.display_id, title='input'})
139 | disp.image(util.scaleBatch(output,100,100),{win=opt.display_id+1, title='output'})
140 | disp.image(util.scaleBatch(target,100,100),{win=opt.display_id+2, title='target'})
141 |
142 | print('Displayed images')
143 | end
144 | end
145 |
146 | filepaths = TableConcat(filepaths, filepaths_curr)
147 | end
148 |
149 | -- make webpage
150 | io.output(paths.concat(opt.results_dir,opt.netG_name .. '_' .. opt.phase, 'index.html'))
151 |
152 | io.write('<table style="text-align:center;">')
153 | io.write('<tr><td>Image #</td><td>Input</td><td>Output</td><td>Ground Truth</td></tr>')
154 | for i=1, #filepaths do
155 |   io.write('<tr>')
156 |   io.write('<td>' .. filepaths[i] .. '</td>')
157 |   io.write('<td><img src="./images/input/' .. filepaths[i] .. '"/></td>')
158 | io.write('