├── .gitignore
├── LICENSE
├── README.md
├── data
│   ├── __init__.py
│   ├── aligned_dataset.py
│   ├── base_dataset.py
│   ├── image_folder.py
│   ├── single_dataset.py
│   └── template_dataset.py
├── datasets
│   ├── bibtex
│   │   ├── cityscapes.tex
│   │   ├── facades.tex
│   │   ├── handbags.tex
│   │   ├── night2day.tex
│   │   ├── shoes.tex
│   │   └── transattr.tex
│   ├── download_dataset.sh
│   ├── download_mini_dataset.sh
│   └── download_testset.sh
├── imgs
│   ├── day2night.gif
│   ├── results_matrix.jpg
│   └── teaser.jpg
├── models
│   ├── __init__.py
│   ├── base_model.py
│   ├── bicycle_gan_model.py
│   ├── networks.py
│   ├── pix2pix_model.py
│   └── template_model.py
├── options
│   ├── __init__.py
│   ├── base_options.py
│   ├── test_options.py
│   ├── train_options.py
│   └── video_options.py
├── pretrained_models
│   └── download_model.sh
├── requirements.txt
├── scripts
│   ├── check_all.sh
│   ├── install_conda.sh
│   ├── install_pip.sh
│   ├── test_before_push.py
│   ├── test_edges2handbags.sh
│   ├── test_edges2shoes.sh
│   ├── test_facades.sh
│   ├── test_maps.sh
│   ├── test_night2day.sh
│   ├── train.sh
│   ├── train_edges2shoes.sh
│   ├── train_facades.sh
│   └── video_edges2shoes.sh
├── test.py
├── train.py
├── util
│   ├── __init__.py
│   ├── html.py
│   ├── util.py
│   └── visualizer.py
└── video.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | *.fuse*
3 | *.pth
4 | *.pyc
5 | debug*
6 | datasets/
7 | videos/
8 | checkpoints/
9 | pretrained_models/
10 | results/
11 | build/
12 | dist/
13 | *.png
14 | torch.egg-info/
15 | */**/__pycache__
16 | torch/version.py
17 | torch/csrc/generic/TensorMethods.cpp
18 | torch/lib/*.so*
19 | torch/lib/*.dylib*
20 | torch/lib/*.h
21 | torch/lib/build
22 | torch/lib/tmp_install
23 | torch/lib/include
24 | torch/lib/torch_shm_manager
25 | torch/csrc/cudnn/cuDNN.cpp
26 | torch/csrc/nn/THNN.cwrap
27 | torch/csrc/nn/THNN.cpp
28 | torch/csrc/nn/THCUNN.cwrap
29 | torch/csrc/nn/THCUNN.cpp
30 | torch/csrc/nn/THNN_generic.cwrap
31 | torch/csrc/nn/THNN_generic.cpp
32 | torch/csrc/nn/THNN_generic.h
33 | docs/src/**/*
34 | test/data/legacy_modules.t7
35 | test/data/gpu_tensors.pt
36 | test/htmlcov
37 | test/.coverage
38 | */*.pyc
39 | */**/*.pyc
40 | */**/**/*.pyc
41 | */**/**/**/*.pyc
42 | */**/**/**/**/*.pyc
43 | */*.so*
44 | */**/*.so*
45 | */**/*.dylib*
46 | test/data/legacy_serialized.pt
47 | *~
48 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2017, Jun-Yan Zhu, Richard Zhang, and Deepak Pathak
2 | All rights reserved.
3 |
4 | Redistribution and use in source and binary forms, with or without
5 | modification, are permitted provided that the following conditions are met:
6 |
7 | * Redistributions of source code must retain the above copyright notice, this
8 | list of conditions and the following disclaimer.
9 |
10 | * Redistributions in binary form must reproduce the above copyright notice,
11 | this list of conditions and the following disclaimer in the documentation
12 | and/or other materials provided with the distribution.
13 |
14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
15 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
16 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
17 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
18 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
19 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
20 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
21 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
22 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
23 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24 |
25 |
26 | --------------------------- LICENSE FOR pytorch-CycleGAN-and-pix2pix ----------------
27 | Copyright (c) 2017, Jun-Yan Zhu and Taesung Park
28 | All rights reserved.
29 |
30 | Redistribution and use in source and binary forms, with or without
31 | modification, are permitted provided that the following conditions are met:
32 |
33 | * Redistributions of source code must retain the above copyright notice, this
34 | list of conditions and the following disclaimer.
35 |
36 | * Redistributions in binary form must reproduce the above copyright notice,
37 | this list of conditions and the following disclaimer in the documentation
38 | and/or other materials provided with the distribution.
39 |
40 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
41 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
42 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
43 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
44 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
45 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
46 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
47 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
48 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
49 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
50 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | # BicycleGAN
6 | [Project Page](https://junyanz.github.io/BicycleGAN/) | [Paper](https://arxiv.org/abs/1711.11586) | [Video](https://youtu.be/JvGysD2EFhw)
7 |
8 |
9 | PyTorch implementation for multimodal image-to-image translation. For example, given the same night image, our model is able to synthesize possible day images with different types of lighting, sky, and clouds. Training requires paired data.
10 |
11 | **Note**: The current software works well with PyTorch 0.4.1+. Check out the older [branch](https://github.com/junyanz/BicycleGAN/tree/pytorch0.3.1) that supports PyTorch 0.1-0.3.
12 |
13 |
14 |
15 | **Toward Multimodal Image-to-Image Translation.**
16 | [Jun-Yan Zhu](https://www.cs.cmu.edu/~junyanz/),
17 | [Richard Zhang](https://richzhang.github.io/), [Deepak Pathak](http://people.eecs.berkeley.edu/~pathak/), [Trevor Darrell](https://people.eecs.berkeley.edu/~trevor/), [Alexei A. Efros](https://people.eecs.berkeley.edu/~efros/), [Oliver Wang](http://www.oliverwang.info/), [Eli Shechtman](https://research.adobe.com/person/eli-shechtman/).
18 | UC Berkeley and Adobe Research
19 | In Neural Information Processing Systems, 2017.
20 |
21 | ## Example results
22 |
23 |
24 |
25 | ## Other Implementations
26 | - [[Tensorflow]](https://github.com/gitlimlab/BicycleGAN-Tensorflow) by Youngwoon Lee (USC CLVR Lab).
27 | - [[Tensorflow]](https://github.com/kvmanohar22/img2imgGAN) by Kv Manohar.
28 |
29 | ## Prerequisites
30 | - Linux or macOS
31 | - Python 3
32 | - CPU or NVIDIA GPU + CUDA CuDNN
33 |
34 |
35 | ## Getting Started
36 | ### Installation
37 | - Clone this repo:
38 | ```bash
39 | git clone -b master --single-branch https://github.com/junyanz/BicycleGAN.git
40 | cd BicycleGAN
41 | ```
42 | - Install PyTorch and dependencies from http://pytorch.org
43 | - Install Python libraries [visdom](https://github.com/facebookresearch/visdom), [dominate](https://github.com/Knio/dominate), and [moviepy](https://github.com/Zulko/moviepy).
44 |
45 | For pip users:
46 | ```bash
47 | bash ./scripts/install_pip.sh
48 | ```
49 |
50 | For conda users:
51 | ```bash
52 | bash ./scripts/install_conda.sh
53 | ```
54 |
55 |
56 | ### Use a Pre-trained Model
57 | - Download some test photos (e.g., edges2shoes):
58 | ```bash
59 | bash ./datasets/download_testset.sh edges2shoes
60 | ```
61 | - Download a pre-trained model (e.g., edges2shoes):
62 | ```bash
63 | bash ./pretrained_models/download_model.sh edges2shoes
64 | ```
65 |
66 | - Generate results with the model:
67 | ```bash
68 | bash ./scripts/test_edges2shoes.sh
69 | ```
70 | The test results will be saved to an HTML file here: `./results/edges2shoes/val/index.html`.
71 |
72 | - Generate results with synchronized latent vectors:
73 | ```bash
74 | bash ./scripts/test_edges2shoes.sh --sync
75 | ```
76 | Results can be found at `./results/edges2shoes/val_sync/index.html`.
77 |
78 | ### Generate Morphing Videos
79 | - We can also produce a morphing video similar to this [GIF](imgs/day2night.gif) and YouTube [video](http://www.youtube.com/watch?v=JvGysD2EFhw&t=2m21s).
80 | ```bash
81 | bash ./scripts/video_edges2shoes.sh
82 | ```
83 | Results can be found at `./videos/edges2shoes/`.
84 |
85 | ### Model Training
86 | - To train a model, download the training images (e.g., edges2shoes).
87 | ```bash
88 | bash ./datasets/download_dataset.sh edges2shoes
89 | ```
90 |
91 | - Train a model:
92 | ```bash
93 | bash ./scripts/train_edges2shoes.sh
94 | ```
95 | - To view training results and loss plots, run `python -m visdom.server` and click the URL http://localhost:8097. To see more intermediate results, check out `./checkpoints/edges2shoes_bicycle_gan/web/index.html`
96 | - See more training details for other datasets in `./scripts/train.sh`.
97 |
98 | ### Datasets (from pix2pix)
99 | Download the datasets using the following script. Many of the datasets are collected by other researchers. Please cite their papers if you use the data.
100 | - Download the test set:
101 | ```bash
102 | bash ./datasets/download_testset.sh dataset_name
103 | ```
104 | - Download the training set and test set:
105 | ```bash
106 | bash ./datasets/download_dataset.sh dataset_name
107 | ```
108 | - `facades`: 400 images from [CMP Facades dataset](http://cmp.felk.cvut.cz/~tylecr1/facade). [[Citation](datasets/bibtex/facades.tex)]
109 | - `maps`: 1096 training images scraped from Google Maps
110 | - `edges2shoes`: 50k training images from [UT Zappos50K dataset](http://vision.cs.utexas.edu/projects/finegrained/utzap50k). Edges are computed by [HED](https://github.com/s9xie/hed) edge detector + post-processing. [[Citation](datasets/bibtex/shoes.tex)]
111 | - `edges2handbags`: 137K Amazon Handbag images from [iGAN project](https://github.com/junyanz/iGAN). Edges are computed by [HED](https://github.com/s9xie/hed) edge detector + post-processing. [[Citation](datasets/bibtex/handbags.tex)]
112 | - `night2day`: around 20K natural scene images from [Transient Attributes dataset](http://transattr.cs.brown.edu/) [[Citation](datasets/bibtex/transattr.tex)]
113 |
114 | ## Models
115 | Download the pre-trained models with the following script.
116 | ```bash
117 | bash ./pretrained_models/download_model.sh model_name
118 | ```
119 | - `edges2shoes` (edge -> photo) trained on the UT Zappos50K dataset.
120 | - `edges2handbags` (edge -> photo) trained on Amazon handbag images.
121 | ```bash
122 | bash ./pretrained_models/download_model.sh edges2handbags
123 | bash ./datasets/download_testset.sh edges2handbags
124 | bash ./scripts/test_edges2handbags.sh
125 | ```
126 | - `night2day` (nighttime scene -> daytime scene) trained on around 100 [webcams](http://transattr.cs.brown.edu/).
127 | ```bash
128 | bash ./pretrained_models/download_model.sh night2day
129 | bash ./datasets/download_testset.sh night2day
130 | bash ./scripts/test_night2day.sh
131 | ```
132 | - `facades` (facade label -> facade photo) trained on the CMP Facades dataset.
133 | ```bash
134 | bash ./pretrained_models/download_model.sh facades
135 | bash ./datasets/download_testset.sh facades
136 | bash ./scripts/test_facades.sh
137 | ```
138 | - `maps` (map photo -> aerial photo) trained on 1096 training images scraped from Google Maps.
139 | ```bash
140 | bash ./pretrained_models/download_model.sh maps
141 | bash ./datasets/download_testset.sh maps
142 | bash ./scripts/test_maps.sh
143 | ```
144 |
145 | ### Metrics
146 |
147 | Figure 6 of our paper shows the realism vs. diversity of our method.
148 |
149 | - **Realism** We use the Amazon Mechanical Turk (AMT) Real vs Fake test from [this repository](https://github.com/phillipi/AMT_Real_vs_Fake), first introduced in [this work](http://richzhang.github.io/colorization/).
150 |
151 | - **Diversity** For each input image, we produce 20 translations by randomly sampling 20 `z` vectors. We compute the LPIPS distance between consecutive pairs to get 19 paired distances. You can compute this by putting the 20 images into a directory and using [this script](https://github.com/richzhang/PerceptualSimilarity/blob/master/compute_dists_pair.py) (note that we used version 0.0 rather than the default 0.1, so use the flag `-v 0.0`). This is done for 100 input images, giving 1900 total distances (100 images × 19 paired distances each), which are averaged together; see the sketch below. A larger number means higher diversity.
152 |
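A minimal sketch of this averaging, assuming the pip-installable `lpips` package (the protocol above used the PerceptualSimilarity v0.0 script directly, so exact numbers may differ slightly):

```python
import torch
import lpips  # pip install lpips

loss_fn = lpips.LPIPS(net='alex')  # LPIPS metric; expects images scaled to [-1, 1]

def diversity_score(samples_per_input):
    """samples_per_input: 100 tensors, each (20, 3, H, W) from one input image."""
    dists = []
    for samples in samples_per_input:
        with torch.no_grad():
            d = loss_fn(samples[:-1], samples[1:])  # 19 consecutive-pair distances
        dists.append(d.view(-1))
    return torch.cat(dists).mean().item()  # mean of 100 * 19 = 1900 distances
```
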
153 | ### Citation
154 |
155 | If you find this useful for your research, please use the following.
156 |
157 | ```
158 | @inproceedings{zhu2017toward,
159 | title={Toward multimodal image-to-image translation},
160 | author={Zhu, Jun-Yan and Zhang, Richard and Pathak, Deepak and Darrell, Trevor and Efros, Alexei A and Wang, Oliver and Shechtman, Eli},
161 | booktitle={Advances in Neural Information Processing Systems},
162 | year={2017}
163 | }
164 |
165 | ```
166 | If you use modules from the CycleGAN or pix2pix paper, please use the following:
167 | ```
168 | @inproceedings{CycleGAN2017,
169 | title={Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks},
170 | author={Zhu, Jun-Yan and Park, Taesung and Isola, Phillip and Efros, Alexei A},
171 | booktitle={Computer Vision (ICCV), 2017 IEEE International Conference on},
172 | year={2017}
173 | }
174 |
175 |
176 | @inproceedings{isola2017image,
177 | title={Image-to-Image Translation with Conditional Adversarial Networks},
178 | author={Isola, Phillip and Zhu, Jun-Yan and Zhou, Tinghui and Efros, Alexei A},
179 | booktitle={Computer Vision and Pattern Recognition (CVPR), 2017 IEEE Conference on},
180 | year={2017}
181 | }
182 | ```
183 | ### Acknowledgements
184 |
185 | This code borrows heavily from the [pytorch-CycleGAN-and-pix2pix](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix) repository.
186 |
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | """This package includes all the modules related to data loading and preprocessing
2 |
3 | To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
4 | You need to implement four functions:
5 | -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
6 | -- <__len__>: return the size of dataset.
7 | -- <__getitem__>: get a data point from data loader.
8 | -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
9 |
10 | Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
11 | See our template dataset class 'template_dataset.py' for more details.
12 | """
13 | import importlib
14 | import torch.utils.data
15 | from data.base_dataset import BaseDataset
16 |
17 |
18 | def find_dataset_using_name(dataset_name):
19 | """Import the module "data/[dataset_name]_dataset.py".
20 |
21 | In the file, the class called DatasetNameDataset() will
22 | be instantiated. It has to be a subclass of BaseDataset,
23 | and it is case-insensitive.
24 | """
25 | dataset_filename = "data." + dataset_name + "_dataset"
26 | datasetlib = importlib.import_module(dataset_filename)
27 |
28 | dataset = None
29 | target_dataset_name = dataset_name.replace('_', '') + 'dataset'
30 | for name, cls in datasetlib.__dict__.items():
31 | if name.lower() == target_dataset_name.lower() \
32 | and issubclass(cls, BaseDataset):
33 | dataset = cls
34 |
35 | if dataset is None:
36 | raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))
37 |
38 | return dataset
39 |
40 |
41 | def get_option_setter(dataset_name):
42 | """Return the static method of the dataset class."""
43 | dataset_class = find_dataset_using_name(dataset_name)
44 | return dataset_class.modify_commandline_options
45 |
46 |
47 | def create_dataset(opt):
48 | """Create a dataset given the option.
49 |
50 | This function wraps the class CustomDatasetDataLoader.
51 | This is the main interface between this package and 'train.py'/'test.py'
52 |
53 | Example:
54 | >>> from data import create_dataset
55 | >>> dataset = create_dataset(opt)
56 | """
57 | data_loader = CustomDatasetDataLoader(opt)
58 | dataset = data_loader.load_data()
59 | return dataset
60 |
61 |
62 | class CustomDatasetDataLoader():
63 | """Wrapper class of Dataset class that performs multi-threaded data loading"""
64 |
65 | def __init__(self, opt):
66 | """Initialize this class
67 |
68 | Step 1: create a dataset instance given the name [dataset_mode]
69 | Step 2: create a multi-threaded data loader.
70 | """
71 | self.opt = opt
72 | dataset_class = find_dataset_using_name(opt.dataset_mode)
73 | self.dataset = dataset_class(opt)
74 | print("dataset [%s] was created" % type(self.dataset).__name__)
75 | self.dataloader = torch.utils.data.DataLoader(
76 | self.dataset,
77 | batch_size=opt.batch_size,
78 | shuffle=not opt.serial_batches,
79 | num_workers=int(opt.num_threads))
80 |
81 | def load_data(self):
82 | return self
83 |
84 | def __len__(self):
85 | """Return the number of data in the dataset"""
86 | return min(len(self.dataset), self.opt.max_dataset_size)
87 |
88 | def __iter__(self):
89 | """Return a batch of data"""
90 | for i, data in enumerate(self.dataloader):
91 | if i * self.opt.batch_size >= self.opt.max_dataset_size:
92 | break
93 | yield data
94 |
--------------------------------------------------------------------------------
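
To make the 'dummy' example in the docstring above concrete, a minimal hypothetical `data/dummy_dataset.py` could look like this; the file name and class name are what `find_dataset_using_name` matches on:

```python
# Hypothetical data/dummy_dataset.py, enabled with '--dataset_mode dummy'.
from data.base_dataset import BaseDataset, get_transform
from data.image_folder import make_dataset
from PIL import Image


class DummyDataset(BaseDataset):
    def __init__(self, opt):
        BaseDataset.__init__(self, opt)  # stores opt and opt.dataroot
        self.image_paths = sorted(make_dataset(self.root, opt.max_dataset_size))
        self.transform = get_transform(opt)

    def __getitem__(self, index):
        path = self.image_paths[index]
        image = Image.open(path).convert('RGB')
        return {'A': self.transform(image), 'A_paths': path}

    def __len__(self):
        return len(self.image_paths)
```
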
/data/aligned_dataset.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | from data.base_dataset import BaseDataset, get_params, get_transform
3 | from data.image_folder import make_dataset
4 | from PIL import Image
5 |
6 |
7 | class AlignedDataset(BaseDataset):
8 | """A dataset class for paired image dataset.
9 |
10 | It assumes that the directory '/path/to/data/train' contains image pairs in the form of {A,B}.
11 | During test time, you need to prepare a directory '/path/to/data/test'.
12 | """
13 |
14 | def __init__(self, opt):
15 | """Initialize this dataset class.
16 |
17 | Parameters:
18 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
19 | """
20 | BaseDataset.__init__(self, opt)
21 | self.dir_AB = os.path.join(opt.dataroot, opt.phase) # get the image directory
22 | self.AB_paths = sorted(make_dataset(self.dir_AB, opt.max_dataset_size)) # get image paths
23 | assert(self.opt.load_size >= self.opt.crop_size) # crop_size should be no larger than load_size
24 | self.input_nc = self.opt.output_nc if self.opt.direction == 'BtoA' else self.opt.input_nc
25 | self.output_nc = self.opt.input_nc if self.opt.direction == 'BtoA' else self.opt.output_nc
26 |
27 | def __getitem__(self, index):
28 | """Return a data point and its metadata information.
29 |
30 | Parameters:
31 | index - - a random integer for data indexing
32 |
33 | Returns a dictionary that contains A, B, A_paths and B_paths
34 | A (tensor) - - an image in the input domain
35 | B (tensor) - - its corresponding image in the target domain
36 | A_paths (str) - - image paths
37 | B_paths (str) - - image paths (same as A_paths)
38 | """
39 | # read an image given a random integer index
40 | AB_path = self.AB_paths[index]
41 | AB = Image.open(AB_path).convert('RGB')
42 | # split AB image into A and B
43 | w, h = AB.size
44 | w2 = int(w / 2)
45 | A = AB.crop((0, 0, w2, h))
46 | B = AB.crop((w2, 0, w, h))
47 |
48 | # apply the same transform to both A and B
49 | transform_params = get_params(self.opt, A.size)
50 | A_transform = get_transform(self.opt, transform_params, grayscale=(self.input_nc == 1))
51 | B_transform = get_transform(self.opt, transform_params, grayscale=(self.output_nc == 1))
52 |
53 | A = A_transform(A)
54 | B = B_transform(B)
55 |
56 | return {'A': A, 'B': B, 'A_paths': AB_path, 'B_paths': AB_path}
57 |
58 | def __len__(self):
59 | """Return the total number of images in the dataset."""
60 | return len(self.AB_paths)
61 |
--------------------------------------------------------------------------------
/data/base_dataset.py:
--------------------------------------------------------------------------------
1 | """This module implements an abstract base class (ABC) 'BaseDataset' for datasets.
2 |
3 | It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
4 | """
5 | import random
6 | import numpy as np
7 | import torch.utils.data as data
8 | from PIL import Image
9 | import torchvision.transforms as transforms
10 | from abc import ABC, abstractmethod
11 |
12 |
13 | class BaseDataset(data.Dataset, ABC):
14 | """This class is an abstract base class (ABC) for datasets.
15 |
16 | To create a subclass, you need to implement the following four functions:
17 | -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
18 | -- <__len__>: return the size of dataset.
19 | -- <__getitem__>: get a data point.
20 | -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
21 | """
22 |
23 | def __init__(self, opt):
24 | """Initialize the class; save the options in the class
25 |
26 | Parameters:
27 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
28 | """
29 | self.opt = opt
30 | self.root = opt.dataroot
31 |
32 | @staticmethod
33 | def modify_commandline_options(parser, is_train):
34 | """Add new dataset-specific options, and rewrite default values for existing options.
35 |
36 | Parameters:
37 | parser -- original option parser
38 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
39 |
40 | Returns:
41 | the modified parser.
42 | """
43 | return parser
44 |
45 | @abstractmethod
46 | def __len__(self):
47 | """Return the total number of images in the dataset."""
48 | return 0
49 |
50 | @abstractmethod
51 | def __getitem__(self, index):
52 | """Return a data point and its metadata information.
53 |
54 | Parameters:
55 | index - - a random integer for data indexing
56 |
57 | Returns:
58 | a dictionary of data with their names. It usually contains the data itself and its metadata information.
59 | """
60 | pass
61 |
62 |
63 | def get_params(opt, size):
64 | w, h = size
65 | new_h = h
66 | new_w = w
67 | if opt.preprocess == 'resize_and_crop':
68 | new_h = new_w = opt.load_size
69 | elif opt.preprocess == 'scale_width_and_crop':
70 | new_w = opt.load_size
71 | new_h = opt.load_size * h // w
72 |
73 | x = random.randint(0, np.maximum(0, new_w - opt.crop_size))
74 | y = random.randint(0, np.maximum(0, new_h - opt.crop_size))
75 |
76 | flip = random.random() > 0.5
77 |
78 | return {'crop_pos': (x, y), 'flip': flip}
79 |
80 |
81 | def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True):
82 | transform_list = []
83 | if grayscale:
84 | transform_list.append(transforms.Grayscale(1))
85 | if 'resize' in opt.preprocess:
86 | osize = [opt.load_size, opt.load_size]
87 | transform_list.append(transforms.Resize(osize, method))
88 | elif 'scale_width' in opt.preprocess:
89 | transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, method)))
90 |
91 | if 'crop' in opt.preprocess:
92 | if params is None:
93 | transform_list.append(transforms.RandomCrop(opt.crop_size))
94 | else:
95 | transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size)))
96 |
97 | if opt.preprocess == 'none':
98 | transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method)))
99 |
100 | if not opt.no_flip:
101 | if params is None:
102 | transform_list.append(transforms.RandomHorizontalFlip())
103 | elif params['flip']:
104 | transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
105 |
106 | if convert:
107 | transform_list += [transforms.ToTensor()]
108 | if grayscale:
109 | transform_list += [transforms.Normalize((0.5,), (0.5,))]
110 | else:
111 | transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
112 | return transforms.Compose(transform_list)
113 |
114 |
115 | def __make_power_2(img, base, method=Image.BICUBIC):
116 | ow, oh = img.size
117 | h = int(round(oh / base) * base)
118 | w = int(round(ow / base) * base)
119 | if (h == oh) and (w == ow):
120 | return img
121 |
122 | __print_size_warning(ow, oh, w, h)
123 | return img.resize((w, h), method)
124 |
125 |
126 | def __scale_width(img, target_width, method=Image.BICUBIC):
127 | ow, oh = img.size
128 | if (ow == target_width):
129 | return img
130 | w = target_width
131 | h = int(target_width * oh / ow)
132 | return img.resize((w, h), method)
133 |
134 |
135 | def __crop(img, pos, size):
136 | ow, oh = img.size
137 | x1, y1 = pos
138 | tw = th = size
139 | if (ow > tw or oh > th):
140 | return img.crop((x1, y1, x1 + tw, y1 + th))
141 | return img
142 |
143 |
144 | def __flip(img, flip):
145 | if flip:
146 | return img.transpose(Image.FLIP_LEFT_RIGHT)
147 | return img
148 |
149 |
150 | def __print_size_warning(ow, oh, w, h):
151 | """Print warning information about image size(only print once)"""
152 | if not hasattr(__print_size_warning, 'has_printed'):
153 | print("The image size needs to be a multiple of 4. "
154 | "The loaded image size was (%d, %d), so it was adjusted to "
155 | "(%d, %d). This adjustment will be done to all images "
156 | "whose sizes are not multiples of 4" % (ow, oh, w, h))
157 | __print_size_warning.has_printed = True
158 |
--------------------------------------------------------------------------------
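
The paired-transform mechanism that `aligned_dataset.py` (the next file) relies on can be exercised on its own: draw the random crop/flip parameters once, then build two identical transform pipelines. A small sketch with hypothetical option values:

```python
from argparse import Namespace
from PIL import Image
from data.base_dataset import get_params, get_transform

# Hand-built options object (hypothetical values; normally parsed by the options package).
opt = Namespace(preprocess='resize_and_crop', load_size=286, crop_size=256, no_flip=False)

img_A = Image.new('RGB', (512, 512))  # stand-ins for a real A/B pair
img_B = Image.new('RGB', (512, 512))

params = get_params(opt, img_A.size)   # one random crop position + flip decision
A = get_transform(opt, params)(img_A)  # identical augmentation applied to A...
B = get_transform(opt, params)(img_B)  # ...and to B, keeping the pair aligned
```
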
/data/image_folder.py:
--------------------------------------------------------------------------------
1 | """A modified image folder class
2 |
3 | We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py)
4 | so that this class can load images from both current directory and its subdirectories.
5 | """
6 |
7 | import torch.utils.data as data
8 |
9 | from PIL import Image
10 | import os
11 | import os.path
12 |
13 | IMG_EXTENSIONS = [
14 | '.jpg', '.JPG', '.jpeg', '.JPEG',
15 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
16 | ]
17 |
18 |
19 | def is_image_file(filename):
20 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
21 |
22 |
23 | def make_dataset(dir, max_dataset_size=float("inf")):
24 | images = []
25 | assert os.path.isdir(dir), '%s is not a valid directory' % dir
26 |
27 | for root, _, fnames in sorted(os.walk(dir)):
28 | for fname in fnames:
29 | if is_image_file(fname):
30 | path = os.path.join(root, fname)
31 | images.append(path)
32 | return images[:min(max_dataset_size, len(images))]
33 |
34 |
35 | def default_loader(path):
36 | return Image.open(path).convert('RGB')
37 |
38 |
39 | class ImageFolder(data.Dataset):
40 |
41 | def __init__(self, root, transform=None, return_paths=False,
42 | loader=default_loader):
43 | imgs = make_dataset(root)
44 | if len(imgs) == 0:
45 | raise(RuntimeError("Found 0 images in: " + root + "\n"
46 | "Supported image extensions are: " +
47 | ",".join(IMG_EXTENSIONS)))
48 |
49 | self.root = root
50 | self.imgs = imgs
51 | self.transform = transform
52 | self.return_paths = return_paths
53 | self.loader = loader
54 |
55 | def __getitem__(self, index):
56 | path = self.imgs[index]
57 | img = self.loader(path)
58 | if self.transform is not None:
59 | img = self.transform(img)
60 | if self.return_paths:
61 | return img, path
62 | else:
63 | return img
64 |
65 | def __len__(self):
66 | return len(self.imgs)
67 |
--------------------------------------------------------------------------------
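
A short usage sketch of `ImageFolder` (hypothetical path; any torchvision transform works):

```python
import torchvision.transforms as transforms
from data.image_folder import ImageFolder

# Images are discovered recursively under root (see make_dataset above).
dataset = ImageFolder(root='./datasets/edges2shoes/val',
                      transform=transforms.ToTensor(),
                      return_paths=True)
img, path = dataset[0]
print(len(dataset), path)
```
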
/data/single_dataset.py:
--------------------------------------------------------------------------------
1 | from data.base_dataset import BaseDataset, get_transform
2 | from data.image_folder import make_dataset
3 | from PIL import Image
4 |
5 |
6 | class SingleDataset(BaseDataset):
7 | """This dataset class can load a set of images specified by the path --dataroot /path/to/data.
8 |
9 | It can be used for generating CycleGAN results only for one side with the model option '--model test'.
10 | """
11 |
12 | def __init__(self, opt):
13 | """Initialize this dataset class.
14 |
15 | Parameters:
16 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
17 | """
18 | BaseDataset.__init__(self, opt)
19 | self.A_paths = sorted(make_dataset(opt.dataroot, opt.max_dataset_size))
20 | input_nc = self.opt.output_nc if self.opt.direction == 'BtoA' else self.opt.input_nc
21 | self.transform = get_transform(opt, grayscale=(input_nc == 1))
22 |
23 | def __getitem__(self, index):
24 | """Return a data point and its metadata information.
25 |
26 | Parameters:
27 | index - - a random integer for data indexing
28 |
29 | Returns a dictionary that contains A and A_paths
30 | A(tensor) - - an image in one domain
31 | A_paths(str) - - the path of the image
32 | """
33 | A_path = self.A_paths[index]
34 | A_img = Image.open(A_path).convert('RGB')
35 | A = self.transform(A_img)
36 | return {'A': A, 'A_paths': A_path}
37 |
38 | def __len__(self):
39 | """Return the total number of images in the dataset."""
40 | return len(self.A_paths)
41 |
--------------------------------------------------------------------------------
/data/template_dataset.py:
--------------------------------------------------------------------------------
1 | """Dataset class template
2 |
3 | This module provides a template for users to implement custom datasets.
4 | You can specify '--dataset_mode template' to use this dataset.
5 | The class name should be consistent with both the filename and its dataset_mode option.
6 | The filename should be <dataset_mode>_dataset.py
7 | The class name should be <DatasetMode>Dataset
8 | You need to implement the following functions:
9 | -- <modify_commandline_options>: Add dataset-specific options and rewrite default values for existing options.
10 | -- <__init__>: Initialize this dataset class.
11 | -- <__getitem__>: Return a data point and its metadata information.
12 | -- <__len__>: Return the number of images.
13 | """
14 | from data.base_dataset import BaseDataset, get_transform
15 | # from data.image_folder import make_dataset
16 | # from PIL import Image
17 |
18 |
19 | class TemplateDataset(BaseDataset):
20 | """A template dataset class for you to implement custom datasets."""
21 | @staticmethod
22 | def modify_commandline_options(parser, is_train):
23 | """Add new dataset-specific options, and rewrite default values for existing options.
24 |
25 | Parameters:
26 | parser -- original option parser
27 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
28 |
29 | Returns:
30 | the modified parser.
31 | """
32 | parser.add_argument('--new_dataset_option', type=float, default=1.0, help='new dataset option')
33 | parser.set_defaults(max_dataset_size=10, new_dataset_option=2.0) # specify dataset-specific default values
34 | return parser
35 |
36 | def __init__(self, opt):
37 | """Initialize this dataset class.
38 |
39 | Parameters:
40 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
41 |
42 | A few things can be done here.
43 | - save the options (have been done in BaseDataset)
44 | - get image paths and meta information of the dataset.
45 | - define the image transformation.
46 | """
47 | # save the option and dataset root
48 | BaseDataset.__init__(self, opt)
49 | # get the image paths of your dataset;
50 | self.image_paths = [] # You can call sorted(make_dataset(self.root, opt.max_dataset_size)) to get all the image paths under the directory self.root
51 | # define the default transform function. You can use <base_dataset.get_transform>; you can also define a custom transform function
52 | self.transform = get_transform(opt)
53 |
54 | def __getitem__(self, index):
55 | """Return a data point and its metadata information.
56 |
57 | Parameters:
58 | index -- a random integer for data indexing
59 |
60 | Returns:
61 | a dictionary of data with their names. It usually contains the data itself and its metadata information.
62 |
63 | Step 1: get a random image path: e.g., path = self.image_paths[index]
64 | Step 2: load your data from the disk: e.g., image = Image.open(path).convert('RGB').
65 | Step 3: convert your data to a PyTorch tensor. You can use helper functions such as self.transform. e.g., data = self.transform(image)
66 | Step 4: return a data point as a dictionary.
67 | """
68 | path = 'temp' # needs to be a string
69 | data_A = None # needs to be a tensor
70 | data_B = None # needs to be a tensor
71 | return {'data_A': data_A, 'data_B': data_B, 'path': path}
72 |
73 | def __len__(self):
74 | """Return the total number of images."""
75 | return len(self.image_paths)
76 |
--------------------------------------------------------------------------------
/datasets/bibtex/cityscapes.tex:
--------------------------------------------------------------------------------
1 | @inproceedings{Cordts2016Cityscapes,
2 | title={The Cityscapes Dataset for Semantic Urban Scene Understanding},
3 | author={Cordts, Marius and Omran, Mohamed and Ramos, Sebastian and Rehfeld, Timo and Enzweiler, Markus and Benenson, Rodrigo and Franke, Uwe and Roth, Stefan and Schiele, Bernt},
4 | booktitle={Proc. of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
5 | year={2016}
6 | }
7 |
--------------------------------------------------------------------------------
/datasets/bibtex/facades.tex:
--------------------------------------------------------------------------------
1 | @INPROCEEDINGS{Tylecek13,
2 | author = {Radim Tyle{\v c}ek and Radim {\v S}{\' a}ra},
3 | title = {Spatial Pattern Templates for Recognition of Objects with Regular Structure},
4 | booktitle = {Proc. GCPR},
5 | year = {2013},
6 | address = {Saarbrucken, Germany},
7 | }
8 |
--------------------------------------------------------------------------------
/datasets/bibtex/handbags.tex:
--------------------------------------------------------------------------------
1 | @inproceedings{zhu2016generative,
2 | title={Generative Visual Manipulation on the Natural Image Manifold},
3 | author={Zhu, Jun-Yan and Kr{\"a}henb{\"u}hl, Philipp and Shechtman, Eli and Efros, Alexei A.},
4 | booktitle={Proceedings of European Conference on Computer Vision (ECCV)},
5 | year={2016}
6 | }
7 |
8 | @InProceedings{xie15hed,
9 | author = {"Xie, Saining and Tu, Zhuowen"},
10 | Title = {Holistically-Nested Edge Detection},
11 | Booktitle = "Proceedings of IEEE International Conference on Computer Vision",
12 | Year = {2015},
13 | }
14 |
--------------------------------------------------------------------------------
/datasets/bibtex/night2day.tex:
--------------------------------------------------------------------------------
1 | @article{laffont2014transient,
2 | title={Transient attributes for high-level understanding and editing of outdoor scenes},
3 | author={Laffont, Pierre-Yves and Ren, Zhile and Tao, Xiaofeng and Qian, Chao and Hays, James},
4 | journal={ACM Transactions on Graphics (TOG)},
5 | volume={33},
6 | number={4},
7 | pages={149},
8 | year={2014},
9 | publisher={ACM}
10 | }
11 |
--------------------------------------------------------------------------------
/datasets/bibtex/shoes.tex:
--------------------------------------------------------------------------------
1 | @InProceedings{fine-grained,
2 | author = {A. Yu and K. Grauman},
3 | title = {{F}ine-{G}rained {V}isual {C}omparisons with {L}ocal {L}earning},
4 | booktitle = {Computer Vision and Pattern Recognition (CVPR)},
5 | month = {June},
6 | year = {2014}
7 | }
8 |
9 | @InProceedings{xie15hed,
10 | author = {"Xie, Saining and Tu, Zhuowen"},
11 | Title = {Holistically-Nested Edge Detection},
12 | Booktitle = "Proceedings of IEEE International Conference on Computer Vision",
13 | Year = {2015},
14 | }
15 |
--------------------------------------------------------------------------------
/datasets/bibtex/transattr.tex:
--------------------------------------------------------------------------------
1 | @article {Laffont14,
2 | title = {Transient Attributes for High-Level Understanding and Editing of Outdoor Scenes},
3 | author = {Pierre-Yves Laffont and Zhile Ren and Xiaofeng Tao and Chao Qian and James Hays},
4 | journal = {ACM Transactions on Graphics (proceedings of SIGGRAPH)},
5 | volume = {33},
6 | number = {4},
7 | year = {2014}
8 | }
9 |
--------------------------------------------------------------------------------
/datasets/download_dataset.sh:
--------------------------------------------------------------------------------
1 | FILE=$1
2 |
3 | if [[ $FILE != "cityscapes" && $FILE != "night2day" && $FILE != "edges2handbags" && $FILE != "edges2shoes" && $FILE != "facades" && $FILE != "maps" ]]; then
4 | echo "Available datasets are cityscapes, night2day, edges2handbags, edges2shoes, facades, maps"
5 | exit 1
6 | fi
7 |
8 | echo "Specified [$FILE]"
9 |
10 | URL=http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/$FILE.tar.gz
11 | TAR_FILE=./datasets/$FILE.tar.gz
12 | TARGET_DIR=./datasets/$FILE/
13 | wget -N $URL -O $TAR_FILE
14 | mkdir -p $TARGET_DIR
15 | tar -zxvf $TAR_FILE -C ./datasets/
16 | rm $TAR_FILE
17 |
--------------------------------------------------------------------------------
/datasets/download_mini_dataset.sh:
--------------------------------------------------------------------------------
1 | FILE=$1
2 | echo "Specified [$FILE]"
3 | URL=https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/$FILE.zip
4 | ZIP_FILE=./datasets/$FILE.zip
5 | TARGET_DIR=./datasets/$FILE/
6 | wget -N $URL -O $ZIP_FILE
7 | mkdir -p $TARGET_DIR
8 | unzip $ZIP_FILE -d ./datasets/
9 | rm $ZIP_FILE
10 |
--------------------------------------------------------------------------------
/datasets/download_testset.sh:
--------------------------------------------------------------------------------
1 | FILE=$1
2 | URL=http://efrosgans.eecs.berkeley.edu/BicycleGAN/testset/${FILE}.tar.gz
3 | TAR_FILE=./datasets/$FILE.tar.gz
4 | TARGET_DIR=./datasets/$FILE/
5 | wget -N $URL -O $TAR_FILE
6 | mkdir -p $TARGET_DIR
7 | tar -zxvf $TAR_FILE -C ./datasets/
8 | rm $TAR_FILE
9 |
--------------------------------------------------------------------------------
/imgs/day2night.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/junyanz/BicycleGAN/40b9d52c27b9831f56c1c7c7a6ddde8bc9149067/imgs/day2night.gif
--------------------------------------------------------------------------------
/imgs/results_matrix.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/junyanz/BicycleGAN/40b9d52c27b9831f56c1c7c7a6ddde8bc9149067/imgs/results_matrix.jpg
--------------------------------------------------------------------------------
/imgs/teaser.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/junyanz/BicycleGAN/40b9d52c27b9831f56c1c7c7a6ddde8bc9149067/imgs/teaser.jpg
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | """This package contains modules related to objective functions, optimizations, and network architectures.
2 | To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
3 | You need to implement the following five functions:
4 | -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
5 | -- <set_input>: unpack data from dataset and apply preprocessing.
6 | -- <forward>: produce intermediate results.
7 | -- <optimize_parameters>: calculate loss, gradients, and update network weights.
8 | -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
9 | In the function <__init__>, you need to define four lists:
10 | -- self.loss_names (str list): specify the training losses that you want to plot and save.
11 | -- self.model_names (str list): define networks used in our training.
12 | -- self.visual_names (str list): specify the images that you want to display and save.
13 | -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
14 | Now you can use the model class by specifying flag '--model dummy'.
15 | See our template model class 'template_model.py' for an example.
16 | """
17 |
18 | import importlib
19 | from models.base_model import BaseModel
20 |
21 |
22 | def find_model_using_name(model_name):
23 | """Import the module "models/[model_name]_model.py".
24 | In the file, the class called ModelNameModel() will
25 | be instantiated. It has to be a subclass of BaseModel,
26 | and it is case-insensitive.
27 | """
28 | model_filename = "models." + model_name + "_model"
29 | modellib = importlib.import_module(model_filename)
30 | model = None
31 | target_model_name = model_name.replace('_', '') + 'model'
32 | for name, cls in modellib.__dict__.items():
33 | if name.lower() == target_model_name.lower() \
34 | and issubclass(cls, BaseModel):
35 | model = cls
36 |
37 | if model is None:
38 | print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
39 | exit(1)
40 |
41 | return model
42 |
43 |
44 | def get_option_setter(model_name):
45 | """Return the static method of the model class."""
46 | model_class = find_model_using_name(model_name)
47 | return model_class.modify_commandline_options
48 |
49 |
50 | def create_model(opt):
51 | """Create a model given the option.
52 | This function instantiates the model class found by find_model_using_name().
53 | This is the main interface between this package and 'train.py'/'test.py'
54 | Example:
55 | >>> from models import create_model
56 | >>> model = create_model(opt)
57 | """
58 | model = find_model_using_name(opt.model)
59 | instance = model(opt)
60 | print("model [%s] was created" % type(instance).__name__)
61 | return instance
62 |
--------------------------------------------------------------------------------
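
A simplified sketch of how `train.py` wires the data and model factories together (abbreviated; the real script adds visualization and checkpointing, and flag names such as `niter` may differ slightly):

```python
from options.train_options import TrainOptions
from data import create_dataset
from models import create_model

opt = TrainOptions().parse()
dataset = create_dataset(opt)  # CustomDatasetDataLoader under the hood
model = create_model(opt)      # e.g. BiCycleGANModel for --model bicycle_gan
model.setup(opt)               # create schedulers / load checkpoints

for epoch in range(opt.niter + opt.niter_decay):
    for data in dataset:
        model.set_input(data)
        if not model.is_train():  # skip incomplete batches (see bicycle_gan_model)
            continue
        model.optimize_parameters()
    model.update_learning_rate()
```
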
/models/base_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from collections import OrderedDict
4 | from abc import ABC, abstractmethod
5 | from . import networks
6 |
7 |
8 | class BaseModel(ABC):
9 | """This class is an abstract base class (ABC) for models.
10 | To create a subclass, you need to implement the following five functions:
11 | -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
12 | -- <set_input>: unpack data from dataset and apply preprocessing.
13 | -- <forward>: produce intermediate results.
14 | -- <optimize_parameters>: calculate losses, gradients, and update network weights.
15 | -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
16 | """
17 |
18 | def __init__(self, opt):
19 | """Initialize the BaseModel class.
20 |
21 | Parameters:
22 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
23 |
24 | When creating your custom class, you need to implement your own initialization.
25 | In this function, you should first call `BaseModel.__init__(self, opt)`
26 | Then, you need to define four lists:
27 | -- self.loss_names (str list): specify the training losses that you want to plot and save.
28 | -- self.model_names (str list): define networks used in our training.
29 | -- self.visual_names (str list): specify the images that you want to display and save.
30 | -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
31 | """
32 | self.opt = opt
33 | self.gpu_ids = opt.gpu_ids
34 | self.isTrain = opt.isTrain
35 | self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') # get device name: CPU or GPU
36 | self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) # save all the checkpoints to save_dir
37 | if opt.preprocess != 'scale_width': # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark.
38 | torch.backends.cudnn.benchmark = True
39 | self.loss_names = []
40 | self.model_names = []
41 | self.visual_names = []
42 | self.optimizers = []
43 | self.image_paths = []
44 |
45 | @staticmethod
46 | def modify_commandline_options(parser, is_train):
47 | """Add new model-specific options, and rewrite default values for existing options.
48 |
49 | Parameters:
50 | parser -- original option parser
51 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
52 |
53 | Returns:
54 | the modified parser.
55 | """
56 | return parser
57 |
58 | @abstractmethod
59 | def set_input(self, input):
60 | """Unpack input data from the dataloader and perform necessary pre-processing steps.
61 |
62 | Parameters:
63 | input (dict): includes the data itself and its metadata information.
64 | """
65 | pass
66 |
67 | @abstractmethod
68 | def forward(self):
69 | """Run forward pass; called by both functions and ."""
70 | pass
71 |
72 | def is_train(self):
73 | """check if the current batch is good for training."""
74 | return True
75 |
76 | @abstractmethod
77 | def optimize_parameters(self):
78 | """Calculate losses, gradients, and update network weights; called in every training iteration"""
79 | pass
80 |
81 | def setup(self, opt):
82 | """Load and print networks; create schedulers
83 |
84 | Parameters:
85 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
86 | """
87 | if self.isTrain:
88 | self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
89 | if not self.isTrain or opt.continue_train:
90 | self.load_networks(opt.epoch)
91 | self.print_networks(opt.verbose)
92 |
93 | def eval(self):
94 | """Make models eval mode during test time"""
95 | for name in self.model_names:
96 | if isinstance(name, str):
97 | net = getattr(self, 'net' + name)
98 | net.eval()
99 |
100 | def test(self):
101 | """Forward function used in test time.
102 |
103 | This function wraps the <forward> function in no_grad() so we don't save intermediate steps for backprop
104 | It also calls <compute_visuals> to produce additional visualization results
105 | """
106 | with torch.no_grad():
107 | self.forward()
108 | self.compute_visuals()
109 |
110 | def compute_visuals(self):
111 | """Calculate additional output images for visdom and HTML visualization"""
112 | pass
113 |
114 | def get_image_paths(self):
115 | """ Return image paths that are used to load current data"""
116 | return self.image_paths
117 |
118 | def update_learning_rate(self):
119 | """Update learning rates for all the networks; called at the end of every epoch"""
120 | for scheduler in self.schedulers:
121 | scheduler.step()
122 | lr = self.optimizers[0].param_groups[0]['lr']
123 | print('learning rate = %.7f' % lr)
124 |
125 | def get_current_visuals(self):
126 | """Return visualization images. train.py will display these images with visdom, and save the images to a HTML"""
127 | visual_ret = OrderedDict()
128 | for name in self.visual_names:
129 | if isinstance(name, str):
130 | visual_ret[name] = getattr(self, name)
131 | return visual_ret
132 |
133 | def get_current_losses(self):
134 | """Return traning losses / errors. train.py will print out these errors on console, and save them to a file"""
135 | errors_ret = OrderedDict()
136 | for name in self.loss_names:
137 | if isinstance(name, str):
138 | errors_ret[name] = float(getattr(self, 'loss_' + name)) # float(...) works for both scalar tensor and float number
139 | return errors_ret
140 |
141 | def save_networks(self, epoch):
142 | """Save all the networks to the disk.
143 |
144 | Parameters:
145 | epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
146 | """
147 | for name in self.model_names:
148 | if isinstance(name, str):
149 | save_filename = '%s_net_%s.pth' % (epoch, name)
150 | save_path = os.path.join(self.save_dir, save_filename)
151 | net = getattr(self, 'net' + name)
152 |
153 | if len(self.gpu_ids) > 0 and torch.cuda.is_available():
154 | torch.save(net.module.cpu().state_dict(), save_path)
155 | net.cuda(self.gpu_ids[0])
156 | else:
157 | torch.save(net.cpu().state_dict(), save_path)
158 |
159 | def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
160 | """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
161 | key = keys[i]
162 | if i + 1 == len(keys): # at the end, pointing to a parameter/buffer
163 | if module.__class__.__name__.startswith('InstanceNorm') and \
164 | (key == 'running_mean' or key == 'running_var'):
165 | if getattr(module, key) is None:
166 | state_dict.pop('.'.join(keys))
167 | if module.__class__.__name__.startswith('InstanceNorm') and \
168 | (key == 'num_batches_tracked'):
169 | state_dict.pop('.'.join(keys))
170 | else:
171 | self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
172 |
173 | def load_networks(self, epoch):
174 | """Load all the networks from the disk.
175 |
176 | Parameters:
177 | epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
178 | """
179 | for name in self.model_names:
180 | if isinstance(name, str):
181 | load_filename = '%s_net_%s.pth' % (epoch, name)
182 | load_path = os.path.join(self.save_dir, load_filename)
183 | net = getattr(self, 'net' + name)
184 | if isinstance(net, torch.nn.DataParallel):
185 | net = net.module
186 | print('loading the model from %s' % load_path)
187 | # if you are using PyTorch newer than 0.4 (e.g., built from
188 | # GitHub source), you can remove str() on self.device
189 | state_dict = torch.load(load_path, map_location=str(self.device))
190 | if hasattr(state_dict, '_metadata'):
191 | del state_dict._metadata
192 |
193 | # patch InstanceNorm checkpoints prior to 0.4
194 | for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop
195 | self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
196 | net.load_state_dict(state_dict)
197 |
198 | def print_networks(self, verbose):
199 | """Print the total number of parameters in the network and (if verbose) network architecture
200 |
201 | Parameters:
202 | verbose (bool) -- if verbose: print the network architecture
203 | """
204 | print('---------- Networks initialized -------------')
205 | for name in self.model_names:
206 | if isinstance(name, str):
207 | net = getattr(self, 'net' + name)
208 | num_params = 0
209 | for param in net.parameters():
210 | num_params += param.numel()
211 | if verbose:
212 | print(net)
213 | print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
214 | print('-----------------------------------------------')
215 |
216 | def set_requires_grad(self, nets, requires_grad=False):
217 | """Set requires_grad=False for all the networks to avoid unnecessary computations
218 | Parameters:
219 | nets (network list) -- a list of networks
220 | requires_grad (bool) -- whether the networks require gradients or not
221 | """
222 | if not isinstance(nets, list):
223 | nets = [nets]
224 | for net in nets:
225 | if net is not None:
226 | for param in net.parameters():
227 | param.requires_grad = requires_grad
228 |
--------------------------------------------------------------------------------
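
The main consumer of `set_requires_grad` is the alternating GAN update: freeze the discriminator while the generator/encoder step runs, then unfreeze it for its own step. A sketch of one alternation, abbreviated to a single discriminator and omitting the encoder optimizer (the wrapper function itself is hypothetical; the concrete version lives in `bicycle_gan_model.py` below):

```python
def gan_alternation(model):
    # G/E step: D supplies a loss signal but must not accumulate gradients.
    model.set_requires_grad(model.netD, False)
    model.optimizer_G.zero_grad()
    model.forward()
    model.backward_EG()          # as in bicycle_gan_model.py
    model.optimizer_G.step()

    # D step: now gradients flow into D only.
    model.set_requires_grad(model.netD, True)
    model.optimizer_D.zero_grad()
    model.backward_D(model.netD, model.real_data_encoded, model.fake_data_encoded)
    model.optimizer_D.step()
```
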
/models/bicycle_gan_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .base_model import BaseModel
3 | from . import networks
4 |
5 |
6 | class BiCycleGANModel(BaseModel):
7 | @staticmethod
8 | def modify_commandline_options(parser, is_train=True):
9 | return parser
10 |
11 | def __init__(self, opt):
12 | if opt.isTrain:
13 | assert opt.batch_size % 2 == 0 # load two images at one time.
14 |
15 | BaseModel.__init__(self, opt)
16 | # specify the training losses you want to print out. The program will call base_model.get_current_losses
17 | self.loss_names = ['G_GAN', 'D', 'G_GAN2', 'D2', 'G_L1', 'z_L1', 'kl']
18 | # specify the images you want to save/display. The program will call base_model.get_current_visuals
19 | self.visual_names = ['real_A_encoded', 'real_B_encoded', 'fake_B_random', 'fake_B_encoded']
20 | # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks
21 | use_D = opt.isTrain and opt.lambda_GAN > 0.0
22 | use_D2 = opt.isTrain and opt.lambda_GAN2 > 0.0 and not opt.use_same_D
23 | use_E = opt.isTrain or not opt.no_encode
24 | use_vae = True
25 | self.model_names = ['G']
26 | self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.nz, opt.ngf, netG=opt.netG,
27 | norm=opt.norm, nl=opt.nl, use_dropout=opt.use_dropout, init_type=opt.init_type, init_gain=opt.init_gain,
28 | gpu_ids=self.gpu_ids, where_add=opt.where_add, upsample=opt.upsample)
29 | D_output_nc = opt.input_nc + opt.output_nc if opt.conditional_D else opt.output_nc
30 | if use_D:
31 | self.model_names += ['D']
32 | self.netD = networks.define_D(D_output_nc, opt.ndf, netD=opt.netD, norm=opt.norm, nl=opt.nl,
33 | init_type=opt.init_type, init_gain=opt.init_gain, num_Ds=opt.num_Ds, gpu_ids=self.gpu_ids)
34 | if use_D2:
35 | self.model_names += ['D2']
36 | self.netD2 = networks.define_D(D_output_nc, opt.ndf, netD=opt.netD2, norm=opt.norm, nl=opt.nl,
37 | init_type=opt.init_type, init_gain=opt.init_gain, num_Ds=opt.num_Ds, gpu_ids=self.gpu_ids)
38 | else:
39 | self.netD2 = None
40 | if use_E:
41 | self.model_names += ['E']
42 | self.netE = networks.define_E(opt.output_nc, opt.nz, opt.nef, netE=opt.netE, norm=opt.norm, nl=opt.nl,
43 | init_type=opt.init_type, init_gain=opt.init_gain, gpu_ids=self.gpu_ids, vaeLike=use_vae)
44 |
45 | if opt.isTrain:
46 | self.criterionGAN = networks.GANLoss(gan_mode=opt.gan_mode).to(self.device)
47 | self.criterionL1 = torch.nn.L1Loss()
48 | self.criterionZ = torch.nn.L1Loss()
49 | # initialize optimizers
50 | self.optimizers = []
51 | self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
52 | self.optimizers.append(self.optimizer_G)
53 | if use_E:
54 | self.optimizer_E = torch.optim.Adam(self.netE.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
55 | self.optimizers.append(self.optimizer_E)
56 |
57 | if use_D:
58 | self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
59 | self.optimizers.append(self.optimizer_D)
60 | if use_D2:
61 | self.optimizer_D2 = torch.optim.Adam(self.netD2.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
62 | self.optimizers.append(self.optimizer_D2)
63 |
64 | def is_train(self):
65 | """check if the current batch is good for training."""
66 | return self.opt.isTrain and self.real_A.size(0) == self.opt.batch_size
67 |
68 | def set_input(self, input):
69 | AtoB = self.opt.direction == 'AtoB'
70 | self.real_A = input['A' if AtoB else 'B'].to(self.device)
71 | self.real_B = input['B' if AtoB else 'A'].to(self.device)
72 | self.image_paths = input['A_paths' if AtoB else 'B_paths']
73 |
74 | def get_z_random(self, batch_size, nz, random_type='gauss'):
75 | if random_type == 'uni':
76 | z = torch.rand(batch_size, nz) * 2.0 - 1.0
77 | elif random_type == 'gauss':
78 | z = torch.randn(batch_size, nz)
79 | return z.detach().to(self.device)
80 |
81 | def encode(self, input_image):
82 | mu, logvar = self.netE.forward(input_image)
83 | std = logvar.mul(0.5).exp_()
84 | eps = self.get_z_random(std.size(0), std.size(1))
85 | z = eps.mul(std).add_(mu)
86 | return z, mu, logvar
87 |
88 | def test(self, z0=None, encode=False):
89 | with torch.no_grad():
90 | if encode: # use encoded z
91 | z0, _ = self.netE(self.real_B)
92 | if z0 is None:
93 | z0 = self.get_z_random(self.real_A.size(0), self.opt.nz)
94 | self.fake_B = self.netG(self.real_A, z0)
95 | return self.real_A, self.fake_B, self.real_B
96 |
97 | def forward(self):
98 | # get real images
99 | half_size = self.opt.batch_size // 2
100 | # A1, B1 for encoded; A2, B2 for random
101 | self.real_A_encoded = self.real_A[0:half_size]
102 | self.real_B_encoded = self.real_B[0:half_size]
103 | self.real_A_random = self.real_A[half_size:]
104 | self.real_B_random = self.real_B[half_size:]
105 | # get encoded z
106 | self.z_encoded, self.mu, self.logvar = self.encode(self.real_B_encoded)
107 | # get random z
108 | self.z_random = self.get_z_random(self.real_A_encoded.size(0), self.opt.nz)
109 | # generate fake_B_encoded
110 | self.fake_B_encoded = self.netG(self.real_A_encoded, self.z_encoded)
111 | # generate fake_B_random
112 | self.fake_B_random = self.netG(self.real_A_encoded, self.z_random)
113 |         if self.opt.conditional_D:  # conditional D scores the concatenated (input, output) pair
114 | self.fake_data_encoded = torch.cat([self.real_A_encoded, self.fake_B_encoded], 1)
115 | self.real_data_encoded = torch.cat([self.real_A_encoded, self.real_B_encoded], 1)
116 | self.fake_data_random = torch.cat([self.real_A_encoded, self.fake_B_random], 1)
117 | self.real_data_random = torch.cat([self.real_A_random, self.real_B_random], 1)
118 | else:
119 | self.fake_data_encoded = self.fake_B_encoded
120 | self.fake_data_random = self.fake_B_random
121 | self.real_data_encoded = self.real_B_encoded
122 | self.real_data_random = self.real_B_random
123 |
124 | # compute z_predict
125 | if self.opt.lambda_z > 0.0:
126 | self.mu2, logvar2 = self.netE(self.fake_B_random) # mu2 is a point estimate
127 |
128 | def backward_D(self, netD, real, fake):
129 |         # Fake; stop backprop to the generator by detaching fake_B
130 | pred_fake = netD(fake.detach())
131 | # real
132 | pred_real = netD(real)
133 | loss_D_fake, _ = self.criterionGAN(pred_fake, False)
134 | loss_D_real, _ = self.criterionGAN(pred_real, True)
135 | # Combined loss
136 | loss_D = loss_D_fake + loss_D_real
137 | loss_D.backward()
138 | return loss_D, [loss_D_fake, loss_D_real]
139 |
140 | def backward_G_GAN(self, fake, netD=None, ll=0.0):
141 | if ll > 0.0:
142 | pred_fake = netD(fake)
143 | loss_G_GAN, _ = self.criterionGAN(pred_fake, True)
144 | else:
145 | loss_G_GAN = 0
146 | return loss_G_GAN * ll
147 |
148 | def backward_EG(self):
149 |         # 1. G(A) should fool D
150 | self.loss_G_GAN = self.backward_G_GAN(self.fake_data_encoded, self.netD, self.opt.lambda_GAN)
151 | if self.opt.use_same_D:
152 | self.loss_G_GAN2 = self.backward_G_GAN(self.fake_data_random, self.netD, self.opt.lambda_GAN2)
153 | else:
154 | self.loss_G_GAN2 = self.backward_G_GAN(self.fake_data_random, self.netD2, self.opt.lambda_GAN2)
155 | # 2. KL loss
156 | if self.opt.lambda_kl > 0.0:
157 | self.loss_kl = torch.sum(1 + self.logvar - self.mu.pow(2) - self.logvar.exp()) * (-0.5 * self.opt.lambda_kl)
158 | else:
159 | self.loss_kl = 0
160 |         # 3. reconstruction |fake_B - real_B|
161 | if self.opt.lambda_L1 > 0.0:
162 | self.loss_G_L1 = self.criterionL1(self.fake_B_encoded, self.real_B_encoded) * self.opt.lambda_L1
163 | else:
164 | self.loss_G_L1 = 0.0
165 |
166 | self.loss_G = self.loss_G_GAN + self.loss_G_GAN2 + self.loss_G_L1 + self.loss_kl
167 | self.loss_G.backward(retain_graph=True)
168 |
169 | def update_D(self):
170 | self.set_requires_grad([self.netD, self.netD2], True)
171 | # update D1
172 | if self.opt.lambda_GAN > 0.0:
173 | self.optimizer_D.zero_grad()
174 | self.loss_D, self.losses_D = self.backward_D(self.netD, self.real_data_encoded, self.fake_data_encoded)
175 | if self.opt.use_same_D:
176 | self.loss_D2, self.losses_D2 = self.backward_D(self.netD, self.real_data_random, self.fake_data_random)
177 | self.optimizer_D.step()
178 |
179 | if self.opt.lambda_GAN2 > 0.0 and not self.opt.use_same_D:
180 | self.optimizer_D2.zero_grad()
181 | self.loss_D2, self.losses_D2 = self.backward_D(self.netD2, self.real_data_random, self.fake_data_random)
182 | self.optimizer_D2.step()
183 |
184 | def backward_G_alone(self):
185 |         # 3. reconstruction |E(G(A, z_random)) - z_random|
186 | if self.opt.lambda_z > 0.0:
187 | self.loss_z_L1 = self.criterionZ(self.mu2, self.z_random) * self.opt.lambda_z
188 | self.loss_z_L1.backward()
189 | else:
190 | self.loss_z_L1 = 0.0
191 |
192 | def update_G_and_E(self):
193 | # update G and E
194 | self.set_requires_grad([self.netD, self.netD2], False)
195 | self.optimizer_E.zero_grad()
196 | self.optimizer_G.zero_grad()
197 | self.backward_EG()
198 |
199 | # update G alone
200 | if self.opt.lambda_z > 0.0:
201 | self.set_requires_grad([self.netE], False)
202 | self.backward_G_alone()
203 | self.set_requires_grad([self.netE], True)
204 |
205 | self.optimizer_E.step()
206 | self.optimizer_G.step()
207 |
208 | def optimize_parameters(self):
209 | self.forward()
210 | self.update_G_and_E()
211 | self.update_D()
212 |
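
Note: the encode() method above implements the standard VAE reparameterization trick, and backward_EG() uses the closed-form KL divergence against a unit Gaussian. Below is a minimal, self-contained sketch of both pieces, with hypothetical tensors that stand in for the encoder's outputs; it is illustrative only, not part of the repo.

    import torch

    # Hypothetical encoder outputs: batch of 4, 8-dimensional latent code.
    mu, logvar = torch.zeros(4, 8), torch.zeros(4, 8)

    # Reparameterization: z = mu + eps * std with std = exp(0.5 * logvar),
    # so gradients flow through mu/logvar but not through the noise eps.
    std = (0.5 * logvar).exp()
    eps = torch.randn_like(std)
    z = mu + eps * std

    # Closed-form KL(q(z|x) || N(0, I)); backward_EG scales this by opt.lambda_kl.
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())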
--------------------------------------------------------------------------------
/models/networks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import init
4 | import functools
5 | from torch.optim import lr_scheduler
6 |
7 | ###############################################################################
8 | # Helper functions
9 | ###############################################################################
10 |
11 |
12 | def init_weights(net, init_type='normal', init_gain=0.02):
13 | """Initialize network weights.
14 | Parameters:
15 | net (network) -- network to be initialized
16 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
17 | init_gain (float) -- scaling factor for normal, xavier and orthogonal.
18 |     We use 'normal' in the original pix2pix and CycleGAN papers. But xavier and kaiming might
19 |     work better for some applications. Feel free to experiment.
20 | """
21 | def init_func(m): # define the initialization function
22 | classname = m.__class__.__name__
23 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
24 | if init_type == 'normal':
25 | init.normal_(m.weight.data, 0.0, init_gain)
26 | elif init_type == 'xavier':
27 | init.xavier_normal_(m.weight.data, gain=init_gain)
28 | elif init_type == 'kaiming':
29 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
30 | elif init_type == 'orthogonal':
31 | init.orthogonal_(m.weight.data, gain=init_gain)
32 | else:
33 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
34 | if hasattr(m, 'bias') and m.bias is not None:
35 | init.constant_(m.bias.data, 0.0)
36 | elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
37 | init.normal_(m.weight.data, 1.0, init_gain)
38 | init.constant_(m.bias.data, 0.0)
39 |
40 | print('initialize network with %s' % init_type)
41 | net.apply(init_func) # apply the initialization function
42 |
43 |
44 | def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
45 | """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights
46 | Parameters:
47 | net (network) -- the network to be initialized
48 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
49 |         init_gain (float)  -- scaling factor for normal, xavier and orthogonal.
50 | gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
51 | Return an initialized network.
52 | """
53 | if len(gpu_ids) > 0:
54 | assert(torch.cuda.is_available())
55 | net.to(gpu_ids[0])
56 | net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs
57 | init_weights(net, init_type, init_gain=init_gain)
58 | return net
59 |
60 |
61 | def get_scheduler(optimizer, opt):
62 | """Return a learning rate scheduler
63 | Parameters:
64 | optimizer -- the optimizer of the network
65 | opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.
66 | opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine
67 |     For 'linear', we keep the same learning rate for the first <opt.niter> epochs
68 |     and linearly decay the rate to zero over the next <opt.niter_decay> epochs.
69 | For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
70 | See https://pytorch.org/docs/stable/optim.html for more details.
71 | """
72 | if opt.lr_policy == 'linear':
73 | def lambda_rule(epoch):
74 | lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
75 | return lr_l
76 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
77 | elif opt.lr_policy == 'step':
78 | scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
79 | elif opt.lr_policy == 'plateau':
80 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
81 | elif opt.lr_policy == 'cosine':
82 | scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.niter, eta_min=0)
83 | else:
84 |         raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
85 | return scheduler
86 |
87 |
88 | def get_norm_layer(norm_type='instance'):
89 | """Return a normalization layer
90 | Parameters:
91 | norm_type (str) -- the name of the normalization layer: batch | instance | none
92 | For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).
93 | For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics.
94 | """
95 | if norm_type == 'batch':
96 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
97 | elif norm_type == 'instance':
98 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
99 | elif norm_type == 'none':
100 | norm_layer = None
101 | else:
102 | raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
103 | return norm_layer
104 |
105 |
106 | def get_non_linearity(layer_type='relu'):
107 | if layer_type == 'relu':
108 | nl_layer = functools.partial(nn.ReLU, inplace=True)
109 | elif layer_type == 'lrelu':
110 | nl_layer = functools.partial(
111 | nn.LeakyReLU, negative_slope=0.2, inplace=True)
112 | elif layer_type == 'elu':
113 | nl_layer = functools.partial(nn.ELU, inplace=True)
114 | else:
115 | raise NotImplementedError(
116 |             'nonlinearity activation [%s] is not found' % layer_type)
117 | return nl_layer
118 |
119 |
120 | def define_G(input_nc, output_nc, nz, ngf, netG='unet_128', norm='batch', nl='relu',
121 | use_dropout=False, init_type='xavier', init_gain=0.02, gpu_ids=[], where_add='input', upsample='bilinear'):
122 | net = None
123 | norm_layer = get_norm_layer(norm_type=norm)
124 | nl_layer = get_non_linearity(layer_type=nl)
125 |
126 | if nz == 0:
127 | where_add = 'input'
128 |
129 | if netG == 'unet_128' and where_add == 'input':
130 | net = G_Unet_add_input(input_nc, output_nc, nz, 7, ngf, norm_layer=norm_layer, nl_layer=nl_layer,
131 | use_dropout=use_dropout, upsample=upsample)
132 | elif netG == 'unet_256' and where_add == 'input':
133 | net = G_Unet_add_input(input_nc, output_nc, nz, 8, ngf, norm_layer=norm_layer, nl_layer=nl_layer,
134 | use_dropout=use_dropout, upsample=upsample)
135 | elif netG == 'unet_128' and where_add == 'all':
136 | net = G_Unet_add_all(input_nc, output_nc, nz, 7, ngf, norm_layer=norm_layer, nl_layer=nl_layer,
137 | use_dropout=use_dropout, upsample=upsample)
138 | elif netG == 'unet_256' and where_add == 'all':
139 | net = G_Unet_add_all(input_nc, output_nc, nz, 8, ngf, norm_layer=norm_layer, nl_layer=nl_layer,
140 | use_dropout=use_dropout, upsample=upsample)
141 | else:
142 |         raise NotImplementedError('Generator model name [%s] is not recognized' % netG)
143 |
144 | return init_net(net, init_type, init_gain, gpu_ids)
145 |
146 |
147 | def define_D(input_nc, ndf, netD, norm='batch', nl='lrelu', init_type='xavier', init_gain=0.02, num_Ds=1, gpu_ids=[]):
148 | net = None
149 | norm_layer = get_norm_layer(norm_type=norm)
150 | nl = 'lrelu' # use leaky relu for D
151 | nl_layer = get_non_linearity(layer_type=nl)
152 |
153 | if netD == 'basic_128':
154 | net = D_NLayers(input_nc, ndf, n_layers=2, norm_layer=norm_layer, nl_layer=nl_layer)
155 | elif netD == 'basic_256':
156 | net = D_NLayers(input_nc, ndf, n_layers=3, norm_layer=norm_layer, nl_layer=nl_layer)
157 | elif netD == 'basic_128_multi':
158 | net = D_NLayersMulti(input_nc=input_nc, ndf=ndf, n_layers=2, norm_layer=norm_layer, num_D=num_Ds)
159 | elif netD == 'basic_256_multi':
160 | net = D_NLayersMulti(input_nc=input_nc, ndf=ndf, n_layers=3, norm_layer=norm_layer, num_D=num_Ds)
161 | else:
162 |         raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD)
163 | return init_net(net, init_type, init_gain, gpu_ids)
164 |
165 |
166 | def define_E(input_nc, output_nc, ndf, netE,
167 | norm='batch', nl='lrelu',
168 | init_type='xavier', init_gain=0.02, gpu_ids=[], vaeLike=False):
169 | net = None
170 | norm_layer = get_norm_layer(norm_type=norm)
171 | nl = 'lrelu' # use leaky relu for E
172 | nl_layer = get_non_linearity(layer_type=nl)
173 | if netE == 'resnet_128':
174 | net = E_ResNet(input_nc, output_nc, ndf, n_blocks=4, norm_layer=norm_layer,
175 | nl_layer=nl_layer, vaeLike=vaeLike)
176 | elif netE == 'resnet_256':
177 | net = E_ResNet(input_nc, output_nc, ndf, n_blocks=5, norm_layer=norm_layer,
178 | nl_layer=nl_layer, vaeLike=vaeLike)
179 | elif netE == 'conv_128':
180 | net = E_NLayers(input_nc, output_nc, ndf, n_layers=4, norm_layer=norm_layer,
181 | nl_layer=nl_layer, vaeLike=vaeLike)
182 | elif netE == 'conv_256':
183 | net = E_NLayers(input_nc, output_nc, ndf, n_layers=5, norm_layer=norm_layer,
184 | nl_layer=nl_layer, vaeLike=vaeLike)
185 | else:
186 |         raise NotImplementedError('Encoder model name [%s] is not recognized' % netE)
187 |
188 | return init_net(net, init_type, init_gain, gpu_ids)
189 |
190 |
191 | class D_NLayersMulti(nn.Module):
192 | def __init__(self, input_nc, ndf=64, n_layers=3,
193 | norm_layer=nn.BatchNorm2d, num_D=1):
194 | super(D_NLayersMulti, self).__init__()
195 |
196 | self.num_D = num_D
197 | if num_D == 1:
198 | layers = self.get_layers(input_nc, ndf, n_layers, norm_layer)
199 | self.model = nn.Sequential(*layers)
200 | else:
201 | layers = self.get_layers(input_nc, ndf, n_layers, norm_layer)
202 | self.add_module("model_0", nn.Sequential(*layers))
203 | self.down = nn.AvgPool2d(3, stride=2, padding=[
204 | 1, 1], count_include_pad=False)
205 | for i in range(1, num_D):
206 | ndf_i = int(round(ndf / (2**i)))
207 | layers = self.get_layers(input_nc, ndf_i, n_layers, norm_layer)
208 | self.add_module("model_%d" % i, nn.Sequential(*layers))
209 |
210 | def get_layers(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
211 | kw = 4
212 | padw = 1
213 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw,
214 | stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
215 |
216 | nf_mult = 1
217 | nf_mult_prev = 1
218 | for n in range(1, n_layers):
219 | nf_mult_prev = nf_mult
220 | nf_mult = min(2**n, 8)
221 | sequence += [
222 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
223 | kernel_size=kw, stride=2, padding=padw),
224 | norm_layer(ndf * nf_mult),
225 | nn.LeakyReLU(0.2, True)
226 | ]
227 |
228 | nf_mult_prev = nf_mult
229 | nf_mult = min(2**n_layers, 8)
230 | sequence += [
231 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
232 | kernel_size=kw, stride=1, padding=padw),
233 | norm_layer(ndf * nf_mult),
234 | nn.LeakyReLU(0.2, True)
235 | ]
236 |
237 | sequence += [nn.Conv2d(ndf * nf_mult, 1,
238 | kernel_size=kw, stride=1, padding=padw)]
239 |
240 | return sequence
241 |
242 | def forward(self, input):
243 | if self.num_D == 1:
244 | return self.model(input)
245 | result = []
246 | down = input
247 | for i in range(self.num_D):
248 | model = getattr(self, "model_%d" % i)
249 | result.append(model(down))
250 | if i != self.num_D - 1:
251 | down = self.down(down)
252 | return result
253 |
254 |
255 | class D_NLayers(nn.Module):
256 | """Defines a PatchGAN discriminator"""
257 |
258 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
259 | """Construct a PatchGAN discriminator
260 | Parameters:
261 | input_nc (int) -- the number of channels in input images
262 | ndf (int) -- the number of filters in the last conv layer
263 | n_layers (int) -- the number of conv layers in the discriminator
264 | norm_layer -- normalization layer
265 | """
266 | super(D_NLayers, self).__init__()
267 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
268 | use_bias = norm_layer.func != nn.BatchNorm2d
269 | else:
270 | use_bias = norm_layer != nn.BatchNorm2d
271 |
272 | kw = 4
273 | padw = 1
274 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
275 | nf_mult = 1
276 | nf_mult_prev = 1
277 | for n in range(1, n_layers): # gradually increase the number of filters
278 | nf_mult_prev = nf_mult
279 | nf_mult = min(2 ** n, 8)
280 | sequence += [
281 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
282 | norm_layer(ndf * nf_mult),
283 | nn.LeakyReLU(0.2, True)
284 | ]
285 |
286 | nf_mult_prev = nf_mult
287 | nf_mult = min(2 ** n_layers, 8)
288 | sequence += [
289 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
290 | norm_layer(ndf * nf_mult),
291 | nn.LeakyReLU(0.2, True)
292 | ]
293 |
294 | sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
295 | self.model = nn.Sequential(*sequence)
296 |
297 | def forward(self, input):
298 | """Standard forward."""
299 | return self.model(input)
300 |
301 |
302 | ##############################################################################
303 | # Classes
304 | ##############################################################################
305 | class RecLoss(nn.Module):
306 | def __init__(self, use_L2=True):
307 | super(RecLoss, self).__init__()
308 | self.use_L2 = use_L2
309 |
310 | def __call__(self, input, target, batch_mean=True):
311 | if self.use_L2:
312 | diff = (input - target) ** 2
313 | else:
314 | diff = torch.abs(input - target)
315 | if batch_mean:
316 | return torch.mean(diff)
317 | else:
318 |             return torch.mean(diff.view(diff.size(0), -1), dim=1)  # per-sample mean over C, H, W
319 |
320 |
321 | # Defines the GAN loss which uses either LSGAN or the regular GAN.
322 | # When LSGAN is used, it is basically same as MSELoss,
323 | # but it abstracts away the need to create the target label tensor
324 | # that has the same size as the input
325 | class GANLoss(nn.Module):
326 | """Define different GAN objectives.
327 |
328 | The GANLoss class abstracts away the need to create the target label tensor
329 | that has the same size as the input.
330 | """
331 |
332 | def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
333 | """ Initialize the GANLoss class.
334 |
335 | Parameters:
336 | gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp.
337 | target_real_label (bool) - - label for a real image
338 | target_fake_label (bool) - - label of a fake image
339 |
340 | Note: Do not use sigmoid as the last layer of Discriminator.
341 | LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.
342 | """
343 | super(GANLoss, self).__init__()
344 | self.register_buffer('real_label', torch.tensor(target_real_label))
345 | self.register_buffer('fake_label', torch.tensor(target_fake_label))
346 | self.gan_mode = gan_mode
347 | if gan_mode == 'lsgan':
348 | self.loss = nn.MSELoss()
349 | elif gan_mode == 'vanilla':
350 | self.loss = nn.BCEWithLogitsLoss()
351 | elif gan_mode in ['wgangp']:
352 | self.loss = None
353 | else:
354 | raise NotImplementedError('gan mode %s not implemented' % gan_mode)
355 |
356 | def get_target_tensor(self, prediction, target_is_real):
357 | """Create label tensors with the same size as the input.
358 |
359 | Parameters:
360 |             prediction (tensor) - - typically the prediction from a discriminator
361 | target_is_real (bool) - - if the ground truth label is for real images or fake images
362 |
363 | Returns:
364 | A label tensor filled with ground truth label, and with the size of the input
365 | """
366 |
367 | if target_is_real:
368 | target_tensor = self.real_label
369 | else:
370 | target_tensor = self.fake_label
371 | return target_tensor.expand_as(prediction)
372 |
373 | def __call__(self, predictions, target_is_real):
374 |         """Calculate loss given Discriminator's output and ground truth labels.
375 |
376 | Parameters:
377 |             predictions (tensor list) - - typically the prediction outputs from a discriminator; supports multiple Ds.
378 | target_is_real (bool) - - if the ground truth label is for real images or fake images
379 |
380 | Returns:
381 | the calculated loss.
382 | """
383 | all_losses = []
384 | for prediction in predictions:
385 | if self.gan_mode in ['lsgan', 'vanilla']:
386 | target_tensor = self.get_target_tensor(prediction, target_is_real)
387 | loss = self.loss(prediction, target_tensor)
388 | elif self.gan_mode == 'wgangp':
389 | if target_is_real:
390 | loss = -prediction.mean()
391 | else:
392 | loss = prediction.mean()
393 | all_losses.append(loss)
394 | total_loss = sum(all_losses)
395 | return total_loss, all_losses
396 |
397 |
398 | def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed', constant=1.0, lambda_gp=10.0):
399 | """Calculate the gradient penalty loss, used in WGAN-GP paper https://arxiv.org/abs/1704.00028
400 | Arguments:
401 | netD (network) -- discriminator network
402 | real_data (tensor array) -- real images
403 | fake_data (tensor array) -- generated images from the generator
404 | device (str) -- GPU / CPU: from torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
405 | type (str) -- if we mix real and fake data or not [real | fake | mixed].
406 |         constant (float)            -- the constant used in the formula (||gradient||_2 - constant)^2
407 | lambda_gp (float) -- weight for this loss
408 | Returns the gradient penalty loss
409 | """
410 | if lambda_gp > 0.0:
411 | if type == 'real': # either use real images, fake images, or a linear interpolation of two.
412 | interpolatesv = real_data
413 | elif type == 'fake':
414 | interpolatesv = fake_data
415 | elif type == 'mixed':
416 | alpha = torch.rand(real_data.shape[0], 1)
417 | alpha = alpha.expand(real_data.shape[0], real_data.nelement() // real_data.shape[0]).contiguous().view(*real_data.shape)
418 | alpha = alpha.to(device)
419 | interpolatesv = alpha * real_data + ((1 - alpha) * fake_data)
420 | else:
421 | raise NotImplementedError('{} not implemented'.format(type))
422 | interpolatesv.requires_grad_(True)
423 | disc_interpolates = netD(interpolatesv)
424 | gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolatesv,
425 | grad_outputs=torch.ones(disc_interpolates.size()).to(device),
426 | create_graph=True, retain_graph=True, only_inputs=True)
427 |         gradients = gradients[0].view(real_data.size(0), -1)  # flatten the data
428 | gradient_penalty = (((gradients + 1e-16).norm(2, dim=1) - constant) ** 2).mean() * lambda_gp # added eps
429 | return gradient_penalty, gradients
430 | else:
431 | return 0.0, None
432 |
433 | # Defines the Unet generator.
434 | # |num_downs|: number of downsamplings in UNet. For example,
435 | # if |num_downs| == 7, image of size 128x128 will become of size 1x1
436 | # at the bottleneck
437 |
438 |
439 | class G_Unet_add_input(nn.Module):
440 | def __init__(self, input_nc, output_nc, nz, num_downs, ngf=64,
441 | norm_layer=None, nl_layer=None, use_dropout=False,
442 | upsample='basic'):
443 | super(G_Unet_add_input, self).__init__()
444 | self.nz = nz
445 | max_nchn = 8
446 | # construct unet structure
447 | unet_block = UnetBlock(ngf * max_nchn, ngf * max_nchn, ngf * max_nchn,
448 | innermost=True, norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
449 | for i in range(num_downs - 5):
450 | unet_block = UnetBlock(ngf * max_nchn, ngf * max_nchn, ngf * max_nchn, unet_block,
451 | norm_layer=norm_layer, nl_layer=nl_layer, use_dropout=use_dropout, upsample=upsample)
452 | unet_block = UnetBlock(ngf * 4, ngf * 4, ngf * max_nchn, unet_block,
453 | norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
454 | unet_block = UnetBlock(ngf * 2, ngf * 2, ngf * 4, unet_block,
455 | norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
456 | unet_block = UnetBlock(ngf, ngf, ngf * 2, unet_block,
457 | norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
458 | unet_block = UnetBlock(input_nc + nz, output_nc, ngf, unet_block,
459 | outermost=True, norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
460 |
461 | self.model = unet_block
462 |
463 | def forward(self, x, z=None):
464 | if self.nz > 0:
465 | z_img = z.view(z.size(0), z.size(1), 1, 1).expand(
466 | z.size(0), z.size(1), x.size(2), x.size(3))
467 | x_with_z = torch.cat([x, z_img], 1)
468 | else:
469 | x_with_z = x # no z
470 |
471 | return self.model(x_with_z)
472 |
473 |
474 | def upsampleLayer(inplanes, outplanes, upsample='basic', padding_type='zero'):
475 | # padding_type = 'zero'
476 | if upsample == 'basic':
477 | upconv = [nn.ConvTranspose2d(
478 | inplanes, outplanes, kernel_size=4, stride=2, padding=1)]
479 | elif upsample == 'bilinear':
480 | upconv = [nn.Upsample(scale_factor=2, mode='bilinear'),
481 | nn.ReflectionPad2d(1),
482 | nn.Conv2d(inplanes, outplanes, kernel_size=3, stride=1, padding=0)]
483 | else:
484 | raise NotImplementedError(
485 | 'upsample layer [%s] not implemented' % upsample)
486 | return upconv
487 |
488 |
489 | # Defines the submodule with skip connection.
490 | # X -------------------identity---------------------- X
491 | # |-- downsampling -- |submodule| -- upsampling --|
492 | class UnetBlock(nn.Module):
493 | def __init__(self, input_nc, outer_nc, inner_nc,
494 | submodule=None, outermost=False, innermost=False,
495 | norm_layer=None, nl_layer=None, use_dropout=False, upsample='basic', padding_type='zero'):
496 | super(UnetBlock, self).__init__()
497 | self.outermost = outermost
498 | p = 0
499 | downconv = []
500 | if padding_type == 'reflect':
501 | downconv += [nn.ReflectionPad2d(1)]
502 | elif padding_type == 'replicate':
503 | downconv += [nn.ReplicationPad2d(1)]
504 | elif padding_type == 'zero':
505 | p = 1
506 | else:
507 | raise NotImplementedError(
508 | 'padding [%s] is not implemented' % padding_type)
509 | downconv += [nn.Conv2d(input_nc, inner_nc,
510 | kernel_size=4, stride=2, padding=p)]
511 | # downsample is different from upsample
512 | downrelu = nn.LeakyReLU(0.2, True)
513 | downnorm = norm_layer(inner_nc) if norm_layer is not None else None
514 | uprelu = nl_layer()
515 | upnorm = norm_layer(outer_nc) if norm_layer is not None else None
516 |
517 | if outermost:
518 | upconv = upsampleLayer(
519 | inner_nc * 2, outer_nc, upsample=upsample, padding_type=padding_type)
520 | down = downconv
521 | up = [uprelu] + upconv + [nn.Tanh()]
522 | model = down + [submodule] + up
523 | elif innermost:
524 | upconv = upsampleLayer(
525 | inner_nc, outer_nc, upsample=upsample, padding_type=padding_type)
526 | down = [downrelu] + downconv
527 | up = [uprelu] + upconv
528 | if upnorm is not None:
529 | up += [upnorm]
530 | model = down + up
531 | else:
532 | upconv = upsampleLayer(
533 | inner_nc * 2, outer_nc, upsample=upsample, padding_type=padding_type)
534 | down = [downrelu] + downconv
535 | if downnorm is not None:
536 | down += [downnorm]
537 | up = [uprelu] + upconv
538 | if upnorm is not None:
539 | up += [upnorm]
540 |
541 | if use_dropout:
542 | model = down + [submodule] + up + [nn.Dropout(0.5)]
543 | else:
544 | model = down + [submodule] + up
545 |
546 | self.model = nn.Sequential(*model)
547 |
548 | def forward(self, x):
549 | if self.outermost:
550 | return self.model(x)
551 | else:
552 | return torch.cat([self.model(x), x], 1)
553 |
554 |
555 | def conv3x3(in_planes, out_planes):
556 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=1,
557 | padding=1, bias=True)
558 |
559 |
560 | # two use cases, depending on kw and padw
561 | def upsampleConv(inplanes, outplanes, kw, padw):
562 | sequence = []
563 | sequence += [nn.Upsample(scale_factor=2, mode='nearest')]
564 | sequence += [nn.Conv2d(inplanes, outplanes, kernel_size=kw,
565 | stride=1, padding=padw, bias=True)]
566 | return nn.Sequential(*sequence)
567 |
568 |
569 | def meanpoolConv(inplanes, outplanes):
570 | sequence = []
571 | sequence += [nn.AvgPool2d(kernel_size=2, stride=2)]
572 | sequence += [nn.Conv2d(inplanes, outplanes,
573 | kernel_size=1, stride=1, padding=0, bias=True)]
574 | return nn.Sequential(*sequence)
575 |
576 |
577 | def convMeanpool(inplanes, outplanes):
578 | sequence = []
579 | sequence += [conv3x3(inplanes, outplanes)]
580 | sequence += [nn.AvgPool2d(kernel_size=2, stride=2)]
581 | return nn.Sequential(*sequence)
582 |
583 |
584 | class BasicBlockUp(nn.Module):
585 | def __init__(self, inplanes, outplanes, norm_layer=None, nl_layer=None):
586 | super(BasicBlockUp, self).__init__()
587 | layers = []
588 | if norm_layer is not None:
589 | layers += [norm_layer(inplanes)]
590 | layers += [nl_layer()]
591 | layers += [upsampleConv(inplanes, outplanes, kw=3, padw=1)]
592 | if norm_layer is not None:
593 | layers += [norm_layer(outplanes)]
594 | layers += [conv3x3(outplanes, outplanes)]
595 | self.conv = nn.Sequential(*layers)
596 | self.shortcut = upsampleConv(inplanes, outplanes, kw=1, padw=0)
597 |
598 | def forward(self, x):
599 | out = self.conv(x) + self.shortcut(x)
600 | return out
601 |
602 |
603 | class BasicBlock(nn.Module):
604 | def __init__(self, inplanes, outplanes, norm_layer=None, nl_layer=None):
605 | super(BasicBlock, self).__init__()
606 | layers = []
607 | if norm_layer is not None:
608 | layers += [norm_layer(inplanes)]
609 | layers += [nl_layer()]
610 | layers += [conv3x3(inplanes, inplanes)]
611 | if norm_layer is not None:
612 | layers += [norm_layer(inplanes)]
613 | layers += [nl_layer()]
614 | layers += [convMeanpool(inplanes, outplanes)]
615 | self.conv = nn.Sequential(*layers)
616 | self.shortcut = meanpoolConv(inplanes, outplanes)
617 |
618 | def forward(self, x):
619 | out = self.conv(x) + self.shortcut(x)
620 | return out
621 |
622 |
623 | class E_ResNet(nn.Module):
624 | def __init__(self, input_nc=3, output_nc=1, ndf=64, n_blocks=4,
625 | norm_layer=None, nl_layer=None, vaeLike=False):
626 | super(E_ResNet, self).__init__()
627 | self.vaeLike = vaeLike
628 | max_ndf = 4
629 | conv_layers = [
630 | nn.Conv2d(input_nc, ndf, kernel_size=4, stride=2, padding=1, bias=True)]
631 | for n in range(1, n_blocks):
632 | input_ndf = ndf * min(max_ndf, n)
633 | output_ndf = ndf * min(max_ndf, n + 1)
634 | conv_layers += [BasicBlock(input_ndf,
635 | output_ndf, norm_layer, nl_layer)]
636 | conv_layers += [nl_layer(), nn.AvgPool2d(8)]
637 | if vaeLike:
638 | self.fc = nn.Sequential(*[nn.Linear(output_ndf, output_nc)])
639 | self.fcVar = nn.Sequential(*[nn.Linear(output_ndf, output_nc)])
640 | else:
641 | self.fc = nn.Sequential(*[nn.Linear(output_ndf, output_nc)])
642 | self.conv = nn.Sequential(*conv_layers)
643 |
644 | def forward(self, x):
645 | x_conv = self.conv(x)
646 | conv_flat = x_conv.view(x.size(0), -1)
647 | output = self.fc(conv_flat)
648 | if self.vaeLike:
649 | outputVar = self.fcVar(conv_flat)
650 | return output, outputVar
651 | else:
652 | return output
653 |
654 |
655 |
656 | # Defines the Unet generator.
657 | # |num_downs|: number of downsamplings in UNet. For example,
658 | # if |num_downs| == 7, image of size 128x128 will become of size 1x1
659 | # at the bottleneck
660 | class G_Unet_add_all(nn.Module):
661 | def __init__(self, input_nc, output_nc, nz, num_downs, ngf=64,
662 | norm_layer=None, nl_layer=None, use_dropout=False, upsample='basic'):
663 | super(G_Unet_add_all, self).__init__()
664 | self.nz = nz
665 | # construct unet structure
666 | unet_block = UnetBlock_with_z(ngf * 8, ngf * 8, ngf * 8, nz, None, innermost=True,
667 | norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
668 | unet_block = UnetBlock_with_z(ngf * 8, ngf * 8, ngf * 8, nz, unet_block,
669 | norm_layer=norm_layer, nl_layer=nl_layer, use_dropout=use_dropout, upsample=upsample)
670 | for i in range(num_downs - 6):
671 | unet_block = UnetBlock_with_z(ngf * 8, ngf * 8, ngf * 8, nz, unet_block,
672 | norm_layer=norm_layer, nl_layer=nl_layer, use_dropout=use_dropout, upsample=upsample)
673 | unet_block = UnetBlock_with_z(ngf * 4, ngf * 4, ngf * 8, nz, unet_block,
674 | norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
675 | unet_block = UnetBlock_with_z(ngf * 2, ngf * 2, ngf * 4, nz, unet_block,
676 | norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
677 | unet_block = UnetBlock_with_z(
678 | ngf, ngf, ngf * 2, nz, unet_block, norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
679 | unet_block = UnetBlock_with_z(input_nc, output_nc, ngf, nz, unet_block,
680 | outermost=True, norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
681 | self.model = unet_block
682 |
683 | def forward(self, x, z):
684 | return self.model(x, z)
685 |
686 |
687 | class UnetBlock_with_z(nn.Module):
688 | def __init__(self, input_nc, outer_nc, inner_nc, nz=0,
689 | submodule=None, outermost=False, innermost=False,
690 | norm_layer=None, nl_layer=None, use_dropout=False, upsample='basic', padding_type='zero'):
691 | super(UnetBlock_with_z, self).__init__()
692 | p = 0
693 | downconv = []
694 | if padding_type == 'reflect':
695 | downconv += [nn.ReflectionPad2d(1)]
696 | elif padding_type == 'replicate':
697 | downconv += [nn.ReplicationPad2d(1)]
698 | elif padding_type == 'zero':
699 | p = 1
700 | else:
701 | raise NotImplementedError(
702 | 'padding [%s] is not implemented' % padding_type)
703 |
704 | self.outermost = outermost
705 | self.innermost = innermost
706 | self.nz = nz
707 | input_nc = input_nc + nz
708 | downconv += [nn.Conv2d(input_nc, inner_nc,
709 | kernel_size=4, stride=2, padding=p)]
710 | # downsample is different from upsample
711 | downrelu = nn.LeakyReLU(0.2, True)
712 | uprelu = nl_layer()
713 |
714 | if outermost:
715 | upconv = upsampleLayer(
716 | inner_nc * 2, outer_nc, upsample=upsample, padding_type=padding_type)
717 | down = downconv
718 | up = [uprelu] + upconv + [nn.Tanh()]
719 | elif innermost:
720 | upconv = upsampleLayer(
721 | inner_nc, outer_nc, upsample=upsample, padding_type=padding_type)
722 | down = [downrelu] + downconv
723 | up = [uprelu] + upconv
724 | if norm_layer is not None:
725 | up += [norm_layer(outer_nc)]
726 | else:
727 | upconv = upsampleLayer(
728 | inner_nc * 2, outer_nc, upsample=upsample, padding_type=padding_type)
729 | down = [downrelu] + downconv
730 | if norm_layer is not None:
731 | down += [norm_layer(inner_nc)]
732 | up = [uprelu] + upconv
733 |
734 | if norm_layer is not None:
735 | up += [norm_layer(outer_nc)]
736 |
737 | if use_dropout:
738 | up += [nn.Dropout(0.5)]
739 | self.down = nn.Sequential(*down)
740 | self.submodule = submodule
741 | self.up = nn.Sequential(*up)
742 |
743 | def forward(self, x, z):
744 |
745 | if self.nz > 0:
746 | z_img = z.view(z.size(0), z.size(1), 1, 1).expand(z.size(0), z.size(1), x.size(2), x.size(3))
747 | x_and_z = torch.cat([x, z_img], 1)
748 | else:
749 | x_and_z = x
750 |
751 | if self.outermost:
752 | x1 = self.down(x_and_z)
753 | x2 = self.submodule(x1, z)
754 | return self.up(x2)
755 | elif self.innermost:
756 | x1 = self.up(self.down(x_and_z))
757 | return torch.cat([x1, x], 1)
758 | else:
759 | x1 = self.down(x_and_z)
760 | x2 = self.submodule(x1, z)
761 | return torch.cat([self.up(x2), x], 1)
762 |
763 |
764 | class E_NLayers(nn.Module):
765 | def __init__(self, input_nc, output_nc=1, ndf=64, n_layers=3,
766 | norm_layer=None, nl_layer=None, vaeLike=False):
767 | super(E_NLayers, self).__init__()
768 | self.vaeLike = vaeLike
769 |
770 | kw, padw = 4, 1
771 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw,
772 | stride=2, padding=padw), nl_layer()]
773 |
774 | nf_mult = 1
775 | nf_mult_prev = 1
776 | for n in range(1, n_layers):
777 | nf_mult_prev = nf_mult
778 | nf_mult = min(2**n, 4)
779 | sequence += [
780 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
781 | kernel_size=kw, stride=2, padding=padw)]
782 | if norm_layer is not None:
783 | sequence += [norm_layer(ndf * nf_mult)]
784 | sequence += [nl_layer()]
785 | sequence += [nn.AvgPool2d(8)]
786 | self.conv = nn.Sequential(*sequence)
787 | self.fc = nn.Sequential(*[nn.Linear(ndf * nf_mult, output_nc)])
788 | if vaeLike:
789 | self.fcVar = nn.Sequential(*[nn.Linear(ndf * nf_mult, output_nc)])
790 |
791 | def forward(self, x):
792 | x_conv = self.conv(x)
793 | conv_flat = x_conv.view(x.size(0), -1)
794 | output = self.fc(conv_flat)
795 | if self.vaeLike:
796 | outputVar = self.fcVar(conv_flat)
797 | return output, outputVar
798 | return output
799 |
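
As a quick interface check for the classes above: D_NLayersMulti with num_D > 1 returns a list of per-scale patch predictions, and GANLoss accepts that list directly, returning both the summed loss and the per-scale losses. A minimal sketch using the classes defined in this file (the batch and image sizes are illustrative):

    import torch

    netD = D_NLayersMulti(input_nc=3, ndf=64, n_layers=3, num_D=2)
    criterion = GANLoss(gan_mode='lsgan')

    fake = torch.randn(2, 3, 256, 256)   # a toy batch of generated images
    preds = netD(fake)                   # list: one patch prediction map per scale
    loss_total, per_scale = criterion(preds, False)
    print(loss_total.item(), [p.shape for p in preds])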
--------------------------------------------------------------------------------
/models/pix2pix_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .base_model import BaseModel
3 | from . import networks
4 |
5 |
6 | class Pix2PixModel(BaseModel):
7 | """ This class implements the pix2pix model, for learning a mapping from input images to output images given paired data.
8 |
9 | The model training requires '--dataset_mode aligned' dataset.
10 |     By default, it uses a '--netG unet_256' U-Net generator,
11 |     a '--netD basic_256_multi' discriminator (PatchGAN),
12 |     and a '--gan_mode vanilla' GAN loss (the cross-entropy objective used in the original GAN paper).
13 |
14 | pix2pix paper: https://arxiv.org/pdf/1611.07004.pdf
15 | """
16 | @staticmethod
17 | def modify_commandline_options(parser, is_train=True):
18 | """Add new dataset-specific options, and rewrite default values for existing options.
19 |
20 | Parameters:
21 | parser -- original option parser
22 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
23 |
24 | Returns:
25 | the modified parser.
26 |
27 |         For pix2pix, we do not use an image buffer.
28 | The training objective is: GAN Loss + lambda_L1 * ||G(A)-B||_1
29 | By default, we use vanilla GAN loss, UNet with batchnorm, and aligned datasets.
30 | """
31 | # changing the default values to match the pix2pix paper (https://phillipi.github.io/pix2pix/)
32 | parser.set_defaults(norm='batch', netG='unet_256', dataset_mode='aligned')
33 | parser.set_defaults(where_add='input', nz=0)
34 | if is_train:
35 |             parser.set_defaults(gan_mode='vanilla', lambda_L1=100.0)
36 |
37 | return parser
38 |
39 | def __init__(self, opt):
40 | """Initialize the pix2pix class.
41 |
42 | Parameters:
43 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
44 | """
45 | BaseModel.__init__(self, opt)
46 |         # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
47 | self.loss_names = ['G_GAN', 'G_L1', 'D_real', 'D_fake']
48 |         # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
49 | self.visual_names = ['real_A', 'fake_B', 'real_B']
50 |         # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>
51 | if self.isTrain:
52 | self.model_names = ['G', 'D']
53 | else: # during test time, only load G
54 | self.model_names = ['G']
55 | # define networks (both generator and discriminator)
56 | self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.nz, opt.ngf, netG=opt.netG,
57 | norm=opt.norm, nl=opt.nl, use_dropout=opt.use_dropout, init_type=opt.init_type, init_gain=opt.init_gain,
58 | gpu_ids=self.gpu_ids, where_add=opt.where_add, upsample=opt.upsample)
59 |
60 | if self.isTrain: # define a discriminator; conditional GANs need to take both input and output images; Therefore, #channels for D is input_nc + output_nc
61 |             self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, netD=opt.netD, norm=opt.norm, nl=opt.nl,
62 | init_type=opt.init_type, init_gain=opt.init_gain, num_Ds=opt.num_Ds, gpu_ids=self.gpu_ids)
63 |
64 | if self.isTrain:
65 | # define loss functions
66 | self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
67 | self.criterionL1 = torch.nn.L1Loss()
68 |             # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
69 | self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
70 | self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
71 | self.optimizers.append(self.optimizer_G)
72 | self.optimizers.append(self.optimizer_D)
73 |
74 | def set_input(self, input):
75 | """Unpack input data from the dataloader and perform necessary pre-processing steps.
76 |
77 | Parameters:
78 | input (dict): include the data itself and its metadata information.
79 |
80 | The option 'direction' can be used to swap images in domain A and domain B.
81 | """
82 | AtoB = self.opt.direction == 'AtoB'
83 | self.real_A = input['A' if AtoB else 'B'].to(self.device)
84 | self.real_B = input['B' if AtoB else 'A'].to(self.device)
85 | self.image_paths = input['A_paths' if AtoB else 'B_paths']
86 |
87 | def forward(self):
88 |         """Run forward pass; called by both functions <optimize_parameters> and <test>."""
89 | self.fake_B = self.netG(self.real_A) # G(A)
90 |
91 | def backward_D(self):
92 | """Calculate GAN loss for the discriminator"""
93 | # Fake; stop backprop to the generator by detaching fake_B
94 | fake_AB = torch.cat((self.real_A, self.fake_B), 1) # we use conditional GANs; we need to feed both input and output to the discriminator
95 | pred_fake = self.netD(fake_AB.detach())
96 | self.loss_D_fake, _ = self.criterionGAN(pred_fake, False)
97 | # Real
98 | real_AB = torch.cat((self.real_A, self.real_B), 1)
99 | pred_real = self.netD(real_AB)
100 | self.loss_D_real, _ = self.criterionGAN(pred_real, True)
101 | # combine loss and calculate gradients
102 | self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
103 | self.loss_D.backward()
104 |
105 | def backward_G(self):
106 | """Calculate GAN and L1 loss for the generator"""
107 | # First, G(A) should fake the discriminator
108 | fake_AB = torch.cat((self.real_A, self.fake_B), 1)
109 | pred_fake = self.netD(fake_AB)
110 | self.loss_G_GAN, _ = self.criterionGAN(pred_fake, True)
111 | # Second, G(A) = B
112 | self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_L1
113 | # combine loss and calculate gradients
114 | self.loss_G = self.loss_G_GAN + self.loss_G_L1
115 | self.loss_G.backward()
116 |
117 | def optimize_parameters(self):
118 | self.forward() # compute fake images: G(A)
119 | # update D
120 | self.set_requires_grad(self.netD, True) # enable backprop for D
121 | self.optimizer_D.zero_grad() # set D's gradients to zero
122 | self.backward_D() # calculate gradients for D
123 | self.optimizer_D.step() # update D's weights
124 | # update G
125 | self.set_requires_grad(self.netD, False) # D requires no gradients when optimizing G
126 | self.optimizer_G.zero_grad() # set G's gradients to zero
127 |         self.backward_G() # calculate gradients for G
128 |         self.optimizer_G.step() # update G's weights
129 |
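
The conditional discriminator used in backward_D/backward_G above works by channel-wise concatenation of the input and output images, so D scores (input, output) pairs rather than outputs alone. A minimal shape sketch (the tensors are illustrative stand-ins for real_A and fake_B):

    import torch

    real_A = torch.randn(1, 3, 256, 256)     # input image
    fake_B = torch.randn(1, 3, 256, 256)     # generated output
    fake_AB = torch.cat((real_A, fake_B), 1) # what D actually sees
    print(fake_AB.shape)                     # [1, 6, 256, 256]: input_nc + output_nc channels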
--------------------------------------------------------------------------------
/models/template_model.py:
--------------------------------------------------------------------------------
1 | """Model class template
2 |
3 | This module provides a template for users to implement custom models.
4 | You can specify '--model template' to use this model.
5 | The class name should be consistent with both the filename and its model option.
6 | The filename should be <model>_model.py
7 | The class name should be <Model>Model
8 | It implements a simple image-to-image translation baseline based on regression loss.
9 | Given input-output pairs (data_A, data_B), it learns a network netG that can minimize the following L1 loss:
10 |     min_<netG> ||netG(data_A) - data_B||_1
11 | You need to implement the following functions:
12 |     <modify_commandline_options>: Add model-specific options and rewrite default values for existing options.
13 |     <__init__>: Initialize this model class.
14 |     <set_input>: Unpack input data and perform data pre-processing.
15 |     <forward>: Run forward pass. This will be called by both <optimize_parameters> and <test>.
16 |     <optimize_parameters>: Update network weights; it will be called in every training iteration.
17 | """
18 | import torch
19 | from .base_model import BaseModel
20 | from . import networks
21 |
22 |
23 | class TemplateModel(BaseModel):
24 | @staticmethod
25 | def modify_commandline_options(parser, is_train=True):
26 | """Add new model-specific options and rewrite default values for existing options.
27 |
28 | Parameters:
29 | parser -- the option parser
30 | is_train -- if it is training phase or test phase. You can use this flag to add training-specific or test-specific options.
31 |
32 | Returns:
33 | the modified parser.
34 | """
35 | parser.set_defaults(dataset_mode='aligned', netG='unet_256', where_add='input', nz=0) # You can rewrite default values for this model. For example, this model usually uses aligned dataset as its dataset.
36 | if is_train:
37 | parser.add_argument('--lambda_regression', type=float, default=1.0, help='weight for the regression loss') # You can define new arguments for this model.
38 |
39 | return parser
40 |
41 | def __init__(self, opt):
42 | """Initialize this model class.
43 |
44 | Parameters:
45 | opt -- training/test options
46 |
47 | A few things can be done here.
48 | - (required) call the initialization function of BaseModel
49 | - define loss function, visualization images, model names, and optimizers
50 | """
51 | BaseModel.__init__(self, opt) # call the initialization method of BaseModel
52 | # specify the training losses you want to print out. The program will call base_model.get_current_losses to plot the losses to the console and save them to the disk.
53 |         self.loss_names = ['G']
54 | # specify the images you want to save and display. The program will call base_model.get_current_visuals to save and display these images.
55 | self.visual_names = ['data_A', 'data_B', 'output']
56 | # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks to save and load networks.
57 | # you can use opt.isTrain to specify different behaviors for training and test. For example, some networks will not be used during test, and you don't need to load them.
58 | self.model_names = ['G']
59 | # define networks; you can use opt.isTrain to specify different behaviors for training and test.
60 | self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.nz, opt.ngf, netG=opt.netG, gpu_ids=self.gpu_ids, where_add=opt.where_add)
61 | if self.isTrain: # only defined during training time
62 | # define your loss functions. You can use losses provided by torch.nn such as torch.nn.L1Loss.
63 |         # We also provide a GANLoss class "networks.GANLoss". self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
64 | self.criterionLoss = torch.nn.L1Loss()
65 | # define and initialize optimizers. You can define one optimizer for each network.
66 | # If two networks are updated at the same time, you can use itertools.chain to group them. See bicycle_gan_model.py for an example.
67 | self.optimizer = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
68 | self.optimizers = [self.optimizer]
69 |
70 |         # Our program will automatically call <BaseModel.setup> to define schedulers, load networks, and print networks
71 |
72 | def set_input(self, input):
73 | """Unpack input data from the dataloader and perform necessary pre-processing steps.
74 |
75 | Parameters:
76 | input: a dictionary that contains the data itself and its metadata information.
77 | """
78 | AtoB = self.opt.direction == 'AtoB' # use to swap data_A and data_B
79 | self.data_A = input['A' if AtoB else 'B'].to(self.device) # get image data A
80 | self.data_B = input['B' if AtoB else 'A'].to(self.device) # get image data B
81 | self.image_paths = input['A_paths' if AtoB else 'B_paths'] # get image paths
82 |
83 | def forward(self):
84 |         """Run forward pass. This will be called by both functions <optimize_parameters> and <test>."""
85 | self.output = self.netG(self.data_A) # generate output image given the input data_A
86 |
87 | def backward(self):
88 | """Calculate losses, gradients, and update network weights; called in every training iteration"""
89 |         # calculate the intermediate results if necessary; here self.output has been computed during function <forward>
90 | # calculate loss given the input and intermediate results
91 | self.loss_G = self.criterionLoss(self.output, self.data_B) * self.opt.lambda_regression
92 | self.loss_G.backward() # calculate gradients of network G w.r.t. loss_G
93 |
94 | def optimize_parameters(self):
95 | """Update network weights; it will be called in every training iteration."""
96 | self.forward() # first call forward to calculate intermediate results
97 | self.optimizer.zero_grad() # clear network G's existing gradients
98 | self.backward() # calculate gradients for network G
99 |         self.optimizer.step() # update network G's weights
100 |
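
The template's objective min_<netG> ||netG(data_A) - data_B||_1 boils down to an L1-regression step. Below is a toy, self-contained sketch of one such step; the 1x1-conv "generator" and the lr/beta values are illustrative, not the repo's defaults:

    import torch

    netG = torch.nn.Conv2d(3, 3, kernel_size=1)  # stand-in generator
    optimizer = torch.optim.Adam(netG.parameters(), lr=2e-4, betas=(0.5, 0.999))
    criterion = torch.nn.L1Loss()

    data_A, data_B = torch.randn(2, 3, 64, 64), torch.randn(2, 3, 64, 64)
    optimizer.zero_grad()
    loss_G = criterion(netG(data_A), data_B)     # lambda_regression = 1.0 implied
    loss_G.backward()
    optimizer.step()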
--------------------------------------------------------------------------------
/options/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/junyanz/BicycleGAN/40b9d52c27b9831f56c1c7c7a6ddde8bc9149067/options/__init__.py
--------------------------------------------------------------------------------
/options/base_options.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from util import util
4 | import torch
5 | import models
6 | import data
7 |
8 |
9 | class BaseOptions():
10 | def __init__(self):
11 | self.initialized = False
12 |
13 | def initialize(self, parser):
14 | """This class defines options used during both training and test time.
15 |
16 | It also implements several helper functions such as parsing, printing, and saving the options.
17 |         It also gathers additional options defined in <modify_commandline_options> functions in both dataset class and model class.
18 | """
19 | parser.add_argument('--dataroot', required=True, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')
20 | parser.add_argument('--batch_size', type=int, default=2, help='input batch size')
21 | parser.add_argument('--load_size', type=int, default=286, help='scale images to this size')
22 | parser.add_argument('--crop_size', type=int, default=256, help='then crop to this size')
23 | parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels')
24 | parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels')
25 |         parser.add_argument('--nz', type=int, default=8, help='# of dimensions of the latent vector z')
26 |         parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 or 0,1,2 or 0,2; use -1 for CPU mode')
27 | parser.add_argument('--name', type=str, default='', help='name of the experiment. It decides where to store samples and models')
28 | parser.add_argument('--preprocess', type=str, default='resize_and_crop', help='not implemented')
29 |         parser.add_argument('--dataset_mode', type=str, default='aligned', help='aligned | single')
30 |         parser.add_argument('--model', type=str, default='bicycle_gan', help='chooses which model to use: bicycle_gan | pix2pix | template')
31 | parser.add_argument('--direction', type=str, default='AtoB', help='AtoB or BtoA')
32 | parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
33 |         parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data')
34 | parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
35 | parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
36 | parser.add_argument('--use_dropout', action='store_true', help='use dropout for the generator')
37 | parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
38 |         parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
39 |
40 | # model parameters
41 |         parser.add_argument('--num_Ds', type=int, default=2, help='number of discriminators')
42 | parser.add_argument('--netD', type=str, default='basic_256_multi', help='selects model to use for netD')
43 | parser.add_argument('--netD2', type=str, default='basic_256_multi', help='selects model to use for netD2')
44 | parser.add_argument('--netG', type=str, default='unet_256', help='selects model to use for netG')
45 | parser.add_argument('--netE', type=str, default='resnet_256', help='selects model to use for netE')
46 | parser.add_argument('--nef', type=int, default=64, help='# of encoder filters in the first conv layer')
47 | parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')
48 | parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')
49 | parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization')
50 | parser.add_argument('--upsample', type=str, default='basic', help='basic | bilinear')
51 | parser.add_argument('--nl', type=str, default='relu', help='non-linearity activation: relu | lrelu | elu')
52 |
53 | # extra parameters
54 |         parser.add_argument('--where_add', type=str, default='all', help='input | all; where to add z in the network G')
55 |         parser.add_argument('--conditional_D', action='store_true', help='if specified, use a conditional discriminator (D sees both input and output)')
56 | parser.add_argument('--init_type', type=str, default='xavier', help='network initialization [normal | xavier | kaiming | orthogonal]')
57 | parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')
58 |         parser.add_argument('--center_crop', action='store_true', help='if specified, apply center cropping at test time')
59 | parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')
60 | parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')
61 | parser.add_argument('--display_winsize', type=int, default=256, help='display window size')
62 |
63 | # special tasks
64 | self.initialized = True
65 | return parser
66 |
67 | def gather_options(self):
68 |         """Initialize our parser with basic options (only once).
69 |         Add additional model-specific and dataset-specific options.
70 |         These options are defined in the <modify_commandline_options> function
71 |         in model and dataset classes.
72 | """
73 | if not self.initialized: # check if it has been initialized
74 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
75 | parser = self.initialize(parser)
76 |
77 | # get the basic options
78 | opt, _ = parser.parse_known_args()
79 |
80 | # modify model-related parser options
81 | model_name = opt.model
82 | model_option_setter = models.get_option_setter(model_name)
83 | parser = model_option_setter(parser, self.isTrain)
84 | opt, _ = parser.parse_known_args() # parse again with new defaults
85 |
86 | # modify dataset-related parser options
87 | dataset_name = opt.dataset_mode
88 | dataset_option_setter = data.get_option_setter(dataset_name)
89 | parser = dataset_option_setter(parser, self.isTrain)
90 |
91 | # save and return the parser
92 | self.parser = parser
93 | return parser.parse_args()
94 |
95 | def print_options(self, opt):
96 | """Print and save options
97 |
98 |         It will print both current options and default values (if different).
99 |         It will save options into a text file: [checkpoints_dir] / [name] / opt.txt
100 | """
101 | message = ''
102 | message += '----------------- Options ---------------\n'
103 | for k, v in sorted(vars(opt).items()):
104 | comment = ''
105 | default = self.parser.get_default(k)
106 | if v != default:
107 | comment = '\t[default: %s]' % str(default)
108 | message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
109 | message += '----------------- End -------------------'
110 | print(message)
111 |
112 | # save to the disk
113 | expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
114 | util.mkdirs(expr_dir)
115 | file_name = os.path.join(expr_dir, 'opt.txt')
116 | with open(file_name, 'wt') as opt_file:
117 | opt_file.write(message)
118 | opt_file.write('\n')
119 |
120 | def parse(self):
121 | """Parse our options, create checkpoints directory suffix, and set up gpu device."""
122 | opt = self.gather_options()
123 | opt.isTrain = self.isTrain # train or test
124 |
125 | # process opt.suffix
126 | if opt.suffix:
127 | suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else ''
128 | opt.name = opt.name + suffix
129 |
130 | self.print_options(opt)
131 |
132 | # set gpu ids
133 | str_ids = opt.gpu_ids.split(',')
134 | opt.gpu_ids = []
135 | for str_id in str_ids:
136 | gpu_id = int(str_id)  # avoid shadowing the built-in id()
137 | if gpu_id >= 0:
138 | opt.gpu_ids.append(gpu_id)
139 | if len(opt.gpu_ids) > 0:
140 | torch.cuda.set_device(opt.gpu_ids[0])
141 |
142 | self.opt = opt
143 | return self.opt
144 |
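A note on the two-pass parsing above: gather_options first calls parse_known_args so that --model and --dataset_mode can be read before any model- or dataset-specific flags exist, then lets those classes register their own options and parses again. A minimal, self-contained sketch of the same pattern (the toy add_model_options setter below is hypothetical, standing in for the setter returned by models.get_option_setter):

    import argparse

    def add_model_options(parser, is_train):
        # hypothetical stand-in for models.get_option_setter(model_name)
        parser.add_argument('--nz', type=int, default=8, help='latent dimension')
        return parser

    argv = ['--model', 'bicycle_gan', '--nz', '16']
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--model', type=str, default='bicycle_gan')

    # pass 1: read the basic options only; unknown flags such as --nz are tolerated
    opt, _ = parser.parse_known_args(argv)

    # pass 2: the chosen model registers its flags, then we parse for real
    parser = add_model_options(parser, is_train=True)
    opt = parser.parse_args(argv)
    print(opt.model, opt.nz)  # bicycle_gan 16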
--------------------------------------------------------------------------------
/options/test_options.py:
--------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 |
3 |
4 | class TestOptions(BaseOptions):
5 | def initialize(self, parser):
6 | BaseOptions.initialize(self, parser)
7 | parser.add_argument('--results_dir', type=str, default='../results/', help='saves results here.')
8 | parser.add_argument('--phase', type=str, default='val', help='train, val, test, etc')
9 | parser.add_argument('--num_test', type=int, default=50, help='how many test images to run')
10 | parser.add_argument('--n_samples', type=int, default=5, help='#samples')
11 | parser.add_argument('--no_encode', action='store_true', help='do not produce encoded image')
12 | parser.add_argument('--sync', action='store_true', help='use the same latent code for different input images')
13 | parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio for the results')
14 | parser.add_argument('--eval', action='store_true', help='use eval mode during test time.')
15 |
16 | self.isTrain = False
17 | return parser
18 |
--------------------------------------------------------------------------------
/options/train_options.py:
--------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 |
3 |
4 | class TrainOptions(BaseOptions):
5 | def initialize(self, parser):
6 | BaseOptions.initialize(self, parser)
7 | parser.add_argument('--display_freq', type=int, default=400, help='frequency of showing training results on screen')
8 | parser.add_argument('--display_ncols', type=int, default=4, help='if positive, display all images in a single visdom web panel with certain number of images per row.')
9 | parser.add_argument('--display_id', type=int, default=1, help='window id of the web display')
10 | parser.add_argument('--display_port', type=int, default=8097, help='visdom display port')
11 | parser.add_argument('--display_env', type=str, default='main', help='visdom display environment name (default is "main")')
12 | parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display')
13 | parser.add_argument('--update_html_freq', type=int, default=4000, help='frequency of saving training results to html')
14 | parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
15 | parser.add_argument('--save_latest_freq', type=int, default=10000, help='frequency of saving the latest results')
16 | parser.add_argument('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs')
17 | parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
18 | parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
19 | parser.add_argument('--gan_mode', type=str, default='lsgan', help='the type of GAN objective. [vanilla | lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.')
20 | # training parameters
21 | parser.add_argument('--niter', type=int, default=100, help='# of iter at starting learning rate')
22 | parser.add_argument('--niter_decay', type=int, default=100, help='# of iter to linearly decay learning rate to zero')
23 | parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
24 | parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count; we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...')
25 | parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
26 | parser.add_argument('--lr_policy', type=str, default='linear', help='learning rate policy: linear | step | plateau | cosine')
27 | parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
28 | parser.add_argument('--lr_decay_iters', type=int, default=100, help='multiply by a gamma every lr_decay_iters iterations')
29 | # lambda parameters
30 | parser.add_argument('--lambda_L1', type=float, default=10.0, help='weight for |B-G(A, E(B))|')
31 | parser.add_argument('--lambda_GAN', type=float, default=1.0, help='weight on D loss. D(G(A, E(B)))')
32 | parser.add_argument('--lambda_GAN2', type=float, default=1.0, help='weight on D2 loss, D(G(A, random_z))')
33 | parser.add_argument('--lambda_z', type=float, default=0.5, help='weight for ||E(G(random_z)) - random_z||')
34 | parser.add_argument('--lambda_kl', type=float, default=0.01, help='weight for KL loss')
35 | parser.add_argument('--use_same_D', action='store_true', help='if two Ds share the weights or not')
36 | self.isTrain = True
37 | return parser
38 |
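The lambda flags above weight the individual terms of the BicycleGAN objective inside models/bicycle_gan_model.py. As a rough sketch of how such weights combine (the loss values here are placeholders, not the model's real outputs, and the real model backpropagates some terms through different subnetworks):

    import torch

    # placeholder loss terms; in the real model these come from D(G(A, E(B))),
    # D2(G(A, random_z)), |B - G(A, E(B))|, ||E(G(A, z)) - z||, and the KL term on E(B)
    loss_G_GAN, loss_G_GAN2 = torch.tensor(0.7), torch.tensor(0.6)
    loss_G_L1, loss_z_L1, loss_kl = torch.tensor(0.12), torch.tensor(0.3), torch.tensor(2.1)

    lambda_GAN, lambda_GAN2, lambda_L1, lambda_z, lambda_kl = 1.0, 1.0, 10.0, 0.5, 0.01
    loss_G = (lambda_GAN * loss_G_GAN + lambda_GAN2 * loss_G_GAN2
              + lambda_L1 * loss_G_L1 + lambda_z * loss_z_L1 + lambda_kl * loss_kl)
    print(float(loss_G))  # 2.671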
--------------------------------------------------------------------------------
/options/video_options.py:
--------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 |
3 |
4 | class VideoOptions(BaseOptions):
5 | def initialize(self, parser):
6 | BaseOptions.initialize(self, parser)
7 | parser.add_argument('--results_dir', type=str, default='../video/', help='saves results here.')
8 | parser.add_argument('--phase', type=str, default='val', help='train, val, test, etc')
9 | parser.add_argument('--num_test', type=int, default=50, help='how many test images to run')
10 | parser.add_argument('--n_samples', type=int, default=5, help='#samples')
11 | parser.add_argument('--num_frames', type=int, default=4, help='number of the frames used in the morphing sequence')
12 | parser.add_argument('--align_mode', type=str, default='horizontal', help='ways of aligning the input images')
13 | parser.add_argument('--border', type=int, default=0, help='border (in pixels) between results')
14 | parser.add_argument('--seed', type=int, default=50, help='random seed for latent vectors')
15 | parser.add_argument('--fps', type=int, default=8, help='speed of the generated video')
16 | self.isTrain = False
17 | return parser
18 |
--------------------------------------------------------------------------------
/pretrained_models/download_model.sh:
--------------------------------------------------------------------------------
1 | FILE=$1
2 |
3 | echo "Note: available models are edges2shoes, edges2handbags, night2day, maps, and facades"
4 | echo "downloading [$FILE]"
5 | MODEL_DIR=./pretrained_models/${FILE}
6 | mkdir -p ${MODEL_DIR}
7 |
8 |
9 | MODEL_FILE_G=${MODEL_DIR}/latest_net_G.pth
10 | URL_G=http://efrosgans.eecs.berkeley.edu/BicycleGAN/models/${FILE}_net_G.pth
11 | wget -N $URL_G -O $MODEL_FILE_G
12 |
13 |
14 | MODEL_FILE_E=${MODEL_DIR}/latest_net_E.pth
15 | URL_E=http://efrosgans.eecs.berkeley.edu/BicycleGAN/models/${FILE}_net_E.pth
16 | wget -N $URL_E -O $MODEL_FILE_E
17 |
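If wget is unavailable, the same two downloads can be reproduced from Python with only the standard library; a sketch, assuming the URL pattern above stays valid:

    import os
    import urllib.request

    name = 'edges2shoes'  # one of: edges2shoes, edges2handbags, night2day, maps, facades
    model_dir = os.path.join('./pretrained_models', name)
    os.makedirs(model_dir, exist_ok=True)
    for net in ('G', 'E'):
        url = 'http://efrosgans.eecs.berkeley.edu/BicycleGAN/models/%s_net_%s.pth' % (name, net)
        dst = os.path.join(model_dir, 'latest_net_%s.pth' % net)
        urllib.request.urlretrieve(url, dst)  # equivalent of the wget -N ... -O ... calls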
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=0.4.1
2 | torchvision>=0.2.1
3 | dominate>=2.3.1
4 | visdom>=0.1.8.3
5 |
--------------------------------------------------------------------------------
/scripts/check_all.sh:
--------------------------------------------------------------------------------
1 | set -ex
2 | DOWNLOAD_MODEL=${1}
3 | # test code
4 | echo 'test edges2handbags'
5 | if [ ${DOWNLOAD_MODEL} -eq 1 ]
6 | then
7 | bash ./pretrained_models/download_model.sh edges2handbags
8 | fi
9 | bash ./datasets/download_testset.sh edges2handbags
10 | bash ./scripts/test_edges2handbags.sh
11 |
12 | echo 'test edges2shoes'
13 | if [ ${DOWNLOAD_MODEL} -eq 1 ]
14 | then
15 | bash ./pretrained_models/download_model.sh edges2shoes
16 | fi
17 | bash ./datasets/download_testset.sh edges2shoes
18 | bash ./scripts/test_edges2shoes.sh
19 |
20 | echo 'test night2day'
21 | if [ ${DOWNLOAD_MODEL} -eq 1 ]
22 | then
23 | bash ./pretrained_models/download_model.sh night2day
24 | fi
25 | bash ./datasets/download_testset.sh night2day
26 | bash ./scripts/test_night2day.sh
27 |
28 | echo 'test maps'
29 | if [ ${DOWNLOAD_MODEL} -eq 1 ]
30 | then
31 | bash ./pretrained_models/download_model.sh maps
32 | fi
33 | bash ./datasets/download_testset.sh maps
34 | bash ./scripts/test_maps.sh
35 |
36 | echo 'test facades'
37 | if [ ${DOWNLOAD_MODEL} -eq 1 ]
38 | then
39 | bash ./pretrained_models/download_model.sh facades
40 | fi
41 | bash ./datasets/download_testset.sh facades
42 | bash ./scripts/test_facades.sh
43 |
44 | echo 'video edges2shoes'
45 | bash ./scripts/video_edges2shoes.sh
46 |
47 | echo "train a pix2pix model"
48 | bash ./datasets/download_dataset.sh facades
49 | python train.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix \
50 | --netG unet_256 --direction BtoA --lambda_L1 10 --dataset_mode aligned \
51 | --gan_mode lsgan --norm batch --niter 1 --niter_decay 0 --save_epoch_freq 1
52 | echo "train a bicyclegan model"
53 | python train.py --dataroot ./datasets/facades --name facades_bicycle --model bicycle_gan \
54 | --direction BtoA --dataset_mode aligned \
55 | --gan_mode lsgan --norm batch --niter 1 --niter_decay 0 --save_epoch_freq 1
56 |
--------------------------------------------------------------------------------
/scripts/install_conda.sh:
--------------------------------------------------------------------------------
1 | set -ex
2 | conda install -c conda-forge visdom
3 | conda install -c conda-forge dominate
4 | conda install -c conda-forge moviepy
5 |
--------------------------------------------------------------------------------
/scripts/install_pip.sh:
--------------------------------------------------------------------------------
1 | set -ex
2 | sudo -H pip install visdom
3 | sudo -H pip install dominate
4 | sudo -H pip install moviepy
5 |
--------------------------------------------------------------------------------
/scripts/test_before_push.py:
--------------------------------------------------------------------------------
1 | """Simple script to make sure basic usage such as training, testing, saving and loading runs without errors."""
2 | import os
3 |
4 |
5 | def run(command):
6 | print(command)
7 | exit_status = os.system(command)
8 | if exit_status > 0:
9 | exit(1)
10 |
11 |
12 | if __name__ == '__main__':
13 | if not os.path.exists('./datasets/mini_pix2pix'):
14 | run('bash ./datasets/download_mini_dataset.sh mini_pix2pix')
15 |
16 | # pix2pix train/test
17 | run('python train.py --model pix2pix --name temp_pix2pix --dataroot ./datasets/mini_pix2pix --niter 1 --niter_decay 0 --save_latest_freq 10 --display_id -1')
18 |
19 | # template train/test
20 | run('python train.py --model template --name temp2 --dataroot ./datasets/mini_pix2pix --niter 1 --niter_decay 0 --save_latest_freq 10 --display_id -1')
21 |
22 | run('bash ./scripts/test_edges2shoes.sh')
23 | run('bash ./scripts/test_edges2shoes.sh --sync')
24 | run('bash ./scripts/video_edges2shoes.sh')
25 | # run('bash ./scripts/train_facades.sh')
26 |
--------------------------------------------------------------------------------
/scripts/test_edges2handbags.sh:
--------------------------------------------------------------------------------
1 | set -ex
2 | # models
3 | RESULTS_DIR='./results/edges2handbags'
4 | G_PATH='./pretrained_models/edges2handbags_net_G.pth'
5 | E_PATH='./pretrained_models/edges2handbags_net_E.pth'
6 |
7 | # dataset
8 | CLASS='edges2handbags'
9 | DIRECTION='AtoB' # from domain A to domain B
10 | LOAD_SIZE=256 # scale images to this size
11 | CROP_SIZE=256 # then crop to this size
12 | INPUT_NC=1 # number of channels in the input image
13 |
14 | # misc
15 | GPU_ID=0 # gpu id
16 | NUM_TEST=10 # number of input images during test
17 | NUM_SAMPLES=10 # number of samples per input image
18 |
19 | # command
20 | CUDA_VISIBLE_DEVICES=${GPU_ID} python ./test.py \
21 | --dataroot ./datasets/${CLASS} \
22 | --results_dir ${RESULTS_DIR} \
23 | --checkpoints_dir ./pretrained_models/ \
24 | --name ${CLASS} \
25 | --direction ${DIRECTION} \
26 | --load_size ${LOAD_SIZE} \
27 | --crop_size ${CROP_SIZE} \
28 | --input_nc ${INPUT_NC} \
29 | --num_test ${NUM_TEST} \
30 | --n_samples ${NUM_SAMPLES} \
31 | --center_crop \
32 | --no_flip
33 |
--------------------------------------------------------------------------------
/scripts/test_edges2shoes.sh:
--------------------------------------------------------------------------------
1 | set -ex
2 | # models
3 | RESULTS_DIR='./results/edges2shoes'
4 | G_PATH='./pretrained_models/edges2shoes_net_G.pth'
5 | E_PATH='./pretrained_models/edges2shoes_net_E.pth'
6 |
7 | # dataset
8 | CLASS='edges2shoes'
9 | DIRECTION='AtoB' # from domain A to domain B
10 | LOAD_SIZE=256 # scale images to this size
11 | CROP_SIZE=256 # then crop to this size
12 | INPUT_NC=1 # number of channels in the input image
13 |
14 | # misc
15 | GPU_ID=0 # gpu id
16 | NUM_TEST=10 # number of input images during test
17 | NUM_SAMPLES=10 # number of samples per input image
18 |
19 | # command
20 | CUDA_VISIBLE_DEVICES=${GPU_ID} python ./test.py \
21 | --dataroot ./datasets/${CLASS} \
22 | --results_dir ${RESULTS_DIR} \
23 | --checkpoints_dir ./pretrained_models/ \
24 | --name ${CLASS} \
25 | --direction ${DIRECTION} \
26 | --load_size ${LOAD_SIZE} \
27 | --crop_size ${CROP_SIZE} \
28 | --input_nc ${INPUT_NC} \
29 | --num_test ${NUM_TEST} \
30 | --n_samples ${NUM_SAMPLES} \
31 | --center_crop \
32 | --no_flip
33 |
--------------------------------------------------------------------------------
/scripts/test_facades.sh:
--------------------------------------------------------------------------------
1 | set -ex
2 | # models
3 | RESULTS_DIR='./results/facades'
4 |
5 | # dataset
6 | CLASS='facades'
7 | DIRECTION='BtoA' # from domain B to domain A
8 | LOAD_SIZE=286 # scale images to this size
9 | CROP_SIZE=256 # then crop to this size
10 | INPUT_NC=3 # number of channels in the input image
11 |
12 | # misc
13 | GPU_ID=0 # gpu id
14 | NUM_TEST=10 # number of input images during test
15 | NUM_SAMPLES=10 # number of samples per input image
16 |
17 |
18 | # command
19 | CUDA_VISIBLE_DEVICES=${GPU_ID} python ./test.py \
20 | --dataroot ./datasets/${CLASS} \
21 | --results_dir ${RESULTS_DIR} \
22 | --checkpoints_dir ./pretrained_models/ \
23 | --name ${CLASS} \
24 | --direction ${DIRECTION} \
25 | --load_size ${LOAD_SIZE} \
26 | --crop_size ${CROP_SIZE} \
27 | --input_nc ${INPUT_NC} \
28 | --num_test ${NUM_TEST} \
29 | --n_samples ${NUM_SAMPLES} \
30 | --center_crop \
31 | --no_flip
32 |
--------------------------------------------------------------------------------
/scripts/test_maps.sh:
--------------------------------------------------------------------------------
1 | set -ex
2 | # models
3 | RESULTS_DIR='./results/maps'
4 | G_PATH='./pretrained_models/map2aerial_net_G.pth'
5 | E_PATH='./pretrained_models/map2aerial_net_E.pth'
6 |
7 | # dataset
8 | CLASS='maps'
9 | DIRECTION='BtoA' # from domain B to domain A
10 | LOAD_SIZE=512 # scale images to this size
11 | CROP_SIZE=512 # then crop to this size
12 | INPUT_NC=3 # number of channels in the input image
13 | ASPECT_RATIO=1.0
14 | # change aspect ratio for the test images
15 |
16 | # misc
17 | GPU_ID=0 # gpu id
18 | NUM_TEST=10 # number of input images during test
19 | NUM_SAMPLES=10 # number of samples per input image
20 |
21 |
22 | # command
23 | CUDA_VISIBLE_DEVICES=${GPU_ID} python ./test.py \
24 | --dataroot ./datasets/${CLASS} \
25 | --results_dir ${RESULTS_DIR} \
26 | --checkpoints_dir ./pretrained_models/ \
27 | --name ${CLASS} \
28 | --direction ${DIRECTION} \
29 | --load_size ${LOAD_SIZE} \
30 | --crop_size ${CROP_SIZE} \
31 | --input_nc ${INPUT_NC} \
32 | --num_test ${NUM_TEST} \
33 | --n_samples ${NUM_SAMPLES} \
34 | --aspect_ratio ${ASPECT_RATIO} \
35 | --center_crop \
36 | --no_flip \
37 | --no_encode
38 |
--------------------------------------------------------------------------------
/scripts/test_night2day.sh:
--------------------------------------------------------------------------------
1 | set -ex
2 | # models
3 | RESULTS_DIR='./results/night2day'
4 | G_PATH='./pretrained_models/night2day_net_G.pth'
5 | E_PATH='./pretrained_models/night2day_net_E.pth'
6 |
7 | # dataset
8 | CLASS='night2day'
9 | DIRECTION='AtoB' # from domain A to domain B
10 | LOAD_SIZE=286 # scale images to this size
11 | CROP_SIZE=256 # then crop to this size
12 | INPUT_NC=3 # number of channels in the input image
13 | ASPECT_RATIO=1.4
14 | # change aspect ratio for the test images
15 |
16 | # misc
17 | GPU_ID=0 # gpu id
18 | NUM_TEST=10 # number of input images during test
19 | NUM_SAMPLES=10 # number of samples per input image
20 |
21 |
22 | # command
23 | CUDA_VISIBLE_DEVICES=${GPU_ID} python ./test.py \
24 | --dataroot ./datasets/${CLASS} \
25 | --results_dir ${RESULTS_DIR} \
26 | --checkpoints_dir ./pretrained_models/ \
27 | --name ${CLASS} \
28 | --direction ${DIRECTION} \
29 | --load_size ${LOAD_SIZE} \
30 | --crop_size ${CROP_SIZE} \
31 | --input_nc ${INPUT_NC} \
32 | --num_test ${NUM_TEST} \
33 | --n_samples ${NUM_SAMPLES} \
34 | --aspect_ratio ${ASPECT_RATIO} \
35 | --center_crop \
36 | --no_flip
37 |
--------------------------------------------------------------------------------
/scripts/train.sh:
--------------------------------------------------------------------------------
1 | set -ex
2 | CLASS='edges2shoes' # facades, day2night, edges2shoes, edges2handbags, maps
3 | MODEL='bicycle_gan'
4 | CLASS=${1}
5 | GPU_ID=${2}
6 | DISPLAY_ID=$((GPU_ID*10+1))
7 | PORT=2005
8 | NZ=8
9 |
10 |
11 | CHECKPOINTS_DIR=../checkpoints/${CLASS}/
12 | DATE=`date '+%d_%m_%Y_%H'`
13 | NAME=${CLASS}_${MODEL}_${DATE}
14 |
15 |
16 | # dataset
17 | NO_FLIP=''
18 | DIRECTION='AtoB'
19 | LOAD_SIZE=286
20 | CROP_SIZE=256
21 | INPUT_NC=3
22 |
23 | # dataset parameters
24 | case ${CLASS} in
25 | 'facades')
26 | NITER=200
27 | NITER_DECAY=200
28 | SAVE_EPOCH=25
29 | DIRECTION='BtoA'
30 | ;;
31 | 'edges2shoes')
32 | NITER=30
33 | NITER_DECAY=30
34 | LOAD_SIZE=256
35 | SAVE_EPOCH=5
36 | INPUT_NC=1
37 | NO_FLIP='--no_flip'
38 | ;;
39 | 'edges2handbags')
40 | NITER=15
41 | NITER_DECAY=15
42 | LOAD_SIZE=256
43 | SAVE_EPOCH=5
44 | INPUT_NC=1
45 | ;;
46 | 'maps')
47 | NITER=200
48 | NITER_DECAY=200
49 | LOAD_SIZE=600
50 | SAVE_EPOCH=25
51 | DIRECTION='BtoA'
52 | ;;
53 | 'day2night')
54 | NITER=50
55 | NITER_DECAY=50
56 | SAVE_EPOCH=10
57 | ;;
58 | *)
59 | echo 'WRONG category: '${CLASS}
60 | ;;
61 | esac
62 |
63 |
64 |
65 | # command
66 | CUDA_VISIBLE_DEVICES=${GPU_ID} python ./train.py \
67 | --display_id ${DISPLAY_ID} \
68 | --dataroot ./datasets/${CLASS} \
69 | --name ${NAME} \
70 | --model ${MODEL} \
71 | --display_port ${PORT} \
72 | --direction ${DIRECTION} \
73 | --checkpoints_dir ${CHECKPOINTS_DIR} \
74 | --load_size ${LOAD_SIZE} \
75 | --crop_size ${CROP_SIZE} \
76 | --nz ${NZ} \
77 | --save_epoch_freq ${SAVE_EPOCH} \
78 | --input_nc ${INPUT_NC} \
79 | --niter ${NITER} \
80 | --niter_decay ${NITER_DECAY} \
81 | --use_dropout
82 |
--------------------------------------------------------------------------------
/scripts/train_edges2shoes.sh:
--------------------------------------------------------------------------------
1 | set -ex
2 | MODEL='bicycle_gan'
3 | # dataset details
4 | CLASS='edges2shoes' # facades, day2night, edges2shoes, edges2handbags, maps
5 | NZ=8
6 | NO_FLIP='--no_flip'
7 | DIRECTION='AtoB'
8 | LOAD_SIZE=256
9 | CROP_SIZE=256
10 | INPUT_NC=1
11 | NITER=30
12 | NITER_DECAY=30
13 |
14 | # training
15 | GPU_ID=0
16 | DISPLAY_ID=$((GPU_ID*10+1))
17 | CHECKPOINTS_DIR=../checkpoints/${CLASS}/
18 | NAME=${CLASS}_${MODEL}
19 |
20 | # command
21 | CUDA_VISIBLE_DEVICES=${GPU_ID} python ./train.py \
22 | --display_id ${DISPLAY_ID} \
23 | --dataroot ./datasets/${CLASS} \
24 | --name ${NAME} \
25 | --model ${MODEL} \
26 | --direction ${DIRECTION} \
27 | --checkpoints_dir ${CHECKPOINTS_DIR} \
28 | --load_size ${LOAD_SIZE} \
29 | --crop_size ${CROP_SIZE} \
30 | --nz ${NZ} \
31 | --input_nc ${INPUT_NC} \
32 | --niter ${NITER} \
33 | --niter_decay ${NITER_DECAY} \
34 | --use_dropout
35 |
--------------------------------------------------------------------------------
/scripts/train_facades.sh:
--------------------------------------------------------------------------------
1 | set -ex
2 | MODEL='bicycle_gan'
3 | # dataset details
4 | CLASS='facades' # facades, day2night, edges2shoes, edges2handbags, maps
5 | NZ=8
6 | NO_FLIP=''
7 | DIRECTION='BtoA'
8 | LOAD_SIZE=256
9 | CROP_SIZE=256
10 | INPUT_NC=3
11 | NITER=200
12 | NITER_DECAY=200
13 | SAVE_EPOCH=25
14 |
15 | # training
16 | GPU_ID=0
17 | DISPLAY_ID=$((GPU_ID*10+1))
18 | CHECKPOINTS_DIR=../checkpoints/${CLASS}/
19 | NAME=${CLASS}_${MODEL}
20 |
21 | # command
22 | python ./train.py \
23 | --display_id ${DISPLAY_ID} \
24 | --dataroot ./datasets/${CLASS} \
25 | --name ${NAME} \
26 | --model ${MODEL} \
27 | --direction ${DIRECTION} \
28 | --checkpoints_dir ${CHECKPOINTS_DIR} \
29 | --load_size ${LOAD_SIZE} \
30 | --crop_size ${CROP_SIZE} \
31 | --nz ${NZ} \
32 | --input_nc ${INPUT_NC} \
33 | --niter ${NITER} \
34 | --niter_decay ${NITER_DECAY} \
35 | --save_epoch_freq ${SAVE_EPOCH} \
36 | --use_dropout
37 |
--------------------------------------------------------------------------------
/scripts/video_edges2shoes.sh:
--------------------------------------------------------------------------------
1 | set -ex
2 | # models
3 | RESULTS_DIR='./videos/edges2shoes'
4 | G_PATH='./pretrained_models/edges2shoes_net_G.pth'
5 |
6 | # dataset
7 | CLASS='edges2shoes'
8 | DIRECTION='AtoB'
9 | LOAD_SIZE=256
10 | CROP_SIZE=256
11 | INPUT_NC=1
12 |
13 | # misc
14 | GPU_ID=0
15 | NUM_TEST=5 # number of input images during test
16 | NUM_SAMPLES=20 # number of samples per input image
17 |
18 | # command
19 | CUDA_VISIBLE_DEVICES=${GPU_ID} python ./video.py \
20 | --dataroot ./datasets/${CLASS} \
21 | --results_dir ${RESULTS_DIR} \
22 | --checkpoints_dir ./pretrained_models/ \
23 | --name ${CLASS} \
24 | --direction ${DIRECTION} \
25 | --load_size ${LOAD_SIZE} --crop_size ${CROP_SIZE} \
26 | --input_nc ${INPUT_NC} \
27 | --num_test ${NUM_TEST} \
28 | --n_samples ${NUM_SAMPLES} \
29 | --center_crop \
30 | --no_flip
31 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | from options.test_options import TestOptions
3 | from data import create_dataset
4 | from models import create_model
5 | from util.visualizer import save_images
6 | from itertools import islice
7 | from util import html
8 |
9 |
10 | # options
11 | opt = TestOptions().parse()
12 | opt.num_threads = 1 # test code only supports num_threads=1
13 | opt.batch_size = 1 # test code only supports batch_size=1
14 | opt.serial_batches = True # no shuffle
15 |
16 | # create dataset
17 | dataset = create_dataset(opt)
18 | model = create_model(opt)
19 | model.setup(opt)
20 | model.eval()
21 | print('Loading model %s' % opt.model)
22 |
23 | # create website
24 | web_dir = os.path.join(opt.results_dir, opt.phase + '_sync' if opt.sync else opt.phase)
25 | webpage = html.HTML(web_dir, 'Training = %s, Phase = %s, Class = %s' % (opt.name, opt.phase, opt.name))
26 |
27 | # sample random z
28 | if opt.sync:
29 | z_samples = model.get_z_random(opt.n_samples + 1, opt.nz)
30 |
31 | # test stage
32 | for i, data in enumerate(islice(dataset, opt.num_test)):
33 | model.set_input(data)
34 | print('process input image %3.3d/%3.3d' % (i, opt.num_test))
35 | if not opt.sync:
36 | z_samples = model.get_z_random(opt.n_samples + 1, opt.nz)
37 | for nn in range(opt.n_samples + 1):
38 | encode = nn == 0 and not opt.no_encode
39 | real_A, fake_B, real_B = model.test(z_samples[[nn]], encode=encode)
40 | if nn == 0:
41 | images = [real_A, real_B, fake_B]
42 | names = ['input', 'ground truth', 'encoded']
43 | else:
44 | images.append(fake_B)
45 | names.append('random_sample%2.2d' % nn)
46 |
47 | img_path = 'input_%3.3d' % i
48 | save_images(webpage, images, names, img_path, aspect_ratio=opt.aspect_ratio, width=opt.crop_size)
49 |
50 | webpage.save()
51 |
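The sampling logic above draws n_samples + 1 latent codes and reserves index 0 for the encoding pass (real_B is re-encoded only there, unless --no_encode is set). A minimal sketch of that bookkeeping, assuming get_z_random in models/base_model.py draws from a standard normal, as this script's video counterpart does:

    import torch

    n_samples, nz = 5, 8
    z_samples = torch.randn(n_samples + 1, nz)  # one extra row; row 0 pairs with the encoded image
    for nn in range(n_samples + 1):
        encode = (nn == 0)       # only the first pass re-encodes the ground-truth B
        z = z_samples[[nn]]      # list index keeps the batch dimension: shape (1, nz)
        print(nn, 'encode' if encode else 'random', tuple(z.shape))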
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | """General-purpose training script for image-to-image translation.
2 | This script works for various models (with option '--model': e.g., bicycle_gan, pix2pix, test) and
3 | different datasets (with option '--dataset_mode': e.g., aligned, single).
4 | You need to specify the dataset ('--dataroot'), experiment name ('--name'), and model ('--model').
5 | It first creates the model, dataset, and visualizer given the options.
6 | It then does standard network training. During training, it also visualizes/saves the images, prints/saves the loss plot, and saves the models.
7 | The script supports continue/resume training. Use '--continue_train' to resume your previous training.
8 | Example:
9 | Train a BiCycleGAN model:
10 | python train.py --dataroot ./datasets/facades --name facades_bicyclegan --model bicycle_gan --direction BtoA
11 | Train a pix2pix model:
12 | python train.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --direction BtoA
13 | See options/base_options.py and options/train_options.py for more training options.
14 | """
15 | import time
16 | from options.train_options import TrainOptions
17 | from data import create_dataset
18 | from models import create_model
19 | from util.visualizer import Visualizer
20 |
21 | if __name__ == '__main__':
22 | opt = TrainOptions().parse() # get training options
23 | dataset = create_dataset(opt) # create a dataset given opt.dataset_mode and other options
24 | dataset_size = len(dataset) # get the number of images in the dataset.
25 | print('The number of training images = %d' % dataset_size)
26 |
27 | model = create_model(opt) # create a model given opt.model and other options
28 | model.setup(opt) # regular setup: load and print networks; create schedulers
29 | visualizer = Visualizer(opt) # create a visualizer that display/save images and plots
30 | total_iters = 0 # the total number of training iterations
31 |
32 | for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):    # outer loop for different epochs; we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>
33 | epoch_start_time = time.time() # timer for entire epoch
34 | iter_data_time = time.time() # timer for data loading per iteration
35 | epoch_iter = 0 # the number of training iterations in current epoch, reset to 0 every epoch
36 |
37 | for i, data in enumerate(dataset): # inner loop within one epoch
38 | iter_start_time = time.time() # timer for computation per iteration
39 | if total_iters % opt.print_freq == 0:
40 | t_data = iter_start_time - iter_data_time
41 | visualizer.reset()
42 | total_iters += opt.batch_size
43 | epoch_iter += opt.batch_size
44 | model.set_input(data) # unpack data from dataset and apply preprocessing
45 | if not model.is_train():  # skip this batch if the input data is not suitable for training
46 | print('skip this batch')
47 | continue
48 | model.optimize_parameters() # calculate loss functions, get gradients, update network weights
49 |
50 | if total_iters % opt.display_freq == 0: # display images on visdom and save images to a HTML file
51 | save_result = total_iters % opt.update_html_freq == 0
52 | model.compute_visuals()
53 | visualizer.display_current_results(model.get_current_visuals(), epoch, save_result)
54 |
55 | if total_iters % opt.print_freq == 0: # print training losses and save logging information to the disk
56 | losses = model.get_current_losses()
57 | t_comp = (time.time() - iter_start_time) / opt.batch_size
58 | visualizer.print_current_losses(epoch, epoch_iter, losses, t_comp, t_data)
59 | if opt.display_id > 0:
60 | visualizer.plot_current_losses(epoch, float(epoch_iter) / dataset_size, losses)
61 |
62 | if total_iters % opt.save_latest_freq == 0:   # cache our latest model every <save_latest_freq> iterations
63 | print('saving the latest model (epoch %d, total_iters %d)' % (epoch, total_iters))
64 | model.save_networks('latest')
65 |
66 | iter_data_time = time.time()
67 | if epoch % opt.save_epoch_freq == 0:              # cache our model every <save_epoch_freq> epochs
68 | print('saving the model at the end of epoch %d, iters %d' % (epoch, total_iters))
69 | model.save_networks('latest')
70 | model.save_networks(epoch)
71 |
72 | print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))
73 | model.update_learning_rate() # update learning rates at the end of every epoch.
74 |
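With the default --lr_policy linear, model.update_learning_rate() keeps the learning rate at --lr for the first --niter epochs, then decays it linearly to zero over the next --niter_decay epochs. A standalone sketch of that schedule, assuming the LambdaLR rule used by the pix2pix-style get_scheduler in models/networks.py:

    niter, niter_decay, epoch_count, lr = 100, 100, 1, 0.0002

    def linear_lr_factor(epoch):
        # 1.0 for the first <niter> epochs, then a linear ramp down to 0
        return 1.0 - max(0, epoch + epoch_count - niter) / float(niter_decay + 1)

    for epoch in (1, 100, 150, 200):
        print(epoch, round(lr * linear_lr_factor(epoch), 7))  # 0.0002 ... 0.0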
--------------------------------------------------------------------------------
/util/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/junyanz/BicycleGAN/40b9d52c27b9831f56c1c7c7a6ddde8bc9149067/util/__init__.py
--------------------------------------------------------------------------------
/util/html.py:
--------------------------------------------------------------------------------
1 | import dominate
2 | from dominate.tags import meta, h3, table, tr, td, p, a, img, br
3 | import os
4 |
5 |
6 | class HTML:
7 | """This HTML class allows us to save images and write texts into a single HTML file.
8 |
9 | It consists of functions such as <add_header> (add a text header to the HTML file),
10 | <add_images> (add a row of images to the HTML file), and <save> (save the HTML to the disk).
11 | It is based on Python library 'dominate', a Python library for creating and manipulating HTML documents using a DOM API.
12 | """
13 |
14 | def __init__(self, web_dir, title, refresh=0):
15 | """Initialize the HTML classes
16 |
17 | Parameters:
18 | web_dir (str) -- a directory that stores the webpage. HTML file will be created at <web_dir>/index.html; images will be saved at <web_dir>/images/
19 | title (str) -- the webpage name
20 | refresh (int) -- how often the website refreshes itself; if 0, no refreshing
21 | """
22 | self.title = title
23 | self.web_dir = web_dir
24 | self.img_dir = os.path.join(self.web_dir, 'images')
25 | if not os.path.exists(self.web_dir):
26 | os.makedirs(self.web_dir)
27 | if not os.path.exists(self.img_dir):
28 | os.makedirs(self.img_dir)
29 |
30 | self.doc = dominate.document(title=title)
31 | if refresh > 0:
32 | with self.doc.head:
33 | meta(http_equiv="refresh", content=str(refresh))
34 |
35 | def get_image_dir(self):
36 | """Return the directory that stores images"""
37 | return self.img_dir
38 |
39 | def add_header(self, text):
40 | """Insert a header to the HTML file
41 |
42 | Parameters:
43 | text (str) -- the header text
44 | """
45 | with self.doc:
46 | h3(text)
47 |
48 | def add_images(self, ims, txts, links, width=400):
49 | """add images to the HTML file
50 |
51 | Parameters:
52 | ims (str list) -- a list of image paths
53 | txts (str list) -- a list of image names shown on the website
54 | links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page
55 | """
56 | self.t = table(border=1, style="table-layout: fixed;") # Insert a table
57 | self.doc.add(self.t)
58 | with self.t:
59 | with tr():
60 | for im, txt, link in zip(ims, txts, links):
61 | with td(style="word-wrap: break-word;", halign="center", valign="top"):
62 | with p():
63 | with a(href=os.path.join('images', link)):
64 | img(style="width:%dpx" % width, src=os.path.join('images', im))
65 | br()
66 | p(txt)
67 |
68 | def save(self):
69 | """save the current content to the HMTL file"""
70 | html_file = '%s/index.html' % self.web_dir
71 | f = open(html_file, 'wt')
72 | f.write(self.doc.render())
73 | f.close()
74 |
75 |
76 | if __name__ == '__main__': # we show an example usage here.
77 | html = HTML('web/', 'test_html')
78 | html.add_header('hello world')
79 |
80 | ims, txts, links = [], [], []
81 | for n in range(4):
82 | ims.append('image_%d.png' % n)
83 | txts.append('text_%d' % n)
84 | links.append('image_%d.png' % n)
85 | html.add_images(ims, txts, links)
86 | html.save()
87 |
--------------------------------------------------------------------------------
/util/util.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import torch
3 | import numpy as np
4 | from PIL import Image
5 | import os
6 | import pickle
7 |
8 |
9 | def tensor2im(input_image, imtype=np.uint8):
10 | """"Convert a Tensor array into a numpy image array.
11 | Parameters:
12 | input_image (tensor) -- the input image tensor array
13 | imtype (type) -- the desired type of the converted numpy array
14 | """
15 | if not isinstance(input_image, np.ndarray):
16 | if isinstance(input_image, torch.Tensor): # get the data from a variable
17 | image_tensor = input_image.data
18 | else:
19 | return input_image
20 | image_numpy = image_tensor[0].cpu().float().numpy() # convert it into a numpy array
21 | if image_numpy.shape[0] == 1: # grayscale to RGB
22 | image_numpy = np.tile(image_numpy, (3, 1, 1))
23 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0  # post-processing: transpose and scaling
24 | else: # if it is a numpy array, do nothing
25 | image_numpy = input_image
26 | return image_numpy.astype(imtype)
27 |
28 |
29 | def tensor2vec(vector_tensor):
30 | numpy_vec = vector_tensor.data.cpu().numpy()
31 | if numpy_vec.ndim == 4:
32 | return numpy_vec[:, :, 0, 0]
33 | else:
34 | return numpy_vec
35 |
36 |
37 | def pickle_load(file_name):
38 | data = None
39 | with open(file_name, 'rb') as f:
40 | data = pickle.load(f)
41 | return data
42 |
43 |
44 | def pickle_save(file_name, data):
45 | with open(file_name, 'wb') as f:
46 | pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)
47 |
48 |
49 | def diagnose_network(net, name='network'):
50 | """Calculate and print the mean of average absolute(gradients)
51 | Parameters:
52 | net (torch network) -- Torch network
53 | name (str) -- the name of the network
54 | """
55 | mean = 0.0
56 | count = 0
57 | for param in net.parameters():
58 | if param.grad is not None:
59 | mean += torch.mean(torch.abs(param.grad.data))
60 | count += 1
61 | if count > 0:
62 | mean = mean / count
63 | print(name)
64 | print(mean)
65 |
66 |
67 | def interp_z(z0, z1, num_frames, interp_mode='linear'):
68 | zs = []
69 | if interp_mode == 'linear':
70 | for n in range(num_frames):
71 | ratio = n / float(num_frames - 1)
72 | z_t = (1 - ratio) * z0 + ratio * z1
73 | zs.append(z_t[np.newaxis, :])
74 | zs = np.concatenate(zs, axis=0).astype(np.float32)
75 |
76 | if interp_mode == 'slerp':
77 | z0_n = z0 / (np.linalg.norm(z0) + 1e-10)
78 | z1_n = z1 / (np.linalg.norm(z1) + 1e-10)
79 | omega = np.arccos(np.dot(z0_n, z1_n))
80 | sin_omega = np.sin(omega)
81 | if sin_omega < 1e-10 and sin_omega > -1e-10:
82 | zs = interp_z(z0, z1, num_frames, interp_mode='linear')
83 | else:
84 | for n in range(num_frames):
85 | ratio = n / float(num_frames - 1)
86 | z_t = np.sin((1 - ratio) * omega) / sin_omega * z0 + np.sin(ratio * omega) / sin_omega * z1
87 | zs.append(z_t[np.newaxis, :])
88 | zs = np.concatenate(zs, axis=0).astype(np.float32)
89 |
90 | return zs
91 |
92 |
93 | def save_image(image_numpy, image_path, aspect_ratio=1.0):
94 | """Save a numpy image to the disk
95 | Parameters:
96 | image_numpy (numpy array) -- input numpy array
97 | image_path (str) -- the path of the image
98 | """
99 |
100 | image_pil = Image.fromarray(image_numpy)
101 | h, w, _ = image_numpy.shape
102 |
103 | if aspect_ratio > 1.0:
104 | image_pil = image_pil.resize((h, int(w * aspect_ratio)), Image.BICUBIC)
105 | if aspect_ratio < 1.0:
106 | image_pil = image_pil.resize((int(h / aspect_ratio), w), Image.BICUBIC)
107 | image_pil.save(image_path)
108 |
109 |
110 | def print_numpy(x, val=True, shp=False):
111 | """Print the mean, min, max, median, std, and size of a numpy array
112 | Parameters:
113 | val (bool) -- whether to print the values of the numpy array
114 | shp (bool) -- whether to print the shape of the numpy array
115 | """
116 | x = x.astype(np.float64)
117 | if shp:
118 | print('shape,', x.shape)
119 | if val:
120 | x = x.flatten()
121 | print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
122 | np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))
123 |
124 |
125 | def mkdirs(paths):
126 | """create empty directories if they don't exist
127 | Parameters:
128 | paths (str list) -- a list of directory paths
129 | """
130 | if isinstance(paths, list) and not isinstance(paths, str):
131 | for path in paths:
132 | mkdir(path)
133 | else:
134 | mkdir(paths)
135 |
136 |
137 | def mkdir(path):
138 | """create a single empty directory if it didn't exist
139 | Parameters:
140 | path (str) -- a single directory path
141 | """
142 | if not os.path.exists(path):
143 | os.makedirs(path)
144 |
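interp_z above is what video.py uses to morph between latent codes: 'linear' walks a straight line from z0 to z1, while 'slerp' follows the great circle between their directions, which better preserves the vector norm a Gaussian prior expects. A quick usage check (run from the repository root):

    import numpy as np
    from util.util import interp_z

    z0 = np.random.RandomState(0).normal(0, 1, 8)
    z1 = np.random.RandomState(1).normal(0, 1, 8)

    zs_lin = interp_z(z0, z1, num_frames=4, interp_mode='linear')
    zs_slerp = interp_z(z0, z1, num_frames=4, interp_mode='slerp')
    print(zs_lin.shape, zs_slerp.shape)  # (4, 8) (4, 8)
    print(np.allclose(zs_lin[0], z0), np.allclose(zs_lin[-1], z1))  # endpoints match: True True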
--------------------------------------------------------------------------------
/util/visualizer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import sys
4 | import ntpath
5 | import time
6 | from . import util
7 | from . import html
8 | from subprocess import Popen, PIPE
9 |
10 |
11 | if sys.version_info[0] == 2:
12 | VisdomExceptionBase = Exception
13 | else:
14 | VisdomExceptionBase = ConnectionError
15 |
16 |
17 | def save_images(webpage, images, names, image_path, aspect_ratio=1.0, width=256):
18 | """Save images to the disk.
19 | Parameters:
20 | webpage (the HTML class) -- the HTML webpage class that stores these images (see html.py for more details)
21 | images (numpy array list) -- a list of numpy array that stores images
22 | names (str list) -- a str list stores the names of the images above
23 | image_path (str) -- the string is used to create image paths
24 | aspect_ratio (float) -- the aspect ratio of saved images
25 | width (int) -- the images will be resized to width x width
26 | This function will save images stored in 'visuals' to the HTML file specified by 'webpage'.
27 | """
28 | image_dir = webpage.get_image_dir()
29 | name = ntpath.basename(image_path)
30 |
31 | webpage.add_header(name)
32 | ims, txts, links = [], [], []
33 |
34 | for label, im_data in zip(names, images):
35 | im = util.tensor2im(im_data)
36 | image_name = '%s_%s.png' % (name, label)
37 | save_path = os.path.join(image_dir, image_name)
38 | util.save_image(im, save_path, aspect_ratio=aspect_ratio)
39 | ims.append(image_name)
40 | txts.append(label)
41 | links.append(image_name)
42 | webpage.add_images(ims, txts, links, width=width)
43 |
44 |
45 | class Visualizer():
46 | """This class includes several functions that can display/save images and print/save logging information.
47 | It uses a Python library 'visdom' for display, and a Python library 'dominate' (wrapped in 'HTML') for creating HTML files with images.
48 | """
49 |
50 | def __init__(self, opt):
51 | """Initialize the Visualizer class
52 | Parameters:
53 | opt -- stores all the experiment flags; needs to be a subclass of BaseOptions
54 | Step 1: Cache the training/test options
55 | Step 2: connect to a visdom server
56 | Step 3: create an HTML object for saving HTML files
57 | Step 4: create a logging file to store training losses
58 | """
59 | self.opt = opt # cache the option
60 | self.display_id = opt.display_id
61 | self.use_html = opt.isTrain and not opt.no_html
62 | self.win_size = opt.display_winsize
63 | self.name = opt.name
64 | self.port = opt.display_port
65 | self.saved = False
66 | if self.display_id > 0:  # connect to a visdom server given <display_port> and <display_server>
67 | import visdom
68 | self.ncols = opt.display_ncols
69 | self.vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, env=opt.display_env)
70 | if not self.vis.check_connection():
71 | self.create_visdom_connections()
72 | if self.use_html: # create an HTML object at /web/; images will be saved under /web/images/
73 | self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
74 | self.img_dir = os.path.join(self.web_dir, 'images')
75 | print('create web directory %s...' % self.web_dir)
76 | util.mkdirs([self.web_dir, self.img_dir])
77 | # create a logging file to store training losses
78 | self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
79 | with open(self.log_name, "a") as log_file:
80 | now = time.strftime("%c")
81 | log_file.write('================ Training Loss (%s) ================\n' % now)
82 |
83 | def reset(self):
84 | """Reset the self.saved status"""
85 | self.saved = False
86 |
87 | def create_visdom_connections(self):
88 | """If the program could not connect to Visdom server, this function will start a new server at port < self.port > """
89 | cmd = sys.executable + ' -m visdom.server -p %d &>/dev/null &' % self.port
90 | print('\n\nCould not connect to Visdom server. \n Trying to start a server....')
91 | print('Command: %s' % cmd)
92 | Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE)
93 |
94 | def display_current_results(self, visuals, epoch, save_result):
95 | """Display current results on visdom; save current results to an HTML file.
96 | Parameters:
97 | visuals (OrderedDict) - - dictionary of images to display or save
98 | epoch (int) - - the current epoch
99 | save_result (bool) - - if save the current results to an HTML file
100 | """
101 | if self.display_id > 0: # show images in the browser using visdom
102 | ncols = self.ncols
103 | if ncols > 0: # show all the images in one visdom panel
104 | ncols = min(ncols, len(visuals))
105 | h, w = next(iter(visuals.values())).shape[:2]
106 | table_css = """""" % (w, h) # create a table css
110 | # create a table of images.
111 | title = self.name
112 | label_html = ''
113 | label_html_row = ''
114 | images = []
115 | idx = 0
116 | for label, image in visuals.items():
117 | image_numpy = util.tensor2im(image)
118 | label_html_row += '<td>%s</td>' % label
119 | images.append(image_numpy.transpose([2, 0, 1]))
120 | idx += 1
121 | if idx % ncols == 0:
122 | label_html += '<tr>%s</tr>' % label_html_row
123 | label_html_row = ''
124 | white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255
125 | while idx % ncols != 0:
126 | images.append(white_image)
127 | label_html_row += '<td></td>'
128 | idx += 1
129 | if label_html_row != '':
130 | label_html += '<tr>%s</tr>' % label_html_row
131 | try:
132 | self.vis.images(images, nrow=ncols, win=self.display_id + 1,
133 | padding=2, opts=dict(title=title + ' images'))
134 | label_html = '<table>%s</table>' % label_html
135 | self.vis.text(table_css + label_html, win=self.display_id + 2,
136 | opts=dict(title=title + ' labels'))
137 | except VisdomExceptionBase:
138 | self.create_visdom_connections()
139 |
140 | else: # show each image in a separate visdom panel;
141 | idx = 1
142 | try:
143 | for label, image in visuals.items():
144 | image_numpy = util.tensor2im(image)
145 | self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label),
146 | win=self.display_id + idx)
147 | idx += 1
148 | except VisdomExceptionBase:
149 | self.create_visdom_connections()
150 |
151 | if self.use_html and (save_result or not self.saved): # save images to an HTML file if they haven't been saved.
152 | self.saved = True
153 | # save images to the disk
154 | for label, image in visuals.items():
155 | image_numpy = util.tensor2im(image)
156 | img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
157 | util.save_image(image_numpy, img_path)
158 |
159 | # update website
160 | webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=1)
161 | for n in range(epoch, 0, -1):
162 | webpage.add_header('epoch [%d]' % n)
163 | ims, txts, links = [], [], []
164 |
165 | for label, image in visuals.items():
166 | image_numpy = util.tensor2im(image)
167 | img_path = 'epoch%.3d_%s.png' % (n, label)
168 | ims.append(img_path)
169 | txts.append(label)
170 | links.append(img_path)
171 | webpage.add_images(ims, txts, links, width=self.win_size)
172 | webpage.save()
173 |
174 | def plot_current_losses(self, epoch, counter_ratio, losses):
175 | """display the current losses on visdom display: dictionary of error labels and values
176 | Parameters:
177 | epoch (int) -- current epoch
178 | counter_ratio (float) -- progress (percentage) in the current epoch, between 0 to 1
179 | losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
180 | """
181 | if not hasattr(self, 'plot_data'):
182 | self.plot_data = {'X': [], 'Y': [], 'legend': list(losses.keys())}
183 | self.plot_data['X'].append(epoch + counter_ratio)
184 | self.plot_data['Y'].append([losses[k] for k in self.plot_data['legend']])
185 | try:
186 | self.vis.line(
187 | X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1),
188 | Y=np.array(self.plot_data['Y']),
189 | opts={
190 | 'title': self.name + ' loss over time',
191 | 'legend': self.plot_data['legend'],
192 | 'xlabel': 'epoch',
193 | 'ylabel': 'loss'},
194 | win=self.display_id)
195 | except VisdomExceptionBase:
196 | self.create_visdom_connections()
197 |
198 | # losses: same format as |losses| of plot_current_losses
199 | def print_current_losses(self, epoch, iters, losses, t_comp, t_data):
200 | """print current losses on console; also save the losses to the disk
201 | Parameters:
202 | epoch (int) -- current epoch
203 | iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch)
204 | losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
205 | t_comp (float) -- computational time per data point (normalized by batch_size)
206 | t_data (float) -- data loading time per data point (normalized by batch_size)
207 | """
208 | message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data)
209 | for k, v in losses.items():
210 | message += '%s: %.3f ' % (k, v)
211 |
212 | print(message) # print the message
213 | with open(self.log_name, "a") as log_file:
214 | log_file.write('%s\n' % message) # save the message
215 |
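One detail worth calling out in plot_current_losses: visdom's vis.line wants X and Y with matching shapes, so the single time axis (epoch + counter_ratio) is tiled once per loss curve. The shape bookkeeping in isolation (numpy only, no visdom server required):

    import numpy as np

    legend = ['G_GAN', 'D', 'G_L1']      # one curve per loss name
    X_axis = [1.0, 1.25, 1.5]            # epoch + counter_ratio at each logging step
    Y_rows = [[0.9, 0.8, 11.0], [0.7, 0.6, 10.0], [0.5, 0.5, 9.0]]

    X = np.stack([np.array(X_axis)] * len(legend), 1)  # (num_points, num_curves)
    Y = np.array(Y_rows)                               # (num_points, num_curves)
    print(X.shape, Y.shape)  # (3, 3) (3, 3) -- the matching shapes vis.line expects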
--------------------------------------------------------------------------------
/video.py:
--------------------------------------------------------------------------------
1 | from options.video_options import VideoOptions
2 | from data import create_dataset
3 | from models import create_model
4 | from itertools import islice
5 | from util import util
6 | import numpy as np
7 | import moviepy.editor
8 | import os
9 | import torch
10 |
11 |
12 | def get_random_z(opt):
13 | z_samples = np.random.normal(0, 1, (opt.n_samples + 1, opt.nz))
14 | return z_samples
15 |
16 |
17 | def produce_frame(t):
18 | k = int(t * opt.fps)
19 | return np.concatenate(frame_rows[k], axis=1 - use_vertical)
20 |
21 |
22 | # hard-code opt
23 | opt = VideoOptions().parse()
24 | opt.num_threads = 1 # test code only supports num_threads=1
25 | opt.batch_size = 1 # test code only supports batch_size=1
26 | opt.no_encode = True # do not use encoder
27 |
28 | dataset = create_dataset(opt)
29 | model = create_model(opt)
30 | model.setup(opt)
31 | model.eval()
32 | interp_mode = 'slerp'
33 | use_vertical = 1 if opt.align_mode == 'vertical' else 0
34 |
35 | print('Loading model %s' % opt.model)
36 | # create website
37 | results_dir = opt.results_dir
38 | util.mkdir(results_dir)
39 | total_frames = opt.num_frames * opt.n_samples
40 |
41 |
42 | z_samples = get_random_z(opt)
43 | frame_rows = [[] for n in range(total_frames)]
44 |
45 | for i, data in enumerate(islice(dataset, opt.num_test)):
46 | print('process input image %3.3d/%3.3d' % (i, opt.num_test))
47 | model.set_input(data)
48 | real_A = util.tensor2im(model.real_A)
49 | wb = opt.border
50 | hb = opt.border
51 | h = real_A.shape[0]
52 | w = real_A.shape[1] # border
53 | real_A_b = np.full((h + hb, w + wb, opt.output_nc), 255, real_A.dtype)
54 | real_A_b[hb:, wb:, :] = real_A
55 | frames = [[real_A_b] for n in range(total_frames)]
56 |
57 | for n in range(opt.n_samples):
58 | z0 = z_samples[n]
59 | z1 = z_samples[n + 1]
60 | zs = util.interp_z(z0, z1, num_frames=opt.num_frames, interp_mode=interp_mode)
61 | for k in range(opt.num_frames):
62 | zs_k = (torch.Tensor(zs[[k]])).to(model.device)
63 | _, fake_B_device, _ = model.test(zs_k, encode=False)
64 | fake_B = util.tensor2im(fake_B_device)
65 | fake_B_b = np.full((h + hb, w + wb, opt.output_nc), 255, fake_B.dtype)
66 | fake_B_b[hb:, wb:, :] = fake_B
67 | frames[k + opt.num_frames * n].append(fake_B_b)
68 |
69 | for k in range(total_frames):
70 | frame_row = np.concatenate(frames[k], axis=use_vertical)
71 | frame_rows[k].append(frame_row)
72 |
73 | # compile the frames into a video
74 | images_dir = os.path.join(results_dir, 'frames_seed%4.4d' % opt.seed)
75 | util.mkdir(images_dir)
76 |
77 |
78 | for k in range(total_frames):
79 | final_frame = np.concatenate(frame_rows[k], axis=1 - use_vertical)
80 | util.save_image(final_frame, os.path.join(
81 | images_dir, 'frame_%4.4d.jpg' % k))
82 |
83 |
84 | video_file = os.path.join(
85 | results_dir, 'morphing_video_seed%4.4d_fps%d.mp4' % (opt.seed, opt.fps))
86 | video = moviepy.editor.VideoClip(
87 | produce_frame, duration=float(total_frames) / opt.fps)
88 | video.write_videofile(video_file, fps=30, codec='libx264', bitrate='16M')
89 |
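moviepy's VideoClip is driven by a make_frame callback that maps a time t in seconds to an RGB frame, which is why produce_frame converts t back to a frame index with int(t * opt.fps). A self-contained sketch of the same contract with synthetic frames (assuming the moviepy 1.x moviepy.editor module this script imports):

    import numpy as np
    import moviepy.editor

    fps, num_frames = 8, 16
    frames = [np.full((64, 64, 3), int(255 * k / (num_frames - 1)), np.uint8)
              for k in range(num_frames)]  # a gray ramp as stand-in frames

    def make_frame(t):
        k = min(int(t * fps), num_frames - 1)  # clamp in case t == duration
        return frames[k]

    clip = moviepy.editor.VideoClip(make_frame, duration=num_frames / float(fps))
    clip.write_videofile('ramp.mp4', fps=30, codec='libx264')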
--------------------------------------------------------------------------------