├── .gitignore ├── LICENSE ├── README.md ├── data ├── __init__.py ├── aligned_dataset.py ├── base_dataset.py ├── image_folder.py ├── single_dataset.py └── template_dataset.py ├── datasets ├── bibtex │ ├── cityscapes.tex │ ├── facades.tex │ ├── handbags.tex │ ├── night2day.tex │ ├── shoes.tex │ └── transattr.tex ├── download_dataset.sh ├── download_mini_dataset.sh └── download_testset.sh ├── imgs ├── day2night.gif ├── results_matrix.jpg └── teaser.jpg ├── models ├── __init__.py ├── base_model.py ├── bicycle_gan_model.py ├── networks.py ├── pix2pix_model.py └── template_model.py ├── options ├── __init__.py ├── base_options.py ├── test_options.py ├── train_options.py └── video_options.py ├── pretrained_models └── download_model.sh ├── requirements.txt ├── scripts ├── check_all.sh ├── install_conda.sh ├── install_pip.sh ├── test_before_push.py ├── test_edges2handbags.sh ├── test_edges2shoes.sh ├── test_facades.sh ├── test_maps.sh ├── test_night2day.sh ├── train.sh ├── train_edges2shoes.sh ├── train_facades.sh └── video_edges2shoes.sh ├── test.py ├── train.py ├── util ├── __init__.py ├── html.py ├── util.py └── visualizer.py └── video.py /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | *.fuse* 3 | *.pth 4 | *.pyc 5 | debug* 6 | datasets/ 7 | videos/ 8 | checkpoints/ 9 | pretrained_models/ 10 | results/ 11 | build/ 12 | dist/ 13 | *.png 14 | torch.egg-info/ 15 | */**/__pycache__ 16 | torch/version.py 17 | torch/csrc/generic/TensorMethods.cpp 18 | torch/lib/*.so* 19 | torch/lib/*.dylib* 20 | torch/lib/*.h 21 | torch/lib/build 22 | torch/lib/tmp_install 23 | torch/lib/include 24 | torch/lib/torch_shm_manager 25 | torch/csrc/cudnn/cuDNN.cpp 26 | torch/csrc/nn/THNN.cwrap 27 | torch/csrc/nn/THNN.cpp 28 | torch/csrc/nn/THCUNN.cwrap 29 | torch/csrc/nn/THCUNN.cpp 30 | torch/csrc/nn/THNN_generic.cwrap 31 | torch/csrc/nn/THNN_generic.cpp 32 | torch/csrc/nn/THNN_generic.h 33 | docs/src/**/* 34 | test/data/legacy_modules.t7 35 | test/data/gpu_tensors.pt 36 | test/htmlcov 37 | test/.coverage 38 | */*.pyc 39 | */**/*.pyc 40 | */**/**/*.pyc 41 | */**/**/**/*.pyc 42 | */**/**/**/**/*.pyc 43 | */*.so* 44 | */**/*.so* 45 | */**/*.dylib* 46 | test/data/legacy_serialized.pt 47 | *~ 48 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2017, Jun-Yan Zhu, Richard Zhang, and Deepak Pathak 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 15 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 16 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 17 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 18 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 19 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 20 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 21 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 22 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 23 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 24 | 25 | 26 | --------------------------- LICENSE FOR pytorch-CycleGAN-and-pix2pix ---------------- 27 | Copyright (c) 2017, Jun-Yan Zhu and Taesung Park 28 | All rights reserved. 29 | 30 | Redistribution and use in source and binary forms, with or without 31 | modification, are permitted provided that the following conditions are met: 32 | 33 | * Redistributions of source code must retain the above copyright notice, this 34 | list of conditions and the following disclaimer. 35 | 36 | * Redistributions in binary form must reproduce the above copyright notice, 37 | this list of conditions and the following disclaimer in the documentation 38 | and/or other materials provided with the distribution. 39 | 40 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 41 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 42 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 43 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 44 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 45 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 46 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 47 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 48 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 49 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 50 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 |
4 | 5 | # BicycleGAN 6 | [Project Page](https://junyanz.github.io/BicycleGAN/) | [Paper](https://arxiv.org/abs/1711.11586) | [Video](https://youtu.be/JvGysD2EFhw) 7 | 8 | 9 | PyTorch implementation for multimodal image-to-image translation. For example, given the same night image, our model is able to synthesize possible day images with different types of lighting, sky, and clouds. Training requires paired data. 10 | 11 | **Note**: The current software works well with PyTorch 0.4.1+. Check out the older [branch](https://github.com/junyanz/BicycleGAN/tree/pytorch0.3.1) that supports PyTorch 0.1-0.3. 12 | 13 | 14 | 15 | **Toward Multimodal Image-to-Image Translation.** 16 | [Jun-Yan Zhu](https://www.cs.cmu.edu/~junyanz/), 17 | [Richard Zhang](https://richzhang.github.io/), [Deepak Pathak](http://people.eecs.berkeley.edu/~pathak/), [Trevor Darrell](https://people.eecs.berkeley.edu/~trevor/), [Alexei A. Efros](https://people.eecs.berkeley.edu/~efros/), [Oliver Wang](http://www.oliverwang.info/), [Eli Shechtman](https://research.adobe.com/person/eli-shechtman/). 18 | UC Berkeley and Adobe Research 19 | In Neural Information Processing Systems, 2017. 20 | 21 | ## Example results 22 | 23 | 24 | 25 | ## Other Implementations 26 | - [[Tensorflow]](https://github.com/gitlimlab/BicycleGAN-Tensorflow) by Youngwoon Lee (USC CLVR Lab). 27 | - [[Tensorflow]](https://github.com/kvmanohar22/img2imgGAN) by Kv Manohar. 28 | 29 | ## Prerequisites 30 | - Linux or macOS 31 | - Python 3 32 | - CPU or NVIDIA GPU + CUDA cuDNN 33 | 34 | 35 | ## Getting Started 36 | ### Installation 37 | - Clone this repo: 38 | ```bash 39 | git clone -b master --single-branch https://github.com/junyanz/BicycleGAN.git 40 | cd BicycleGAN 41 | ``` 42 | - Install PyTorch and dependencies from http://pytorch.org 43 | - Install the Python libraries [visdom](https://github.com/facebookresearch/visdom), [dominate](https://github.com/Knio/dominate), and [moviepy](https://github.com/Zulko/moviepy). 44 | 45 | For pip users: 46 | ```bash 47 | bash ./scripts/install_pip.sh 48 | ``` 49 | 50 | For conda users: 51 | ```bash 52 | bash ./scripts/install_conda.sh 53 | ``` 54 | 55 | 56 | ### Use a Pre-trained Model 57 | - Download some test photos (e.g., edges2shoes): 58 | ```bash 59 | bash ./datasets/download_testset.sh edges2shoes 60 | ``` 61 | - Download a pre-trained model (e.g., edges2shoes): 62 | ```bash 63 | bash ./pretrained_models/download_model.sh edges2shoes 64 | ``` 65 | 66 | - Generate results with the model: 67 | ```bash 68 | bash ./scripts/test_edges2shoes.sh 69 | ``` 70 | The test results will be saved to an HTML file here: `./results/edges2shoes/val/index.html`. 71 | 72 | - Generate results with synchronized latent vectors: 73 | ```bash 74 | bash ./scripts/test_edges2shoes.sh --sync 75 | ``` 76 | Results can be found at `./results/edges2shoes/val_sync/index.html`. 77 | 78 | ### Generate Morphing Videos 79 | - We can also produce a morphing video similar to this [GIF](imgs/day2night.gif) and YouTube [video](http://www.youtube.com/watch?v=JvGysD2EFhw&t=2m21s). 80 | ```bash 81 | bash ./scripts/video_edges2shoes.sh 82 | ``` 83 | Results can be found at `./videos/edges2shoes/`. 84 |
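Each frame of such a video decodes one point on a path through the latent space. A minimal sketch of the interpolation step (hypothetical helper; see `video.py` for the actual implementation):

```python
import torch

def interp_frames(z0, z1, n_frames=60):
    """Latent codes for one morph segment: linear interpolation from z0 to z1,
    each of shape (1, nz). Decoding each code with netG(real_A, z) gives a frame."""
    return [torch.lerp(z0, z1, float(t)) for t in torch.linspace(0.0, 1.0, n_frames)]
```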
85 | ### Model Training 86 | - To train a model, download the training images (e.g., edges2shoes): 87 | ```bash 88 | bash ./datasets/download_dataset.sh edges2shoes 89 | ``` 90 | 91 | - Train a model: 92 | ```bash 93 | bash ./scripts/train_edges2shoes.sh 94 | ``` 95 | - To view training results and loss plots, run `python -m visdom.server` and click the URL http://localhost:8097. To see more intermediate results, check out `./checkpoints/edges2shoes_bicycle_gan/web/index.html`. 96 | - See more training details for other datasets in `./scripts/train.sh`. 97 | 98 | ### Datasets (from pix2pix) 99 | Download the datasets using the following scripts. Many of the datasets were collected by other researchers. Please cite their papers if you use the data. 100 | - Download a test set: 101 | ```bash 102 | bash ./datasets/download_testset.sh dataset_name 103 | ``` 104 | - Download a training set and test set: 105 | ```bash 106 | bash ./datasets/download_dataset.sh dataset_name 107 | ``` 108 | - `facades`: 400 images from the [CMP Facades dataset](http://cmp.felk.cvut.cz/~tylecr1/facade). [[Citation](datasets/bibtex/facades.tex)] 109 | - `maps`: 1096 training images scraped from Google Maps. 110 | - `edges2shoes`: 50k training images from the [UT Zappos50K dataset](http://vision.cs.utexas.edu/projects/finegrained/utzap50k). Edges are computed by the [HED](https://github.com/s9xie/hed) edge detector + post-processing. [[Citation](datasets/bibtex/shoes.tex)] 111 | - `edges2handbags`: 137K Amazon handbag images from the [iGAN project](https://github.com/junyanz/iGAN). Edges are computed by the [HED](https://github.com/s9xie/hed) edge detector + post-processing. [[Citation](datasets/bibtex/handbags.tex)] 112 | - `night2day`: around 20K natural scene images from the [Transient Attributes dataset](http://transattr.cs.brown.edu/). [[Citation](datasets/bibtex/transattr.tex)] 113 | 114 | ## Models 115 | Download the pre-trained models with the following script. 116 | ```bash 117 | bash ./pretrained_models/download_model.sh model_name 118 | ``` 119 | - `edges2shoes` (edge -> photo) trained on the UT Zappos50K dataset. 120 | - `edges2handbags` (edge -> photo) trained on Amazon handbag images. 121 | ```bash 122 | bash ./pretrained_models/download_model.sh edges2handbags 123 | bash ./datasets/download_testset.sh edges2handbags 124 | bash ./scripts/test_edges2handbags.sh 125 | ``` 126 | - `night2day` (nighttime scene -> daytime scene) trained on around 100 [webcams](http://transattr.cs.brown.edu/). 127 | ```bash 128 | bash ./pretrained_models/download_model.sh night2day 129 | bash ./datasets/download_testset.sh night2day 130 | bash ./scripts/test_night2day.sh 131 | ``` 132 | - `facades` (facade label -> facade photo) trained on the CMP Facades dataset. 133 | ```bash 134 | bash ./pretrained_models/download_model.sh facades 135 | bash ./datasets/download_testset.sh facades 136 | bash ./scripts/test_facades.sh 137 | ``` 138 | - `maps` (map photo -> aerial photo) trained on 1096 training images scraped from Google Maps. 139 | ```bash 140 | bash ./pretrained_models/download_model.sh maps 141 | bash ./datasets/download_testset.sh maps 142 | bash ./scripts/test_maps.sh 143 | ``` 144 | 145 | ### Metrics 146 | 147 | Figure 6 of the paper shows the realism vs. diversity of our method. 148 | 149 | - **Realism**: We use the Amazon Mechanical Turk (AMT) Real vs Fake test from [this repository](https://github.com/phillipi/AMT_Real_vs_Fake), first introduced in [this work](http://richzhang.github.io/colorization/). 150 | 151 | - **Diversity**: For each input image, we produce 20 translations by randomly sampling 20 `z` vectors. We compute the LPIPS distance between consecutive pairs to get 19 paired distances. You can compute this by putting the 20 images into a directory and using [this script](https://github.com/richzhang/PerceptualSimilarity/blob/master/compute_dists_pair.py) (note that we used version 0.0 rather than the default 0.1, so use the flag `-v 0.0`). This is done for 100 input images, which yields 1900 distances in total (100 images × 19 paired distances each); these are averaged together. A larger number means higher diversity.
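A rough sketch of the per-image computation, assuming the pip-installable [`lpips`](https://github.com/richzhang/PerceptualSimilarity) package (the paper used the v0.0 weights, hence `version='0.0'`):

```python
import lpips

loss_fn = lpips.LPIPS(net='alex', version='0.0')  # paper setting; the package default is 0.1

def diversity(samples):
    """samples: (20, 3, H, W) tensor in [-1, 1], the translations of one input image."""
    dists = [loss_fn(samples[i:i + 1], samples[i + 1:i + 2]).item()
             for i in range(samples.size(0) - 1)]  # 19 consecutive pairs
    return sum(dists) / len(dists)  # average this score over 100 input images
```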
152 | 153 | ### Citation 154 | 155 | If you find this code useful for your research, please cite our paper: 156 | 157 | ``` 158 | @inproceedings{zhu2017toward, 159 | title={Toward multimodal image-to-image translation}, 160 | author={Zhu, Jun-Yan and Zhang, Richard and Pathak, Deepak and Darrell, Trevor and Efros, Alexei A and Wang, Oliver and Shechtman, Eli}, 161 | booktitle={Advances in Neural Information Processing Systems}, 162 | year={2017} 163 | } 164 | 165 | ``` 166 | If you use modules from the CycleGAN or pix2pix paper, please cite the following: 167 | ``` 168 | @inproceedings{CycleGAN2017, 169 | title={Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks}, 170 | author={Zhu, Jun-Yan and Park, Taesung and Isola, Phillip and Efros, Alexei A}, 171 | booktitle={Computer Vision (ICCV), 2017 IEEE International Conference on}, 172 | year={2017} 173 | } 174 | 175 | 176 | @inproceedings{isola2017image, 177 | title={Image-to-Image Translation with Conditional Adversarial Networks}, 178 | author={Isola, Phillip and Zhu, Jun-Yan and Zhou, Tinghui and Efros, Alexei A}, 179 | booktitle={Computer Vision and Pattern Recognition (CVPR), 2017 IEEE Conference on}, 180 | year={2017} 181 | } 182 | ``` 183 | ### Acknowledgements 184 | 185 | This code borrows heavily from the [pytorch-CycleGAN-and-pix2pix](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix) repository. 186 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | """This package includes all the modules related to data loading and preprocessing 2 | 3 | To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset. 4 | You need to implement four functions: 5 | -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). 6 | -- <__len__>: return the size of dataset. 7 | -- <__getitem__>: get a data point from data loader. 8 | -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options. 9 | 10 | Now you can use the dataset class by specifying flag '--dataset_mode dummy'. 11 | See our template dataset class 'template_dataset.py' for more details. 12 | """
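# --- Editor's sketch (not part of this file): a hypothetical data/dummy_dataset.py
# satisfying the four-function contract above; '--dataset_mode dummy' would then
# locate it via find_dataset_using_name() below. ---
from data.base_dataset import BaseDataset, get_transform
from data.image_folder import make_dataset
from PIL import Image


class DummyDataset(BaseDataset):
    def __init__(self, opt):
        BaseDataset.__init__(self, opt)  # saves opt and sets self.root = opt.dataroot
        self.paths = sorted(make_dataset(self.root, opt.max_dataset_size))
        self.transform = get_transform(opt)

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        path = self.paths[index]
        img = self.transform(Image.open(path).convert('RGB'))
        return {'A': img, 'A_paths': path}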
13 | import importlib 14 | import torch.utils.data 15 | from data.base_dataset import BaseDataset 16 | 17 | 18 | def find_dataset_using_name(dataset_name): 19 | """Import the module "data/[dataset_name]_dataset.py". 20 | 21 | In the file, the class called DatasetNameDataset() will 22 | be instantiated. It has to be a subclass of BaseDataset, 23 | and the match is case-insensitive. 24 | """ 25 | dataset_filename = "data." + dataset_name + "_dataset" 26 | datasetlib = importlib.import_module(dataset_filename) 27 | 28 | dataset = None 29 | target_dataset_name = dataset_name.replace('_', '') + 'dataset' 30 | for name, cls in datasetlib.__dict__.items(): 31 | if name.lower() == target_dataset_name.lower() \ 32 | and issubclass(cls, BaseDataset): 33 | dataset = cls 34 | 35 | if dataset is None: 36 | raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name)) 37 | 38 | return dataset 39 | 40 | 41 | def get_option_setter(dataset_name): 42 | """Return the static method <modify_commandline_options> of the dataset class.""" 43 | dataset_class = find_dataset_using_name(dataset_name) 44 | return dataset_class.modify_commandline_options 45 | 46 | 47 | def create_dataset(opt): 48 | """Create a dataset given the option. 49 | 50 | This function wraps the class CustomDatasetDataLoader. 51 | This is the main interface between this package and 'train.py'/'test.py' 52 | 53 | Example: 54 | >>> from data import create_dataset 55 | >>> dataset = create_dataset(opt) 56 | """ 57 | data_loader = CustomDatasetDataLoader(opt) 58 | dataset = data_loader.load_data() 59 | return dataset 60 | 61 | 62 | class CustomDatasetDataLoader(): 63 | """Wrapper class of Dataset class that performs multi-threaded data loading""" 64 | 65 | def __init__(self, opt): 66 | """Initialize this class 67 | 68 | Step 1: create a dataset instance given the name [dataset_mode] 69 | Step 2: create a multi-threaded data loader. 70 | """ 71 | self.opt = opt 72 | dataset_class = find_dataset_using_name(opt.dataset_mode) 73 | self.dataset = dataset_class(opt) 74 | print("dataset [%s] was created" % type(self.dataset).__name__) 75 | self.dataloader = torch.utils.data.DataLoader( 76 | self.dataset, 77 | batch_size=opt.batch_size, 78 | shuffle=not opt.serial_batches, 79 | num_workers=int(opt.num_threads)) 80 | 81 | def load_data(self): 82 | return self 83 | 84 | def __len__(self): 85 | """Return the number of items in the dataset""" 86 | return min(len(self.dataset), self.opt.max_dataset_size) 87 | 88 | def __iter__(self): 89 | """Yield batches of data""" 90 | for i, data in enumerate(self.dataloader): 91 | if i * self.opt.batch_size >= self.opt.max_dataset_size: 92 | break 93 | yield data 94 | -------------------------------------------------------------------------------- /data/aligned_dataset.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | from data.base_dataset import BaseDataset, get_params, get_transform 3 | from data.image_folder import make_dataset 4 | from PIL import Image 5 | 6 | 7 | class AlignedDataset(BaseDataset): 8 | """A dataset class for paired image datasets. 9 | 10 | It assumes that the directory '/path/to/data/train' contains image pairs in the form of {A,B}. 11 | During test time, you need to prepare a directory '/path/to/data/test'. 12 | """ 13 | 14 | def __init__(self, opt): 15 | """Initialize this dataset class. 
16 | 17 | Parameters: 18 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 19 | """ 20 | BaseDataset.__init__(self, opt) 21 | self.dir_AB = os.path.join(opt.dataroot, opt.phase) # get the image directory 22 | self.AB_paths = sorted(make_dataset(self.dir_AB, opt.max_dataset_size)) # get image paths 23 | assert(self.opt.load_size >= self.opt.crop_size) # crop_size should be smaller than the size of loaded image 24 | self.input_nc = self.opt.output_nc if self.opt.direction == 'BtoA' else self.opt.input_nc 25 | self.output_nc = self.opt.input_nc if self.opt.direction == 'BtoA' else self.opt.output_nc 26 | 27 | def __getitem__(self, index): 28 | """Return a data point and its metadata information. 29 | 30 | Parameters: 31 | index - - a random integer for data indexing 32 | 33 | Returns a dictionary that contains A, B, A_paths and B_paths 34 | A (tensor) - - an image in the input domain 35 | B (tensor) - - its corresponding image in the target domain 36 | A_paths (str) - - image paths 37 | B_paths (str) - - image paths (same as A_paths) 38 | """ 39 | # read a image given a random integer index 40 | AB_path = self.AB_paths[index] 41 | AB = Image.open(AB_path).convert('RGB') 42 | # split AB image into A and B 43 | w, h = AB.size 44 | w2 = int(w / 2) 45 | A = AB.crop((0, 0, w2, h)) 46 | B = AB.crop((w2, 0, w, h)) 47 | 48 | # apply the same transform to both A and B 49 | transform_params = get_params(self.opt, A.size) 50 | A_transform = get_transform(self.opt, transform_params, grayscale=(self.input_nc == 1)) 51 | B_transform = get_transform(self.opt, transform_params, grayscale=(self.output_nc == 1)) 52 | 53 | A = A_transform(A) 54 | B = B_transform(B) 55 | 56 | return {'A': A, 'B': B, 'A_paths': AB_path, 'B_paths': AB_path} 57 | 58 | def __len__(self): 59 | """Return the total number of images in the dataset.""" 60 | return len(self.AB_paths) 61 | -------------------------------------------------------------------------------- /data/base_dataset.py: -------------------------------------------------------------------------------- 1 | """This module implements an abstract base class (ABC) 'BaseDataset' for datasets. 2 | 3 | It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses. 4 | """ 5 | import random 6 | import numpy as np 7 | import torch.utils.data as data 8 | from PIL import Image 9 | import torchvision.transforms as transforms 10 | from abc import ABC, abstractmethod 11 | 12 | 13 | class BaseDataset(data.Dataset, ABC): 14 | """This class is an abstract base class (ABC) for datasets. 15 | 16 | To create a subclass, you need to implement the following four functions: 17 | -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). 18 | -- <__len__>: return the size of dataset. 19 | -- <__getitem__>: get a data point. 20 | -- : (optionally) add dataset-specific options and set default options. 21 | """ 22 | 23 | def __init__(self, opt): 24 | """Initialize the class; save the options in the class 25 | 26 | Parameters: 27 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions 28 | """ 29 | self.opt = opt 30 | self.root = opt.dataroot 31 | 32 | @staticmethod 33 | def modify_commandline_options(parser, is_train): 34 | """Add new dataset-specific options, and rewrite default values for existing options. 35 | 36 | Parameters: 37 | parser -- original option parser 38 | is_train (bool) -- whether training phase or test phase. 
You can use this flag to add training-specific or test-specific options. 39 | 40 | Returns: 41 | the modified parser. 42 | """ 43 | return parser 44 | 45 | @abstractmethod 46 | def __len__(self): 47 | """Return the total number of images in the dataset.""" 48 | return 0 49 | 50 | @abstractmethod 51 | def __getitem__(self, index): 52 | """Return a data point and its metadata information. 53 | 54 | Parameters: 55 | index - - a random integer for data indexing 56 | 57 | Returns: 58 | a dictionary of data with their names. It ususally contains the data itself and its metadata information. 59 | """ 60 | pass 61 | 62 | 63 | def get_params(opt, size): 64 | w, h = size 65 | new_h = h 66 | new_w = w 67 | if opt.preprocess == 'resize_and_crop': 68 | new_h = new_w = opt.load_size 69 | elif opt.preprocess == 'scale_width_and_crop': 70 | new_w = opt.load_size 71 | new_h = opt.load_size * h // w 72 | 73 | x = random.randint(0, np.maximum(0, new_w - opt.crop_size)) 74 | y = random.randint(0, np.maximum(0, new_h - opt.crop_size)) 75 | 76 | flip = random.random() > 0.5 77 | 78 | return {'crop_pos': (x, y), 'flip': flip} 79 | 80 | 81 | def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True): 82 | transform_list = [] 83 | if grayscale: 84 | transform_list.append(transforms.Grayscale(1)) 85 | if 'resize' in opt.preprocess: 86 | osize = [opt.load_size, opt.load_size] 87 | transform_list.append(transforms.Resize(osize, method)) 88 | elif 'scale_width' in opt.preprocess: 89 | transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, method))) 90 | 91 | if 'crop' in opt.preprocess: 92 | if params is None: 93 | transform_list.append(transforms.RandomCrop(opt.crop_size)) 94 | else: 95 | transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size))) 96 | 97 | if opt.preprocess == 'none': 98 | transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method))) 99 | 100 | if not opt.no_flip: 101 | if params is None: 102 | transform_list.append(transforms.RandomHorizontalFlip()) 103 | elif params['flip']: 104 | transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip']))) 105 | 106 | if convert: 107 | transform_list += [transforms.ToTensor()] 108 | if grayscale: 109 | transform_list += [transforms.Normalize((0.5,), (0.5,))] 110 | else: 111 | transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] 112 | return transforms.Compose(transform_list) 113 | 114 | 115 | def __make_power_2(img, base, method=Image.BICUBIC): 116 | ow, oh = img.size 117 | h = int(round(oh / base) * base) 118 | w = int(round(ow / base) * base) 119 | if (h == oh) and (w == ow): 120 | return img 121 | 122 | __print_size_warning(ow, oh, w, h) 123 | return img.resize((w, h), method) 124 | 125 | 126 | def __scale_width(img, target_width, method=Image.BICUBIC): 127 | ow, oh = img.size 128 | if (ow == target_width): 129 | return img 130 | w = target_width 131 | h = int(target_width * oh / ow) 132 | return img.resize((w, h), method) 133 | 134 | 135 | def __crop(img, pos, size): 136 | ow, oh = img.size 137 | x1, y1 = pos 138 | tw = th = size 139 | if (ow > tw or oh > th): 140 | return img.crop((x1, y1, x1 + tw, y1 + th)) 141 | return img 142 | 143 | 144 | def __flip(img, flip): 145 | if flip: 146 | return img.transpose(Image.FLIP_LEFT_RIGHT) 147 | return img 148 | 149 | 150 | def __print_size_warning(ow, oh, w, h): 151 | """Print warning information about image size(only print once)""" 152 | if 
not hasattr(__print_size_warning, 'has_printed'): 153 | print("The image size needs to be a multiple of 4. " 154 | "The loaded image size was (%d, %d), so it was adjusted to " 155 | "(%d, %d). This adjustment will be done to all images " 156 | "whose sizes are not multiples of 4" % (ow, oh, w, h)) 157 | __print_size_warning.has_printed = True 158 | -------------------------------------------------------------------------------- /data/image_folder.py: -------------------------------------------------------------------------------- 1 | """A modified image folder class 2 | 3 | We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py) 4 | so that this class can load images from both current directory and its subdirectories. 5 | """ 6 | 7 | import torch.utils.data as data 8 | 9 | from PIL import Image 10 | import os 11 | import os.path 12 | 13 | IMG_EXTENSIONS = [ 14 | '.jpg', '.JPG', '.jpeg', '.JPEG', 15 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 16 | ] 17 | 18 | 19 | def is_image_file(filename): 20 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 21 | 22 | 23 | def make_dataset(dir, max_dataset_size=float("inf")): 24 | images = [] 25 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 26 | 27 | for root, _, fnames in sorted(os.walk(dir)): 28 | for fname in fnames: 29 | if is_image_file(fname): 30 | path = os.path.join(root, fname) 31 | images.append(path) 32 | return images[:min(max_dataset_size, len(images))] 33 | 34 | 35 | def default_loader(path): 36 | return Image.open(path).convert('RGB') 37 | 38 | 39 | class ImageFolder(data.Dataset): 40 | 41 | def __init__(self, root, transform=None, return_paths=False, 42 | loader=default_loader): 43 | imgs = make_dataset(root) 44 | if len(imgs) == 0: 45 | raise(RuntimeError("Found 0 images in: " + root + "\n" 46 | "Supported image extensions are: " + 47 | ",".join(IMG_EXTENSIONS))) 48 | 49 | self.root = root 50 | self.imgs = imgs 51 | self.transform = transform 52 | self.return_paths = return_paths 53 | self.loader = loader 54 | 55 | def __getitem__(self, index): 56 | path = self.imgs[index] 57 | img = self.loader(path) 58 | if self.transform is not None: 59 | img = self.transform(img) 60 | if self.return_paths: 61 | return img, path 62 | else: 63 | return img 64 | 65 | def __len__(self): 66 | return len(self.imgs) 67 | -------------------------------------------------------------------------------- /data/single_dataset.py: -------------------------------------------------------------------------------- 1 | from data.base_dataset import BaseDataset, get_transform 2 | from data.image_folder import make_dataset 3 | from PIL import Image 4 | 5 | 6 | class SingleDataset(BaseDataset): 7 | """This dataset class can load a set of images specified by the path --dataroot /path/to/data. 8 | 9 | It can be used for generating CycleGAN results only for one side with the model option '-model test'. 10 | """ 11 | 12 | def __init__(self, opt): 13 | """Initialize this dataset class. 
14 | 15 | Parameters: 16 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 17 | """ 18 | BaseDataset.__init__(self, opt) 19 | self.A_paths = sorted(make_dataset(opt.dataroot, opt.max_dataset_size)) 20 | input_nc = self.opt.output_nc if self.opt.direction == 'BtoA' else self.opt.input_nc 21 | self.transform = get_transform(opt, grayscale=(input_nc == 1)) 22 | 23 | def __getitem__(self, index): 24 | """Return a data point and its metadata information. 25 | 26 | Parameters: 27 | index - - a random integer for data indexing 28 | 29 | Returns a dictionary that contains A and A_paths 30 | A(tensor) - - an image in one domain 31 | A_paths(str) - - the path of the image 32 | """ 33 | A_path = self.A_paths[index] 34 | A_img = Image.open(A_path).convert('RGB') 35 | A = self.transform(A_img) 36 | return {'A': A, 'A_paths': A_path} 37 | 38 | def __len__(self): 39 | """Return the total number of images in the dataset.""" 40 | return len(self.A_paths) 41 | -------------------------------------------------------------------------------- /data/template_dataset.py: -------------------------------------------------------------------------------- 1 | """Dataset class template 2 | 3 | This module provides a template for users to implement custom datasets. 4 | You can specify '--dataset_mode template' to use this dataset. 5 | The class name should be consistent with both the filename and its dataset_mode option. 6 | The filename should be _dataset.py 7 | The class name should be Dataset.py 8 | You need to implement the following functions: 9 | -- : Add dataset-specific options and rewrite default values for existing options. 10 | -- <__init__>: Initialize this dataset class. 11 | -- <__getitem__>: Return a data point and its metadata information. 12 | -- <__len__>: Return the number of images. 13 | """ 14 | from data.base_dataset import BaseDataset, get_transform 15 | # from data.image_folder import make_dataset 16 | # from PIL import Image 17 | 18 | 19 | class TemplateDataset(BaseDataset): 20 | """A template dataset class for you to implement custom datasets.""" 21 | @staticmethod 22 | def modify_commandline_options(parser, is_train): 23 | """Add new dataset-specific options, and rewrite default values for existing options. 24 | 25 | Parameters: 26 | parser -- original option parser 27 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. 28 | 29 | Returns: 30 | the modified parser. 31 | """ 32 | parser.add_argument('--new_dataset_option', type=float, default=1.0, help='new dataset option') 33 | parser.set_defaults(max_dataset_size=10, new_dataset_option=2.0) # specify dataset-specific default values 34 | return parser 35 | 36 | def __init__(self, opt): 37 | """Initialize this dataset class. 38 | 39 | Parameters: 40 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 41 | 42 | A few things can be done here. 43 | - save the options (have been done in BaseDataset) 44 | - get image paths and meta information of the dataset. 45 | - define the image transformation. 46 | """ 47 | # save the option and dataset root 48 | BaseDataset.__init__(self, opt) 49 | # get the image paths of your dataset; 50 | self.image_paths = [] # You can call sorted(make_dataset(self.root, opt.max_dataset_size)) to get all the image paths under the directory self.root 51 | # define the default transform function. 
You can use <base_dataset.get_transform>; you can also define your custom transform function. 52 | self.transform = get_transform(opt) 53 | 54 | def __getitem__(self, index): 55 | """Return a data point and its metadata information. 56 | 57 | Parameters: 58 | index -- a random integer for data indexing 59 | 60 | Returns: 61 | a dictionary of data with their names. It usually contains the data itself and its metadata information. 62 | 63 | Step 1: get a random image path: e.g., path = self.image_paths[index] 64 | Step 2: load your data from the disk: e.g., image = Image.open(path).convert('RGB'). 65 | Step 3: convert your data to a PyTorch tensor. You can use helper functions such as self.transform. e.g., data = self.transform(image) 66 | Step 4: return a data point as a dictionary. 67 | """ 68 | path = 'temp' # needs to be a string 69 | data_A = None # needs to be a tensor 70 | data_B = None # needs to be a tensor 71 | return {'data_A': data_A, 'data_B': data_B, 'path': path} 72 | 73 | def __len__(self): 74 | """Return the total number of images.""" 75 | return len(self.image_paths) 76 | -------------------------------------------------------------------------------- /datasets/bibtex/cityscapes.tex: -------------------------------------------------------------------------------- 1 | @inproceedings{Cordts2016Cityscapes, 2 | title={The Cityscapes Dataset for Semantic Urban Scene Understanding}, 3 | author={Cordts, Marius and Omran, Mohamed and Ramos, Sebastian and Rehfeld, Timo and Enzweiler, Markus and Benenson, Rodrigo and Franke, Uwe and Roth, Stefan and Schiele, Bernt}, 4 | booktitle={Proc. of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, 5 | year={2016} 6 | } 7 | -------------------------------------------------------------------------------- /datasets/bibtex/facades.tex: -------------------------------------------------------------------------------- 1 | @INPROCEEDINGS{Tylecek13, 2 | author = {Radim Tyle{\v c}ek and Radim {\v S}{\' a}ra}, 3 | title = {Spatial Pattern Templates for Recognition of Objects with Regular Structure}, 4 | booktitle = {Proc. 
GCPR}, 5 | year = {2013}, 6 | address = {Saarbrucken, Germany}, 7 | } 8 | -------------------------------------------------------------------------------- /datasets/bibtex/handbags.tex: -------------------------------------------------------------------------------- 1 | @inproceedings{zhu2016generative, 2 | title={Generative Visual Manipulation on the Natural Image Manifold}, 3 | author={Zhu, Jun-Yan and Kr{\"a}henb{\"u}hl, Philipp and Shechtman, Eli and Efros, Alexei A.}, 4 | booktitle={Proceedings of European Conference on Computer Vision (ECCV)}, 5 | year={2016} 6 | } 7 | 8 | @InProceedings{xie15hed, 9 | author = {"Xie, Saining and Tu, Zhuowen"}, 10 | Title = {Holistically-Nested Edge Detection}, 11 | Booktitle = "Proceedings of IEEE International Conference on Computer Vision", 12 | Year = {2015}, 13 | } 14 | -------------------------------------------------------------------------------- /datasets/bibtex/night2day.tex: -------------------------------------------------------------------------------- 1 | @article{laffont2014transient, 2 | title={Transient attributes for high-level understanding and editing of outdoor scenes}, 3 | author={Laffont, Pierre-Yves and Ren, Zhile and Tao, Xiaofeng and Qian, Chao and Hays, James}, 4 | journal={ACM Transactions on Graphics (TOG)}, 5 | volume={33}, 6 | number={4}, 7 | pages={149}, 8 | year={2014}, 9 | publisher={ACM} 10 | } 11 | -------------------------------------------------------------------------------- /datasets/bibtex/shoes.tex: -------------------------------------------------------------------------------- 1 | @InProceedings{fine-grained, 2 | author = {A. Yu and K. Grauman}, 3 | title = {{F}ine-{G}rained {V}isual {C}omparisons with {L}ocal {L}earning}, 4 | booktitle = {Computer Vision and Pattern Recognition (CVPR)}, 5 | month = {June}, 6 | year = {2014} 7 | } 8 | 9 | @InProceedings{xie15hed, 10 | author = {"Xie, Saining and Tu, Zhuowen"}, 11 | Title = {Holistically-Nested Edge Detection}, 12 | Booktitle = "Proceedings of IEEE International Conference on Computer Vision", 13 | Year = {2015}, 14 | } 15 | -------------------------------------------------------------------------------- /datasets/bibtex/transattr.tex: -------------------------------------------------------------------------------- 1 | @article {Laffont14, 2 | title = {Transient Attributes for High-Level Understanding and Editing of Outdoor Scenes}, 3 | author = {Pierre-Yves Laffont and Zhile Ren and Xiaofeng Tao and Chao Qian and James Hays}, 4 | journal = {ACM Transactions on Graphics (proceedings of SIGGRAPH)}, 5 | volume = {33}, 6 | number = {4}, 7 | year = {2014} 8 | } 9 | -------------------------------------------------------------------------------- /datasets/download_dataset.sh: -------------------------------------------------------------------------------- 1 | FILE=$1 2 | 3 | if [[ $FILE != "cityscapes" && $FILE != "night2day" && $FILE != "edges2handbags" && $FILE != "edges2shoes" && $FILE != "facades" && $FILE != "maps" ]]; then 4 | echo "Available datasets are cityscapes, night2day, edges2handbags, edges2shoes, facades, maps" 5 | exit 1 6 | fi 7 | 8 | echo "Specified [$FILE]" 9 | 10 | URL=http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/$FILE.tar.gz 11 | TAR_FILE=./datasets/$FILE.tar.gz 12 | TARGET_DIR=./datasets/$FILE/ 13 | wget -N $URL -O $TAR_FILE 14 | mkdir -p $TARGET_DIR 15 | tar -zxvf $TAR_FILE -C ./datasets/ 16 | rm $TAR_FILE 17 | -------------------------------------------------------------------------------- /datasets/download_mini_dataset.sh: 
-------------------------------------------------------------------------------- 1 | FILE=$1 2 | echo "Specified [$FILE]" 3 | URL=https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/$FILE.zip 4 | ZIP_FILE=./datasets/$FILE.zip 5 | TARGET_DIR=./datasets/$FILE/ 6 | wget -N $URL -O $ZIP_FILE 7 | mkdir $TARGET_DIR 8 | unzip $ZIP_FILE -d ./datasets/ 9 | rm $ZIP_FILE 10 | -------------------------------------------------------------------------------- /datasets/download_testset.sh: -------------------------------------------------------------------------------- 1 | FILE=$1 2 | URL=http://efrosgans.eecs.berkeley.edu/BicycleGAN/testset/${FILE}.tar.gz 3 | TAR_FILE=./datasets/$FILE.tar.gz 4 | TARGET_DIR=./datasets/$FILE/ 5 | wget -N $URL -O $TAR_FILE 6 | mkdir $TARGET_DIR 7 | tar -zxvf $TAR_FILE -C ./datasets/ 8 | rm $TAR_FILE 9 | -------------------------------------------------------------------------------- /imgs/day2night.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junyanz/BicycleGAN/40b9d52c27b9831f56c1c7c7a6ddde8bc9149067/imgs/day2night.gif -------------------------------------------------------------------------------- /imgs/results_matrix.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junyanz/BicycleGAN/40b9d52c27b9831f56c1c7c7a6ddde8bc9149067/imgs/results_matrix.jpg -------------------------------------------------------------------------------- /imgs/teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junyanz/BicycleGAN/40b9d52c27b9831f56c1c7c7a6ddde8bc9149067/imgs/teaser.jpg -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | """This package contains modules related to objective functions, optimizations, and network architectures. 2 | To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel. 3 | You need to implement the following five functions: 4 | -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). 5 | -- <set_input>: unpack data from dataset and apply preprocessing. 6 | -- <forward>: produce intermediate results. 7 | -- <optimize_parameters>: calculate loss, gradients, and update network weights. 8 | -- <modify_commandline_options>: (optionally) add model-specific options and set default options. 9 | In the function <__init__>, you need to define four lists: 10 | -- self.loss_names (str list): specify the training losses that you want to plot and save. 11 | -- self.model_names (str list): define the networks used in our training; they will be saved to and loaded from disk. 12 | -- self.visual_names (str list): specify the images that you want to display and save. 13 | -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example. 14 | Now you can use the model class by specifying flag '--model dummy'. 15 | See our template model class 'template_model.py' for an example. 16 | """ 17 | 
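# --- Editor's sketch (not part of this file): the skeleton of a hypothetical
# models/dummy_model.py implementing the five functions and four lists above;
# '--model dummy' would then locate it via find_model_using_name() below. ---
from models.base_model import BaseModel


class DummyModel(BaseModel):
    @staticmethod
    def modify_commandline_options(parser, is_train):
        return parser  # optionally add model-specific flags here

    def __init__(self, opt):
        BaseModel.__init__(self, opt)
        self.loss_names = []    # losses to plot/save, e.g., ['G_L1']
        self.model_names = []   # networks to save/load, e.g., ['G']
        self.visual_names = []  # images to display, e.g., ['real_A', 'fake_B']
        self.optimizers = []

    def set_input(self, input):
        self.real_A = input['A'].to(self.device)
        self.image_paths = input['A_paths']

    def forward(self):
        pass  # produce intermediate results, e.g., self.fake_B = self.netG(self.real_A)

    def optimize_parameters(self):
        self.forward()  # then compute losses, call backward(), and step the optimizers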
18 | import importlib 19 | from models.base_model import BaseModel 20 | 21 | 22 | def find_model_using_name(model_name): 23 | """Import the module "models/[model_name]_model.py". 24 | In the file, the class called DatasetNameModel() will 25 | be instantiated. It has to be a subclass of BaseModel, 26 | and the match is case-insensitive. 27 | """ 28 | model_filename = "models." + model_name + "_model" 29 | modellib = importlib.import_module(model_filename) 30 | model = None 31 | target_model_name = model_name.replace('_', '') + 'model' 32 | for name, cls in modellib.__dict__.items(): 33 | if name.lower() == target_model_name.lower() \ 34 | and issubclass(cls, BaseModel): 35 | model = cls 36 | 37 | if model is None: 38 | raise NotImplementedError("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name)) 39 | 40 | return model 41 | 42 | 43 | def get_option_setter(model_name): 44 | """Return the static method <modify_commandline_options> of the model class.""" 45 | model_class = find_model_using_name(model_name) 46 | return model_class.modify_commandline_options 47 | 48 | 49 | def create_model(opt): 50 | """Create a model given the option. 51 | This function instantiates the model class found by <find_model_using_name>. 52 | This is the main interface between this package and 'train.py'/'test.py'. 53 | Example: 54 | >>> from models import create_model 55 | >>> model = create_model(opt) 56 | """ 57 | model = find_model_using_name(opt.model) 58 | instance = model(opt) 59 | print("model [%s] was created" % type(instance).__name__) 60 | return instance 61 | -------------------------------------------------------------------------------- /models/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from collections import OrderedDict 4 | from abc import ABC, abstractmethod 5 | from . import networks 6 | 7 | 8 | class BaseModel(ABC): 9 | """This class is an abstract base class (ABC) for models. 10 | To create a subclass, you need to implement the following five functions: 11 | -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). 12 | -- <set_input>: unpack data from dataset and apply preprocessing. 13 | -- <forward>: produce intermediate results. 14 | -- <optimize_parameters>: calculate losses, gradients, and update network weights. 15 | -- <modify_commandline_options>: (optionally) add model-specific options and set default options. 16 | """ 17 | 18 | def __init__(self, opt): 19 | """Initialize the BaseModel class. 20 | 21 | Parameters: 22 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 23 | 24 | When creating your custom class, you need to implement your own initialization. 25 | In this function, you should first call `BaseModel.__init__(self, opt)` 26 | Then, you need to define four lists: 27 | -- self.loss_names (str list): specify the training losses that you want to plot and save. 28 | -- self.model_names (str list): define the networks used in our training; they will be saved to and loaded from disk. 29 | -- self.visual_names (str list): specify the images that you want to display and save. 30 | -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example. 
31 | """ 32 | self.opt = opt 33 | self.gpu_ids = opt.gpu_ids 34 | self.isTrain = opt.isTrain 35 | self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') # get device name: CPU or GPU 36 | self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) # save all the checkpoints to save_dir 37 | if opt.preprocess != 'scale_width': # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark. 38 | torch.backends.cudnn.benchmark = True 39 | self.loss_names = [] 40 | self.model_names = [] 41 | self.visual_names = [] 42 | self.optimizers = [] 43 | self.image_paths = [] 44 | 45 | @staticmethod 46 | def modify_commandline_options(parser, is_train): 47 | """Add new model-specific options, and rewrite default values for existing options. 48 | 49 | Parameters: 50 | parser -- original option parser 51 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. 52 | 53 | Returns: 54 | the modified parser. 55 | """ 56 | return parser 57 | 58 | @abstractmethod 59 | def set_input(self, input): 60 | """Unpack input data from the dataloader and perform necessary pre-processing steps. 61 | 62 | Parameters: 63 | input (dict): includes the data itself and its metadata information. 64 | """ 65 | pass 66 | 67 | @abstractmethod 68 | def forward(self): 69 | """Run forward pass; called by both functions and .""" 70 | pass 71 | 72 | def is_train(self): 73 | """check if the current batch is good for training.""" 74 | return True 75 | 76 | @abstractmethod 77 | def optimize_parameters(self): 78 | """Calculate losses, gradients, and update network weights; called in every training iteration""" 79 | pass 80 | 81 | def setup(self, opt): 82 | """Load and print networks; create schedulers 83 | 84 | Parameters: 85 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 86 | """ 87 | if self.isTrain: 88 | self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers] 89 | if not self.isTrain or opt.continue_train: 90 | self.load_networks(opt.epoch) 91 | self.print_networks(opt.verbose) 92 | 93 | def eval(self): 94 | """Make models eval mode during test time""" 95 | for name in self.model_names: 96 | if isinstance(name, str): 97 | net = getattr(self, 'net' + name) 98 | net.eval() 99 | 100 | def test(self): 101 | """Forward function used in test time. 102 | 103 | This function wraps function in no_grad() so we don't save intermediate steps for backprop 104 | It also calls to produce additional visualization results 105 | """ 106 | with torch.no_grad(): 107 | self.forward() 108 | self.compute_visuals() 109 | 110 | def compute_visuals(self): 111 | """Calculate additional output images for visdom and HTML visualization""" 112 | pass 113 | 114 | def get_image_paths(self): 115 | """ Return image paths that are used to load current data""" 116 | return self.image_paths 117 | 118 | def update_learning_rate(self): 119 | """Update learning rates for all the networks; called at the end of every epoch""" 120 | for scheduler in self.schedulers: 121 | scheduler.step() 122 | lr = self.optimizers[0].param_groups[0]['lr'] 123 | print('learning rate = %.7f' % lr) 124 | 125 | def get_current_visuals(self): 126 | """Return visualization images. 
train.py will display these images with visdom, and save the images to a HTML""" 127 | visual_ret = OrderedDict() 128 | for name in self.visual_names: 129 | if isinstance(name, str): 130 | visual_ret[name] = getattr(self, name) 131 | return visual_ret 132 | 133 | def get_current_losses(self): 134 | """Return traning losses / errors. train.py will print out these errors on console, and save them to a file""" 135 | errors_ret = OrderedDict() 136 | for name in self.loss_names: 137 | if isinstance(name, str): 138 | errors_ret[name] = float(getattr(self, 'loss_' + name)) # float(...) works for both scalar tensor and float number 139 | return errors_ret 140 | 141 | def save_networks(self, epoch): 142 | """Save all the networks to the disk. 143 | 144 | Parameters: 145 | epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) 146 | """ 147 | for name in self.model_names: 148 | if isinstance(name, str): 149 | save_filename = '%s_net_%s.pth' % (epoch, name) 150 | save_path = os.path.join(self.save_dir, save_filename) 151 | net = getattr(self, 'net' + name) 152 | 153 | if len(self.gpu_ids) > 0 and torch.cuda.is_available(): 154 | torch.save(net.module.cpu().state_dict(), save_path) 155 | net.cuda(self.gpu_ids[0]) 156 | else: 157 | torch.save(net.cpu().state_dict(), save_path) 158 | 159 | def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0): 160 | """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)""" 161 | key = keys[i] 162 | if i + 1 == len(keys): # at the end, pointing to a parameter/buffer 163 | if module.__class__.__name__.startswith('InstanceNorm') and \ 164 | (key == 'running_mean' or key == 'running_var'): 165 | if getattr(module, key) is None: 166 | state_dict.pop('.'.join(keys)) 167 | if module.__class__.__name__.startswith('InstanceNorm') and \ 168 | (key == 'num_batches_tracked'): 169 | state_dict.pop('.'.join(keys)) 170 | else: 171 | self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1) 172 | 173 | def load_networks(self, epoch): 174 | """Load all the networks from the disk. 
175 | 176 | Parameters: 177 | epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) 178 | """ 179 | for name in self.model_names: 180 | if isinstance(name, str): 181 | load_filename = '%s_net_%s.pth' % (epoch, name) 182 | load_path = os.path.join(self.save_dir, load_filename) 183 | net = getattr(self, 'net' + name) 184 | if isinstance(net, torch.nn.DataParallel): 185 | net = net.module 186 | print('loading the model from %s' % load_path) 187 | # if you are using PyTorch newer than 0.4 (e.g., built from 188 | # GitHub source), you can remove str() on self.device 189 | state_dict = torch.load(load_path, map_location=str(self.device)) 190 | if hasattr(state_dict, '_metadata'): 191 | del state_dict._metadata 192 | 193 | # patch InstanceNorm checkpoints prior to 0.4 194 | for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop 195 | self.__patch_instance_norm_state_dict(state_dict, net, key.split('.')) 196 | net.load_state_dict(state_dict) 197 | 198 | def print_networks(self, verbose): 199 | """Print the total number of parameters in the network and (if verbose) network architecture 200 | 201 | Parameters: 202 | verbose (bool) -- if verbose: print the network architecture 203 | """ 204 | print('---------- Networks initialized -------------') 205 | for name in self.model_names: 206 | if isinstance(name, str): 207 | net = getattr(self, 'net' + name) 208 | num_params = 0 209 | for param in net.parameters(): 210 | num_params += param.numel() 211 | if verbose: 212 | print(net) 213 | print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6)) 214 | print('-----------------------------------------------') 215 | 216 | def set_requires_grad(self, nets, requires_grad=False): 217 | """Set requires_grad=False for all the networks to avoid unnecessary computations 218 | Parameters: 219 | nets (network list) -- a list of networks 220 | requires_grad (bool) -- whether the networks require gradients or not 221 | """ 222 | if not isinstance(nets, list): 223 | nets = [nets] 224 | for net in nets: 225 | if net is not None: 226 | for param in net.parameters(): 227 | param.requires_grad = requires_grad 228 | -------------------------------------------------------------------------------- /models/bicycle_gan_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .base_model import BaseModel 3 | from . import networks 4 | 5 | 6 | class BiCycleGANModel(BaseModel): 7 | @staticmethod 8 | def modify_commandline_options(parser, is_train=True): 9 | return parser 10 | 11 | def __init__(self, opt): 12 | if opt.isTrain: 13 | assert opt.batch_size % 2 == 0 # load two images at one time. 14 | 15 | BaseModel.__init__(self, opt) 16 | # specify the training losses you want to print out. The program will call base_model.get_current_losses 17 | self.loss_names = ['G_GAN', 'D', 'G_GAN2', 'D2', 'G_L1', 'z_L1', 'kl'] 18 | # specify the images you want to save/display. The program will call base_model.get_current_visuals 19 | self.visual_names = ['real_A_encoded', 'real_B_encoded', 'fake_B_random', 'fake_B_encoded'] 20 | # specify the models you want to save to the disk. 
The program will call base_model.save_networks and base_model.load_networks 21 | use_D = opt.isTrain and opt.lambda_GAN > 0.0 22 | use_D2 = opt.isTrain and opt.lambda_GAN2 > 0.0 and not opt.use_same_D 23 | use_E = opt.isTrain or not opt.no_encode 24 | use_vae = True 25 | self.model_names = ['G'] 26 | self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.nz, opt.ngf, netG=opt.netG, 27 | norm=opt.norm, nl=opt.nl, use_dropout=opt.use_dropout, init_type=opt.init_type, init_gain=opt.init_gain, 28 | gpu_ids=self.gpu_ids, where_add=opt.where_add, upsample=opt.upsample) 29 | D_output_nc = opt.input_nc + opt.output_nc if opt.conditional_D else opt.output_nc 30 | if use_D: 31 | self.model_names += ['D'] 32 | self.netD = networks.define_D(D_output_nc, opt.ndf, netD=opt.netD, norm=opt.norm, nl=opt.nl, 33 | init_type=opt.init_type, init_gain=opt.init_gain, num_Ds=opt.num_Ds, gpu_ids=self.gpu_ids) 34 | if use_D2: 35 | self.model_names += ['D2'] 36 | self.netD2 = networks.define_D(D_output_nc, opt.ndf, netD=opt.netD2, norm=opt.norm, nl=opt.nl, 37 | init_type=opt.init_type, init_gain=opt.init_gain, num_Ds=opt.num_Ds, gpu_ids=self.gpu_ids) 38 | else: 39 | self.netD2 = None 40 | if use_E: 41 | self.model_names += ['E'] 42 | self.netE = networks.define_E(opt.output_nc, opt.nz, opt.nef, netE=opt.netE, norm=opt.norm, nl=opt.nl, 43 | init_type=opt.init_type, init_gain=opt.init_gain, gpu_ids=self.gpu_ids, vaeLike=use_vae) 44 | 45 | if opt.isTrain: 46 | self.criterionGAN = networks.GANLoss(gan_mode=opt.gan_mode).to(self.device) 47 | self.criterionL1 = torch.nn.L1Loss() 48 | self.criterionZ = torch.nn.L1Loss() 49 | # initialize optimizers 50 | self.optimizers = [] 51 | self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) 52 | self.optimizers.append(self.optimizer_G) 53 | if use_E: 54 | self.optimizer_E = torch.optim.Adam(self.netE.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) 55 | self.optimizers.append(self.optimizer_E) 56 | 57 | if use_D: 58 | self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) 59 | self.optimizers.append(self.optimizer_D) 60 | if use_D2: 61 | self.optimizer_D2 = torch.optim.Adam(self.netD2.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) 62 | self.optimizers.append(self.optimizer_D2) 63 | 64 | def is_train(self): 65 | """check if the current batch is good for training.""" 66 | return self.opt.isTrain and self.real_A.size(0) == self.opt.batch_size 67 | 68 | def set_input(self, input): 69 | AtoB = self.opt.direction == 'AtoB' 70 | self.real_A = input['A' if AtoB else 'B'].to(self.device) 71 | self.real_B = input['B' if AtoB else 'A'].to(self.device) 72 | self.image_paths = input['A_paths' if AtoB else 'B_paths'] 73 | 74 | def get_z_random(self, batch_size, nz, random_type='gauss'): 75 | if random_type == 'uni': 76 | z = torch.rand(batch_size, nz) * 2.0 - 1.0 77 | elif random_type == 'gauss': 78 | z = torch.randn(batch_size, nz) 79 | return z.detach().to(self.device) 80 | 81 | def encode(self, input_image): 82 | mu, logvar = self.netE.forward(input_image) 83 | std = logvar.mul(0.5).exp_() 84 | eps = self.get_z_random(std.size(0), std.size(1)) 85 | z = eps.mul(std).add_(mu) 86 | return z, mu, logvar 87 | 88 | def test(self, z0=None, encode=False): 89 | with torch.no_grad(): 90 | if encode: # use encoded z 91 | z0, _ = self.netE(self.real_B) 92 | if z0 is None: 93 | z0 = self.get_z_random(self.real_A.size(0), self.opt.nz) 94 | self.fake_B = self.netG(self.real_A, z0) 95 | return self.real_A, 
self.fake_B, self.real_B 96 | 97 | def forward(self): 98 | # get real images 99 | half_size = self.opt.batch_size // 2 100 | # A1, B1 for encoded; A2, B2 for random 101 | self.real_A_encoded = self.real_A[0:half_size] 102 | self.real_B_encoded = self.real_B[0:half_size] 103 | self.real_A_random = self.real_A[half_size:] 104 | self.real_B_random = self.real_B[half_size:] 105 | # get encoded z 106 | self.z_encoded, self.mu, self.logvar = self.encode(self.real_B_encoded) 107 | # get random z 108 | self.z_random = self.get_z_random(self.real_A_encoded.size(0), self.opt.nz) 109 | # generate fake_B_encoded 110 | self.fake_B_encoded = self.netG(self.real_A_encoded, self.z_encoded) 111 | # generate fake_B_random 112 | self.fake_B_random = self.netG(self.real_A_encoded, self.z_random) 113 | if self.opt.conditional_D: # tedious conditoinal data 114 | self.fake_data_encoded = torch.cat([self.real_A_encoded, self.fake_B_encoded], 1) 115 | self.real_data_encoded = torch.cat([self.real_A_encoded, self.real_B_encoded], 1) 116 | self.fake_data_random = torch.cat([self.real_A_encoded, self.fake_B_random], 1) 117 | self.real_data_random = torch.cat([self.real_A_random, self.real_B_random], 1) 118 | else: 119 | self.fake_data_encoded = self.fake_B_encoded 120 | self.fake_data_random = self.fake_B_random 121 | self.real_data_encoded = self.real_B_encoded 122 | self.real_data_random = self.real_B_random 123 | 124 | # compute z_predict 125 | if self.opt.lambda_z > 0.0: 126 | self.mu2, logvar2 = self.netE(self.fake_B_random) # mu2 is a point estimate 127 | 128 | def backward_D(self, netD, real, fake): 129 | # Fake, stop backprop to the generator by detaching fake_B 130 | pred_fake = netD(fake.detach()) 131 | # real 132 | pred_real = netD(real) 133 | loss_D_fake, _ = self.criterionGAN(pred_fake, False) 134 | loss_D_real, _ = self.criterionGAN(pred_real, True) 135 | # Combined loss 136 | loss_D = loss_D_fake + loss_D_real 137 | loss_D.backward() 138 | return loss_D, [loss_D_fake, loss_D_real] 139 | 140 | def backward_G_GAN(self, fake, netD=None, ll=0.0): 141 | if ll > 0.0: 142 | pred_fake = netD(fake) 143 | loss_G_GAN, _ = self.criterionGAN(pred_fake, True) 144 | else: 145 | loss_G_GAN = 0 146 | return loss_G_GAN * ll 147 | 148 | def backward_EG(self): 149 | # 1, G(A) should fool D 150 | self.loss_G_GAN = self.backward_G_GAN(self.fake_data_encoded, self.netD, self.opt.lambda_GAN) 151 | if self.opt.use_same_D: 152 | self.loss_G_GAN2 = self.backward_G_GAN(self.fake_data_random, self.netD, self.opt.lambda_GAN2) 153 | else: 154 | self.loss_G_GAN2 = self.backward_G_GAN(self.fake_data_random, self.netD2, self.opt.lambda_GAN2) 155 | # 2. 
KL loss 156 | if self.opt.lambda_kl > 0.0: 157 | self.loss_kl = torch.sum(1 + self.logvar - self.mu.pow(2) - self.logvar.exp()) * (-0.5 * self.opt.lambda_kl) 158 | else: 159 | self.loss_kl = 0 160 | # 3, reconstruction |fake_B-real_B| 161 | if self.opt.lambda_L1 > 0.0: 162 | self.loss_G_L1 = self.criterionL1(self.fake_B_encoded, self.real_B_encoded) * self.opt.lambda_L1 163 | else: 164 | self.loss_G_L1 = 0.0 165 | 166 | self.loss_G = self.loss_G_GAN + self.loss_G_GAN2 + self.loss_G_L1 + self.loss_kl 167 | self.loss_G.backward(retain_graph=True) 168 | 169 | def update_D(self): 170 | self.set_requires_grad([self.netD, self.netD2], True) 171 | # update D1 172 | if self.opt.lambda_GAN > 0.0: 173 | self.optimizer_D.zero_grad() 174 | self.loss_D, self.losses_D = self.backward_D(self.netD, self.real_data_encoded, self.fake_data_encoded) 175 | if self.opt.use_same_D: 176 | self.loss_D2, self.losses_D2 = self.backward_D(self.netD, self.real_data_random, self.fake_data_random) 177 | self.optimizer_D.step() 178 | 179 | if self.opt.lambda_GAN2 > 0.0 and not self.opt.use_same_D: 180 | self.optimizer_D2.zero_grad() 181 | self.loss_D2, self.losses_D2 = self.backward_D(self.netD2, self.real_data_random, self.fake_data_random) 182 | self.optimizer_D2.step() 183 | 184 | def backward_G_alone(self): 185 | # 3, reconstruction |(E(G(A, z_random)))-z_random| 186 | if self.opt.lambda_z > 0.0: 187 | self.loss_z_L1 = self.criterionZ(self.mu2, self.z_random) * self.opt.lambda_z 188 | self.loss_z_L1.backward() 189 | else: 190 | self.loss_z_L1 = 0.0 191 | 192 | def update_G_and_E(self): 193 | # update G and E 194 | self.set_requires_grad([self.netD, self.netD2], False) 195 | self.optimizer_E.zero_grad() 196 | self.optimizer_G.zero_grad() 197 | self.backward_EG() 198 | 199 | # update G alone 200 | if self.opt.lambda_z > 0.0: 201 | self.set_requires_grad([self.netE], False) 202 | self.backward_G_alone() 203 | self.set_requires_grad([self.netE], True) 204 | 205 | self.optimizer_E.step() 206 | self.optimizer_G.step() 207 | 208 | def optimize_parameters(self): 209 | self.forward() 210 | self.update_G_and_E() 211 | self.update_D() 212 | -------------------------------------------------------------------------------- /models/networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | import functools 5 | from torch.optim import lr_scheduler 6 | 7 | ############################################################################### 8 | # Helper functions 9 | ############################################################################### 10 | 11 | 12 | def init_weights(net, init_type='normal', init_gain=0.02): 13 | """Initialize network weights. 14 | Parameters: 15 | net (network) -- network to be initialized 16 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal 17 | init_gain (float) -- scaling factor for normal, xavier and orthogonal. 18 | We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might 19 | work better for some applications. Feel free to try yourself. 
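Example (a minimal usage sketch, not from this repo; any nn.Module works as the argument):
    import torch.nn as nn
    layer = nn.Conv2d(3, 64, kernel_size=3, padding=1)
    init_weights(layer, init_type='xavier', init_gain=0.02)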
20 | """ 21 | def init_func(m): # define the initialization function 22 | classname = m.__class__.__name__ 23 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 24 | if init_type == 'normal': 25 | init.normal_(m.weight.data, 0.0, init_gain) 26 | elif init_type == 'xavier': 27 | init.xavier_normal_(m.weight.data, gain=init_gain) 28 | elif init_type == 'kaiming': 29 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 30 | elif init_type == 'orthogonal': 31 | init.orthogonal_(m.weight.data, gain=init_gain) 32 | else: 33 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 34 | if hasattr(m, 'bias') and m.bias is not None: 35 | init.constant_(m.bias.data, 0.0) 36 | elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies. 37 | init.normal_(m.weight.data, 1.0, init_gain) 38 | init.constant_(m.bias.data, 0.0) 39 | 40 | print('initialize network with %s' % init_type) 41 | net.apply(init_func) # apply the initialization function 42 | 43 | 44 | def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]): 45 | """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights 46 | Parameters: 47 | net (network) -- the network to be initialized 48 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal 49 | init_gain (float) -- scaling factor for normal, xavier and orthogonal. 50 | gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2 51 | Return an initialized network. 52 | """ 53 | if len(gpu_ids) > 0: 54 | assert(torch.cuda.is_available()) 55 | net.to(gpu_ids[0]) 56 | net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs 57 | init_weights(net, init_type, init_gain=init_gain) 58 | return net 59 | 60 | 61 | def get_scheduler(optimizer, opt): 62 | """Return a learning rate scheduler 63 | Parameters: 64 | optimizer -- the optimizer of the network 65 | opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions. 66 | opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine 67 | For 'linear', we keep the same learning rate for the first <opt.niter> epochs 68 | and linearly decay the rate to zero over the next <opt.niter_decay> epochs. 69 | For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers. 70 | See https://pytorch.org/docs/stable/optim.html for more details.
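Example (a minimal sketch; `optimizer` and `opt` are assumed to be set up as in train.py):
    scheduler = get_scheduler(optimizer, opt)
    for epoch in range(opt.niter + opt.niter_decay):
        ...  # train for one epoch
        scheduler.step()  # update the learning rate once per epoch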
71 | """ 72 | if opt.lr_policy == 'linear': 73 | def lambda_rule(epoch): 74 | lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1) 75 | return lr_l 76 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) 77 | elif opt.lr_policy == 'step': 78 | scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1) 79 | elif opt.lr_policy == 'plateau': 80 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5) 81 | elif opt.lr_policy == 'cosine': 82 | scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.niter, eta_min=0) 83 | else: 84 | raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy) 85 | return scheduler 86 | 87 | 88 | def get_norm_layer(norm_type='instance'): 89 | """Return a normalization layer 90 | Parameters: 91 | norm_type (str) -- the name of the normalization layer: batch | instance | none 92 | For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev). 93 | For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics. 94 | """ 95 | if norm_type == 'batch': 96 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True) 97 | elif norm_type == 'instance': 98 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False) 99 | elif norm_type == 'none': 100 | norm_layer = None 101 | else: 102 | raise NotImplementedError('normalization layer [%s] is not found' % norm_type) 103 | return norm_layer 104 | 105 | 106 | def get_non_linearity(layer_type='relu'): 107 | if layer_type == 'relu': 108 | nl_layer = functools.partial(nn.ReLU, inplace=True) 109 | elif layer_type == 'lrelu': 110 | nl_layer = functools.partial( 111 | nn.LeakyReLU, negative_slope=0.2, inplace=True) 112 | elif layer_type == 'elu': 113 | nl_layer = functools.partial(nn.ELU, inplace=True) 114 | else: 115 | raise NotImplementedError( 116 | 'nonlinearity activation [%s] is not found' % layer_type) 117 | return nl_layer 118 | 119 | 120 | def define_G(input_nc, output_nc, nz, ngf, netG='unet_128', norm='batch', nl='relu', 121 | use_dropout=False, init_type='xavier', init_gain=0.02, gpu_ids=[], where_add='input', upsample='bilinear'): 122 | net = None 123 | norm_layer = get_norm_layer(norm_type=norm) 124 | nl_layer = get_non_linearity(layer_type=nl) 125 | 126 | if nz == 0: 127 | where_add = 'input' 128 | 129 | if netG == 'unet_128' and where_add == 'input': 130 | net = G_Unet_add_input(input_nc, output_nc, nz, 7, ngf, norm_layer=norm_layer, nl_layer=nl_layer, 131 | use_dropout=use_dropout, upsample=upsample) 132 | elif netG == 'unet_256' and where_add == 'input': 133 | net = G_Unet_add_input(input_nc, output_nc, nz, 8, ngf, norm_layer=norm_layer, nl_layer=nl_layer, 134 | use_dropout=use_dropout, upsample=upsample) 135 | elif netG == 'unet_128' and where_add == 'all': 136 | net = G_Unet_add_all(input_nc, output_nc, nz, 7, ngf, norm_layer=norm_layer, nl_layer=nl_layer, 137 | use_dropout=use_dropout, upsample=upsample) 138 | elif netG == 'unet_256' and where_add == 'all': 139 | net = G_Unet_add_all(input_nc, output_nc, nz, 8, ngf, norm_layer=norm_layer, nl_layer=nl_layer, 140 | use_dropout=use_dropout, upsample=upsample) 141 | else: 142 | raise NotImplementedError('Generator model name [%s] is not recognized' % netG) 143 | 144 | return init_net(net, init_type, init_gain, gpu_ids) 145 | 146 | 147 | def define_D(input_nc, ndf, netD, norm='batch',
nl='lrelu', init_type='xavier', init_gain=0.02, num_Ds=1, gpu_ids=[]): 148 | net = None 149 | norm_layer = get_norm_layer(norm_type=norm) 150 | nl = 'lrelu' # use leaky relu for D 151 | nl_layer = get_non_linearity(layer_type=nl) 152 | 153 | if netD == 'basic_128': 154 | net = D_NLayers(input_nc, ndf, n_layers=2, norm_layer=norm_layer, nl_layer=nl_layer) 155 | elif netD == 'basic_256': 156 | net = D_NLayers(input_nc, ndf, n_layers=3, norm_layer=norm_layer, nl_layer=nl_layer) 157 | elif netD == 'basic_128_multi': 158 | net = D_NLayersMulti(input_nc=input_nc, ndf=ndf, n_layers=2, norm_layer=norm_layer, num_D=num_Ds) 159 | elif netD == 'basic_256_multi': 160 | net = D_NLayersMulti(input_nc=input_nc, ndf=ndf, n_layers=3, norm_layer=norm_layer, num_D=num_Ds) 161 | else: 162 | raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD) 163 | return init_net(net, init_type, init_gain, gpu_ids) 164 | 165 | 166 | def define_E(input_nc, output_nc, ndf, netE, 167 | norm='batch', nl='lrelu', 168 | init_type='xavier', init_gain=0.02, gpu_ids=[], vaeLike=False): 169 | net = None 170 | norm_layer = get_norm_layer(norm_type=norm) 171 | nl = 'lrelu' # use leaky relu for E 172 | nl_layer = get_non_linearity(layer_type=nl) 173 | if netE == 'resnet_128': 174 | net = E_ResNet(input_nc, output_nc, ndf, n_blocks=4, norm_layer=norm_layer, 175 | nl_layer=nl_layer, vaeLike=vaeLike) 176 | elif netE == 'resnet_256': 177 | net = E_ResNet(input_nc, output_nc, ndf, n_blocks=5, norm_layer=norm_layer, 178 | nl_layer=nl_layer, vaeLike=vaeLike) 179 | elif netE == 'conv_128': 180 | net = E_NLayers(input_nc, output_nc, ndf, n_layers=4, norm_layer=norm_layer, 181 | nl_layer=nl_layer, vaeLike=vaeLike) 182 | elif netE == 'conv_256': 183 | net = E_NLayers(input_nc, output_nc, ndf, n_layers=5, norm_layer=norm_layer, 184 | nl_layer=nl_layer, vaeLike=vaeLike) 185 | else: 186 | raise NotImplementedError('Encoder model name [%s] is not recognized' % netE) 187 | 188 | return init_net(net, init_type, init_gain, gpu_ids) 189 | 190 | 191 | class D_NLayersMulti(nn.Module): 192 | def __init__(self, input_nc, ndf=64, n_layers=3, 193 | norm_layer=nn.BatchNorm2d, num_D=1): 194 | super(D_NLayersMulti, self).__init__() 195 | 196 | self.num_D = num_D 197 | if num_D == 1: 198 | layers = self.get_layers(input_nc, ndf, n_layers, norm_layer) 199 | self.model = nn.Sequential(*layers) 200 | else: 201 | layers = self.get_layers(input_nc, ndf, n_layers, norm_layer) 202 | self.add_module("model_0", nn.Sequential(*layers)) 203 | self.down = nn.AvgPool2d(3, stride=2, padding=[ 204 | 1, 1], count_include_pad=False) 205 | for i in range(1, num_D): 206 | ndf_i = int(round(ndf / (2**i))) 207 | layers = self.get_layers(input_nc, ndf_i, n_layers, norm_layer) 208 | self.add_module("model_%d" % i, nn.Sequential(*layers)) 209 | 210 | def get_layers(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d): 211 | kw = 4 212 | padw = 1 213 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, 214 | stride=2, padding=padw), nn.LeakyReLU(0.2, True)] 215 | 216 | nf_mult = 1 217 | nf_mult_prev = 1 218 | for n in range(1, n_layers): 219 | nf_mult_prev = nf_mult 220 | nf_mult = min(2**n, 8) 221 | sequence += [ 222 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, 223 | kernel_size=kw, stride=2, padding=padw), 224 | norm_layer(ndf * nf_mult), 225 | nn.LeakyReLU(0.2, True) 226 | ] 227 | 228 | nf_mult_prev = nf_mult 229 | nf_mult = min(2**n_layers, 8) 230 | sequence += [ 231 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, 232 | kernel_size=kw,
stride=1, padding=padw), 233 | norm_layer(ndf * nf_mult), 234 | nn.LeakyReLU(0.2, True) 235 | ] 236 | 237 | sequence += [nn.Conv2d(ndf * nf_mult, 1, 238 | kernel_size=kw, stride=1, padding=padw)] 239 | 240 | return sequence 241 | 242 | def forward(self, input): 243 | if self.num_D == 1: 244 | return self.model(input) 245 | result = [] 246 | down = input 247 | for i in range(self.num_D): 248 | model = getattr(self, "model_%d" % i) 249 | result.append(model(down)) 250 | if i != self.num_D - 1: 251 | down = self.down(down) 252 | return result 253 | 254 | 255 | class D_NLayers(nn.Module): 256 | """Defines a PatchGAN discriminator""" 257 | 258 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d): 259 | """Construct a PatchGAN discriminator 260 | Parameters: 261 | input_nc (int) -- the number of channels in input images 262 | ndf (int) -- the number of filters in the first conv layer 263 | n_layers (int) -- the number of conv layers in the discriminator 264 | norm_layer -- normalization layer 265 | """ 266 | super(D_NLayers, self).__init__() 267 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters 268 | use_bias = norm_layer.func != nn.BatchNorm2d 269 | else: 270 | use_bias = norm_layer != nn.BatchNorm2d 271 | 272 | kw = 4 273 | padw = 1 274 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] 275 | nf_mult = 1 276 | nf_mult_prev = 1 277 | for n in range(1, n_layers): # gradually increase the number of filters 278 | nf_mult_prev = nf_mult 279 | nf_mult = min(2 ** n, 8) 280 | sequence += [ 281 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), 282 | norm_layer(ndf * nf_mult), 283 | nn.LeakyReLU(0.2, True) 284 | ] 285 | 286 | nf_mult_prev = nf_mult 287 | nf_mult = min(2 ** n_layers, 8) 288 | sequence += [ 289 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), 290 | norm_layer(ndf * nf_mult), 291 | nn.LeakyReLU(0.2, True) 292 | ] 293 | 294 | sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map 295 | self.model = nn.Sequential(*sequence) 296 | 297 | def forward(self, input): 298 | """Standard forward.""" 299 | return self.model(input) 300 | 301 | 302 | ############################################################################## 303 | # Classes 304 | ############################################################################## 305 | class RecLoss(nn.Module): 306 | def __init__(self, use_L2=True): 307 | super(RecLoss, self).__init__() 308 | self.use_L2 = use_L2 309 | 310 | def __call__(self, input, target, batch_mean=True): 311 | if self.use_L2: 312 | diff = (input - target) ** 2 313 | else: 314 | diff = torch.abs(input - target) 315 | if batch_mean: 316 | return torch.mean(diff) 317 | else: 318 | return torch.mean(torch.mean(torch.mean(diff, dim=3), dim=2), dim=1) 319 | 320 | 321 | # Defines the GAN loss which uses either LSGAN or the regular GAN. 322 | # When LSGAN is used, it is basically the same as MSELoss, 323 | # but it abstracts away the need to create the target label tensor 324 | # that has the same size as the input 325 | class GANLoss(nn.Module): 326 | """Define different GAN objectives. 327 | 328 | The GANLoss class abstracts away the need to create the target label tensor 329 | that has the same size as the input.
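Example (a minimal sketch; `pred_fake` and `pred_real` stand for lists of discriminator outputs, one tensor per scale, as returned by D_NLayersMulti):
    criterionGAN = GANLoss(gan_mode='lsgan')
    loss_D_fake, losses_per_D = criterionGAN(pred_fake, False)
    loss_D_real, losses_per_D = criterionGAN(pred_real, True)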
330 | """ 331 | 332 | def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0): 333 | """ Initialize the GANLoss class. 334 | 335 | Parameters: 336 | gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp. 337 | target_real_label (float) - - label for a real image 338 | target_fake_label (float) - - label for a fake image 339 | 340 | Note: Do not use sigmoid as the last layer of Discriminator. 341 | LSGAN needs no sigmoid. Vanilla GANs will handle it with BCEWithLogitsLoss. 342 | """ 343 | super(GANLoss, self).__init__() 344 | self.register_buffer('real_label', torch.tensor(target_real_label)) 345 | self.register_buffer('fake_label', torch.tensor(target_fake_label)) 346 | self.gan_mode = gan_mode 347 | if gan_mode == 'lsgan': 348 | self.loss = nn.MSELoss() 349 | elif gan_mode == 'vanilla': 350 | self.loss = nn.BCEWithLogitsLoss() 351 | elif gan_mode in ['wgangp']: 352 | self.loss = None 353 | else: 354 | raise NotImplementedError('gan mode %s not implemented' % gan_mode) 355 | 356 | def get_target_tensor(self, prediction, target_is_real): 357 | """Create label tensors with the same size as the input. 358 | 359 | Parameters: 360 | prediction (tensor) - - typically the prediction from a discriminator 361 | target_is_real (bool) - - whether the ground truth label is for real images or fake images 362 | 363 | Returns: 364 | A label tensor filled with ground truth label, and with the size of the input 365 | """ 366 | 367 | if target_is_real: 368 | target_tensor = self.real_label 369 | else: 370 | target_tensor = self.fake_label 371 | return target_tensor.expand_as(prediction) 372 | 373 | def __call__(self, predictions, target_is_real): 374 | """Calculate loss given Discriminator's output and ground truth labels. 375 | 376 | Parameters: 377 | predictions (tensor list) - - typically the prediction output from a discriminator; supports multi Ds. 378 | target_is_real (bool) - - whether the ground truth label is for real images or fake images 379 | 380 | Returns: 381 | the calculated loss. 382 | """ 383 | all_losses = [] 384 | for prediction in predictions: 385 | if self.gan_mode in ['lsgan', 'vanilla']: 386 | target_tensor = self.get_target_tensor(prediction, target_is_real) 387 | loss = self.loss(prediction, target_tensor) 388 | elif self.gan_mode == 'wgangp': 389 | if target_is_real: 390 | loss = -prediction.mean() 391 | else: 392 | loss = prediction.mean() 393 | all_losses.append(loss) 394 | total_loss = sum(all_losses) 395 | return total_loss, all_losses 396 | 397 | 398 | def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed', constant=1.0, lambda_gp=10.0): 399 | """Calculate the gradient penalty loss, used in WGAN-GP paper https://arxiv.org/abs/1704.00028 400 | Arguments: 401 | netD (network) -- discriminator network 402 | real_data (tensor array) -- real images 403 | fake_data (tensor array) -- generated images from the generator 404 | device (str) -- GPU / CPU: from torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') 405 | type (str) -- whether we mix real and fake data or not [real | fake | mixed]. 406 | constant (float) -- the constant used in formula (||gradient||_2 - constant)^2 407 | lambda_gp (float) -- weight for this loss 408 | Returns the gradient penalty loss 409 | """ 410 | if lambda_gp > 0.0: 411 | if type == 'real': # either use real images, fake images, or a linear interpolation of the two.
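# interpolatesv is the point at which the penalty is evaluated: the real batch, the fake
# batch, or (for 'mixed', the WGAN-GP default) a per-sample random convex combination
# alpha * real_data + (1 - alpha) * fake_data; the penalty below pushes the gradient
# norm of netD at that point toward `constant`.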
412 | interpolatesv = real_data 413 | elif type == 'fake': 414 | interpolatesv = fake_data 415 | elif type == 'mixed': 416 | alpha = torch.rand(real_data.shape[0], 1) 417 | alpha = alpha.expand(real_data.shape[0], real_data.nelement() // real_data.shape[0]).contiguous().view(*real_data.shape) 418 | alpha = alpha.to(device) 419 | interpolatesv = alpha * real_data + ((1 - alpha) * fake_data) 420 | else: 421 | raise NotImplementedError('{} not implemented'.format(type)) 422 | interpolatesv.requires_grad_(True) 423 | disc_interpolates = netD(interpolatesv) 424 | gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolatesv, 425 | grad_outputs=torch.ones(disc_interpolates.size()).to(device), 426 | create_graph=True, retain_graph=True, only_inputs=True) 427 | gradients = gradients[0].view(real_data.size(0), -1) # flatten per sample 428 | gradient_penalty = (((gradients + 1e-16).norm(2, dim=1) - constant) ** 2).mean() * lambda_gp # eps added for numerical stability 429 | return gradient_penalty, gradients 430 | else: 431 | return 0.0, None 432 | 433 | # Defines the Unet generator. 434 | # |num_downs|: number of downsamplings in UNet. For example, 435 | # if |num_downs| == 7, image of size 128x128 will become of size 1x1 436 | # at the bottleneck 437 | 438 | 439 | class G_Unet_add_input(nn.Module): 440 | def __init__(self, input_nc, output_nc, nz, num_downs, ngf=64, 441 | norm_layer=None, nl_layer=None, use_dropout=False, 442 | upsample='basic'): 443 | super(G_Unet_add_input, self).__init__() 444 | self.nz = nz 445 | max_nchn = 8 446 | # construct unet structure 447 | unet_block = UnetBlock(ngf * max_nchn, ngf * max_nchn, ngf * max_nchn, 448 | innermost=True, norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample) 449 | for i in range(num_downs - 5): 450 | unet_block = UnetBlock(ngf * max_nchn, ngf * max_nchn, ngf * max_nchn, unet_block, 451 | norm_layer=norm_layer, nl_layer=nl_layer, use_dropout=use_dropout, upsample=upsample) 452 | unet_block = UnetBlock(ngf * 4, ngf * 4, ngf * max_nchn, unet_block, 453 | norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample) 454 | unet_block = UnetBlock(ngf * 2, ngf * 2, ngf * 4, unet_block, 455 | norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample) 456 | unet_block = UnetBlock(ngf, ngf, ngf * 2, unet_block, 457 | norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample) 458 | unet_block = UnetBlock(input_nc + nz, output_nc, ngf, unet_block, 459 | outermost=True, norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample) 460 | 461 | self.model = unet_block 462 | 463 | def forward(self, x, z=None): 464 | if self.nz > 0: 465 | z_img = z.view(z.size(0), z.size(1), 1, 1).expand( 466 | z.size(0), z.size(1), x.size(2), x.size(3)) 467 | x_with_z = torch.cat([x, z_img], 1) 468 | else: 469 | x_with_z = x # no z 470 | 471 | return self.model(x_with_z) 472 | 473 | 474 | def upsampleLayer(inplanes, outplanes, upsample='basic', padding_type='zero'): 475 | 476 | if upsample == 'basic': 477 | upconv = [nn.ConvTranspose2d( 478 | inplanes, outplanes, kernel_size=4, stride=2, padding=1)] 479 | elif upsample == 'bilinear': 480 | upconv = [nn.Upsample(scale_factor=2, mode='bilinear'), 481 | nn.ReflectionPad2d(1), 482 | nn.Conv2d(inplanes, outplanes, kernel_size=3, stride=1, padding=0)] 483 | else: 484 | raise NotImplementedError( 485 | 'upsample layer [%s] not implemented' % upsample) 486 | return upconv 487 | 488 | 489 | # Defines the submodule with skip connection.
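# (A non-outermost UnetBlock returns torch.cat([self.model(x), x], 1): the block's input
# is concatenated channel-wise with its upsampled output, as sketched below.)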
490 | # X -------------------identity---------------------- X 491 | # |-- downsampling -- |submodule| -- upsampling --| 492 | class UnetBlock(nn.Module): 493 | def __init__(self, input_nc, outer_nc, inner_nc, 494 | submodule=None, outermost=False, innermost=False, 495 | norm_layer=None, nl_layer=None, use_dropout=False, upsample='basic', padding_type='zero'): 496 | super(UnetBlock, self).__init__() 497 | self.outermost = outermost 498 | p = 0 499 | downconv = [] 500 | if padding_type == 'reflect': 501 | downconv += [nn.ReflectionPad2d(1)] 502 | elif padding_type == 'replicate': 503 | downconv += [nn.ReplicationPad2d(1)] 504 | elif padding_type == 'zero': 505 | p = 1 506 | else: 507 | raise NotImplementedError( 508 | 'padding [%s] is not implemented' % padding_type) 509 | downconv += [nn.Conv2d(input_nc, inner_nc, 510 | kernel_size=4, stride=2, padding=p)] 511 | # the downsampling path uses LeakyReLU; the upsampling path uses nl_layer() 512 | downrelu = nn.LeakyReLU(0.2, True) 513 | downnorm = norm_layer(inner_nc) if norm_layer is not None else None 514 | uprelu = nl_layer() 515 | upnorm = norm_layer(outer_nc) if norm_layer is not None else None 516 | 517 | if outermost: 518 | upconv = upsampleLayer( 519 | inner_nc * 2, outer_nc, upsample=upsample, padding_type=padding_type) 520 | down = downconv 521 | up = [uprelu] + upconv + [nn.Tanh()] 522 | model = down + [submodule] + up 523 | elif innermost: 524 | upconv = upsampleLayer( 525 | inner_nc, outer_nc, upsample=upsample, padding_type=padding_type) 526 | down = [downrelu] + downconv 527 | up = [uprelu] + upconv 528 | if upnorm is not None: 529 | up += [upnorm] 530 | model = down + up 531 | else: 532 | upconv = upsampleLayer( 533 | inner_nc * 2, outer_nc, upsample=upsample, padding_type=padding_type) 534 | down = [downrelu] + downconv 535 | if downnorm is not None: 536 | down += [downnorm] 537 | up = [uprelu] + upconv 538 | if upnorm is not None: 539 | up += [upnorm] 540 | 541 | if use_dropout: 542 | model = down + [submodule] + up + [nn.Dropout(0.5)] 543 | else: 544 | model = down + [submodule] + up 545 | 546 | self.model = nn.Sequential(*model) 547 | 548 | def forward(self, x): 549 | if self.outermost: 550 | return self.model(x) 551 | else: 552 | return torch.cat([self.model(x), x], 1) 553 | 554 | 555 | def conv3x3(in_planes, out_planes): 556 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=1, 557 | padding=1, bias=True) 558 | 559 | 560 | # two usage cases, depending on kw and padw 561 | def upsampleConv(inplanes, outplanes, kw, padw): 562 | sequence = [] 563 | sequence += [nn.Upsample(scale_factor=2, mode='nearest')] 564 | sequence += [nn.Conv2d(inplanes, outplanes, kernel_size=kw, 565 | stride=1, padding=padw, bias=True)] 566 | return nn.Sequential(*sequence) 567 | 568 | 569 | def meanpoolConv(inplanes, outplanes): 570 | sequence = [] 571 | sequence += [nn.AvgPool2d(kernel_size=2, stride=2)] 572 | sequence += [nn.Conv2d(inplanes, outplanes, 573 | kernel_size=1, stride=1, padding=0, bias=True)] 574 | return nn.Sequential(*sequence) 575 | 576 | 577 | def convMeanpool(inplanes, outplanes): 578 | sequence = [] 579 | sequence += [conv3x3(inplanes, outplanes)] 580 | sequence += [nn.AvgPool2d(kernel_size=2, stride=2)] 581 | return nn.Sequential(*sequence) 582 | 583 | 584 | class BasicBlockUp(nn.Module): 585 | def __init__(self, inplanes, outplanes, norm_layer=None, nl_layer=None): 586 | super(BasicBlockUp, self).__init__() 587 | layers = [] 588 | if norm_layer is not None: 589 | layers += [norm_layer(inplanes)] 590 | layers += [nl_layer()] 591 | layers +=
[upsampleConv(inplanes, outplanes, kw=3, padw=1)] 592 | if norm_layer is not None: 593 | layers += [norm_layer(outplanes)] 594 | layers += [conv3x3(outplanes, outplanes)] 595 | self.conv = nn.Sequential(*layers) 596 | self.shortcut = upsampleConv(inplanes, outplanes, kw=1, padw=0) 597 | 598 | def forward(self, x): 599 | out = self.conv(x) + self.shortcut(x) 600 | return out 601 | 602 | 603 | class BasicBlock(nn.Module): 604 | def __init__(self, inplanes, outplanes, norm_layer=None, nl_layer=None): 605 | super(BasicBlock, self).__init__() 606 | layers = [] 607 | if norm_layer is not None: 608 | layers += [norm_layer(inplanes)] 609 | layers += [nl_layer()] 610 | layers += [conv3x3(inplanes, inplanes)] 611 | if norm_layer is not None: 612 | layers += [norm_layer(inplanes)] 613 | layers += [nl_layer()] 614 | layers += [convMeanpool(inplanes, outplanes)] 615 | self.conv = nn.Sequential(*layers) 616 | self.shortcut = meanpoolConv(inplanes, outplanes) 617 | 618 | def forward(self, x): 619 | out = self.conv(x) + self.shortcut(x) 620 | return out 621 | 622 | 623 | class E_ResNet(nn.Module): 624 | def __init__(self, input_nc=3, output_nc=1, ndf=64, n_blocks=4, 625 | norm_layer=None, nl_layer=None, vaeLike=False): 626 | super(E_ResNet, self).__init__() 627 | self.vaeLike = vaeLike 628 | max_ndf = 4 629 | conv_layers = [ 630 | nn.Conv2d(input_nc, ndf, kernel_size=4, stride=2, padding=1, bias=True)] 631 | for n in range(1, n_blocks): 632 | input_ndf = ndf * min(max_ndf, n) 633 | output_ndf = ndf * min(max_ndf, n + 1) 634 | conv_layers += [BasicBlock(input_ndf, 635 | output_ndf, norm_layer, nl_layer)] 636 | conv_layers += [nl_layer(), nn.AvgPool2d(8)] 637 | if vaeLike: 638 | self.fc = nn.Sequential(*[nn.Linear(output_ndf, output_nc)]) 639 | self.fcVar = nn.Sequential(*[nn.Linear(output_ndf, output_nc)]) 640 | else: 641 | self.fc = nn.Sequential(*[nn.Linear(output_ndf, output_nc)]) 642 | self.conv = nn.Sequential(*conv_layers) 643 | 644 | def forward(self, x): 645 | x_conv = self.conv(x) 646 | conv_flat = x_conv.view(x.size(0), -1) 647 | output = self.fc(conv_flat) 648 | if self.vaeLike: 649 | outputVar = self.fcVar(conv_flat) 650 | return output, outputVar 651 | else: 652 | return output 653 | 654 | 655 | 656 | # Defines the Unet generator. 657 | # |num_downs|: number of downsamplings in UNet.
For example, 658 | # if |num_downs| == 7, image of size 128x128 will become of size 1x1 659 | # at the bottleneck 660 | class G_Unet_add_all(nn.Module): 661 | def __init__(self, input_nc, output_nc, nz, num_downs, ngf=64, 662 | norm_layer=None, nl_layer=None, use_dropout=False, upsample='basic'): 663 | super(G_Unet_add_all, self).__init__() 664 | self.nz = nz 665 | # construct unet structure 666 | unet_block = UnetBlock_with_z(ngf * 8, ngf * 8, ngf * 8, nz, None, innermost=True, 667 | norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample) 668 | unet_block = UnetBlock_with_z(ngf * 8, ngf * 8, ngf * 8, nz, unet_block, 669 | norm_layer=norm_layer, nl_layer=nl_layer, use_dropout=use_dropout, upsample=upsample) 670 | for i in range(num_downs - 6): 671 | unet_block = UnetBlock_with_z(ngf * 8, ngf * 8, ngf * 8, nz, unet_block, 672 | norm_layer=norm_layer, nl_layer=nl_layer, use_dropout=use_dropout, upsample=upsample) 673 | unet_block = UnetBlock_with_z(ngf * 4, ngf * 4, ngf * 8, nz, unet_block, 674 | norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample) 675 | unet_block = UnetBlock_with_z(ngf * 2, ngf * 2, ngf * 4, nz, unet_block, 676 | norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample) 677 | unet_block = UnetBlock_with_z( 678 | ngf, ngf, ngf * 2, nz, unet_block, norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample) 679 | unet_block = UnetBlock_with_z(input_nc, output_nc, ngf, nz, unet_block, 680 | outermost=True, norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample) 681 | self.model = unet_block 682 | 683 | def forward(self, x, z): 684 | return self.model(x, z) 685 | 686 | 687 | class UnetBlock_with_z(nn.Module): 688 | def __init__(self, input_nc, outer_nc, inner_nc, nz=0, 689 | submodule=None, outermost=False, innermost=False, 690 | norm_layer=None, nl_layer=None, use_dropout=False, upsample='basic', padding_type='zero'): 691 | super(UnetBlock_with_z, self).__init__() 692 | p = 0 693 | downconv = [] 694 | if padding_type == 'reflect': 695 | downconv += [nn.ReflectionPad2d(1)] 696 | elif padding_type == 'replicate': 697 | downconv += [nn.ReplicationPad2d(1)] 698 | elif padding_type == 'zero': 699 | p = 1 700 | else: 701 | raise NotImplementedError( 702 | 'padding [%s] is not implemented' % padding_type) 703 | 704 | self.outermost = outermost 705 | self.innermost = innermost 706 | self.nz = nz 707 | input_nc = input_nc + nz 708 | downconv += [nn.Conv2d(input_nc, inner_nc, 709 | kernel_size=4, stride=2, padding=p)] 710 | # downsample is different from upsample 711 | downrelu = nn.LeakyReLU(0.2, True) 712 | uprelu = nl_layer() 713 | 714 | if outermost: 715 | upconv = upsampleLayer( 716 | inner_nc * 2, outer_nc, upsample=upsample, padding_type=padding_type) 717 | down = downconv 718 | up = [uprelu] + upconv + [nn.Tanh()] 719 | elif innermost: 720 | upconv = upsampleLayer( 721 | inner_nc, outer_nc, upsample=upsample, padding_type=padding_type) 722 | down = [downrelu] + downconv 723 | up = [uprelu] + upconv 724 | if norm_layer is not None: 725 | up += [norm_layer(outer_nc)] 726 | else: 727 | upconv = upsampleLayer( 728 | inner_nc * 2, outer_nc, upsample=upsample, padding_type=padding_type) 729 | down = [downrelu] + downconv 730 | if norm_layer is not None: 731 | down += [norm_layer(inner_nc)] 732 | up = [uprelu] + upconv 733 | 734 | if norm_layer is not None: 735 | up += [norm_layer(outer_nc)] 736 | 737 | if use_dropout: 738 | up += [nn.Dropout(0.5)] 739 | self.down = nn.Sequential(*down) 740 | self.submodule = submodule 741 | self.up = nn.Sequential(*up) 742 | 743 
| def forward(self, x, z): 744 | 745 | if self.nz > 0: 746 | z_img = z.view(z.size(0), z.size(1), 1, 1).expand(z.size(0), z.size(1), x.size(2), x.size(3)) 747 | x_and_z = torch.cat([x, z_img], 1) 748 | else: 749 | x_and_z = x 750 | 751 | if self.outermost: 752 | x1 = self.down(x_and_z) 753 | x2 = self.submodule(x1, z) 754 | return self.up(x2) 755 | elif self.innermost: 756 | x1 = self.up(self.down(x_and_z)) 757 | return torch.cat([x1, x], 1) 758 | else: 759 | x1 = self.down(x_and_z) 760 | x2 = self.submodule(x1, z) 761 | return torch.cat([self.up(x2), x], 1) 762 | 763 | 764 | class E_NLayers(nn.Module): 765 | def __init__(self, input_nc, output_nc=1, ndf=64, n_layers=3, 766 | norm_layer=None, nl_layer=None, vaeLike=False): 767 | super(E_NLayers, self).__init__() 768 | self.vaeLike = vaeLike 769 | 770 | kw, padw = 4, 1 771 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, 772 | stride=2, padding=padw), nl_layer()] 773 | 774 | nf_mult = 1 775 | nf_mult_prev = 1 776 | for n in range(1, n_layers): 777 | nf_mult_prev = nf_mult 778 | nf_mult = min(2**n, 4) 779 | sequence += [ 780 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, 781 | kernel_size=kw, stride=2, padding=padw)] 782 | if norm_layer is not None: 783 | sequence += [norm_layer(ndf * nf_mult)] 784 | sequence += [nl_layer()] 785 | sequence += [nn.AvgPool2d(8)] 786 | self.conv = nn.Sequential(*sequence) 787 | self.fc = nn.Sequential(*[nn.Linear(ndf * nf_mult, output_nc)]) 788 | if vaeLike: 789 | self.fcVar = nn.Sequential(*[nn.Linear(ndf * nf_mult, output_nc)]) 790 | 791 | def forward(self, x): 792 | x_conv = self.conv(x) 793 | conv_flat = x_conv.view(x.size(0), -1) 794 | output = self.fc(conv_flat) 795 | if self.vaeLike: 796 | outputVar = self.fcVar(conv_flat) 797 | return output, outputVar 798 | return output 799 | -------------------------------------------------------------------------------- /models/pix2pix_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .base_model import BaseModel 3 | from . import networks 4 | 5 | 6 | class Pix2PixModel(BaseModel): 7 | """ This class implements the pix2pix model, for learning a mapping from input images to output images given paired data. 8 | 9 | The model training requires '--dataset_mode aligned' dataset. 10 | By default, it uses a '--netG unet_256' U-Net generator, 11 | a '--netD basic_256_multi' discriminator (PatchGAN), 12 | and a '--gan_mode' vanilla GAN loss (the cross-entropy objective used in the original GAN paper). 13 | 14 | pix2pix paper: https://arxiv.org/pdf/1611.07004.pdf 15 | """ 16 | @staticmethod 17 | def modify_commandline_options(parser, is_train=True): 18 | """Add new dataset-specific options, and rewrite default values for existing options. 19 | 20 | Parameters: 21 | parser -- original option parser 22 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. 23 | 24 | Returns: 25 | the modified parser. 26 | 27 | For pix2pix, we do not use an image buffer. 28 | The training objective is: GAN Loss + lambda_L1 * ||G(A)-B||_1 29 | By default, we use vanilla GAN loss, UNet with batchnorm, and aligned datasets.
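Example (a minimal sketch of a matching training command; the facades dataset path mirrors scripts/check_all.sh):
    python train.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --direction BtoA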
30 | """ 31 | # changing the default values to match the pix2pix paper (https://phillipi.github.io/pix2pix/) 32 | parser.set_defaults(norm='batch', netG='unet_256', dataset_mode='aligned') 33 | parser.set_defaults(where_add='input', nz=0) 34 | if is_train: 35 | parser.set_defaults(gan_mode='vanilla', lambda_L1=100.0) 36 | 37 | return parser 38 | 39 | def __init__(self, opt): 40 | """Initialize the pix2pix class. 41 | 42 | Parameters: 43 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 44 | """ 45 | BaseModel.__init__(self, opt) 46 | # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses> 47 | self.loss_names = ['G_GAN', 'G_L1', 'D_real', 'D_fake'] 48 | # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals> 49 | self.visual_names = ['real_A', 'fake_B', 'real_B'] 50 | # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks> 51 | if self.isTrain: 52 | self.model_names = ['G', 'D'] 53 | else: # during test time, only load G 54 | self.model_names = ['G'] 55 | # define networks (both generator and discriminator) 56 | self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.nz, opt.ngf, netG=opt.netG, 57 | norm=opt.norm, nl=opt.nl, use_dropout=opt.use_dropout, init_type=opt.init_type, init_gain=opt.init_gain, 58 | gpu_ids=self.gpu_ids, where_add=opt.where_add, upsample=opt.upsample) 59 | 60 | if self.isTrain: # define a discriminator; conditional GANs need to take both input and output images; therefore, #channels for D is input_nc + output_nc 61 | self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, netD=opt.netD2, norm=opt.norm, nl=opt.nl, 62 | init_type=opt.init_type, init_gain=opt.init_gain, num_Ds=opt.num_Ds, gpu_ids=self.gpu_ids) 63 | 64 | if self.isTrain: 65 | # define loss functions 66 | self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device) 67 | self.criterionL1 = torch.nn.L1Loss() 68 | # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>. 69 | self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) 70 | self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) 71 | self.optimizers.append(self.optimizer_G) 72 | self.optimizers.append(self.optimizer_D) 73 | 74 | def set_input(self, input): 75 | """Unpack input data from the dataloader and perform necessary pre-processing steps. 76 | 77 | Parameters: 78 | input (dict): include the data itself and its metadata information. 79 | 80 | The option 'direction' can be used to swap images in domain A and domain B.
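Example (a minimal sketch; the zero tensors are placeholders with the default 3-channel 256x256 shape):
    data = {'A': torch.zeros(1, 3, 256, 256), 'B': torch.zeros(1, 3, 256, 256),
            'A_paths': ['a.jpg'], 'B_paths': ['b.jpg']}
    model.set_input(data)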
81 | """ 82 | AtoB = self.opt.direction == 'AtoB' 83 | self.real_A = input['A' if AtoB else 'B'].to(self.device) 84 | self.real_B = input['B' if AtoB else 'A'].to(self.device) 85 | self.image_paths = input['A_paths' if AtoB else 'B_paths'] 86 | 87 | def forward(self): 88 | """Run forward pass; called by both functions <optimize_parameters> and <test>.""" 89 | self.fake_B = self.netG(self.real_A) # G(A) 90 | 91 | def backward_D(self): 92 | """Calculate GAN loss for the discriminator""" 93 | # Fake; stop backprop to the generator by detaching fake_B 94 | fake_AB = torch.cat((self.real_A, self.fake_B), 1) # we use conditional GANs; we need to feed both input and output to the discriminator 95 | pred_fake = self.netD(fake_AB.detach()) 96 | self.loss_D_fake, _ = self.criterionGAN(pred_fake, False) 97 | # Real 98 | real_AB = torch.cat((self.real_A, self.real_B), 1) 99 | pred_real = self.netD(real_AB) 100 | self.loss_D_real, _ = self.criterionGAN(pred_real, True) 101 | # combine loss and calculate gradients 102 | self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5 103 | self.loss_D.backward() 104 | 105 | def backward_G(self): 106 | """Calculate GAN and L1 loss for the generator""" 107 | # First, G(A) should fake the discriminator 108 | fake_AB = torch.cat((self.real_A, self.fake_B), 1) 109 | pred_fake = self.netD(fake_AB) 110 | self.loss_G_GAN, _ = self.criterionGAN(pred_fake, True) 111 | # Second, G(A) = B 112 | self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_L1 113 | # combine loss and calculate gradients 114 | self.loss_G = self.loss_G_GAN + self.loss_G_L1 115 | self.loss_G.backward() 116 | 117 | def optimize_parameters(self): 118 | self.forward() # compute fake images: G(A) 119 | # update D 120 | self.set_requires_grad(self.netD, True) # enable backprop for D 121 | self.optimizer_D.zero_grad() # set D's gradients to zero 122 | self.backward_D() # calculate gradients for D 123 | self.optimizer_D.step() # update D's weights 124 | # update G 125 | self.set_requires_grad(self.netD, False) # D requires no gradients when optimizing G 126 | self.optimizer_G.zero_grad() # set G's gradients to zero 127 | self.backward_G() # calculate gradients for G 128 | self.optimizer_G.step() # update G's weights 129 | -------------------------------------------------------------------------------- /models/template_model.py: -------------------------------------------------------------------------------- 1 | """Model class template 2 | 3 | This module provides a template for users to implement custom models. 4 | You can specify '--model template' to use this model. 5 | The class name should be consistent with both the filename and its model option. 6 | The filename should be <model>_model.py 7 | The class name should be <Model>Model 8 | It implements a simple image-to-image translation baseline based on regression loss. 9 | Given input-output pairs (data_A, data_B), it learns a network netG that can minimize the following L1 loss: 10 | min_<netG> ||netG(data_A) - data_B||_1 11 | You need to implement the following functions: 12 | <modify_commandline_options>: Add model-specific options and rewrite default values for existing options. 13 | <__init__>: Initialize this model class. 14 | <set_input>: Unpack input data and perform data pre-processing. 15 | <forward>: Run forward pass. This will be called by both <optimize_parameters> and <test>. 16 | <optimize_parameters>: Update network weights; it will be called in every training iteration. 17 | """ 18 | import torch 19 | from .base_model import BaseModel 20 | from .
import networks 21 | 22 | 23 | class TemplateModel(BaseModel): 24 | @staticmethod 25 | def modify_commandline_options(parser, is_train=True): 26 | """Add new model-specific options and rewrite default values for existing options. 27 | 28 | Parameters: 29 | parser -- the option parser 30 | is_train -- whether it is the training phase or the test phase. You can use this flag to add training-specific or test-specific options. 31 | 32 | Returns: 33 | the modified parser. 34 | """ 35 | parser.set_defaults(dataset_mode='aligned', netG='unet_256', where_add='input', nz=0) # You can rewrite default values for this model. For example, this model usually uses aligned dataset as its dataset. 36 | if is_train: 37 | parser.add_argument('--lambda_regression', type=float, default=1.0, help='weight for the regression loss') # You can define new arguments for this model. 38 | 39 | return parser 40 | 41 | def __init__(self, opt): 42 | """Initialize this model class. 43 | 44 | Parameters: 45 | opt -- training/test options 46 | 47 | A few things can be done here. 48 | - (required) call the initialization function of BaseModel 49 | - define loss function, visualization images, model names, and optimizers 50 | """ 51 | BaseModel.__init__(self, opt) # call the initialization method of BaseModel 52 | # specify the training losses you want to print out. The program will call base_model.get_current_losses to plot the losses to the console and save them to the disk. 53 | self.loss_names = ['G'] # get_current_losses looks up 'loss_' + name, i.e. self.loss_G 54 | # specify the images you want to save and display. The program will call base_model.get_current_visuals to save and display these images. 55 | self.visual_names = ['data_A', 'data_B', 'output'] 56 | # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks to save and load networks. 57 | # you can use opt.isTrain to specify different behaviors for training and test. For example, some networks will not be used during test, and you don't need to load them. 58 | self.model_names = ['G'] 59 | # define networks; you can use opt.isTrain to specify different behaviors for training and test. 60 | self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.nz, opt.ngf, netG=opt.netG, gpu_ids=self.gpu_ids, where_add=opt.where_add) 61 | if self.isTrain: # only defined during training time 62 | # define your loss functions. You can use losses provided by torch.nn such as torch.nn.L1Loss. 63 | # We also provide a GANLoss class "networks.GANLoss". self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device) 64 | self.criterionLoss = torch.nn.L1Loss() 65 | # define and initialize optimizers. You can define one optimizer for each network. 66 | # If two networks are updated at the same time, you can use itertools.chain to group them. See bicycle_gan_model.py for an example. 67 | self.optimizer = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) 68 | self.optimizers = [self.optimizer] 69 | 70 | # Our program will automatically call <model.setup(opt)> to define schedulers, load networks, and print networks 71 | 72 | def set_input(self, input): 73 | """Unpack input data from the dataloader and perform necessary pre-processing steps. 74 | 75 | Parameters: 76 | input: a dictionary that contains the data itself and its metadata information.
77 | """ 78 | AtoB = self.opt.direction == 'AtoB' # use <direction> to swap data_A and data_B 79 | self.data_A = input['A' if AtoB else 'B'].to(self.device) # get image data A 80 | self.data_B = input['B' if AtoB else 'A'].to(self.device) # get image data B 81 | self.image_paths = input['A_paths' if AtoB else 'B_paths'] # get image paths 82 | 83 | def forward(self): 84 | """Run forward pass. This will be called by both functions <optimize_parameters> and <test>.""" 85 | self.output = self.netG(self.data_A) # generate output image given the input data_A 86 | 87 | def backward(self): 88 | """Calculate losses, gradients, and update network weights; called in every training iteration""" 89 | # calculate the intermediate results if necessary; here self.output has been computed during function <forward> 90 | # calculate loss given the input and intermediate results 91 | self.loss_G = self.criterionLoss(self.output, self.data_B) * self.opt.lambda_regression 92 | self.loss_G.backward() # calculate gradients of network G w.r.t. loss_G 93 | 94 | def optimize_parameters(self): 95 | """Update network weights; it will be called in every training iteration.""" 96 | self.forward() # first call forward to calculate intermediate results 97 | self.optimizer.zero_grad() # clear network G's existing gradients 98 | self.backward() # calculate gradients for network G 99 | self.optimizer.step() # update network G's weights 100 | -------------------------------------------------------------------------------- /options/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junyanz/BicycleGAN/40b9d52c27b9831f56c1c7c7a6ddde8bc9149067/options/__init__.py -------------------------------------------------------------------------------- /options/base_options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from util import util 4 | import torch 5 | import models 6 | import data 7 | 8 | 9 | class BaseOptions(): 10 | def __init__(self): 11 | self.initialized = False 12 | 13 | def initialize(self, parser): 14 | """This class defines options used during both training and test time. 15 | 16 | It also implements several helper functions such as parsing, printing, and saving the options. 17 | It also gathers additional options defined in <modify_commandline_options> functions in both dataset class and model class. 18 | """ 19 | parser.add_argument('--dataroot', required=True, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)') 20 | parser.add_argument('--batch_size', type=int, default=2, help='input batch size') 21 | parser.add_argument('--load_size', type=int, default=286, help='scale images to this size') 22 | parser.add_argument('--crop_size', type=int, default=256, help='then crop to this size') 23 | parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels') 24 | parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels') 25 | parser.add_argument('--nz', type=int, default=8, help='dimension of the latent vector z') 26 | parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids, e.g. 0; 0,1,2; 0,2. use -1 for CPU mode') 27 | parser.add_argument('--name', type=str, default='', help='name of the experiment.
It decides where to store samples and models') 28 | parser.add_argument('--preprocess', type=str, default='resize_and_crop', help='not implemented') 29 | parser.add_argument('--dataset_mode', type=str, default='aligned', help='aligned | single') 30 | parser.add_argument('--model', type=str, default='bicycle_gan', help='chooses which model to use: bicycle_gan | pix2pix | template') 31 | parser.add_argument('--direction', type=str, default='AtoB', help='AtoB or BtoA') 32 | parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model') 33 | parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data') 34 | parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here') 35 | parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly') 36 | parser.add_argument('--use_dropout', action='store_true', help='use dropout for the generator') 37 | parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.') 38 | parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation') 39 | 40 | # model parameters 41 | parser.add_argument('--num_Ds', type=int, default=2, help='number of Discriminators') 42 | parser.add_argument('--netD', type=str, default='basic_256_multi', help='selects model to use for netD') 43 | parser.add_argument('--netD2', type=str, default='basic_256_multi', help='selects model to use for netD2') 44 | parser.add_argument('--netG', type=str, default='unet_256', help='selects model to use for netG') 45 | parser.add_argument('--netE', type=str, default='resnet_256', help='selects model to use for netE') 46 | parser.add_argument('--nef', type=int, default=64, help='# of encoder filters in the first conv layer') 47 | parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in the last conv layer') 48 | parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer') 49 | parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization') 50 | parser.add_argument('--upsample', type=str, default='basic', help='basic | bilinear') 51 | parser.add_argument('--nl', type=str, default='relu', help='non-linearity activation: relu | lrelu | elu') 52 | 53 | # extra parameters 54 | parser.add_argument('--where_add', type=str, default='all', help='input|all|middle; where to add z in the network G') 55 | parser.add_argument('--conditional_D', action='store_true', help='use a conditional GAN for D') 56 | parser.add_argument('--init_type', type=str, default='xavier', help='network initialization [normal | xavier | kaiming | orthogonal]') 57 | parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.') 58 | parser.add_argument('--center_crop', action='store_true', help='apply center cropping at test time') 59 | parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information') 60 | parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}') 61 | parser.add_argument('--display_winsize', type=int, default=256,
help='display window size') 62 | 63 | # special tasks 64 | self.initialized = True 65 | return parser 66 | 67 | def gather_options(self): 68 | """Initialize our parser with basic options (only once). 69 | Add additional model-specific and dataset-specific options. 70 | These options are defined in the <modify_commandline_options> function 71 | in model and dataset classes. 72 | """ 73 | if not self.initialized: # check if it has been initialized 74 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 75 | parser = self.initialize(parser) 76 | 77 | # get the basic options 78 | opt, _ = parser.parse_known_args() 79 | 80 | # modify model-related parser options 81 | model_name = opt.model 82 | model_option_setter = models.get_option_setter(model_name) 83 | parser = model_option_setter(parser, self.isTrain) 84 | opt, _ = parser.parse_known_args() # parse again with new defaults 85 | 86 | # modify dataset-related parser options 87 | dataset_name = opt.dataset_mode 88 | dataset_option_setter = data.get_option_setter(dataset_name) 89 | parser = dataset_option_setter(parser, self.isTrain) 90 | 91 | # save and return the parser 92 | self.parser = parser 93 | return parser.parse_args() 94 | 95 | def print_options(self, opt): 96 | """Print and save options 97 | 98 | It will print both current options and default values (if different). 99 | It will save options into a text file / [checkpoints_dir] / opt.txt 100 | """ 101 | message = '' 102 | message += '----------------- Options ---------------\n' 103 | for k, v in sorted(vars(opt).items()): 104 | comment = '' 105 | default = self.parser.get_default(k) 106 | if v != default: 107 | comment = '\t[default: %s]' % str(default) 108 | message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) 109 | message += '----------------- End -------------------' 110 | print(message) 111 | 112 | # save to the disk 113 | expr_dir = os.path.join(opt.checkpoints_dir, opt.name) 114 | util.mkdirs(expr_dir) 115 | file_name = os.path.join(expr_dir, 'opt.txt') 116 | with open(file_name, 'wt') as opt_file: 117 | opt_file.write(message) 118 | opt_file.write('\n') 119 | 120 | def parse(self): 121 | """Parse our options, create checkpoints directory suffix, and set up gpu device.""" 122 | opt = self.gather_options() 123 | opt.isTrain = self.isTrain # train or test 124 | 125 | # process opt.suffix 126 | if opt.suffix: 127 | suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else '' 128 | opt.name = opt.name + suffix 129 | 130 | self.print_options(opt) 131 | 132 | # set gpu ids 133 | str_ids = opt.gpu_ids.split(',') 134 | opt.gpu_ids = [] 135 | for str_id in str_ids: 136 | id = int(str_id) 137 | if id >= 0: 138 | opt.gpu_ids.append(id) 139 | if len(opt.gpu_ids) > 0: 140 | torch.cuda.set_device(opt.gpu_ids[0]) 141 | 142 | self.opt = opt 143 | return self.opt 144 | -------------------------------------------------------------------------------- /options/test_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | 4 | class TestOptions(BaseOptions): 5 | def initialize(self, parser): 6 | BaseOptions.initialize(self, parser) 7 | parser.add_argument('--results_dir', type=str, default='../results/', help='saves results here.') 8 | parser.add_argument('--phase', type=str, default='val', help='train, val, test, etc') 9 | parser.add_argument('--num_test', type=int, default=50, help='how many test images to run') 10 | parser.add_argument('--n_samples', type=int, default=5,
help='# of samples per input image') 11 | parser.add_argument('--no_encode', action='store_true', help='do not produce encoded image') 12 | parser.add_argument('--sync', action='store_true', help='use the same latent code for different input images') 13 | parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio for the results') 14 | parser.add_argument('--eval', action='store_true', help='use eval mode during test time.') 15 | 16 | self.isTrain = False 17 | return parser 18 | -------------------------------------------------------------------------------- /options/train_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | 4 | class TrainOptions(BaseOptions): 5 | def initialize(self, parser): 6 | BaseOptions.initialize(self, parser) 7 | parser.add_argument('--display_freq', type=int, default=400, help='frequency of showing training results on screen') 8 | parser.add_argument('--display_ncols', type=int, default=4, help='if positive, display all images in a single visdom web panel with a certain number of images per row.') 9 | parser.add_argument('--display_id', type=int, default=1, help='window id of the web display') 10 | parser.add_argument('--display_port', type=int, default=8097, help='visdom display port') 11 | parser.add_argument('--display_env', type=str, default='main', help='visdom display environment name (default is "main")') 12 | parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display') 13 | parser.add_argument('--update_html_freq', type=int, default=4000, help='frequency of saving training results to html') 14 | parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console') 15 | parser.add_argument('--save_latest_freq', type=int, default=10000, help='frequency of saving the latest results') 16 | parser.add_argument('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs') 17 | parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc') 18 | parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/') 19 | parser.add_argument('--gan_mode', type=str, default='lsgan', help='the type of GAN objective. [vanilla | lsgan | wgangp].
20 |         # training parameters
21 |         parser.add_argument('--niter', type=int, default=100, help='# of iter at starting learning rate')
22 |         parser.add_argument('--niter_decay', type=int, default=100, help='# of iter to linearly decay learning rate to zero')
23 |         parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
24 |         parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...')
25 |         parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
26 |         parser.add_argument('--lr_policy', type=str, default='linear', help='learning rate policy: linear | step | plateau | cosine')
27 |         parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
28 |         parser.add_argument('--lr_decay_iters', type=int, default=100, help='multiply by a gamma every lr_decay_iters iterations')
29 |         # lambda parameters
30 |         parser.add_argument('--lambda_L1', type=float, default=10.0, help='weight for |B-G(A, E(B))|')
31 |         parser.add_argument('--lambda_GAN', type=float, default=1.0, help='weight on D loss. D(G(A, E(B)))')
32 |         parser.add_argument('--lambda_GAN2', type=float, default=1.0, help='weight on D2 loss, D(G(A, random_z))')
33 |         parser.add_argument('--lambda_z', type=float, default=0.5, help='weight for ||E(G(random_z)) - random_z||')
34 |         parser.add_argument('--lambda_kl', type=float, default=0.01, help='weight for KL loss')
35 |         parser.add_argument('--use_same_D', action='store_true', help='whether the two discriminators share weights')
36 |         self.isTrain = True
37 |         return parser
38 | 
--------------------------------------------------------------------------------
/options/video_options.py:
--------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 | 
3 | 
4 | class VideoOptions(BaseOptions):
5 |     def initialize(self, parser):
6 |         BaseOptions.initialize(self, parser)
7 |         parser.add_argument('--results_dir', type=str, default='../video/', help='saves results here.')
8 |         parser.add_argument('--phase', type=str, default='val', help='train, val, test, etc')
9 |         parser.add_argument('--num_test', type=int, default=50, help='how many test images to run')
10 |         parser.add_argument('--n_samples', type=int, default=5, help='#samples')
11 |         parser.add_argument('--num_frames', type=int, default=4, help='number of the frames used in the morphing sequence')
12 |         parser.add_argument('--align_mode', type=str, default='horizontal', help='ways of aligning the input images')
13 |         parser.add_argument('--border', type=int, default=0, help='border between results')
14 |         parser.add_argument('--seed', type=int, default=50, help='random seed for latent vectors')
15 |         parser.add_argument('--fps', type=int, default=8, help='speed of the generated video')
16 |         self.isTrain = False
17 |         return parser
18 | 
--------------------------------------------------------------------------------
/pretrained_models/download_model.sh:
--------------------------------------------------------------------------------
1 | FILE=$1
2 | 
3 | echo "Note: available models are edges2shoes, edges2handbags, night2day, maps, and facades"
4 | echo "downloading [$FILE]"
5 | MODEL_DIR=./pretrained_models/${FILE}
6 | mkdir -p ${MODEL_DIR}
7 | 
8 | 
9 | MODEL_FILE_G=${MODEL_DIR}/latest_net_G.pth
10 | URL_G=http://efrosgans.eecs.berkeley.edu/BicycleGAN//models/${FILE}_net_G.pth
11 | wget -N $URL_G -O $MODEL_FILE_G
12 | 
13 | 
14 | MODEL_FILE_E=${MODEL_DIR}/latest_net_E.pth
15 | URL_E=http://efrosgans.eecs.berkeley.edu/BicycleGAN//models/${FILE}_net_E.pth
16 | wget -N $URL_E -O $MODEL_FILE_E
17 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=0.4.1
2 | torchvision>=0.2.1
3 | dominate>=2.3.1
4 | visdom>=0.1.8.3
5 | 
--------------------------------------------------------------------------------
/scripts/check_all.sh:
--------------------------------------------------------------------------------
1 | set -ex
2 | DOWNLOAD_MODEL=${1}
3 | # test code
4 | echo 'test edges2handbags'
5 | if [ ${DOWNLOAD_MODEL} -eq 1 ]
6 | then
7 |     bash ./pretrained_models/download_model.sh edges2handbags
8 | fi
9 | bash ./datasets/download_testset.sh edges2handbags
10 | bash ./scripts/test_edges2handbags.sh
11 | 
12 | echo 'test edges2shoes'
13 | if [ ${DOWNLOAD_MODEL} -eq 1 ]
14 | then
15 |     bash ./pretrained_models/download_model.sh edges2shoes
16 | fi
17 | bash ./datasets/download_testset.sh edges2shoes
18 | bash ./scripts/test_edges2shoes.sh
19 | 
20 | echo 'test night2day'
21 | if [ ${DOWNLOAD_MODEL} -eq 1 ]
22 | then
23 |     bash ./pretrained_models/download_model.sh night2day
24 | fi
25 | bash ./datasets/download_testset.sh night2day
26 | bash ./scripts/test_night2day.sh
27 | 
28 | echo 'test maps'
29 | if [ ${DOWNLOAD_MODEL} -eq 1 ]
30 | then
31 |     bash ./pretrained_models/download_model.sh maps
32 | fi
33 | bash ./datasets/download_testset.sh maps
34 | bash ./scripts/test_maps.sh
35 | 
36 | echo 'test facades'
37 | if [ ${DOWNLOAD_MODEL} -eq 1 ]
38 | then
39 |     bash ./pretrained_models/download_model.sh facades
40 | fi
41 | bash ./datasets/download_testset.sh facades
42 | bash ./scripts/test_facades.sh
43 | 
44 | echo 'test night2day'
45 | if [ ${DOWNLOAD_MODEL} -eq 1 ]
46 | then
47 |     bash ./pretrained_models/download_model.sh night2day
48 | fi
49 | bash ./datasets/download_testset.sh night2day
50 | bash ./scripts/test_night2day.sh
51 | 
52 | echo 'video edges2shoes'
53 | bash ./scripts/video_edges2shoes.sh
54 | 
55 | echo "train a pix2pix model"
56 | bash ./datasets/download_dataset.sh facades
57 | python train.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix \
58 |   --netG unet_256 --direction BtoA --lambda_L1 10 --dataset_mode aligned \
59 |   --gan_mode lsgan --norm batch --niter 1 --niter_decay 0 --save_epoch_freq 1
60 | echo "train a bicyclegan model"
61 | python train.py --dataroot ./datasets/facades --name facades_bicycle --model bicycle_gan \
62 |   --direction BtoA --dataset_mode aligned \
63 |   --gan_mode lsgan --norm batch --niter 1 --niter_decay 0 --save_epoch_freq 1
64 | 
--------------------------------------------------------------------------------
/scripts/install_conda.sh:
--------------------------------------------------------------------------------
1 | set -ex
2 | conda install -c conda-forge visdom
3 | conda install -c conda-forge dominate
4 | conda install -c conda-forge moviepy
5 | 
--------------------------------------------------------------------------------
/scripts/install_pip.sh:
--------------------------------------------------------------------------------
1 | set -ex
2 | sudo -H pip install visdom
3 | sudo -H pip install dominate
4 | sudo -H pip install moviepy
5 | 
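6 | # If sudo is unavailable, a user-level install is a reasonable alternative
7 | # (a sketch, not part of the original script):
8 | #     pip install --user visdom dominate moviepy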
--------------------------------------------------------------------------------
/scripts/test_before_push.py:
--------------------------------------------------------------------------------
1 | """Simple script to make sure basic usage such as training, testing, saving and loading runs without errors."""
2 | import os
3 | 
4 | 
5 | def run(command):
6 |     print(command)
7 |     exit_status = os.system(command)
8 |     if exit_status > 0:
9 |         exit(1)
10 | 
11 | 
12 | if __name__ == '__main__':
13 |     if not os.path.exists('./datasets/mini_pix2pix'):
14 |         run('bash ./datasets/download_mini_dataset.sh mini_pix2pix')
15 | 
16 |     # pix2pix train/test
17 |     run('python train.py --model pix2pix --name temp_pix2pix --dataroot ./datasets/mini_pix2pix --niter 1 --niter_decay 0 --save_latest_freq 10 --display_id -1')
18 | 
19 |     # template train/test
20 |     run('python train.py --model template --name temp2 --dataroot ./datasets/mini_pix2pix --niter 1 --niter_decay 0 --save_latest_freq 10 --display_id -1')
21 | 
22 |     run('bash ./scripts/test_edges2shoes.sh')
23 |     run('bash ./scripts/test_edges2shoes.sh --sync')
24 |     run('bash ./scripts/video_edges2shoes.sh')
25 |     # run('bash ./scripts/train_facades.sh')
26 | 
--------------------------------------------------------------------------------
/scripts/test_edges2handbags.sh:
--------------------------------------------------------------------------------
1 | set -ex
2 | # models
3 | RESULTS_DIR='./results/edges2handbags'
4 | G_PATH='./pretrained_models/edges2handbags_net_G.pth'
5 | E_PATH='./pretrained_models/edges2handbags_net_E.pth'
6 | 
7 | # dataset
8 | CLASS='edges2handbags'
9 | DIRECTION='AtoB'  # from domain A to domain B
10 | LOAD_SIZE=256  # scale images to this size
11 | CROP_SIZE=256  # then crop to this size
12 | INPUT_NC=1  # number of channels in the input image
13 | 
14 | # misc
15 | GPU_ID=0  # gpu id
16 | NUM_TEST=10  # number of input images during test
17 | NUM_SAMPLES=10  # number of samples per input image
18 | 
19 | # command
20 | CUDA_VISIBLE_DEVICES=${GPU_ID} python ./test.py \
21 |   --dataroot ./datasets/${CLASS} \
22 |   --results_dir ${RESULTS_DIR} \
23 |   --checkpoints_dir ./pretrained_models/ \
24 |   --name ${CLASS} \
25 |   --direction ${DIRECTION} \
26 |   --load_size ${LOAD_SIZE} \
27 |   --crop_size ${CROP_SIZE} \
28 |   --input_nc ${INPUT_NC} \
29 |   --num_test ${NUM_TEST} \
30 |   --n_samples ${NUM_SAMPLES} \
31 |   --center_crop \
32 |   --no_flip
33 | 
--------------------------------------------------------------------------------
/scripts/test_edges2shoes.sh:
--------------------------------------------------------------------------------
1 | set -ex
2 | # models
3 | RESULTS_DIR='./results/edges2shoes'
4 | G_PATH='./pretrained_models/edges2shoes_net_G.pth'
5 | E_PATH='./pretrained_models/edges2shoes_net_E.pth'
6 | 
7 | # dataset
8 | CLASS='edges2shoes'
9 | DIRECTION='AtoB'  # from domain A to domain B
10 | LOAD_SIZE=256  # scale images to this size
11 | CROP_SIZE=256  # then crop to this size
12 | INPUT_NC=1  # number of channels in the input image
13 | 
14 | # misc
15 | GPU_ID=0  # gpu id
16 | NUM_TEST=10  # number of input images during test
17 | NUM_SAMPLES=10  # number of samples per input image
18 | 
19 | # command
20 | CUDA_VISIBLE_DEVICES=${GPU_ID} python ./test.py \
21 |   --dataroot ./datasets/${CLASS} \
22 |   --results_dir ${RESULTS_DIR} \
23 |   --checkpoints_dir ./pretrained_models/ \
24 |   --name ${CLASS} \
25 |   --direction ${DIRECTION} \
26 |   --load_size ${LOAD_SIZE} \
27 |   --crop_size ${CROP_SIZE} \
28 |   --input_nc ${INPUT_NC} \
29 |   --num_test ${NUM_TEST} \
30 |   --n_samples ${NUM_SAMPLES} \
31 |   --center_crop \
32 |   --no_flip
33 | 
--------------------------------------------------------------------------------
/scripts/test_facades.sh:
--------------------------------------------------------------------------------
1 | set -ex
2 | # models
3 | RESULTS_DIR='./results/facades'
4 | 
5 | # dataset
6 | CLASS='facades'
7 | DIRECTION='BtoA'  # from domain B to domain A
8 | LOAD_SIZE=286  # scale images to this size
9 | CROP_SIZE=256  # then crop to this size
10 | INPUT_NC=3  # number of channels in the input image
11 | 
12 | # misc
13 | GPU_ID=0  # gpu id
14 | NUM_TEST=10  # number of input images during test
15 | NUM_SAMPLES=10  # number of samples per input image
16 | 
17 | 
18 | # command
19 | CUDA_VISIBLE_DEVICES=${GPU_ID} python ./test.py \
20 |   --dataroot ./datasets/${CLASS} \
21 |   --results_dir ${RESULTS_DIR} \
22 |   --checkpoints_dir ./pretrained_models/ \
23 |   --name ${CLASS} \
24 |   --direction ${DIRECTION} \
25 |   --load_size ${LOAD_SIZE} \
26 |   --crop_size ${CROP_SIZE} \
27 |   --input_nc ${INPUT_NC} \
28 |   --num_test ${NUM_TEST} \
29 |   --n_samples ${NUM_SAMPLES} \
30 |   --center_crop \
31 |   --no_flip
32 | 
--------------------------------------------------------------------------------
/scripts/test_maps.sh:
--------------------------------------------------------------------------------
1 | set -ex
2 | # models
3 | RESULTS_DIR='./results/maps'
4 | G_PATH='./pretrained_models/map2aerial_net_G.pth'
5 | E_PATH='./pretrained_models/map2aerial_net_E.pth'
6 | 
7 | # dataset
8 | CLASS='maps'
9 | DIRECTION='BtoA'  # from domain B to domain A
10 | LOAD_SIZE=512  # scale images to this size
11 | CROP_SIZE=512  # then crop to this size
12 | INPUT_NC=3  # number of channels in the input image
13 | ASPECT_RATIO=1.0
14 | # change aspect ratio for the test images
15 | 
16 | # misc
17 | GPU_ID=0  # gpu id
18 | NUM_TEST=10  # number of input images during test
19 | NUM_SAMPLES=10  # number of samples per input image
20 | 
21 | 
22 | # command
23 | CUDA_VISIBLE_DEVICES=${GPU_ID} python ./test.py \
24 |   --dataroot ./datasets/${CLASS} \
25 |   --results_dir ${RESULTS_DIR} \
26 |   --checkpoints_dir ./pretrained_models/ \
27 |   --name ${CLASS} \
28 |   --direction ${DIRECTION} \
29 |   --load_size ${LOAD_SIZE} \
30 |   --crop_size ${CROP_SIZE} \
31 |   --input_nc ${INPUT_NC} \
32 |   --num_test ${NUM_TEST} \
33 |   --n_samples ${NUM_SAMPLES} \
34 |   --aspect_ratio ${ASPECT_RATIO} \
35 |   --center_crop \
36 |   --no_flip \
37 |   --no_encode
38 | 
--------------------------------------------------------------------------------
/scripts/test_night2day.sh:
--------------------------------------------------------------------------------
1 | set -ex
2 | # models
3 | RESULTS_DIR='./results/night2day'
4 | G_PATH='./pretrained_models/night2day_net_G.pth'
5 | E_PATH='./pretrained_models/night2day_net_E.pth'
6 | 
7 | # dataset
8 | CLASS='night2day'
9 | DIRECTION='AtoB'  # from domain A to domain B
10 | LOAD_SIZE=286  # scale images to this size
11 | CROP_SIZE=256  # then crop to this size
12 | INPUT_NC=3  # number of channels in the input image
13 | ASPECT_RATIO=1.4
14 | # change aspect ratio for the test images
15 | 
16 | # misc
17 | GPU_ID=0  # gpu id
18 | NUM_TEST=10  # number of input images during test
19 | NUM_SAMPLES=10  # number of samples per input image
20 | 
21 | 
22 | # command
23 | CUDA_VISIBLE_DEVICES=${GPU_ID} python ./test.py \
24 |   --dataroot ./datasets/${CLASS} \
25 |   --results_dir ${RESULTS_DIR} \
26 |   --checkpoints_dir ./pretrained_models/ \
27 |   --name ${CLASS} \
28 |   --direction ${DIRECTION} \
29 |   --load_size ${LOAD_SIZE} \
30 |   --crop_size ${CROP_SIZE} \
31 |   --input_nc ${INPUT_NC} \
32 |   --num_test ${NUM_TEST} \
33 |   --n_samples ${NUM_SAMPLES} \
34 |   --aspect_ratio ${ASPECT_RATIO} \
35 |   --center_crop \
36 |   --no_flip
37 | 
--------------------------------------------------------------------------------
/scripts/train.sh:
--------------------------------------------------------------------------------
1 | set -ex
2 | CLASS='edges2shoes'  # facades, day2night, edges2shoes, edges2handbags, maps
3 | MODEL='bicycle_gan'
4 | CLASS=${1:-${CLASS}}  # first argument overrides the default CLASS above
5 | GPU_ID=${2}
6 | DISPLAY_ID=$((GPU_ID*10+1))
7 | PORT=2005
8 | NZ=8
9 | 
10 | 
11 | CHECKPOINTS_DIR=../checkpoints/${CLASS}/
12 | DATE=`date '+%d_%m_%Y_%H'`
13 | NAME=${CLASS}_${MODEL}_${DATE}
14 | 
15 | 
16 | # dataset
17 | NO_FLIP=''
18 | DIRECTION='AtoB'
19 | LOAD_SIZE=286
20 | CROP_SIZE=256
21 | INPUT_NC=3
22 | 
23 | # dataset parameters
24 | case ${CLASS} in
25 |   'facades')
26 |     NITER=200
27 |     NITER_DECAY=200
28 |     SAVE_EPOCH=25
29 |     DIRECTION='BtoA'
30 |     ;;
31 |   'edges2shoes')
32 |     NITER=30
33 |     NITER_DECAY=30
34 |     LOAD_SIZE=256
35 |     SAVE_EPOCH=5
36 |     INPUT_NC=1
37 |     NO_FLIP='--no_flip'
38 |     ;;
39 |   'edges2handbags')
40 |     NITER=15
41 |     NITER_DECAY=15
42 |     LOAD_SIZE=256
43 |     SAVE_EPOCH=5
44 |     INPUT_NC=1
45 |     ;;
46 |   'maps')
47 |     NITER=200
48 |     NITER_DECAY=200
49 |     LOAD_SIZE=600
50 |     SAVE_EPOCH=25
51 |     DIRECTION='BtoA'
52 |     ;;
53 |   'day2night')
54 |     NITER=50
55 |     NITER_DECAY=50
56 |     SAVE_EPOCH=10
57 |     ;;
58 |   *)
59 |     echo 'WRONG category: '${CLASS}
60 |     ;;
61 | esac
62 | 
63 | 
64 | 
65 | # command
66 | CUDA_VISIBLE_DEVICES=${GPU_ID} python ./train.py \
67 |   --display_id ${DISPLAY_ID} \
68 |   --dataroot ./datasets/${CLASS} \
69 |   --name ${NAME} \
70 |   --model ${MODEL} \
71 |   --display_port ${PORT} \
72 |   --direction ${DIRECTION} \
73 |   --checkpoints_dir ${CHECKPOINTS_DIR} \
74 |   --load_size ${LOAD_SIZE} \
75 |   --crop_size ${CROP_SIZE} \
76 |   --nz ${NZ} \
77 |   --save_epoch_freq ${SAVE_EPOCH} \
78 |   --input_nc ${INPUT_NC} \
79 |   --niter ${NITER} \
80 |   --niter_decay ${NITER_DECAY} \
81 |   --use_dropout ${NO_FLIP}  # pass the flip setting chosen in the case block above
82 | 
--------------------------------------------------------------------------------
/scripts/train_edges2shoes.sh:
--------------------------------------------------------------------------------
1 | set -ex
2 | MODEL='bicycle_gan'
3 | # dataset details
4 | CLASS='edges2shoes'  # facades, day2night, edges2shoes, edges2handbags, maps
5 | NZ=8
6 | NO_FLIP='--no_flip'
7 | DIRECTION='AtoB'
8 | LOAD_SIZE=256
9 | CROP_SIZE=256
10 | INPUT_NC=1
11 | NITER=30
12 | NITER_DECAY=30
13 | 
14 | # training
15 | GPU_ID=0
16 | DISPLAY_ID=$((GPU_ID*10+1))
17 | CHECKPOINTS_DIR=../checkpoints/${CLASS}/
18 | NAME=${CLASS}_${MODEL}
19 | 
20 | # command
21 | CUDA_VISIBLE_DEVICES=${GPU_ID} python ./train.py \
22 |   --display_id ${DISPLAY_ID} \
23 |   --dataroot ./datasets/${CLASS} \
24 |   --name ${NAME} \
25 |   --model ${MODEL} \
26 |   --direction ${DIRECTION} \
27 |   --checkpoints_dir ${CHECKPOINTS_DIR} \
28 |   --load_size ${LOAD_SIZE} \
29 |   --crop_size ${CROP_SIZE} \
30 |   --nz ${NZ} \
31 |   --input_nc ${INPUT_NC} \
32 |   --niter ${NITER} \
33 |   --niter_decay ${NITER_DECAY} \
34 |   --use_dropout ${NO_FLIP}  # pass the flip setting defined above
35 | 
--------------------------------------------------------------------------------
/scripts/train_facades.sh:
--------------------------------------------------------------------------------
1 | set -ex
2 | MODEL='bicycle_gan'
3 | # dataset details
4 | CLASS='facades'  # facades, day2night, edges2shoes, edges2handbags, maps
5 | NZ=8
6 | NO_FLIP=''
7 | DIRECTION='BtoA'
8 | LOAD_SIZE=256
9 | CROP_SIZE=256
10 | INPUT_NC=3
11 | NITER=200
12 | NITER_DECAY=200
13 | SAVE_EPOCH=25
14 | 
15 | # training
16 | GPU_ID=0
17 | DISPLAY_ID=$((GPU_ID*10+1))
18 | CHECKPOINTS_DIR=../checkpoints/${CLASS}/
19 | NAME=${CLASS}_${MODEL}
20 | 
21 | # command
22 | python ./train.py \
23 |   --display_id ${DISPLAY_ID} \
24 |   --dataroot ./datasets/${CLASS} \
25 |   --name ${NAME} \
26 |   --model ${MODEL} \
27 |   --direction ${DIRECTION} \
28 |   --checkpoints_dir ${CHECKPOINTS_DIR} \
29 |   --load_size ${LOAD_SIZE} \
30 |   --crop_size ${CROP_SIZE} \
31 |   --nz ${NZ} \
32 |   --input_nc ${INPUT_NC} \
33 |   --niter ${NITER} \
34 |   --niter_decay ${NITER_DECAY} \
35 |   --save_epoch_freq ${SAVE_EPOCH} \
36 |   --use_dropout
37 | 
--------------------------------------------------------------------------------
/scripts/video_edges2shoes.sh:
--------------------------------------------------------------------------------
1 | set -ex
2 | # models
3 | RESULTS_DIR='./videos/edges2shoes'
4 | G_PATH='./pretrained_models/edges2shoes_net_G.pth'
5 | 
6 | # dataset
7 | CLASS='edges2shoes'
8 | DIRECTION='AtoB'
9 | LOAD_SIZE=256
10 | CROP_SIZE=256
11 | INPUT_NC=1
12 | 
13 | # misc
14 | GPU_ID=0
15 | NUM_TEST=5  # number of input images during test
16 | NUM_SAMPLES=20  # number of samples per input image
17 | 
18 | # command
19 | CUDA_VISIBLE_DEVICES=${GPU_ID} python ./video.py \
20 |   --dataroot ./datasets/${CLASS} \
21 |   --results_dir ${RESULTS_DIR} \
22 |   --checkpoints_dir ./pretrained_models/ \
23 |   --name ${CLASS} \
24 |   --direction ${DIRECTION} \
25 |   --load_size ${LOAD_SIZE} --crop_size ${CROP_SIZE} \
26 |   --input_nc ${INPUT_NC} \
27 |   --num_test ${NUM_TEST} \
28 |   --n_samples ${NUM_SAMPLES} \
29 |   --center_crop \
30 |   --no_flip
31 | 
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | from options.test_options import TestOptions
3 | from data import create_dataset
4 | from models import create_model
5 | from util.visualizer import save_images
6 | from itertools import islice
7 | from util import html
8 | 
9 | 
10 | # options
11 | opt = TestOptions().parse()
12 | opt.num_threads = 1  # test code only supports num_threads=1
13 | opt.batch_size = 1   # test code only supports batch_size=1
14 | opt.serial_batches = True  # no shuffle
15 | 
16 | # create dataset
17 | dataset = create_dataset(opt)
18 | model = create_model(opt)
19 | model.setup(opt)
20 | model.eval()
21 | print('Loading model %s' % opt.model)
22 | 
23 | # create website
24 | web_dir = os.path.join(opt.results_dir, opt.phase + '_sync' if opt.sync else opt.phase)
25 | webpage = html.HTML(web_dir, 'Training = %s, Phase = %s, Class = %s' % (opt.name, opt.phase, opt.name))
26 | 
27 | # sample random z
28 | if opt.sync:
29 |     z_samples = model.get_z_random(opt.n_samples + 1, opt.nz)
30 | 
31 | # test stage
32 | for i, data in enumerate(islice(dataset, opt.num_test)):
33 |     model.set_input(data)
34 |     print('process input image %3.3d/%3.3d' % (i, opt.num_test))
35 |     if not opt.sync:
36 |         z_samples = model.get_z_random(opt.n_samples + 1, opt.nz)
37 |     for nn in range(opt.n_samples + 1):
38 |         encode = nn == 0 and not opt.no_encode
39 |         real_A, fake_B, real_B = model.test(z_samples[[nn]], encode=encode)
40 |         if nn == 0:
41 |             images = [real_A, real_B, fake_B]
42 |             names = ['input', 'ground truth', 'encoded']
43 |         else:
44 |             images.append(fake_B)
45 |             names.append('random_sample%2.2d' % nn)
46 | 
47 |     img_path = 'input_%3.3d' % i
48 |     save_images(webpage, images, names, img_path, aspect_ratio=opt.aspect_ratio, width=opt.crop_size)
49 | 
50 | webpage.save()
51 | 
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | """General-purpose training script for image-to-image translation.
2 | This script works for various models (with option '--model': e.g., bicycle_gan, pix2pix, test) and
3 | different datasets (with option '--dataset_mode': e.g., aligned, single).
4 | You need to specify the dataset ('--dataroot'), experiment name ('--name'), and model ('--model').
5 | It first creates the model, dataset, and visualizer given the options.
6 | It then does standard network training. During training, it also visualizes/saves the images, prints/saves the loss plot, and saves the models.
7 | The script supports continuing/resuming training. Use '--continue_train' to resume your previous training.
8 | Example:
9 |     Train a BiCycleGAN model:
10 |         python train.py --dataroot ./datasets/facades --name facades_bicyclegan --model bicycle_gan --direction BtoA
11 |     Train a pix2pix model:
12 |         python train.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --direction BtoA
13 | See options/base_options.py and options/train_options.py for more training options.
14 | """
15 | import time
16 | from options.train_options import TrainOptions
17 | from data import create_dataset
18 | from models import create_model
19 | from util.visualizer import Visualizer
20 | 
21 | if __name__ == '__main__':
22 |     opt = TrainOptions().parse()   # get training options
23 |     dataset = create_dataset(opt)  # create a dataset given opt.dataset_mode and other options
24 |     dataset_size = len(dataset)    # get the number of images in the dataset
25 |     print('The number of training images = %d' % dataset_size)
26 | 
27 |     model = create_model(opt)      # create a model given opt.model and other options
28 |     model.setup(opt)               # regular setup: load and print networks; create schedulers
29 |     visualizer = Visualizer(opt)   # create a visualizer that displays/saves images and plots
30 |     total_iters = 0                # the total number of training iterations
31 | 
32 |     for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):  # outer loop for different epochs; we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>
33 |         epoch_start_time = time.time()  # timer for entire epoch
34 |         iter_data_time = time.time()    # timer for data loading per iteration
35 |         epoch_iter = 0                  # the number of training iterations in current epoch, reset to 0 every epoch
36 | 
37 |         for i, data in enumerate(dataset):  # inner loop within one epoch
38 |             iter_start_time = time.time()  # timer for computation per iteration
39 |             if total_iters % opt.print_freq == 0:
40 |                 t_data = iter_start_time - iter_data_time
41 |             visualizer.reset()
42 |             total_iters += opt.batch_size
43 |             epoch_iter += opt.batch_size
44 |             model.set_input(data)  # unpack data from dataset and apply preprocessing
45 |             if not model.is_train():  # skip this batch if the input data is not suitable for training
46 |                 print('skip this batch')
47 |                 continue
48 |             model.optimize_parameters()  # calculate loss functions, get gradients, update network weights
49 | 
50 |             if total_iters % opt.display_freq == 0:  # display images on visdom and save images to a HTML file
51 |                 save_result = total_iters % opt.update_html_freq == 0
52 |                 model.compute_visuals()
53 |                 visualizer.display_current_results(model.get_current_visuals(), epoch, save_result)
54 | 
55 |             if total_iters % opt.print_freq == 0:  # print training losses and save logging information to the disk
56 |                 losses = model.get_current_losses()
57 |                 t_comp = (time.time() - iter_start_time) / opt.batch_size
58 |                 visualizer.print_current_losses(epoch, epoch_iter, losses, t_comp, t_data)
59 |                 if opt.display_id > 0:
60 |                     visualizer.plot_current_losses(epoch, float(epoch_iter) / dataset_size, losses)
61 | 
62 |             if total_iters % opt.save_latest_freq == 0:  # cache our latest model every <save_latest_freq> iterations
63 |                 print('saving the latest model (epoch %d, total_iters %d)' % (epoch, total_iters))
64 |                 model.save_networks('latest')
65 | 
66 |             iter_data_time = time.time()
67 |         if epoch % opt.save_epoch_freq == 0:  # cache our model every <save_epoch_freq> epochs
68 |             print('saving the model at the end of epoch %d, iters %d' % (epoch, total_iters))
69 |             model.save_networks('latest')
70 |             model.save_networks(epoch)
71 | 
72 |         print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))
73 |         model.update_learning_rate()  # update learning rates at the end of every epoch
74 | 
--------------------------------------------------------------------------------
/util/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/junyanz/BicycleGAN/40b9d52c27b9831f56c1c7c7a6ddde8bc9149067/util/__init__.py
--------------------------------------------------------------------------------
/util/html.py:
--------------------------------------------------------------------------------
1 | import dominate
2 | from dominate.tags import meta, h3, table, tr, td, p, a, img, br
3 | import os
4 | 
5 | 
6 | class HTML:
7 |     """This HTML class allows us to save images and write texts into a single HTML file.
8 | 
9 |     It consists of functions such as <add_header> (add a text header to the HTML file),
10 |     <add_images> (add a row of images to the HTML file), and <save> (save the HTML to the disk).
11 |     It is based on 'dominate', a Python library for creating and manipulating HTML documents using a DOM API.
12 |     """
13 | 
14 |     def __init__(self, web_dir, title, refresh=0):
15 |         """Initialize the HTML classes
16 | 
17 |         Parameters:
18 |             web_dir (str) -- a directory that stores the webpage. HTML file will be created at <web_dir>/index.html; images will be saved at <web_dir>/images/
19 |             title (str)   -- the webpage name
20 |             refresh (int) -- how often the website refreshes itself; if 0, no refreshing
21 |         """
22 |         self.title = title
23 |         self.web_dir = web_dir
24 |         self.img_dir = os.path.join(self.web_dir, 'images')
25 |         if not os.path.exists(self.web_dir):
26 |             os.makedirs(self.web_dir)
27 |         if not os.path.exists(self.img_dir):
28 |             os.makedirs(self.img_dir)
29 | 
30 |         self.doc = dominate.document(title=title)
31 |         if refresh > 0:
32 |             with self.doc.head:
33 |                 meta(http_equiv="refresh", content=str(refresh))
34 | 
35 |     def get_image_dir(self):
36 |         """Return the directory that stores images"""
37 |         return self.img_dir
38 | 
39 |     def add_header(self, text):
40 |         """Insert a header to the HTML file
41 | 
42 |         Parameters:
43 |             text (str) -- the header text
44 |         """
45 |         with self.doc:
46 |             h3(text)
47 | 
48 |     def add_images(self, ims, txts, links, width=400):
49 |         """Add images to the HTML file
50 | 
51 |         Parameters:
52 |             ims (str list)   -- a list of image paths
53 |             txts (str list)  -- a list of image names shown on the website
54 |             links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page
55 |         """
56 |         self.t = table(border=1, style="table-layout: fixed;")  # insert a table
57 |         self.doc.add(self.t)
58 |         with self.t:
59 |             with tr():
60 |                 for im, txt, link in zip(ims, txts, links):
61 |                     with td(style="word-wrap: break-word;", halign="center", valign="top"):
62 |                         with p():
63 |                             with a(href=os.path.join('images', link)):
64 |                                 img(style="width:%dpx" % width, src=os.path.join('images', im))
65 |                             br()
66 |                             p(txt)
67 | 
68 |     def save(self):
69 |         """Save the current content to the HTML file"""
70 |         html_file = '%s/index.html' % self.web_dir
71 |         f = open(html_file, 'wt')
72 |         f.write(self.doc.render())
73 |         f.close()
74 | 
75 | 
76 | if __name__ == '__main__':  # we show an example usage here.
77 |     html = HTML('web/', 'test_html')
78 |     html.add_header('hello world')
79 | 
80 |     ims, txts, links = [], [], []
81 |     for n in range(4):
82 |         ims.append('image_%d.png' % n)
83 |         txts.append('text_%d' % n)
84 |         links.append('image_%d.png' % n)
85 |     html.add_images(ims, txts, links)
86 |     html.save()
87 | 
--------------------------------------------------------------------------------
/util/util.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import torch
3 | import numpy as np
4 | from PIL import Image
5 | import os
6 | import pickle
7 | 
8 | 
9 | def tensor2im(input_image, imtype=np.uint8):
10 |     """Convert a Tensor array into a numpy image array.
11 |     Parameters:
12 |         input_image (tensor) -- the input image tensor array
13 |         imtype (type)        -- the desired type of the converted numpy array
14 |     """
15 |     if not isinstance(input_image, np.ndarray):
16 |         if isinstance(input_image, torch.Tensor):  # get the data from a variable
17 |             image_tensor = input_image.data
18 |         else:
19 |             return input_image
20 |         image_numpy = image_tensor[0].cpu().float().numpy()  # convert it into a numpy array
21 |         if image_numpy.shape[0] == 1:  # grayscale to RGB
22 |             image_numpy = np.tile(image_numpy, (3, 1, 1))
23 |         image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0  # post-processing: transpose and scaling
24 |     else:  # if it is a numpy array, do nothing
25 |         image_numpy = input_image
26 |     return image_numpy.astype(imtype)
27 | 
28 | 
29 | def tensor2vec(vector_tensor):
30 |     numpy_vec = vector_tensor.data.cpu().numpy()
31 |     if numpy_vec.ndim == 4:
32 |         return numpy_vec[:, :, 0, 0]
33 |     else:
34 |         return numpy_vec
35 | 
36 | 
37 | def pickle_load(file_name):
38 |     data = None
39 |     with open(file_name, 'rb') as f:
40 |         data = pickle.load(f)
41 |     return data
42 | 
43 | 
44 | def pickle_save(file_name, data):
45 |     with open(file_name, 'wb') as f:
46 |         pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)
47 | 
48 | 
49 | def diagnose_network(net, name='network'):
50 |     """Calculate and print the mean of the average absolute gradients
51 |     Parameters:
52 |         net (torch network) -- Torch network
53 |         name (str)          -- the name of the network
54 |     """
55 |     mean = 0.0
56 |     count = 0
57 |     for param in net.parameters():
58 |         if param.grad is not None:
59 |             mean += torch.mean(torch.abs(param.grad.data))
60 |             count += 1
61 |     if count > 0:
62 |         mean = mean / count
63 |     print(name)
64 |     print(mean)
65 | 
66 | 
67 | def interp_z(z0, z1, num_frames, interp_mode='linear'):
68 |     zs = []
69 |     if interp_mode == 'linear':
70 |         for n in range(num_frames):
71 |             ratio = n / float(num_frames - 1)
72 |             z_t = (1 - ratio) * z0 + ratio * z1
73 |             zs.append(z_t[np.newaxis, :])
74 |         zs = np.concatenate(zs, axis=0).astype(np.float32)
75 | 
76 |     if interp_mode == 'slerp':
77 |         z0_n = z0 / (np.linalg.norm(z0) + 1e-10)
78 |         z1_n = z1 / (np.linalg.norm(z1) + 1e-10)
79 |         omega = np.arccos(np.dot(z0_n, z1_n))
80 |         sin_omega = np.sin(omega)
81 |         if sin_omega < 1e-10 and sin_omega > -1e-10:  # fall back to linear interpolation when the endpoints are (anti)parallel
82 |             zs = interp_z(z0, z1, num_frames, interp_mode='linear')
83 |         else:
84 |             for n in range(num_frames):
85 |                 ratio = n / float(num_frames - 1)
86 |                 z_t = np.sin((1 - ratio) * omega) / sin_omega * z0 + np.sin(ratio * omega) / sin_omega * z1
87 |                 zs.append(z_t[np.newaxis, :])
88 |             zs = np.concatenate(zs, axis=0).astype(np.float32)
89 | 
90 |     return zs
91 | 
92 | 
93 | def save_image(image_numpy, image_path, aspect_ratio=1.0):
94 |     """Save a numpy image to the disk
95 |     Parameters:
96 |         image_numpy (numpy array) -- input numpy array
97 |         image_path (str)          -- the path of the image
98 |     """
99 | 
100 |     image_pil = Image.fromarray(image_numpy)
101 |     h, w, _ = image_numpy.shape
102 | 
103 |     if aspect_ratio > 1.0:
104 |         image_pil = image_pil.resize((h, int(w * aspect_ratio)), Image.BICUBIC)
105 |     if aspect_ratio < 1.0:
106 |         image_pil = image_pil.resize((int(h / aspect_ratio), w), Image.BICUBIC)
107 |     image_pil.save(image_path)
108 | 
109 | 
110 | def print_numpy(x, val=True, shp=False):
111 |     """Print the mean, min, max, median, std, and size of a numpy array
112 |     Parameters:
113 |         val (bool) -- if print the values of the numpy array
114 |         shp (bool) -- if print the shape of the numpy array
115 |     """
116 |     x = x.astype(np.float64)
117 |     if shp:
118 |         print('shape,', x.shape)
119 |     if val:
120 |         x = x.flatten()
121 |         print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
122 |             np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))
123 | 
124 | 
125 | def mkdirs(paths):
126 |     """Create empty directories if they don't exist
127 |     Parameters:
128 |         paths (str list) -- a list of directory paths
129 |     """
130 |     if isinstance(paths, list) and not isinstance(paths, str):
131 |         for path in paths:
132 |             mkdir(path)
133 |     else:
134 |         mkdir(paths)
135 | 
136 | 
137 | def mkdir(path):
138 |     """Create a single empty directory if it doesn't exist
139 |     Parameters:
140 |         path (str) -- a single directory path
141 |     """
142 |     if not os.path.exists(path):
143 |         os.makedirs(path)
144 | 
--------------------------------------------------------------------------------
/util/visualizer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import sys
4 | import ntpath
5 | import time
6 | from . import util
7 | from . import html
8 | from subprocess import Popen, PIPE
9 | 
10 | 
11 | if sys.version_info[0] == 2:
12 |     VisdomExceptionBase = Exception
13 | else:
14 |     VisdomExceptionBase = ConnectionError
15 | 
16 | 
17 | def save_images(webpage, images, names, image_path, aspect_ratio=1.0, width=256):
18 |     """Save images to the disk.
19 |     Parameters:
20 |         webpage (the HTML class)  -- the HTML webpage class that stores these images (see html.py for more details)
21 |         images (numpy array list) -- a list of numpy arrays that store images
22 |         names (str list)          -- a str list that stores the names of the images above
23 |         image_path (str)          -- the string used to create image paths
24 |         aspect_ratio (float)      -- the aspect ratio of saved images
25 |         width (int)               -- the images will be resized to width x width
26 |     This function will save images stored in 'visuals' to the HTML file specified by 'webpage'.
27 |     """
28 |     image_dir = webpage.get_image_dir()
29 |     name = ntpath.basename(image_path)
30 | 
31 |     webpage.add_header(name)
32 |     ims, txts, links = [], [], []
33 | 
34 |     for label, im_data in zip(names, images):
35 |         im = util.tensor2im(im_data)
36 |         image_name = '%s_%s.png' % (name, label)
37 |         save_path = os.path.join(image_dir, image_name)
38 |         util.save_image(im, save_path, aspect_ratio=aspect_ratio)
39 |         ims.append(image_name)
40 |         txts.append(label)
41 |         links.append(image_name)
42 |     webpage.add_images(ims, txts, links, width=width)
43 | 
44 | 
45 | class Visualizer():
46 |     """This class includes several functions that can display/save images and print/save logging information.
47 |     It uses the Python library 'visdom' for display, and the Python library 'dominate' (wrapped in 'HTML') for creating HTML files with images.
48 | """ 49 | 50 | def __init__(self, opt): 51 | """Initialize the Visualizer class 52 | Parameters: 53 | opt -- stores all the experiment flags; needs to be a subclass of BaseOptions 54 | Step 1: Cache the training/test options 55 | Step 2: connect to a visdom server 56 | Step 3: create an HTML object for saveing HTML filters 57 | Step 4: create a logging file to store training losses 58 | """ 59 | self.opt = opt # cache the option 60 | self.display_id = opt.display_id 61 | self.use_html = opt.isTrain and not opt.no_html 62 | self.win_size = opt.display_winsize 63 | self.name = opt.name 64 | self.port = opt.display_port 65 | self.saved = False 66 | if self.display_id > 0: # connect to a visdom server given and 67 | import visdom 68 | self.ncols = opt.display_ncols 69 | self.vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, env=opt.display_env) 70 | if not self.vis.check_connection(): 71 | self.create_visdom_connections() 72 | if self.use_html: # create an HTML object at /web/; images will be saved under /web/images/ 73 | self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web') 74 | self.img_dir = os.path.join(self.web_dir, 'images') 75 | print('create web directory %s...' % self.web_dir) 76 | util.mkdirs([self.web_dir, self.img_dir]) 77 | # create a logging file to store training losses 78 | self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt') 79 | with open(self.log_name, "a") as log_file: 80 | now = time.strftime("%c") 81 | log_file.write('================ Training Loss (%s) ================\n' % now) 82 | 83 | def reset(self): 84 | """Reset the self.saved status""" 85 | self.saved = False 86 | 87 | def create_visdom_connections(self): 88 | """If the program could not connect to Visdom server, this function will start a new server at port < self.port > """ 89 | cmd = sys.executable + ' -m visdom.server -p %d &>/dev/null &' % self.port 90 | print('\n\nCould not connect to Visdom server. \n Trying to start a server....') 91 | print('Command: %s' % cmd) 92 | Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE) 93 | 94 | def display_current_results(self, visuals, epoch, save_result): 95 | """Display current results on visdom; save current results to an HTML file. 96 | Parameters: 97 | visuals (OrderedDict) - - dictionary of images to display or save 98 | epoch (int) - - the current epoch 99 | save_result (bool) - - if save the current results to an HTML file 100 | """ 101 | if self.display_id > 0: # show images in the browser using visdom 102 | ncols = self.ncols 103 | if ncols > 0: # show all the images in one visdom panel 104 | ncols = min(ncols, len(visuals)) 105 | h, w = next(iter(visuals.values())).shape[:2] 106 | table_css = """""" % (w, h) # create a table css 110 | # create a table of images. 
111 |                 title = self.name
112 |                 label_html = ''
113 |                 label_html_row = ''
114 |                 images = []
115 |                 idx = 0
116 |                 for label, image in visuals.items():
117 |                     image_numpy = util.tensor2im(image)
118 |                     label_html_row += '<td>%s</td>' % label
119 |                     images.append(image_numpy.transpose([2, 0, 1]))
120 |                     idx += 1
121 |                     if idx % ncols == 0:
122 |                         label_html += '<tr>%s</tr>' % label_html_row
123 |                         label_html_row = ''
124 |                 white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255
125 |                 while idx % ncols != 0:
126 |                     images.append(white_image)
127 |                     label_html_row += '<td></td>'
128 |                     idx += 1
129 |                 if label_html_row != '':
130 |                     label_html += '<tr>%s</tr>' % label_html_row
131 |                 try:
132 |                     self.vis.images(images, nrow=ncols, win=self.display_id + 1,
133 |                                     padding=2, opts=dict(title=title + ' images'))
134 |                     label_html = '<table>%s</table>' % label_html
135 |                     self.vis.text(table_css + label_html, win=self.display_id + 2,
136 |                                   opts=dict(title=title + ' labels'))
137 |                 except VisdomExceptionBase:
138 |                     self.create_visdom_connections()
139 | 
140 |             else:  # show each image in a separate visdom panel
141 |                 idx = 1
142 |                 try:
143 |                     for label, image in visuals.items():
144 |                         image_numpy = util.tensor2im(image)
145 |                         self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label),
146 |                                        win=self.display_id + idx)
147 |                         idx += 1
148 |                 except VisdomExceptionBase:
149 |                     self.create_visdom_connections()
150 | 
151 |         if self.use_html and (save_result or not self.saved):  # save images to an HTML file if they haven't been saved.
152 |             self.saved = True
153 |             # save images to the disk
154 |             for label, image in visuals.items():
155 |                 image_numpy = util.tensor2im(image)
156 |                 img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
157 |                 util.save_image(image_numpy, img_path)
158 | 
159 |             # update website
160 |             webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=1)
161 |             for n in range(epoch, 0, -1):
162 |                 webpage.add_header('epoch [%d]' % n)
163 |                 ims, txts, links = [], [], []
164 | 
165 |                 for label, image in visuals.items():
166 |                     image_numpy = util.tensor2im(image)
167 |                     img_path = 'epoch%.3d_%s.png' % (n, label)
168 |                     ims.append(img_path)
169 |                     txts.append(label)
170 |                     links.append(img_path)
171 |                 webpage.add_images(ims, txts, links, width=self.win_size)
172 |             webpage.save()
173 | 
174 |     def plot_current_losses(self, epoch, counter_ratio, losses):
175 |         """Display the current losses on the visdom display: a dictionary of error labels and values
176 |         Parameters:
177 |             epoch (int)           -- current epoch
178 |             counter_ratio (float) -- progress (percentage) in the current epoch, between 0 and 1
179 |             losses (OrderedDict)  -- training losses stored in the format of (name, float) pairs
180 |         """
181 |         if not hasattr(self, 'plot_data'):
182 |             self.plot_data = {'X': [], 'Y': [], 'legend': list(losses.keys())}
183 |         self.plot_data['X'].append(epoch + counter_ratio)
184 |         self.plot_data['Y'].append([losses[k] for k in self.plot_data['legend']])
185 |         try:
186 |             self.vis.line(
187 |                 X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1),
188 |                 Y=np.array(self.plot_data['Y']),
189 |                 opts={
190 |                     'title': self.name + ' loss over time',
191 |                     'legend': self.plot_data['legend'],
192 |                     'xlabel': 'epoch',
193 |                     'ylabel': 'loss'},
194 |                 win=self.display_id)
195 |         except VisdomExceptionBase:
196 |             self.create_visdom_connections()
197 | 
198 |     # losses: same format as |losses| of plot_current_losses
199 |     def print_current_losses(self, epoch, iters, losses, t_comp, t_data):
200 |         """Print current losses on the console; also save the losses to the disk
201 |         Parameters:
202 |             epoch (int)          -- current epoch
203 |             iters (int)          -- current training iteration during this epoch (reset to 0 at the end of every epoch)
204 |             losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
205 |             t_comp (float)       -- computational time per data point (normalized by batch_size)
206 |             t_data (float)       -- data loading time per data point (normalized by batch_size)
207 |         """
208 |         message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data)
209 |         for k, v in losses.items():
210 |             message += '%s: %.3f ' % (k, v)
211 | 
212 |         print(message)  # print the message
213 |         with open(self.log_name, "a") as log_file:
214 |             log_file.write('%s\n' % message)  # save the message
215 | 
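216 | # Minimal usage sketch (not part of the original module; assumes `opt` was
217 | # produced by TrainOptions().parse() and `model` by create_model(opt)):
218 | #
219 | #     visualizer = Visualizer(opt)
220 | #     visualizer.reset()
221 | #     visualizer.display_current_results(model.get_current_visuals(), epoch=1, save_result=True)
222 | #     visualizer.print_current_losses(1, 0, model.get_current_losses(), t_comp=0.1, t_data=0.01)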
--------------------------------------------------------------------------------
/video.py:
--------------------------------------------------------------------------------
1 | from options.video_options import VideoOptions
2 | from data import create_dataset
3 | from models import create_model
4 | from itertools import islice
5 | from util import util
6 | import numpy as np
7 | import moviepy.editor
8 | import os
9 | import torch
10 | 
11 | 
12 | def get_random_z(opt):
13 |     z_samples = np.random.normal(0, 1, (opt.n_samples + 1, opt.nz))
14 |     return z_samples
15 | 
16 | 
17 | def produce_frame(t):
18 |     k = min(int(t * opt.fps), len(frame_rows) - 1)  # clamp to avoid an off-by-one at t == duration
19 |     return np.concatenate(frame_rows[k], axis=1 - use_vertical)
20 | 
21 | 
22 | # hard-code opt
23 | opt = VideoOptions().parse()
24 | opt.num_threads = 1  # test code only supports num_threads=1
25 | opt.batch_size = 1   # test code only supports batch_size=1
26 | opt.no_encode = True  # do not use encoder
27 | 
28 | dataset = create_dataset(opt)
29 | model = create_model(opt)
30 | model.setup(opt)
31 | model.eval()
32 | interp_mode = 'slerp'
33 | use_vertical = 1 if opt.align_mode == 'vertical' else 0
34 | 
35 | print('Loading model %s' % opt.model)
36 | # create website
37 | results_dir = opt.results_dir
38 | util.mkdir(results_dir)
39 | total_frames = opt.num_frames * opt.n_samples
40 | 
41 | 
42 | z_samples = get_random_z(opt)
43 | frame_rows = [[] for n in range(total_frames)]
44 | 
45 | for i, data in enumerate(islice(dataset, opt.num_test)):
46 |     print('process input image %3.3d/%3.3d' % (i, opt.num_test))
47 |     model.set_input(data)
48 |     real_A = util.tensor2im(model.real_A)
49 |     wb = opt.border
50 |     hb = opt.border
51 |     h = real_A.shape[0]
52 |     w = real_A.shape[1]
53 |     real_A_b = np.full((h + hb, w + wb, opt.output_nc), 255, real_A.dtype)  # add a white border
54 |     real_A_b[hb:, wb:, :] = real_A
55 |     frames = [[real_A_b] for n in range(total_frames)]
56 | 
57 |     for n in range(opt.n_samples):
58 |         z0 = z_samples[n]
59 |         z1 = z_samples[n + 1]
60 |         zs = util.interp_z(z0, z1, num_frames=opt.num_frames, interp_mode=interp_mode)
61 |         for k in range(opt.num_frames):
62 |             zs_k = (torch.Tensor(zs[[k]])).to(model.device)
63 |             _, fake_B_device, _ = model.test(zs_k, encode=False)
64 |             fake_B = util.tensor2im(fake_B_device)
65 |             fake_B_b = np.full((h + hb, w + wb, opt.output_nc), 255, fake_B.dtype)
66 |             fake_B_b[hb:, wb:, :] = fake_B
67 |             frames[k + opt.num_frames * n].append(fake_B_b)
68 | 
69 |     for k in range(total_frames):
70 |         frame_row = np.concatenate(frames[k], axis=use_vertical)
71 |         frame_rows[k].append(frame_row)
72 | 
73 | # compile the frames into a video
74 | images_dir = os.path.join(results_dir, 'frames_seed%4.4d' % opt.seed)
75 | util.mkdir(images_dir)
76 | 
77 | 
78 | for k in range(total_frames):
79 |     final_frame = np.concatenate(frame_rows[k], axis=1 - use_vertical)
80 |     util.save_image(final_frame, os.path.join(
81 |         images_dir, 'frame_%4.4d.jpg' % k))
82 | 
83 | 
84 | video_file = os.path.join(
85 |     results_dir, 'morphing_video_seed%4.4d_fps%d.mp4' % (opt.seed, opt.fps))
86 | video = moviepy.editor.VideoClip(
87 |     produce_frame, duration=float(total_frames) / opt.fps)
88 | video.write_videofile(video_file, fps=30, codec='libx264', bitrate='16M')
89 | 
--------------------------------------------------------------------------------