├── .gitignore
├── LICENSE.txt
├── README.md
├── _config.yml
├── data
│   ├── __init__.py
│   ├── aligned_dataset.py
│   ├── base_data_loader.py
│   ├── base_dataset.py
│   ├── custom_dataset_data_loader.py
│   ├── data_loader.py
│   └── image_folder.py
├── datasets
│   └── cityscapes
│       ├── test_inst
│       │   ├── frankfurt_000000_000576_gtFine_instanceIds.png
│       │   ├── frankfurt_000000_001236_gtFine_instanceIds.png
│       │   ├── frankfurt_000000_003357_gtFine_instanceIds.png
│       │   ├── frankfurt_000000_011810_gtFine_instanceIds.png
│       │   ├── frankfurt_000000_012868_gtFine_instanceIds.png
│       │   ├── frankfurt_000001_013710_gtFine_instanceIds.png
│       │   ├── frankfurt_000001_015328_gtFine_instanceIds.png
│       │   ├── frankfurt_000001_023769_gtFine_instanceIds.png
│       │   ├── frankfurt_000001_028335_gtFine_instanceIds.png
│       │   ├── frankfurt_000001_032711_gtFine_instanceIds.png
│       │   ├── frankfurt_000001_033655_gtFine_instanceIds.png
│       │   ├── frankfurt_000001_042733_gtFine_instanceIds.png
│       │   ├── frankfurt_000001_047552_gtFine_instanceIds.png
│       │   ├── frankfurt_000001_054640_gtFine_instanceIds.png
│       │   └── frankfurt_000001_055387_gtFine_instanceIds.png
│       ├── test_label
│       │   ├── frankfurt_000000_000576_gtFine_labelIds.png
│       │   ├── frankfurt_000000_001236_gtFine_labelIds.png
│       │   ├── frankfurt_000000_003357_gtFine_labelIds.png
│       │   ├── frankfurt_000000_011810_gtFine_labelIds.png
│       │   ├── frankfurt_000000_012868_gtFine_labelIds.png
│       │   ├── frankfurt_000001_013710_gtFine_labelIds.png
│       │   ├── frankfurt_000001_015328_gtFine_labelIds.png
│       │   ├── frankfurt_000001_023769_gtFine_labelIds.png
│       │   ├── frankfurt_000001_028335_gtFine_labelIds.png
│       │   ├── frankfurt_000001_032711_gtFine_labelIds.png
│       │   ├── frankfurt_000001_033655_gtFine_labelIds.png
│       │   ├── frankfurt_000001_042733_gtFine_labelIds.png
│       │   ├── frankfurt_000001_047552_gtFine_labelIds.png
│       │   ├── frankfurt_000001_054640_gtFine_labelIds.png
│       │   └── frankfurt_000001_055387_gtFine_labelIds.png
│       ├── train_img
│       │   ├── aachen_000000_000019_leftImg8bit.png
│       │   ├── aachen_000001_000019_leftImg8bit.png
│       │   ├── aachen_000002_000019_leftImg8bit.png
│       │   ├── aachen_000003_000019_leftImg8bit.png
│       │   └── aachen_000004_000019_leftImg8bit.png
│       ├── train_inst
│       │   ├── aachen_000000_000019_gtFine_instanceIds.png
│       │   ├── aachen_000001_000019_gtFine_instanceIds.png
│       │   ├── aachen_000002_000019_gtFine_instanceIds.png
│       │   ├── aachen_000003_000019_gtFine_instanceIds.png
│       │   └── aachen_000004_000019_gtFine_instanceIds.png
│       └── train_label
│           ├── aachen_000000_000019_gtFine_labelIds.png
│           ├── aachen_000001_000019_gtFine_labelIds.png
│           ├── aachen_000002_000019_gtFine_labelIds.png
│           ├── aachen_000003_000019_gtFine_labelIds.png
│           └── aachen_000004_000019_gtFine_labelIds.png
├── encode_features.py
├── imgs
│   ├── city_short.gif
│   ├── cityscapes_1.jpg
│   ├── cityscapes_2.jpg
│   ├── cityscapes_3.jpg
│   ├── cityscapes_4.jpg
│   ├── face1_1.jpg
│   ├── face1_2.jpg
│   ├── face1_3.jpg
│   ├── face2_1.jpg
│   ├── face2_2.jpg
│   ├── face2_3.jpg
│   ├── face_short.gif
│   ├── teaser_720.gif
│   ├── teaser_label.gif
│   ├── teaser_label.png
│   ├── teaser_ours.jpg
│   └── teaser_style.gif
├── models
│   ├── __init__.py
│   ├── base_model.py
│   ├── models.py
│   ├── networks.py
│   └── pix2pixHD_model.py
├── options
│   ├── __init__.py
│   ├── base_options.py
│   ├── test_options.py
│   └── train_options.py
├── precompute_feature_maps.py
├── scripts
│   ├── test_1024p.sh
│   ├── test_1024p_feat.sh
│   ├── test_512p.sh
│   ├── test_512p_feat.sh
│   ├── train_1024p_12G.sh
│   ├── train_1024p_24G.sh
│   ├── train_1024p_feat_12G.sh
│   ├── train_1024p_feat_24G.sh
│   ├── train_512p.sh
│   ├── train_512p_feat.sh
│   └── train_512p_multigpu.sh
├── test.py
├── train.py
└── util
    ├── __init__.py
    ├── html.py
    ├── image_pool.py
    ├── util.py
    └── visualizer.py

/.gitignore:
--------------------------------------------------------------------------------
debug*
checkpoints/
results/
build/
dist/
torch.egg-info/
*/**/__pycache__
torch/version.py
torch/csrc/generic/TensorMethods.cpp
torch/lib/*.so*
torch/lib/*.dylib*
torch/lib/*.h
torch/lib/build
torch/lib/tmp_install
torch/lib/include
torch/lib/torch_shm_manager
torch/csrc/cudnn/cuDNN.cpp
torch/csrc/nn/THNN.cwrap
torch/csrc/nn/THNN.cpp
torch/csrc/nn/THCUNN.cwrap
torch/csrc/nn/THCUNN.cpp
torch/csrc/nn/THNN_generic.cwrap
torch/csrc/nn/THNN_generic.cpp
torch/csrc/nn/THNN_generic.h
docs/src/**/*
test/data/legacy_modules.t7
test/data/gpu_tensors.pt
test/htmlcov
test/.coverage
*/*.pyc
*/**/*.pyc
*/**/**/*.pyc
*/**/**/**/*.pyc
*/**/**/**/**/*.pyc
*/*.so*
*/**/*.so*
*/**/*.dylib*
test/data/legacy_serialized.pt
*.DS_Store
*~
--------------------------------------------------------------------------------

/LICENSE.txt:
--------------------------------------------------------------------------------
Copyright (C) 2017 NVIDIA Corporation. Ting-Chun Wang, Ming-Yu Liu, Jun-Yan Zhu.
All rights reserved.
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).

Permission to use, copy, modify, and distribute this software and its documentation
for any non-commercial purpose is hereby granted without fee, provided that the above
copyright notice appear in all copies and that both that copyright notice and this
permission notice appear in supporting documentation, and that the name of the author
not be used in advertising or publicity pertaining to distribution of the software
without specific, written prior permission.

THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING ALL
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR ANY PARTICULAR PURPOSE.
IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL
DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING
OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.


--------------------------- LICENSE FOR pytorch-CycleGAN-and-pix2pix ----------------
Copyright (c) 2017, Jun-Yan Zhu and Taesung Park
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
  this list of conditions and the following disclaimer in the documentation
  and/or other materials provided with the distribution.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------

/README.md:
--------------------------------------------------------------------------------

# pix2pixHD
### [[Project]](https://tcwang0509.github.io/pix2pixHD/) [[Youtube]](https://youtu.be/3AIpPlzM_qs) [[Paper]](https://arxiv.org/pdf/1711.11585.pdf)
PyTorch implementation of our method for high-resolution (e.g. 2048x1024) photorealistic image-to-image translation. It can be used for turning semantic label maps into photo-realistic images or synthesizing portraits from face label maps.

[High-Resolution Image Synthesis and Semantic Manipulation with Conditional GANs](https://tcwang0509.github.io/pix2pixHD/)
[Ting-Chun Wang](https://tcwang0509.github.io/)<sup>1</sup>, [Ming-Yu Liu](http://mingyuliu.net/)<sup>1</sup>, [Jun-Yan Zhu](http://people.eecs.berkeley.edu/~junyanz/)<sup>2</sup>, Andrew Tao<sup>1</sup>, [Jan Kautz](http://jankautz.com/)<sup>1</sup>, [Bryan Catanzaro](http://catanzaro.name/)<sup>1</sup>
<sup>1</sup>NVIDIA Corporation, <sup>2</sup>UC Berkeley
In arXiv, 2017.

## Image-to-image translation at 2k/1k resolution
- Our label-to-streetview results
- Interactive editing results
- Additional streetview results
- Label-to-face and interactive editing results
- Our editing interface

(The result images and GIFs that illustrate the items above live in the `imgs` folder.)

## Prerequisites
- Linux or macOS
- Python 2 or 3
- NVIDIA GPU (12G or 24G memory) + CUDA + cuDNN

## Getting Started
### Installation
- Install PyTorch and dependencies from http://pytorch.org
- Install the python library [dominate](https://github.com/Knio/dominate):
```bash
pip install dominate
```
- Clone this repo:
```bash
git clone https://github.com/NVIDIA/pix2pixHD
cd pix2pixHD
```

### Testing
- A few example Cityscapes test images are included in the `datasets` folder.
- Please download the pre-trained Cityscapes model from [here](https://drive.google.com/file/d/1h9SykUnuZul7J3Nbms2QGH1wa85nbN2-/view?usp=sharing) (Google Drive link) and put it under `./checkpoints/label2city_1024p/`.
- Test the model (`bash ./scripts/test_1024p.sh`):
```bash
#!./scripts/test_1024p.sh
python test.py --name label2city_1024p --netG local --ngf 32 --resize_or_crop none
```
The test results will be saved to an HTML file: `./results/label2city_1024p/test_latest/index.html`.

More example scripts can be found in the `scripts` directory.

### Dataset
- We use the Cityscapes dataset. To train a model on the full dataset, please download it from the [official website](https://www.cityscapes-dataset.com/) (registration required).
After downloading, please put it under the `datasets` folder in the same way the example images are provided.

### Training
- Train a model at 1024 x 512 resolution (`bash ./scripts/train_512p.sh`):
```bash
#!./scripts/train_512p.sh
python train.py --name label2city_512p
```
- To view training results, please check out the intermediate results in `./checkpoints/label2city_512p/web/index.html`.
If you have TensorFlow installed, you can see TensorBoard logs in `./checkpoints/label2city_512p/logs` by adding `--tf_log` to the training scripts.

### Multi-GPU training
- Train a model using multiple GPUs (`bash ./scripts/train_512p_multigpu.sh`):
```bash
#!./scripts/train_512p_multigpu.sh
python train.py --name label2city_512p --batchSize 8 --gpu_ids 0,1,2,3,4,5,6,7
```
Note: this is not tested, and we trained our model using a single GPU only. Please use at your own discretion.

### Training at full resolution
- Training at full resolution (2048 x 1024) requires a GPU with 24G memory (`bash ./scripts/train_1024p_24G.sh`).
If only GPUs with 12G memory are available, please use the 12G script (`bash ./scripts/train_1024p_12G.sh`), which will crop the images during training. Performance is not guaranteed with this script.

### Training with your own dataset
- If you want to train with your own dataset, please generate label maps that are one-channel images whose pixel values correspond to the object labels (i.e. 0, 1, ..., N-1, where N is the number of labels). This is because we need to generate one-hot vectors from the label maps. Please also specify `--label_nc N` during both training and testing.
- If your input is not a label map, please just specify `--label_nc 0`, which will directly use the RGB colors as input.
- If you don't have instance maps or don't want to use them, please specify `--no_instance`.
- The default preprocessing setting is `scale_width`, which scales the width of all training images to `opt.loadSize` (1024) while keeping the aspect ratio.
If you want a different setting, please change it with the `--resize_or_crop` option. For example, `scale_width_and_crop` first resizes the image to have width `opt.loadSize` and then does a random crop of size `(opt.fineSize, opt.fineSize)`. `crop` skips the resizing step and only performs random cropping. If you don't want any preprocessing, please specify `none`, which does nothing other than making sure the image dimensions are divisible by 32.

## More Training/Test Details
- Flags: see `options/train_options.py` and `options/base_options.py` for all the training flags; see `options/test_options.py` and `options/base_options.py` for all the test flags.
- Instance map: we take in both label maps and instance maps as input. If you don't want to use instance maps, please specify the flag `--no_instance`.

## Citation

If you find this useful for your research, please use the following.

```
@article{wang2017highres,
  title={High-Resolution Image Synthesis and Semantic Manipulation with Conditional GANs},
  author={Ting-Chun Wang and Ming-Yu Liu and Jun-Yan Zhu and Andrew Tao and Jan Kautz and Bryan Catanzaro},
  journal={arXiv preprint arXiv:1711.11585},
  year={2017}
}
```

## Acknowledgments
This code borrows heavily from [pytorch-CycleGAN-and-pix2pix](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix).
--------------------------------------------------------------------------------

/_config.yml:
--------------------------------------------------------------------------------
theme: jekyll-theme-minimal
--------------------------------------------------------------------------------

/data/__init__.py:
--------------------------------------------------------------------------------
(contents omitted from this dump)
--------------------------------------------------------------------------------

/data/aligned_dataset.py:
--------------------------------------------------------------------------------
### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
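# AlignedDataset pairs each label map with the corresponding image and
# (optionally) instance map by sorting the file lists of the <phase>_label,
# <phase>_img, and <phase>_inst folders under opt.dataroot.  Label maps are
# kept as integer class IDs here (hence the `* 255.0` below); the model later
# expands them into one-hot tensors, conceptually (illustrative sketch, not
# code from this file):
#   one_hot = torch.zeros(label_nc, H, W).scatter_(0, label_map.long(), 1.0)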
import os.path
from data.base_dataset import BaseDataset, get_params, get_transform, normalize
from data.image_folder import make_dataset
from PIL import Image

class AlignedDataset(BaseDataset):
    def initialize(self, opt):
        self.opt = opt
        self.root = opt.dataroot

        ### label maps
        self.dir_label = os.path.join(opt.dataroot, opt.phase + '_label')
        self.label_paths = sorted(make_dataset(self.dir_label))

        ### real images
        if opt.isTrain:
            self.dir_image = os.path.join(opt.dataroot, opt.phase + '_img')
            self.image_paths = sorted(make_dataset(self.dir_image))

        ### instance maps
        if not opt.no_instance:
            self.dir_inst = os.path.join(opt.dataroot, opt.phase + '_inst')
            self.inst_paths = sorted(make_dataset(self.dir_inst))

        ### load precomputed instance-wise encoded features
        if opt.load_features:
            self.dir_feat = os.path.join(opt.dataroot, opt.phase + '_feat')
            print('----------- loading features from %s ----------' % self.dir_feat)
            self.feat_paths = sorted(make_dataset(self.dir_feat))

        self.dataset_size = len(self.label_paths)

    def __getitem__(self, index):
        ### label maps
        label_path = self.label_paths[index]
        label = Image.open(label_path)
        params = get_params(self.opt, label.size)
        if self.opt.label_nc == 0:
            # input is a regular RGB image rather than a label map
            transform_label = get_transform(self.opt, params)
            label_tensor = transform_label(label.convert('RGB'))
        else:
            # keep integer class IDs intact: nearest-neighbor resampling, no normalization
            transform_label = get_transform(self.opt, params, method=Image.NEAREST, normalize=False)
            label_tensor = transform_label(label) * 255.0

        image_tensor = inst_tensor = feat_tensor = 0
        ### real images
        if self.opt.isTrain:
            image_path = self.image_paths[index]
            image = Image.open(image_path).convert('RGB')
            transform_image = get_transform(self.opt, params)
            image_tensor = transform_image(image)

        ### if using instance maps
        if not self.opt.no_instance:
            inst_path = self.inst_paths[index]
            inst = Image.open(inst_path)
            inst_tensor = transform_label(inst)

            if self.opt.load_features:
                feat_path = self.feat_paths[index]
                feat = Image.open(feat_path).convert('RGB')
                norm = normalize()
                feat_tensor = norm(transform_label(feat))

        input_dict = {'label': label_tensor, 'inst': inst_tensor, 'image': image_tensor,
                      'feat': feat_tensor, 'path': label_path}

        return input_dict

    def __len__(self):
        return len(self.label_paths)

    def name(self):
        return 'AlignedDataset'
--------------------------------------------------------------------------------

/data/base_data_loader.py:
--------------------------------------------------------------------------------

class BaseDataLoader():
    def __init__(self):
        pass

    def initialize(self, opt):
        self.opt = opt
        pass

    def load_data(self):
        return None
--------------------------------------------------------------------------------

/data/base_dataset.py:
--------------------------------------------------------------------------------
### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
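# Transform helpers shared by the datasets: get_params() draws the random crop
# position and flip decision once per sample, so the same augmentation can be
# applied to the label map, instance map, and image, keeping them pixel-aligned.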
import torch.utils.data as data
from PIL import Image
import torchvision.transforms as transforms
import numpy as np
import random

class BaseDataset(data.Dataset):
    def __init__(self):
        super(BaseDataset, self).__init__()

    def name(self):
        return 'BaseDataset'

    def initialize(self, opt):
        pass

def get_params(opt, size):
    w, h = size
    new_h = h
    new_w = w
    if opt.resize_or_crop == 'resize_and_crop':
        new_h = new_w = opt.loadSize
    elif opt.resize_or_crop == 'scale_width_and_crop':
        new_w = opt.loadSize
        new_h = opt.loadSize * h // w

    x = random.randint(0, np.maximum(0, new_w - opt.fineSize))
    y = random.randint(0, np.maximum(0, new_h - opt.fineSize))

    flip = random.random() > 0.5
    return {'crop_pos': (x, y), 'flip': flip}

def get_transform(opt, params, method=Image.BICUBIC, normalize=True):
    transform_list = []
    if 'resize' in opt.resize_or_crop:
        osize = [opt.loadSize, opt.loadSize]
        transform_list.append(transforms.Scale(osize, method))  # transforms.Resize in newer torchvision
    elif 'scale_width' in opt.resize_or_crop:
        transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.loadSize, method)))

    if 'crop' in opt.resize_or_crop:
        transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.fineSize)))

    if opt.resize_or_crop == 'none':
        # make dimensions divisible by the generator's total downsampling factor
        base = float(2 ** opt.n_downsample_global)
        if opt.netG == 'local':
            base *= (2 ** opt.n_local_enhancers)
        transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method)))

    if opt.isTrain and not opt.no_flip:
        transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))

    transform_list += [transforms.ToTensor()]

    if normalize:
        transform_list += [transforms.Normalize((0.5, 0.5, 0.5),
                                                (0.5, 0.5, 0.5))]
    return transforms.Compose(transform_list)

def normalize():
    return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

def __make_power_2(img, base, method=Image.BICUBIC):
    ow, oh = img.size
    h = int(round(oh / base) * base)
    w = int(round(ow / base) * base)
    if (h == oh) and (w == ow):
        return img
    return img.resize((w, h), method)

def __scale_width(img, target_width, method=Image.BICUBIC):
    ow, oh = img.size
    if (ow == target_width):
        return img
    w = target_width
    h = int(target_width * oh / ow)
    return img.resize((w, h), method)

def __crop(img, pos, size):
    ow, oh = img.size
    x1, y1 = pos
    tw = th = size
    if (ow > tw or oh > th):
        return img.crop((x1, y1, x1 + tw, y1 + th))
    return img

def __flip(img, flip):
    if flip:
        return img.transpose(Image.FLIP_LEFT_RIGHT)
    return img
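# Typical usage (mirrors AlignedDataset.__getitem__): draw the augmentation
# parameters once, then build matching transforms for every modality:
#   params = get_params(opt, label.size)
#   label_t = get_transform(opt, params, method=Image.NEAREST, normalize=False)(label)
#   image_t = get_transform(opt, params)(image)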
--------------------------------------------------------------------------------

/data/custom_dataset_data_loader.py:
--------------------------------------------------------------------------------
import torch.utils.data
from data.base_data_loader import BaseDataLoader


def CreateDataset(opt):
    dataset = None
    from data.aligned_dataset import AlignedDataset
    dataset = AlignedDataset()

    print("dataset [%s] was created" % (dataset.name()))
    dataset.initialize(opt)
    return dataset

class CustomDatasetDataLoader(BaseDataLoader):
    def name(self):
        return 'CustomDatasetDataLoader'

    def initialize(self, opt):
        BaseDataLoader.initialize(self, opt)
        self.dataset = CreateDataset(opt)
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=opt.batchSize,
            shuffle=not opt.serial_batches,
            num_workers=int(opt.nThreads))

    def load_data(self):
        return self.dataloader

    def __len__(self):
        return min(len(self.dataset), self.opt.max_dataset_size)
--------------------------------------------------------------------------------

/data/data_loader.py:
--------------------------------------------------------------------------------

def CreateDataLoader(opt):
    from data.custom_dataset_data_loader import CustomDatasetDataLoader
    data_loader = CustomDatasetDataLoader()
    print(data_loader.name())
    data_loader.initialize(opt)
    return data_loader
--------------------------------------------------------------------------------

/data/image_folder.py:
--------------------------------------------------------------------------------
###############################################################################
# Code from
# https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py
# Modified the original code so that it also loads images from the current
# directory as well as the subdirectories
###############################################################################
import torch.utils.data as data
from PIL import Image
import os

IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff'
]


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)


def make_dataset(dir):
    images = []
    assert os.path.isdir(dir), '%s is not a valid directory' % dir

    for root, _, fnames in sorted(os.walk(dir)):
        for fname in fnames:
            if is_image_file(fname):
                path = os.path.join(root, fname)
                images.append(path)

    return images


def default_loader(path):
    return Image.open(path).convert('RGB')


class ImageFolder(data.Dataset):

    def __init__(self, root, transform=None, return_paths=False,
                 loader=default_loader):
        imgs = make_dataset(root)
        if len(imgs) == 0:
            raise(RuntimeError("Found 0 images in: " + root + "\n"
                               "Supported image extensions are: " +
                               ",".join(IMG_EXTENSIONS)))

        self.root = root
        self.imgs = imgs
        self.transform = transform
        self.return_paths = return_paths
        self.loader = loader

    def __getitem__(self, index):
        path = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.return_paths:
            return img, path
        else:
            return img

    def __len__(self):
        return len(self.imgs)
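# Illustrative usage of ImageFolder (the training pipeline itself appears to go
# through CreateDataLoader and only uses make_dataset from this module); the
# path below is the bundled example data:
#   folder = ImageFolder('./datasets/cityscapes/train_label', return_paths=True)
#   img, path = folder[0]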
--------------------------------------------------------------------------------

/datasets/cityscapes/ (test_inst, test_label, train_img, train_inst, train_label):
--------------------------------------------------------------------------------
Binary PNG assets (the example label, instance, and image files listed in the
directory tree above); image contents omitted from this dump.
--------------------------------------------------------------------------------

/encode_features.py:
--------------------------------------------------------------------------------
### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
from options.train_options import TrainOptions
from data.data_loader import CreateDataLoader
from models.models import create_model
import numpy as np
import os

opt = TrainOptions().parse()
opt.nThreads = 1
opt.batchSize = 1
opt.serial_batches = True
opt.no_flip = True
opt.instance_feat = True

name = 'features'
save_path = os.path.join(opt.checkpoints_dir, opt.name)

############ Initialize #########
data_loader = CreateDataLoader(opt)
dataset = data_loader.load_data()
dataset_size = len(data_loader)
model = create_model(opt)  # isTrain, so the model comes back wrapped in DataParallel (hence model.module below)

########### Encode features ###########
reencode = True
if reencode:
    features = {}
    for label in range(opt.label_nc):
        features[label] = np.zeros((0, opt.feat_num+1))
    for i, data in enumerate(dataset):
        feat = model.module.encode_features(data['image'], data['inst'])
        for label in range(opt.label_nc):
            features[label] = np.append(features[label], feat[label], axis=0)

        print('%d / %d images' % (i+1, dataset_size))
    save_name = os.path.join(save_path, name + '.npy')
    np.save(save_name, features)

############## Clustering ###########
n_clusters = opt.n_clusters
load_name = os.path.join(save_path, name + '.npy')
features = np.load(load_name).item()  # on NumPy >= 1.16.3 this needs np.load(load_name, allow_pickle=True).item()
from sklearn.cluster import KMeans
centers = {}
for label in range(opt.label_nc):
    feat = features[label]
    feat = feat[feat[:,-1] > 0.5, :-1]  # keep rows whose last (indicator) column exceeds 0.5, then drop that column
    if feat.shape[0]:
        n_clusters = min(feat.shape[0], opt.n_clusters)
        kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(feat)
        centers[label] = kmeans.cluster_centers_
save_name = os.path.join(save_path, name + '_clustered_%03d.npy' % opt.n_clusters)
np.save(save_name, centers)
print('saving to %s' % save_name)
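# Illustrative invocation (the experiment name and dataroot are placeholders;
# --name must point at a trained checkpoint that includes the encoder):
#   python encode_features.py --name label2city_512p_feat --dataroot ./datasets/cityscapes/
# The saved cluster centers (features_clustered_*.npy) are intended for the
# *_feat test scripts, which sample instance-wise styles from them.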
--------------------------------------------------------------------------------

/imgs/:
--------------------------------------------------------------------------------
Binary image assets used by the README (city_short.gif, cityscapes_1-4.jpg,
face1_1-3.jpg, face2_1-3.jpg, face_short.gif, teaser_720.gif, teaser_label.gif,
teaser_label.png, teaser_ours.jpg, teaser_style.gif); contents omitted from
this dump.
--------------------------------------------------------------------------------

/models/__init__.py:
--------------------------------------------------------------------------------
(contents omitted from this dump)
--------------------------------------------------------------------------------

/models/base_model.py:
--------------------------------------------------------------------------------
### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
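# Common checkpoint plumbing shared by the concrete models: save_network()
# writes <epoch>_net_<label>.pth files into checkpoints/<name>/, and
# load_network() reloads them, falling back to a partial load when the
# checkpoint and the current architecture disagree on some layers.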
import os
import torch

class BaseModel(torch.nn.Module):
    def name(self):
        return 'BaseModel'

    def initialize(self, opt):
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.isTrain = opt.isTrain
        self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor
        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)

    def set_input(self, input):
        self.input = input

    def forward(self):
        pass

    # used in test time, no backprop
    def test(self):
        pass

    def get_image_paths(self):
        pass

    def optimize_parameters(self):
        pass

    def get_current_visuals(self):
        return self.input

    def get_current_errors(self):
        return {}

    def save(self, label):
        pass

    # helper saving function that can be used by subclasses
    def save_network(self, network, network_label, epoch_label, gpu_ids):
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(self.save_dir, save_filename)
        torch.save(network.cpu().state_dict(), save_path)
        if len(gpu_ids) and torch.cuda.is_available():
            network.cuda()

    # helper loading function that can be used by subclasses
    def load_network(self, network, network_label, epoch_label, save_dir=''):
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        if not save_dir:
            save_dir = self.save_dir
        save_path = os.path.join(save_dir, save_filename)
        if not os.path.isfile(save_path):
            print('%s does not exist yet!' % save_path)
            if network_label == 'G':
                raise Exception('Generator must exist!')
        else:
            #network.load_state_dict(torch.load(save_path))
            try:
                network.load_state_dict(torch.load(save_path))
            except:
                pretrained_dict = torch.load(save_path)
                model_dict = network.state_dict()
                try:
                    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
                    network.load_state_dict(pretrained_dict)
                    print('Pretrained network %s has excessive layers; Only loading layers that are used' % network_label)
                except:
                    print('Pretrained network %s has fewer layers; The following are not initialized:' % network_label)
                    not_initialized = set()
                    for k, v in pretrained_dict.items():
                        if v.size() == model_dict[k].size():
                            model_dict[k] = v

                    for k, v in model_dict.items():
                        if k not in pretrained_dict or v.size() != pretrained_dict[k].size():
                            not_initialized.add(k.split('.')[0])
                    print(sorted(not_initialized))
                    network.load_state_dict(model_dict)

    def update_learning_rate(self):
        pass
--------------------------------------------------------------------------------

/models/models.py:
--------------------------------------------------------------------------------
### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
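# Model factory.  Note the DataParallel wrapper is only applied at training
# time, which is why training-time call sites (e.g. encode_features.py above)
# reach the underlying model via model.module.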
3 | import torch
4 |
5 | def create_model(opt):
6 |     from .pix2pixHD_model import Pix2PixHDModel
7 |     model = Pix2PixHDModel()
8 |     model.initialize(opt)
9 |     print("model [%s] was created" % (model.name()))
10 |
11 |     if opt.isTrain and len(opt.gpu_ids):
12 |         model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)
13 |
14 |     return model
15 |
-------------------------------------------------------------------------------- /models/networks.py: --------------------------------------------------------------------------------
1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
3 | import torch
4 | import torch.nn as nn
5 | import functools
6 | from torch.autograd import Variable
7 | import numpy as np
8 |
9 | ###############################################################################
10 | # Functions
11 | ###############################################################################
12 | def weights_init(m):
13 |     classname = m.__class__.__name__
14 |     if classname.find('Conv') != -1:
15 |         m.weight.data.normal_(0.0, 0.02)
16 |     elif classname.find('BatchNorm2d') != -1:
17 |         m.weight.data.normal_(1.0, 0.02)
18 |         m.bias.data.fill_(0)
19 |
20 | def get_norm_layer(norm_type='instance'):
21 |     if norm_type == 'batch':
22 |         norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
23 |     elif norm_type == 'instance':
24 |         norm_layer = functools.partial(nn.InstanceNorm2d, affine=False)
25 |     else:
26 |         raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
27 |     return norm_layer
28 |
29 | def define_G(input_nc, output_nc, ngf, netG, n_downsample_global=3, n_blocks_global=9, n_local_enhancers=1,
30 |              n_blocks_local=3, norm='instance', gpu_ids=[]):
31 |     norm_layer = get_norm_layer(norm_type=norm)
32 |     if netG == 'global':
33 |         netG = GlobalGenerator(input_nc, output_nc, ngf, n_downsample_global, n_blocks_global, norm_layer)
34 |     elif netG == 'local':
35 |         netG = LocalEnhancer(input_nc, output_nc, ngf, n_downsample_global, n_blocks_global,
36 |                              n_local_enhancers, n_blocks_local, norm_layer)
37 |     elif netG == 'encoder':
38 |         netG = Encoder(input_nc, output_nc, ngf, n_downsample_global, norm_layer)
39 |     else:
40 |         raise NotImplementedError('generator [%s] is not implemented' % netG)
41 |     print(netG)
42 |     if len(gpu_ids) > 0:
43 |         assert(torch.cuda.is_available())
44 |         netG.cuda(gpu_ids[0])
45 |     netG.apply(weights_init)
46 |     return netG
47 |
48 | def define_D(input_nc, ndf, n_layers_D, norm='instance', use_sigmoid=False, num_D=1, getIntermFeat=False, gpu_ids=[]):
49 |     norm_layer = get_norm_layer(norm_type=norm)
50 |     netD = MultiscaleDiscriminator(input_nc, ndf, n_layers_D, norm_layer, use_sigmoid, num_D, getIntermFeat)
51 |     print(netD)
52 |     if len(gpu_ids) > 0:
53 |         assert(torch.cuda.is_available())
54 |         netD.cuda(gpu_ids[0])
55 |     netD.apply(weights_init)
56 |     return netD
57 |
58 | def print_network(net):
59 |     if isinstance(net, list):
60 |         net = net[0]
61 |     num_params = 0
62 |     for param in net.parameters():
63 |         num_params += param.numel()
64 |     print(net)
65 |     print('Total number of parameters: %d' % num_params)
66 |
67 | ##############################################################################
68 | # Losses
69 | ##############################################################################
70 | class GANLoss(nn.Module):
71 |     def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0,
72 |                  tensor=torch.FloatTensor):
73 |         super(GANLoss, self).__init__()
74 | 
self.real_label = target_real_label 75 | self.fake_label = target_fake_label 76 | self.real_label_var = None 77 | self.fake_label_var = None 78 | self.Tensor = tensor 79 | if use_lsgan: 80 | self.loss = nn.MSELoss() 81 | else: 82 | self.loss = nn.BCELoss() 83 | 84 | def get_target_tensor(self, input, target_is_real): 85 | target_tensor = None 86 | if target_is_real: 87 | create_label = ((self.real_label_var is None) or 88 | (self.real_label_var.numel() != input.numel())) 89 | if create_label: 90 | real_tensor = self.Tensor(input.size()).fill_(self.real_label) 91 | self.real_label_var = Variable(real_tensor, requires_grad=False) 92 | target_tensor = self.real_label_var 93 | else: 94 | create_label = ((self.fake_label_var is None) or 95 | (self.fake_label_var.numel() != input.numel())) 96 | if create_label: 97 | fake_tensor = self.Tensor(input.size()).fill_(self.fake_label) 98 | self.fake_label_var = Variable(fake_tensor, requires_grad=False) 99 | target_tensor = self.fake_label_var 100 | return target_tensor 101 | 102 | def __call__(self, input, target_is_real): 103 | if isinstance(input[0], list): 104 | loss = 0 105 | for input_i in input: 106 | pred = input_i[-1] 107 | target_tensor = self.get_target_tensor(pred, target_is_real) 108 | loss += self.loss(pred, target_tensor) 109 | return loss 110 | else: 111 | target_tensor = self.get_target_tensor(input[-1], target_is_real) 112 | return self.loss(input[-1], target_tensor) 113 | 114 | class VGGLoss(nn.Module): 115 | def __init__(self, gpu_ids): 116 | super(VGGLoss, self).__init__() 117 | self.vgg = Vgg19().cuda() 118 | self.criterion = nn.L1Loss() 119 | self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0] 120 | 121 | def forward(self, x, y): 122 | x_vgg, y_vgg = self.vgg(x), self.vgg(y) 123 | loss = 0 124 | for i in range(len(x_vgg)): 125 | loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach()) 126 | return loss 127 | 128 | ############################################################################## 129 | # Generator 130 | ############################################################################## 131 | class LocalEnhancer(nn.Module): 132 | def __init__(self, input_nc, output_nc, ngf=32, n_downsample_global=3, n_blocks_global=9, 133 | n_local_enhancers=1, n_blocks_local=3, norm_layer=nn.BatchNorm2d, padding_type='reflect'): 134 | super(LocalEnhancer, self).__init__() 135 | self.n_local_enhancers = n_local_enhancers 136 | 137 | ###### global generator model ##### 138 | ngf_global = ngf * (2**n_local_enhancers) 139 | model_global = GlobalGenerator(input_nc, output_nc, ngf_global, n_downsample_global, n_blocks_global, norm_layer).model 140 | model_global = [model_global[i] for i in range(len(model_global)-3)] # get rid of final convolution layers 141 | self.model = nn.Sequential(*model_global) 142 | 143 | ###### local enhancer layers ##### 144 | for n in range(1, n_local_enhancers+1): 145 | ### downsample 146 | ngf_global = ngf * (2**(n_local_enhancers-n)) 147 | model_downsample = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf_global, kernel_size=7, padding=0), 148 | norm_layer(ngf_global), nn.ReLU(True), 149 | nn.Conv2d(ngf_global, ngf_global * 2, kernel_size=3, stride=2, padding=1), 150 | norm_layer(ngf_global * 2), nn.ReLU(True)] 151 | ### residual blocks 152 | model_upsample = [] 153 | for i in range(n_blocks_local): 154 | model_upsample += [ResnetBlock(ngf_global * 2, padding_type=padding_type, norm_layer=norm_layer)] 155 | 156 | ### upsample 157 | model_upsample += [nn.ConvTranspose2d(ngf_global * 2, ngf_global, 
kernel_size=3, stride=2, padding=1, output_padding=1),
158 |                            norm_layer(ngf_global), nn.ReLU(True)]
159 |
160 |             ### final convolution
161 |             if n == n_local_enhancers:
162 |                 model_upsample += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()]
163 |
164 |             setattr(self, 'model'+str(n)+'_1', nn.Sequential(*model_downsample))
165 |             setattr(self, 'model'+str(n)+'_2', nn.Sequential(*model_upsample))
166 |
167 |         self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)
168 |
169 |     def forward(self, input):
170 |         ### create input pyramid
171 |         input_downsampled = [input]
172 |         for i in range(self.n_local_enhancers):
173 |             input_downsampled.append(self.downsample(input_downsampled[-1]))
174 |
175 |         ### output at coarsest level
176 |         output_prev = self.model(input_downsampled[-1])
177 |         ### build up one layer at a time
178 |         for n_local_enhancers in range(1, self.n_local_enhancers+1):
179 |             model_downsample = getattr(self, 'model'+str(n_local_enhancers)+'_1')
180 |             model_upsample = getattr(self, 'model'+str(n_local_enhancers)+'_2')
181 |             input_i = input_downsampled[self.n_local_enhancers-n_local_enhancers]
182 |             output_prev = model_upsample(model_downsample(input_i) + output_prev)
183 |         return output_prev
184 |
185 | class GlobalGenerator(nn.Module):
186 |     def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d,
187 |                  padding_type='reflect'):
188 |         assert(n_blocks >= 0)
189 |         super(GlobalGenerator, self).__init__()
190 |         activation = nn.ReLU(True)
191 |
192 |         model = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0), norm_layer(ngf), activation]
193 |         ### downsample
194 |         for i in range(n_downsampling):
195 |             mult = 2**i
196 |             model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1),
197 |                       norm_layer(ngf * mult * 2), activation]
198 |
199 |         ### resnet blocks
200 |         mult = 2**n_downsampling
201 |         for i in range(n_blocks):
202 |             model += [ResnetBlock(ngf * mult, padding_type=padding_type, activation=activation, norm_layer=norm_layer)]
203 |
204 |         ### upsample
205 |         for i in range(n_downsampling):
206 |             mult = 2**(n_downsampling - i)
207 |             model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1),
208 |                       norm_layer(int(ngf * mult / 2)), activation]
209 |         model += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()]
210 |         self.model = nn.Sequential(*model)
211 |
212 |     def forward(self, input):
213 |         return self.model(input)
214 |
215 | # Define a resnet block
216 | class ResnetBlock(nn.Module):
217 |     def __init__(self, dim, padding_type, norm_layer, activation=nn.ReLU(True), use_dropout=False):
218 |         super(ResnetBlock, self).__init__()
219 |         self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, activation, use_dropout)
220 |
221 |     def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout):
222 |         conv_block = []
223 |         p = 0
224 |         if padding_type == 'reflect':
225 |             conv_block += [nn.ReflectionPad2d(1)]
226 |         elif padding_type == 'replicate':
227 |             conv_block += [nn.ReplicationPad2d(1)]
228 |         elif padding_type == 'zero':
229 |             p = 1
230 |         else:
231 |             raise NotImplementedError('padding [%s] is not implemented' % padding_type)
232 |
233 |         conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
234 |                        norm_layer(dim),
235 |                        activation]
236 |         if use_dropout:
237 |             conv_block += [nn.Dropout(0.5)]
238 |
239 |         p = 0
240 | if padding_type == 'reflect': 241 | conv_block += [nn.ReflectionPad2d(1)] 242 | elif padding_type == 'replicate': 243 | conv_block += [nn.ReplicationPad2d(1)] 244 | elif padding_type == 'zero': 245 | p = 1 246 | else: 247 | raise NotImplementedError('padding [%s] is not implemented' % padding_type) 248 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p), 249 | norm_layer(dim)] 250 | 251 | return nn.Sequential(*conv_block) 252 | 253 | def forward(self, x): 254 | out = x + self.conv_block(x) 255 | return out 256 | 257 | class Encoder(nn.Module): 258 | def __init__(self, input_nc, output_nc, ngf=32, n_downsampling=4, norm_layer=nn.BatchNorm2d): 259 | super(Encoder, self).__init__() 260 | self.output_nc = output_nc 261 | 262 | model = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0), 263 | norm_layer(ngf), nn.ReLU(True)] 264 | ### downsample 265 | for i in range(n_downsampling): 266 | mult = 2**i 267 | model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1), 268 | norm_layer(ngf * mult * 2), nn.ReLU(True)] 269 | 270 | ### upsample 271 | for i in range(n_downsampling): 272 | mult = 2**(n_downsampling - i) 273 | model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1), 274 | norm_layer(int(ngf * mult / 2)), nn.ReLU(True)] 275 | 276 | model += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()] 277 | self.model = nn.Sequential(*model) 278 | 279 | def forward(self, input, inst): 280 | outputs = self.model(input) 281 | 282 | # instance-wise average pooling 283 | outputs_mean = outputs.clone() 284 | inst_list = np.unique(inst.cpu().numpy().astype(int)) 285 | for i in inst_list: 286 | indices = (inst == i).nonzero() # n x 4 287 | for j in range(self.output_nc): 288 | output_ins = outputs[indices[:,0], indices[:,1] + j, indices[:,2], indices[:,3]] 289 | mean_feat = torch.mean(output_ins).expand_as(output_ins) 290 | outputs_mean[indices[:,0], indices[:,1] + j, indices[:,2], indices[:,3]] = mean_feat 291 | return outputs_mean 292 | 293 | class MultiscaleDiscriminator(nn.Module): 294 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, 295 | use_sigmoid=False, num_D=3, getIntermFeat=False): 296 | super(MultiscaleDiscriminator, self).__init__() 297 | self.num_D = num_D 298 | self.n_layers = n_layers 299 | self.getIntermFeat = getIntermFeat 300 | 301 | for i in range(num_D): 302 | netD = NLayerDiscriminator(input_nc, ndf, n_layers, norm_layer, use_sigmoid, getIntermFeat) 303 | if getIntermFeat: 304 | for j in range(n_layers+2): 305 | setattr(self, 'scale'+str(i)+'_layer'+str(j), getattr(netD, 'model'+str(j))) 306 | else: 307 | setattr(self, 'layer'+str(i), netD.model) 308 | 309 | self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False) 310 | 311 | def singleD_forward(self, model, input): 312 | if self.getIntermFeat: 313 | result = [input] 314 | for i in range(len(model)): 315 | result.append(model[i](result[-1])) 316 | return result[1:] 317 | else: 318 | return [model(input)] 319 | 320 | def forward(self, input): 321 | num_D = self.num_D 322 | result = [] 323 | input_downsampled = input 324 | for i in range(num_D): 325 | if self.getIntermFeat: 326 | model = [getattr(self, 'scale'+str(num_D-1-i)+'_layer'+str(j)) for j in range(self.n_layers+2)] 327 | else: 328 | model = getattr(self, 'layer'+str(num_D-1-i)) 329 | result.append(self.singleD_forward(model, input_downsampled)) 330 | if i != 
(num_D-1): 331 | input_downsampled = self.downsample(input_downsampled) 332 | return result 333 | 334 | # Defines the PatchGAN discriminator with the specified arguments. 335 | class NLayerDiscriminator(nn.Module): 336 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, getIntermFeat=False): 337 | super(NLayerDiscriminator, self).__init__() 338 | self.getIntermFeat = getIntermFeat 339 | self.n_layers = n_layers 340 | 341 | kw = 4 342 | padw = int(np.ceil((kw-1.0)/2)) 343 | sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]] 344 | 345 | nf = ndf 346 | for n in range(1, n_layers): 347 | nf_prev = nf 348 | nf = min(nf * 2, 512) 349 | sequence += [[ 350 | nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw), 351 | norm_layer(nf), nn.LeakyReLU(0.2, True) 352 | ]] 353 | 354 | nf_prev = nf 355 | nf = min(nf * 2, 512) 356 | sequence += [[ 357 | nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw), 358 | norm_layer(nf), 359 | nn.LeakyReLU(0.2, True) 360 | ]] 361 | 362 | sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]] 363 | 364 | if use_sigmoid: 365 | sequence += [[nn.Sigmoid()]] 366 | 367 | if getIntermFeat: 368 | for n in range(len(sequence)): 369 | setattr(self, 'model'+str(n), nn.Sequential(*sequence[n])) 370 | else: 371 | sequence_stream = [] 372 | for n in range(len(sequence)): 373 | sequence_stream += sequence[n] 374 | self.model = nn.Sequential(*sequence_stream) 375 | 376 | def forward(self, input): 377 | if self.getIntermFeat: 378 | res = [input] 379 | for n in range(self.n_layers+2): 380 | model = getattr(self, 'model'+str(n)) 381 | res.append(model(res[-1])) 382 | return res[1:] 383 | else: 384 | return self.model(input) 385 | 386 | from torchvision import models 387 | class Vgg19(torch.nn.Module): 388 | def __init__(self, requires_grad=False): 389 | super(Vgg19, self).__init__() 390 | vgg_pretrained_features = models.vgg19(pretrained=True).features 391 | self.slice1 = torch.nn.Sequential() 392 | self.slice2 = torch.nn.Sequential() 393 | self.slice3 = torch.nn.Sequential() 394 | self.slice4 = torch.nn.Sequential() 395 | self.slice5 = torch.nn.Sequential() 396 | for x in range(2): 397 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 398 | for x in range(2, 7): 399 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 400 | for x in range(7, 12): 401 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 402 | for x in range(12, 21): 403 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 404 | for x in range(21, 30): 405 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 406 | if not requires_grad: 407 | for param in self.parameters(): 408 | param.requires_grad = False 409 | 410 | def forward(self, X): 411 | h_relu1 = self.slice1(X) 412 | h_relu2 = self.slice2(h_relu1) 413 | h_relu3 = self.slice3(h_relu2) 414 | h_relu4 = self.slice4(h_relu3) 415 | h_relu5 = self.slice5(h_relu4) 416 | out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5] 417 | return out 418 | -------------------------------------------------------------------------------- /models/pix2pixHD_model.py: -------------------------------------------------------------------------------- 1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved. 2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
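A note on the discriminator output convention used throughout networks.py above: MultiscaleDiscriminator.forward returns one list per scale, and with getIntermFeat=True each of those lists holds the intermediate activations followed by the final prediction, which is why GANLoss only scores input_i[-1] at every scale. A toy CPU shape check under the default constructor arguments (illustrative only, assuming the repository root is on PYTHONPATH):

import torch
from models.networks import define_D, GANLoss

netD = define_D(input_nc=6, ndf=64, n_layers_D=3, num_D=2, getIntermFeat=True)
preds = netD(torch.randn(1, 6, 128, 128))
print(len(preds), len(preds[0]))              # 2 scales, n_layers_D + 2 = 5 tensors each
loss = GANLoss(use_lsgan=True)(preds, True)   # uses only preds[i][-1] at each scale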
3 | import numpy as np 4 | import torch 5 | import os 6 | from torch.autograd import Variable 7 | from util.image_pool import ImagePool 8 | from .base_model import BaseModel 9 | from . import networks 10 | 11 | class Pix2PixHDModel(BaseModel): 12 | def name(self): 13 | return 'Pix2PixHDModel' 14 | 15 | def initialize(self, opt): 16 | BaseModel.initialize(self, opt) 17 | if opt.resize_or_crop != 'none': # when training at full res this causes OOM 18 | torch.backends.cudnn.benchmark = True 19 | self.isTrain = opt.isTrain 20 | self.use_features = opt.instance_feat or opt.label_feat 21 | self.gen_features = self.use_features and not self.opt.load_features 22 | input_nc = opt.label_nc if opt.label_nc != 0 else 3 23 | 24 | ##### define networks 25 | # Generator network 26 | netG_input_nc = input_nc 27 | if not opt.no_instance: 28 | netG_input_nc += 1 29 | if self.use_features: 30 | netG_input_nc += opt.feat_num 31 | self.netG = networks.define_G(netG_input_nc, opt.output_nc, opt.ngf, opt.netG, 32 | opt.n_downsample_global, opt.n_blocks_global, opt.n_local_enhancers, 33 | opt.n_blocks_local, opt.norm, gpu_ids=self.gpu_ids) 34 | 35 | # Discriminator network 36 | if self.isTrain: 37 | use_sigmoid = opt.no_lsgan 38 | netD_input_nc = input_nc + opt.output_nc 39 | if not opt.no_instance: 40 | netD_input_nc += 1 41 | self.netD = networks.define_D(netD_input_nc, opt.ndf, opt.n_layers_D, opt.norm, use_sigmoid, 42 | opt.num_D, not opt.no_ganFeat_loss, gpu_ids=self.gpu_ids) 43 | 44 | ### Encoder network 45 | if self.gen_features: 46 | self.netE = networks.define_G(opt.output_nc, opt.feat_num, opt.nef, 'encoder', 47 | opt.n_downsample_E, norm=opt.norm, gpu_ids=self.gpu_ids) 48 | 49 | print('---------- Networks initialized -------------') 50 | 51 | # load networks 52 | if not self.isTrain or opt.continue_train or opt.load_pretrain: 53 | pretrained_path = '' if not self.isTrain else opt.load_pretrain 54 | self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path) 55 | if self.isTrain: 56 | self.load_network(self.netD, 'D', opt.which_epoch, pretrained_path) 57 | if self.gen_features: 58 | self.load_network(self.netE, 'E', opt.which_epoch, pretrained_path) 59 | 60 | # set loss functions and optimizers 61 | if self.isTrain: 62 | if opt.pool_size > 0 and (len(self.gpu_ids)) > 1: 63 | raise NotImplementedError("Fake Pool Not Implemented for MultiGPU") 64 | self.fake_pool = ImagePool(opt.pool_size) 65 | self.old_lr = opt.lr 66 | 67 | # define loss functions 68 | self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor) 69 | self.criterionFeat = torch.nn.L1Loss() 70 | if not opt.no_vgg_loss: 71 | self.criterionVGG = networks.VGGLoss(self.gpu_ids) 72 | 73 | # Names so we can breakout loss 74 | self.loss_names = ['G_GAN', 'G_GAN_Feat', 'G_VGG', 'D_real', 'D_fake'] 75 | 76 | # initialize optimizers 77 | # optimizer G 78 | if opt.niter_fix_global > 0: 79 | print('------------- Only training the local enhancer network (for %d epochs) ------------' % opt.niter_fix_global) 80 | params_dict = dict(self.netG.named_parameters()) 81 | params = [] 82 | for key, value in params_dict.items(): 83 | if key.startswith('model' + str(opt.n_local_enhancers)): 84 | params += [{'params':[value],'lr':opt.lr}] 85 | else: 86 | params += [{'params':[value],'lr':0.0}] 87 | else: 88 | params = list(self.netG.parameters()) 89 | if self.gen_features: 90 | params += list(self.netE.parameters()) 91 | self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999)) 92 | 93 | # optimizer D 94 | 
params = list(self.netD.parameters()) 95 | self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999)) 96 | 97 | def encode_input(self, label_map, inst_map=None, real_image=None, feat_map=None, infer=False): 98 | if self.opt.label_nc == 0: 99 | input_label = label_map.data.cuda() 100 | else: 101 | # create one-hot vector for label map 102 | size = label_map.size() 103 | oneHot_size = (size[0], self.opt.label_nc, size[2], size[3]) 104 | input_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_() 105 | input_label = input_label.scatter_(1, label_map.data.long().cuda(), 1.0) 106 | 107 | # get edges from instance map 108 | if not self.opt.no_instance: 109 | inst_map = inst_map.data.cuda() 110 | edge_map = self.get_edges(inst_map) 111 | input_label = torch.cat((input_label, edge_map), dim=1) 112 | input_label = Variable(input_label, volatile=infer) 113 | 114 | # real images for training 115 | if real_image is not None: 116 | real_image = Variable(real_image.data.cuda()) 117 | 118 | # instance map for feature encoding 119 | if self.use_features: 120 | # get precomputed feature maps 121 | if self.opt.load_features: 122 | feat_map = Variable(feat_map.data.cuda()) 123 | 124 | return input_label, inst_map, real_image, feat_map 125 | 126 | def discriminate(self, input_label, test_image, use_pool=False): 127 | input_concat = torch.cat((input_label, test_image.detach()), dim=1) 128 | if use_pool: 129 | fake_query = self.fake_pool.query(input_concat) 130 | return self.netD.forward(fake_query) 131 | else: 132 | return self.netD.forward(input_concat) 133 | 134 | def forward(self, label, inst, image, feat, infer=False): 135 | # Encode Inputs 136 | input_label, inst_map, real_image, feat_map = self.encode_input(label, inst, image, feat) 137 | 138 | # Fake Generation 139 | if self.use_features: 140 | if not self.opt.load_features: 141 | feat_map = self.netE.forward(real_image, inst_map) 142 | input_concat = torch.cat((input_label, feat_map), dim=1) 143 | else: 144 | input_concat = input_label 145 | fake_image = self.netG.forward(input_concat) 146 | 147 | # Fake Detection and Loss 148 | pred_fake_pool = self.discriminate(input_label, fake_image, use_pool=True) 149 | loss_D_fake = self.criterionGAN(pred_fake_pool, False) 150 | 151 | # Real Detection and Loss 152 | pred_real = self.discriminate(input_label, real_image) 153 | loss_D_real = self.criterionGAN(pred_real, True) 154 | 155 | # GAN loss (Fake Passability Loss) 156 | pred_fake = self.netD.forward(torch.cat((input_label, fake_image), dim=1)) 157 | loss_G_GAN = self.criterionGAN(pred_fake, True) 158 | 159 | # GAN feature matching loss 160 | loss_G_GAN_Feat = 0 161 | if not self.opt.no_ganFeat_loss: 162 | feat_weights = 4.0 / (self.opt.n_layers_D + 1) 163 | D_weights = 1.0 / self.opt.num_D 164 | for i in range(self.opt.num_D): 165 | for j in range(len(pred_fake[i])-1): 166 | loss_G_GAN_Feat += D_weights * feat_weights * \ 167 | self.criterionFeat(pred_fake[i][j], pred_real[i][j].detach()) * self.opt.lambda_feat 168 | 169 | # VGG feature matching loss 170 | loss_G_VGG = 0 171 | if not self.opt.no_vgg_loss: 172 | loss_G_VGG = self.criterionVGG(fake_image, real_image) * self.opt.lambda_feat 173 | 174 | # Only return the fake_B image if necessary to save BW 175 | return [ [ loss_G_GAN, loss_G_GAN_Feat, loss_G_VGG, loss_D_real, loss_D_fake ], None if not infer else fake_image ] 176 | 177 | def inference(self, label, inst): 178 | # Encode Inputs 179 | input_label, inst_map, _, _ = self.encode_input(Variable(label), Variable(inst), 
infer=True) 180 | 181 | # Fake Generation 182 | if self.use_features: 183 | # sample clusters from precomputed features 184 | feat_map = self.sample_features(inst_map) 185 | input_concat = torch.cat((input_label, feat_map), dim=1) 186 | else: 187 | input_concat = input_label 188 | fake_image = self.netG.forward(input_concat) 189 | return fake_image 190 | 191 | def sample_features(self, inst): 192 | # read precomputed feature clusters 193 | cluster_path = os.path.join(self.opt.checkpoints_dir, self.opt.name, self.opt.cluster_path) 194 | features_clustered = np.load(cluster_path).item() 195 | 196 | # randomly sample from the feature clusters 197 | inst_np = inst.cpu().numpy().astype(int) 198 | feat_map = torch.cuda.FloatTensor(1, self.opt.feat_num, inst.size()[2], inst.size()[3]) 199 | for i in np.unique(inst_np): 200 | label = i if i < 1000 else i//1000 201 | if label in features_clustered: 202 | feat = features_clustered[label] 203 | cluster_idx = np.random.randint(0, feat.shape[0]) 204 | 205 | idx = (inst == i).nonzero() 206 | for k in range(self.opt.feat_num): 207 | feat_map[idx[:,0], idx[:,1] + k, idx[:,2], idx[:,3]] = feat[cluster_idx, k] 208 | return feat_map 209 | 210 | def encode_features(self, image, inst): 211 | image = Variable(image.cuda(), volatile=True) 212 | feat_num = self.opt.feat_num 213 | h, w = inst.size()[2], inst.size()[3] 214 | block_num = 32 215 | feat_map = self.netE.forward(image, inst.cuda()) 216 | inst_np = inst.cpu().numpy().astype(int) 217 | feature = {} 218 | for i in range(self.opt.label_nc): 219 | feature[i] = np.zeros((0, feat_num+1)) 220 | for i in np.unique(inst_np): 221 | label = i if i < 1000 else i//1000 222 | idx = (inst == i).nonzero() 223 | num = idx.size()[0] 224 | idx = idx[num//2,:] 225 | val = np.zeros((1, feat_num+1)) 226 | for k in range(feat_num): 227 | val[0, k] = feat_map[idx[0], idx[1] + k, idx[2], idx[3]].data[0] 228 | val[0, feat_num] = float(num) / (h * w // block_num) 229 | feature[label] = np.append(feature[label], val, axis=0) 230 | return feature 231 | 232 | def get_edges(self, t): 233 | edge = torch.cuda.ByteTensor(t.size()).zero_() 234 | edge[:,:,:,1:] = edge[:,:,:,1:] | (t[:,:,:,1:] != t[:,:,:,:-1]) 235 | edge[:,:,:,:-1] = edge[:,:,:,:-1] | (t[:,:,:,1:] != t[:,:,:,:-1]) 236 | edge[:,:,1:,:] = edge[:,:,1:,:] | (t[:,:,1:,:] != t[:,:,:-1,:]) 237 | edge[:,:,:-1,:] = edge[:,:,:-1,:] | (t[:,:,1:,:] != t[:,:,:-1,:]) 238 | return edge.float() 239 | 240 | def save(self, which_epoch): 241 | self.save_network(self.netG, 'G', which_epoch, self.gpu_ids) 242 | self.save_network(self.netD, 'D', which_epoch, self.gpu_ids) 243 | if self.gen_features: 244 | self.save_network(self.netE, 'E', which_epoch, self.gpu_ids) 245 | 246 | def update_fixed_params(self): 247 | # after fixing the global generator for a number of iterations, also start finetuning it 248 | params = list(self.netG.parameters()) 249 | if self.gen_features: 250 | params += list(self.netE.parameters()) 251 | self.optimizer_G = torch.optim.Adam(params, lr=self.opt.lr, betas=(self.opt.beta1, 0.999)) 252 | print('------------ Now also finetuning global generator -----------') 253 | 254 | def update_learning_rate(self): 255 | lrd = self.opt.lr / self.opt.niter_decay 256 | lr = self.old_lr - lrd 257 | for param_group in self.optimizer_D.param_groups: 258 | param_group['lr'] = lr 259 | for param_group in self.optimizer_G.param_groups: 260 | param_group['lr'] = lr 261 | print('update learning rate: %f -> %f' % (self.old_lr, lr)) 262 | self.old_lr = lr 263 | 
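Two details of encode_input and get_edges above are easy to miss: the one-hot label encoding is done in a single scatter_ call, and the instance boundary map flags every pixel whose horizontal or vertical neighbour carries a different instance id. A CPU toy example of the scatter_ one-hot trick (the model code does the same on torch.cuda.FloatTensor with label_nc channels):

import torch

label_map = torch.LongTensor([[[[0, 1],
                                [3, 3]]]])              # N x 1 x H x W class ids
one_hot = torch.zeros(1, 4, 2, 2).scatter_(1, label_map, 1.0)
print(one_hot[0, 3])   # tensor([[0., 0.], [1., 1.]]) -- channel k flags class k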
-------------------------------------------------------------------------------- /options/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxli/High-Resolution-Image-Synthesis-and-Semantic-Manipulation-with-Conditional-GANsl-/dd05da797863a13f6e45ec0a2d2ff3c7f8142f38/options/__init__.py
-------------------------------------------------------------------------------- /options/base_options.py: --------------------------------------------------------------------------------
1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
3 | import argparse
4 | import os
5 | from util import util
6 | import torch
7 |
8 | class BaseOptions():
9 |     def __init__(self):
10 |         self.parser = argparse.ArgumentParser()
11 |         self.initialized = False
12 |
13 |     def initialize(self):
14 |         # experiment specifics
15 |         self.parser.add_argument('--name', type=str, default='label2city', help='name of the experiment. It decides where to store samples and models')
16 |         self.parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0, or 0,1,2, or 0,2. Use -1 for CPU')
17 |         self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
18 |         self.parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization')
19 |         self.parser.add_argument('--use_dropout', action='store_true', help='use dropout for the generator')
20 |
21 |         # input/output sizes
22 |         self.parser.add_argument('--batchSize', type=int, default=1, help='input batch size')
23 |         self.parser.add_argument('--loadSize', type=int, default=1024, help='scale images to this size')
24 |         self.parser.add_argument('--fineSize', type=int, default=512, help='then crop to this size')
25 |         self.parser.add_argument('--label_nc', type=int, default=35, help='# of input label channels')
26 |         self.parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels')
27 |
28 |         # for setting inputs
29 |         self.parser.add_argument('--dataroot', type=str, default='./datasets/cityscapes/')
30 |         self.parser.add_argument('--resize_or_crop', type=str, default='scale_width', help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop]')
31 |         self.parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
32 |         self.parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
33 |         self.parser.add_argument('--nThreads', default=2, type=int, help='# threads for loading data')
34 |         self.parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
35 |
36 |         # for displays
37 |         self.parser.add_argument('--display_winsize', type=int, default=512, help='display window size')
38 |         self.parser.add_argument('--tf_log', action='store_true', help='if specified, use tensorboard logging. Requires tensorflow installed')
39 |
40 |         # for generator
41 |         self.parser.add_argument('--netG', type=str, default='global', help='selects model to use for netG')
42 |         self.parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer')
43 |         self.parser.add_argument('--n_downsample_global', type=int, default=4, help='number of downsampling layers in netG')
44 |         self.parser.add_argument('--n_blocks_global', type=int, default=9, help='number of residual blocks in the global generator network')
45 |         self.parser.add_argument('--n_blocks_local', type=int, default=3, help='number of residual blocks in the local enhancer network')
46 |         self.parser.add_argument('--n_local_enhancers', type=int, default=1, help='number of local enhancers to use')
47 |         self.parser.add_argument('--niter_fix_global', type=int, default=0, help='number of epochs that we only train the outermost local enhancer')
48 |
49 |         # for instance-wise features
50 |         self.parser.add_argument('--no_instance', action='store_true', help='if specified, do *not* add instance map as input')
51 |         self.parser.add_argument('--instance_feat', action='store_true', help='if specified, add encoded instance features as input')
52 |         self.parser.add_argument('--label_feat', action='store_true', help='if specified, add encoded label features as input')
53 |         self.parser.add_argument('--feat_num', type=int, default=3, help='vector length for encoded features')
54 |         self.parser.add_argument('--load_features', action='store_true', help='if specified, load precomputed feature maps')
55 |         self.parser.add_argument('--n_downsample_E', type=int, default=4, help='# of downsampling layers in encoder')
56 |         self.parser.add_argument('--nef', type=int, default=16, help='# of encoder filters in the first conv layer')
57 |         self.parser.add_argument('--n_clusters', type=int, default=10, help='number of clusters for features')
58 |
59 |         self.initialized = True
60 |
61 |     def parse(self, save=True):
62 |         if not self.initialized:
63 |             self.initialize()
64 |         self.opt = self.parser.parse_args()
65 |         self.opt.isTrain = self.isTrain   # train or test
66 |
67 |         str_ids = self.opt.gpu_ids.split(',')
68 |         self.opt.gpu_ids = []
69 |         for str_id in str_ids:
70 |             id = int(str_id)
71 |             if id >= 0:
72 |                 self.opt.gpu_ids.append(id)
73 |
74 |         # set gpu ids
75 |         if len(self.opt.gpu_ids) > 0:
76 |             torch.cuda.set_device(self.opt.gpu_ids[0])
77 |
78 |         args = vars(self.opt)
79 |
80 |         print('------------ Options -------------')
81 |         for k, v in sorted(args.items()):
82 |             print('%s: %s' % (str(k), str(v)))
83 |         print('-------------- End ----------------')
84 |
85 |         # save to the disk
86 |         expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.name)
87 |         util.mkdirs(expr_dir)
88 |         if save and not self.opt.continue_train:
89 |             file_name = os.path.join(expr_dir, 'opt.txt')
90 |             with open(file_name, 'wt') as opt_file:
91 |                 opt_file.write('------------ Options -------------\n')
92 |                 for k, v in sorted(args.items()):
93 |                     opt_file.write('%s: %s\n' % (str(k), str(v)))
94 |                 opt_file.write('-------------- End ----------------\n')
95 |         return self.opt
-------------------------------------------------------------------------------- /options/test_options.py: --------------------------------------------------------------------------------
1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
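The gpu_ids handling in BaseOptions.parse above converts the comma-separated string into a list of non-negative ints, so --gpu_ids -1 yields an empty list and the code falls back to CPU tensors (see BaseModel.initialize). The same transformation in isolation (a sketch; parse_gpu_ids is not a repository function):

def parse_gpu_ids(arg):
    return [int(tok) for tok in arg.split(',') if int(tok) >= 0]

assert parse_gpu_ids('0,1,2') == [0, 1, 2]
assert parse_gpu_ids('-1') == []   # CPU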
3 | from .base_options import BaseOptions 4 | 5 | class TestOptions(BaseOptions): 6 | def initialize(self): 7 | BaseOptions.initialize(self) 8 | self.parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.') 9 | self.parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.') 10 | self.parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images') 11 | self.parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc') 12 | self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model') 13 | self.parser.add_argument('--how_many', type=int, default=50, help='how many test images to run') 14 | self.parser.add_argument('--cluster_path', type=str, default='features_clustered_010.npy', help='the path for clustered results of encoded features') 15 | self.isTrain = False 16 | -------------------------------------------------------------------------------- /options/train_options.py: -------------------------------------------------------------------------------- 1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved. 2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 3 | from .base_options import BaseOptions 4 | 5 | class TrainOptions(BaseOptions): 6 | def initialize(self): 7 | BaseOptions.initialize(self) 8 | # for displays 9 | self.parser.add_argument('--display_freq', type=int, default=100, help='frequency of showing training results on screen') 10 | self.parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console') 11 | self.parser.add_argument('--save_latest_freq', type=int, default=1000, help='frequency of saving the latest results') 12 | self.parser.add_argument('--save_epoch_freq', type=int, default=10, help='frequency of saving checkpoints at the end of epochs') 13 | self.parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/') 14 | self.parser.add_argument('--debug', action='store_true', help='only do one epoch and displays at each iteration') 15 | 16 | # for training 17 | self.parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model') 18 | self.parser.add_argument('--load_pretrain', type=str, default='', help='load the pretrained model from the specified location') 19 | self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? 
set to latest to use latest cached model')
20 |         self.parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
21 |         self.parser.add_argument('--niter', type=int, default=100, help='# of epochs at the starting learning rate')
22 |         self.parser.add_argument('--niter_decay', type=int, default=100, help='# of epochs to linearly decay learning rate to zero')
23 |         self.parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
24 |         self.parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
25 |
26 |         # for discriminators
27 |         self.parser.add_argument('--num_D', type=int, default=2, help='number of discriminators to use')
28 |         self.parser.add_argument('--n_layers_D', type=int, default=3, help='only used if which_model_netD==n_layers')
29 |         self.parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer')
30 |         self.parser.add_argument('--lambda_feat', type=float, default=10.0, help='weight for feature matching loss')
31 |         self.parser.add_argument('--no_ganFeat_loss', action='store_true', help='if specified, do *not* use discriminator feature matching loss')
32 |         self.parser.add_argument('--no_vgg_loss', action='store_true', help='if specified, do *not* use VGG feature matching loss')
33 |         self.parser.add_argument('--no_lsgan', action='store_true', help='do *not* use least squares GAN; if specified, use vanilla GAN')
34 |         self.parser.add_argument('--pool_size', type=int, default=0, help='the size of image buffer that stores previously generated images')
35 |
36 |         self.isTrain = True
37 |
-------------------------------------------------------------------------------- /precompute_feature_maps.py: --------------------------------------------------------------------------------
1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
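With the defaults above (--niter 100 --niter_decay 100 --lr 0.0002), training runs for 200 epochs: the learning rate is constant for the first 100 epochs and then shrinks by lr/niter_decay at the end of every later epoch until it reaches zero (see update_learning_rate in pix2pixHD_model.py and the epoch loop in train.py). A closed-form sketch of the rate used during a given epoch (lr_for_epoch is an illustrative helper, not part of this repository):

def lr_for_epoch(epoch, lr=0.0002, niter=100, niter_decay=100):
    # decay steps happen at the end of epochs niter+1, niter+2, ...
    return lr - max(0, epoch - niter - 1) * lr / niter_decay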
3 | from options.train_options import TrainOptions
4 | from data.data_loader import CreateDataLoader
5 | from models.models import create_model
6 | import os
7 | import util.util as util
8 | from torch.autograd import Variable
9 | import torch.nn as nn
10 |
11 | opt = TrainOptions().parse()
12 | opt.nThreads = 1
13 | opt.batchSize = 1
14 | opt.serial_batches = True
15 | opt.no_flip = True
16 | opt.instance_feat = True
17 |
18 | name = 'features'
19 | save_path = os.path.join(opt.checkpoints_dir, opt.name)
20 |
21 | ############ Initialize #########
22 | data_loader = CreateDataLoader(opt)
23 | dataset = data_loader.load_data()
24 | dataset_size = len(data_loader)
25 | model = create_model(opt)
26 | util.mkdirs(os.path.join(opt.dataroot, opt.phase + '_feat'))
27 |
28 | ######## Save precomputed feature maps for 1024p training #######
29 | for i, data in enumerate(dataset):
30 |     print('%d / %d images' % (i+1, dataset_size))
31 |     feat_map = model.module.netE.forward(Variable(data['image'].cuda(), volatile=True), data['inst'].cuda())
32 |     feat_map = nn.Upsample(scale_factor=2, mode='nearest')(feat_map)
33 |     image_numpy = util.tensor2im(feat_map.data[0])
34 |     save_path = data['path'][0].replace('/train_label/', '/train_feat/')
35 |     util.save_image(image_numpy, save_path)
-------------------------------------------------------------------------------- /scripts/test_1024p.sh: --------------------------------------------------------------------------------
1 | ################################ Testing ################################
2 | # labels only
3 | python test.py --name label2city_1024p --netG local --ngf 32 --resize_or_crop none
-------------------------------------------------------------------------------- /scripts/test_1024p_feat.sh: --------------------------------------------------------------------------------
1 | ################################ Testing ################################
2 | # first precompute and cluster all features
3 | python encode_features.py --name label2city_1024p_feat --netG local --ngf 32 --resize_or_crop none;
4 | # use instance-wise features
5 | python test.py --name label2city_1024p_feat --netG local --ngf 32 --resize_or_crop none --instance_feat
-------------------------------------------------------------------------------- /scripts/test_512p.sh: --------------------------------------------------------------------------------
1 | ################################ Testing ################################
2 | # labels only
3 | python test.py --name label2city_512p
-------------------------------------------------------------------------------- /scripts/test_512p_feat.sh: --------------------------------------------------------------------------------
1 | ################################ Testing ################################
2 | # first precompute and cluster all features
3 | python encode_features.py --name label2city_512p_feat;
4 | # use instance-wise features
5 | python test.py --name label2city_512p_feat --instance_feat
-------------------------------------------------------------------------------- /scripts/train_1024p_12G.sh: --------------------------------------------------------------------------------
1 | ############## To train images at 2048 x 1024 resolution after training 1024 x 512 resolution models #############
2 | ##### Using GPUs with 12G memory (not tested)
3 | # Using labels only
4 | python train.py --name label2city_1024p --netG local --ngf 32 --num_D 3 --load_pretrain checkpoints/label2city_512p/ --niter_fix_global 20 --resize_or_crop crop --fineSize 1024
-------------------------------------------------------------------------------- /scripts/train_1024p_24G.sh: --------------------------------------------------------------------------------
1 | ############## To train images at 2048 x 1024 resolution after training 1024 x 512 resolution models #############
2 | ######## Using GPUs with 24G memory
3 | # Using labels only
4 | python train.py --name label2city_1024p --netG local --ngf 32 --num_D 3 --load_pretrain checkpoints/label2city_512p/ --niter 50 --niter_decay 50 --niter_fix_global 10 --resize_or_crop none
-------------------------------------------------------------------------------- /scripts/train_1024p_feat_12G.sh: --------------------------------------------------------------------------------
1 | ############## To train images at 2048 x 1024 resolution after training 1024 x 512 resolution models #############
2 | ##### Using GPUs with 12G memory (not tested)
3 | # First precompute feature maps and save them
4 | python precompute_feature_maps.py --name label2city_512p_feat;
5 | # Adding instances and encoded features
6 | python train.py --name label2city_1024p_feat --netG local --ngf 32 --num_D 3 --load_pretrain checkpoints/label2city_512p_feat/ --niter_fix_global 20 --resize_or_crop crop --fineSize 896 --instance_feat --load_features
-------------------------------------------------------------------------------- /scripts/train_1024p_feat_24G.sh: --------------------------------------------------------------------------------
1 | ############## To train images at 2048 x 1024 resolution after training 1024 x 512 resolution models #############
2 | ######## Using GPUs with 24G memory
3 | # First precompute feature maps and save them
4 | python precompute_feature_maps.py --name label2city_512p_feat;
5 | # Adding instances and encoded features
6 | python train.py --name label2city_1024p_feat --netG local --ngf 32 --num_D 3 --load_pretrain checkpoints/label2city_512p_feat/ --niter 50 --niter_decay 50 --niter_fix_global 10 --resize_or_crop none --instance_feat --load_features
-------------------------------------------------------------------------------- /scripts/train_512p.sh: --------------------------------------------------------------------------------
1 | ### Using labels only
2 | python train.py --name label2city_512p
-------------------------------------------------------------------------------- /scripts/train_512p_feat.sh: --------------------------------------------------------------------------------
1 | ### Adding instances and encoded features
2 | python train.py --name label2city_512p_feat --instance_feat
-------------------------------------------------------------------------------- /scripts/train_512p_multigpu.sh: --------------------------------------------------------------------------------
1 | ######## Multi-GPU training example #######
2 | python train.py --name label2city_512p --batchSize 8 --gpu_ids 0,1,2,3,4,5,6,7
-------------------------------------------------------------------------------- /test.py: --------------------------------------------------------------------------------
1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
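The 1024p training scripts above implement the coarse-to-fine schedule: --load_pretrain restores the 512p networks, and --niter_fix_global keeps the global generator frozen for the first epochs by giving every parameter outside the local enhancer (prefix 'model1' when n_local_enhancers is 1) a zero learning rate; the param-group construction lives in Pix2PixHDModel.initialize. The same trick in isolation (a sketch; freeze_all_but is illustrative, not a repository function):

import torch

def freeze_all_but(net, prefix, lr, beta1=0.5):
    # one param group per parameter, with lr zeroed outside the trainable prefix
    groups = [{'params': [p], 'lr': lr if name.startswith(prefix) else 0.0}
              for name, p in net.named_parameters()]
    return torch.optim.Adam(groups, lr=lr, betas=(beta1, 0.999))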
3 | import os 4 | from collections import OrderedDict 5 | from options.test_options import TestOptions 6 | from data.data_loader import CreateDataLoader 7 | from models.models import create_model 8 | import util.util as util 9 | from util.visualizer import Visualizer 10 | from util import html 11 | 12 | opt = TestOptions().parse(save=False) 13 | opt.nThreads = 1 # test code only supports nThreads = 1 14 | opt.batchSize = 1 # test code only supports batchSize = 1 15 | opt.serial_batches = True # no shuffle 16 | opt.no_flip = True # no flip 17 | 18 | data_loader = CreateDataLoader(opt) 19 | dataset = data_loader.load_data() 20 | model = create_model(opt) 21 | visualizer = Visualizer(opt) 22 | # create website 23 | web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch)) 24 | webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch)) 25 | # test 26 | for i, data in enumerate(dataset): 27 | if i >= opt.how_many: 28 | break 29 | generated = model.inference(data['label'], data['inst']) 30 | visuals = OrderedDict([('input_label', util.tensor2label(data['label'][0], opt.label_nc)), 31 | ('synthesized_image', util.tensor2im(generated.data[0]))]) 32 | img_path = data['path'] 33 | print('process image... %s' % img_path) 34 | visualizer.save_images(webpage, visuals, img_path) 35 | 36 | webpage.save() 37 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved. 2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 3 | import time 4 | from collections import OrderedDict 5 | from options.train_options import TrainOptions 6 | from data.data_loader import CreateDataLoader 7 | from models.models import create_model 8 | import util.util as util 9 | from util.visualizer import Visualizer 10 | import os 11 | import numpy as np 12 | import torch 13 | from torch.autograd import Variable 14 | 15 | opt = TrainOptions().parse() 16 | iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt') 17 | if opt.continue_train: 18 | try: 19 | start_epoch, epoch_iter = np.loadtxt(iter_path , delimiter=',', dtype=int) 20 | except: 21 | start_epoch, epoch_iter = 1, 0 22 | print('Resuming from epoch %d at iteration %d' % (start_epoch, epoch_iter)) 23 | else: 24 | start_epoch, epoch_iter = 1, 0 25 | 26 | if opt.debug: 27 | opt.display_freq = 1 28 | opt.print_freq = 1 29 | opt.niter = 1 30 | opt.niter_decay = 0 31 | opt.max_dataset_size = 10 32 | 33 | data_loader = CreateDataLoader(opt) 34 | dataset = data_loader.load_data() 35 | dataset_size = len(data_loader) 36 | print('#training images = %d' % dataset_size) 37 | 38 | model = create_model(opt) 39 | visualizer = Visualizer(opt) 40 | 41 | total_steps = (start_epoch-1) * dataset_size + epoch_iter 42 | for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1): 43 | epoch_start_time = time.time() 44 | if epoch != start_epoch: 45 | epoch_iter = epoch_iter % dataset_size 46 | for i, data in enumerate(dataset, start=epoch_iter): 47 | iter_start_time = time.time() 48 | total_steps += opt.batchSize 49 | epoch_iter += opt.batchSize 50 | 51 | # whether to collect output images 52 | save_fake = total_steps % opt.display_freq == 0 53 | 54 | ############## Forward Pass ###################### 55 | losses, generated = model(Variable(data['label']), 
Variable(data['inst']), 56 | Variable(data['image']), Variable(data['feat']), infer=save_fake) 57 | 58 | # sum per device losses 59 | losses = [ torch.mean(x) if not isinstance(x, int) else x for x in losses ] 60 | loss_dict = dict(zip(model.module.loss_names, losses)) 61 | 62 | # calculate final loss scalar 63 | loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5 64 | loss_G = loss_dict['G_GAN'] + loss_dict['G_GAN_Feat'] + loss_dict['G_VGG'] 65 | 66 | ############### Backward Pass #################### 67 | # update generator weights 68 | model.module.optimizer_G.zero_grad() 69 | loss_G.backward() 70 | model.module.optimizer_G.step() 71 | 72 | # update discriminator weights 73 | model.module.optimizer_D.zero_grad() 74 | loss_D.backward() 75 | model.module.optimizer_D.step() 76 | 77 | #call(["nvidia-smi", "--format=csv", "--query-gpu=memory.used,memory.free"]) 78 | 79 | ############## Display results and errors ########## 80 | ### print out errors 81 | if total_steps % opt.print_freq == 0: 82 | errors = {k: v.data[0] if not isinstance(v, int) else v for k, v in loss_dict.items()} 83 | t = (time.time() - iter_start_time) / opt.batchSize 84 | visualizer.print_current_errors(epoch, epoch_iter, errors, t) 85 | visualizer.plot_current_errors(errors, total_steps) 86 | 87 | ### display output images 88 | if save_fake: 89 | visuals = OrderedDict([('input_label', util.tensor2label(data['label'][0], opt.label_nc)), 90 | ('synthesized_image', util.tensor2im(generated.data[0])), 91 | ('real_image', util.tensor2im(data['image'][0]))]) 92 | visualizer.display_current_results(visuals, epoch, total_steps) 93 | 94 | ### save latest model 95 | if total_steps % opt.save_latest_freq == 0: 96 | print('saving the latest model (epoch %d, total_steps %d)' % (epoch, total_steps)) 97 | model.module.save('latest') 98 | np.savetxt(iter_path, (epoch, epoch_iter), delimiter=',', fmt='%d') 99 | 100 | # end of epoch 101 | iter_end_time = time.time() 102 | print('End of epoch %d / %d \t Time Taken: %d sec' % 103 | (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time)) 104 | 105 | ### save model for this epoch 106 | if epoch % opt.save_epoch_freq == 0: 107 | print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps)) 108 | model.module.save('latest') 109 | model.module.save(epoch) 110 | np.savetxt(iter_path, (epoch+1, 0), delimiter=',', fmt='%d') 111 | 112 | ### instead of only training the local enhancer, train the entire network after certain iterations 113 | if (opt.niter_fix_global != 0) and (epoch == opt.niter_fix_global): 114 | model.module.update_fixed_params() 115 | 116 | ### linearly decay learning rate after certain iterations 117 | if epoch > opt.niter: 118 | model.module.update_learning_rate() 119 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenxli/High-Resolution-Image-Synthesis-and-Semantic-Manipulation-with-Conditional-GANsl-/dd05da797863a13f6e45ec0a2d2ff3c7f8142f38/util/__init__.py -------------------------------------------------------------------------------- /util/html.py: -------------------------------------------------------------------------------- 1 | import dominate 2 | from dominate.tags import * 3 | import os 4 | 5 | 6 | class HTML: 7 | def __init__(self, web_dir, title, refresh=0): 8 | self.title = title 9 | self.web_dir = web_dir 10 | self.img_dir = os.path.join(self.web_dir, 
'images') 11 | if not os.path.exists(self.web_dir): 12 | os.makedirs(self.web_dir) 13 | if not os.path.exists(self.img_dir): 14 | os.makedirs(self.img_dir) 15 | 16 | self.doc = dominate.document(title=title) 17 | if refresh > 0: 18 | with self.doc.head: 19 | meta(http_equiv="refresh", content=str(refresh)) 20 | 21 | def get_image_dir(self): 22 | return self.img_dir 23 | 24 | def add_header(self, str): 25 | with self.doc: 26 | h3(str) 27 | 28 | def add_table(self, border=1): 29 | self.t = table(border=border, style="table-layout: fixed;") 30 | self.doc.add(self.t) 31 | 32 | def add_images(self, ims, txts, links, width=512): 33 | self.add_table() 34 | with self.t: 35 | with tr(): 36 | for im, txt, link in zip(ims, txts, links): 37 | with td(style="word-wrap: break-word;", halign="center", valign="top"): 38 | with p(): 39 | with a(href=os.path.join('images', link)): 40 | img(style="width:%dpx" % (width), src=os.path.join('images', im)) 41 | br() 42 | p(txt) 43 | 44 | def save(self): 45 | html_file = '%s/index.html' % self.web_dir 46 | f = open(html_file, 'wt') 47 | f.write(self.doc.render()) 48 | f.close() 49 | 50 | 51 | if __name__ == '__main__': 52 | html = HTML('web/', 'test_html') 53 | html.add_header('hello world') 54 | 55 | ims = [] 56 | txts = [] 57 | links = [] 58 | for n in range(4): 59 | ims.append('image_%d.jpg' % n) 60 | txts.append('text_%d' % n) 61 | links.append('image_%d.jpg' % n) 62 | html.add_images(ims, txts, links) 63 | html.save() 64 | -------------------------------------------------------------------------------- /util/image_pool.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | from torch.autograd import Variable 4 | class ImagePool(): 5 | def __init__(self, pool_size): 6 | self.pool_size = pool_size 7 | if self.pool_size > 0: 8 | self.num_imgs = 0 9 | self.images = [] 10 | 11 | def query(self, images): 12 | if self.pool_size == 0: 13 | return images 14 | return_images = [] 15 | for image in images.data: 16 | image = torch.unsqueeze(image, 0) 17 | if self.num_imgs < self.pool_size: 18 | self.num_imgs = self.num_imgs + 1 19 | self.images.append(image) 20 | return_images.append(image) 21 | else: 22 | p = random.uniform(0, 1) 23 | if p > 0.5: 24 | random_id = random.randint(0, self.pool_size-1) 25 | tmp = self.images[random_id].clone() 26 | self.images[random_id] = image 27 | return_images.append(tmp) 28 | else: 29 | return_images.append(image) 30 | return_images = Variable(torch.cat(return_images, 0)) 31 | return return_images 32 | -------------------------------------------------------------------------------- /util/util.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | import numpy as np 4 | from PIL import Image 5 | import numpy as np 6 | import os 7 | 8 | # Converts a Tensor into a Numpy array 9 | # |imtype|: the desired type of the converted numpy array 10 | def tensor2im(image_tensor, imtype=np.uint8, normalize=True): 11 | if isinstance(image_tensor, list): 12 | image_numpy = [] 13 | for i in range(len(image_tensor)): 14 | image_numpy.append(tensor2im(image_tensor[i], imtype, normalize)) 15 | return image_numpy 16 | image_numpy = image_tensor.cpu().float().numpy() 17 | if normalize: 18 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 19 | else: 20 | image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0 21 | image_numpy = np.clip(image_numpy, 0, 255) 22 | if 
/util/util.py:
--------------------------------------------------------------------------------
from __future__ import print_function
import torch
import numpy as np
from PIL import Image
import os

# Converts a Tensor into a Numpy array
# |imtype|: the desired type of the converted numpy array
def tensor2im(image_tensor, imtype=np.uint8, normalize=True):
    if isinstance(image_tensor, list):
        image_numpy = []
        for i in range(len(image_tensor)):
            image_numpy.append(tensor2im(image_tensor[i], imtype, normalize))
        return image_numpy
    image_numpy = image_tensor.cpu().float().numpy()
    if normalize:
        # map [-1, 1] network output back to [0, 255]
        image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
    else:
        image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0
    image_numpy = np.clip(image_numpy, 0, 255)
    if image_numpy.shape[2] == 1:
        # drop the channel axis for single-channel images
        image_numpy = image_numpy[:, :, 0]
    return image_numpy.astype(imtype)

# Converts a one-hot tensor into a colorful label map
def tensor2label(label_tensor, n_label, imtype=np.uint8):
    if n_label == 0:
        return tensor2im(label_tensor, imtype)
    label_tensor = label_tensor.cpu().float()
    if label_tensor.size()[0] > 1:
        label_tensor = label_tensor.max(0, keepdim=True)[1]
    label_tensor = Colorize(n_label)(label_tensor)
    label_numpy = np.transpose(label_tensor.numpy(), (1, 2, 0))
    return label_numpy.astype(imtype)

def save_image(image_numpy, image_path):
    image_pil = Image.fromarray(image_numpy)
    image_pil.save(image_path)

def mkdirs(paths):
    if isinstance(paths, list) and not isinstance(paths, str):
        for path in paths:
            mkdir(path)
    else:
        mkdir(paths)

def mkdir(path):
    if not os.path.exists(path):
        os.makedirs(path)

###############################################################################
# Code from
# https://github.com/ycszen/pytorch-seg/blob/master/transform.py
# Modified so it complies with the Cityscapes label map colors
###############################################################################
def uint82bin(n, count=8):
    """Returns the binary string of integer n; count is the number of bits."""
    return ''.join([str((n >> y) & 1) for y in range(count - 1, -1, -1)])

def labelcolormap(N):
    if N == 35:  # Cityscapes
        cmap = np.array([(  0,  0,  0), (  0,  0,  0), (  0,  0,  0), (  0,  0,  0), (  0,  0,  0), (111, 74,  0), ( 81,  0, 81),
                         (128, 64,128), (244, 35,232), (250,170,160), (230,150,140), ( 70, 70, 70), (102,102,156), (190,153,153),
                         (180,165,180), (150,100,100), (150,120, 90), (153,153,153), (153,153,153), (250,170, 30), (220,220,  0),
                         (107,142, 35), (152,251,152), ( 70,130,180), (220, 20, 60), (255,  0,  0), (  0,  0,142), (  0,  0, 70),
                         (  0, 60,100), (  0,  0, 90), (  0,  0,110), (  0, 80,100), (  0,  0,230), (119, 11, 32), (  0,  0,142)],
                        dtype=np.uint8)
    else:
        # derive a color per label by spreading the label's bits over the RGB channels
        cmap = np.zeros((N, 3), dtype=np.uint8)
        for i in range(N):
            r, g, b = 0, 0, 0
            idx = i
            for j in range(7):
                str_id = uint82bin(idx)
                r = r ^ (np.uint8(str_id[-1]) << (7 - j))
                g = g ^ (np.uint8(str_id[-2]) << (7 - j))
                b = b ^ (np.uint8(str_id[-3]) << (7 - j))
                idx = idx >> 3
            cmap[i, 0] = r
            cmap[i, 1] = g
            cmap[i, 2] = b
    return cmap

class Colorize(object):
    def __init__(self, n=35):
        self.cmap = labelcolormap(n)
        self.cmap = torch.from_numpy(self.cmap[:n])

    def __call__(self, gray_image):
        size = gray_image.size()
        color_image = torch.ByteTensor(3, size[1], size[2]).fill_(0)

        for label in range(0, len(self.cmap)):
            mask = (label == gray_image[0]).cpu()
            color_image[0][mask] = self.cmap[label][0]
            color_image[1][mask] = self.cmap[label][1]
            color_image[2][mask] = self.cmap[label][2]

        return color_image
--------------------------------------------------------------------------------
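A short sketch of how the helpers above fit together (the random label map is synthetic, purely for illustration):

import numpy as np
import torch
from util.util import tensor2label, save_image

# fake 1xHxW label map with Cityscapes-style ids in [0, 34]
label = torch.from_numpy(np.random.randint(0, 35, size=(1, 64, 128))).float()

# colorize with the 35-entry Cityscapes palette and save as an RGB image
colored = tensor2label(label, n_label=35)   # HxWx3 uint8 numpy array
save_image(colored, 'label_preview.png')
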
/util/visualizer.py:
--------------------------------------------------------------------------------
### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
import os
import ntpath
import time
from . import util
from . import html
import scipy.misc  # scipy.misc.toimage requires scipy < 1.2
try:
    from StringIO import StringIO  # Python 2.7
except ImportError:
    from io import BytesIO  # Python 3.x

class Visualizer():
    def __init__(self, opt):
        # self.opt = opt
        self.tf_log = opt.tf_log
        self.use_html = opt.isTrain and not opt.no_html
        self.win_size = opt.display_winsize
        self.name = opt.name
        if self.tf_log:
            import tensorflow as tf
            self.tf = tf
            self.log_dir = os.path.join(opt.checkpoints_dir, opt.name, 'logs')
            self.writer = tf.summary.FileWriter(self.log_dir)

        if self.use_html:
            self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
            self.img_dir = os.path.join(self.web_dir, 'images')
            print('create web directory %s...' % self.web_dir)
            util.mkdirs([self.web_dir, self.img_dir])
        self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
        with open(self.log_name, "a") as log_file:
            now = time.strftime("%c")
            log_file.write('================ Training Loss (%s) ================\n' % now)

    # |visuals|: dictionary of images to display or save
    def display_current_results(self, visuals, epoch, step):
        if self.tf_log:  # show images in tensorboard output
            img_summaries = []
            for label, image_numpy in visuals.items():
                # Write the image to a string
                try:
                    s = StringIO()
                except NameError:  # Python 3: StringIO was never imported
                    s = BytesIO()
                scipy.misc.toimage(image_numpy).save(s, format="jpeg")
                # Create an Image object
                img_sum = self.tf.Summary.Image(encoded_image_string=s.getvalue(), height=image_numpy.shape[0], width=image_numpy.shape[1])
                # Create a Summary value
                img_summaries.append(self.tf.Summary.Value(tag=label, image=img_sum))

            # Create and write Summary
            summary = self.tf.Summary(value=img_summaries)
            self.writer.add_summary(summary, step)

        if self.use_html:  # save images to a html file
            for label, image_numpy in visuals.items():
                if isinstance(image_numpy, list):
                    for i in range(len(image_numpy)):
                        img_path = os.path.join(self.img_dir, 'epoch%.3d_%s_%d.jpg' % (epoch, label, i))
                        util.save_image(image_numpy[i], img_path)
                else:
                    img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.jpg' % (epoch, label))
                    util.save_image(image_numpy, img_path)

            # update website: rebuild one page listing every epoch, newest first
            webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=5)
            for n in range(epoch, 0, -1):
                webpage.add_header('epoch [%d]' % n)
                ims = []
                txts = []
                links = []

                for label, image_numpy in visuals.items():
                    if isinstance(image_numpy, list):
                        for i in range(len(image_numpy)):
                            img_path = 'epoch%.3d_%s_%d.jpg' % (n, label, i)
                            ims.append(img_path)
                            txts.append(label + str(i))
                            links.append(img_path)
                    else:
                        img_path = 'epoch%.3d_%s.jpg' % (n, label)
                        ims.append(img_path)
                        txts.append(label)
                        links.append(img_path)
                if len(ims) < 10:
                    webpage.add_images(ims, txts, links, width=self.win_size)
                else:
                    # split wide rows in half so the table stays readable
                    num = int(round(len(ims) / 2.0))
                    webpage.add_images(ims[:num], txts[:num], links[:num], width=self.win_size)
                    webpage.add_images(ims[num:], txts[num:], links[num:], width=self.win_size)
            webpage.save()

    # errors: dictionary of error labels and values
    def plot_current_errors(self, errors, step):
        if self.tf_log:
            for tag, value in errors.items():
                summary = self.tf.Summary(value=[self.tf.Summary.Value(tag=tag, simple_value=value)])
                self.writer.add_summary(summary, step)
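
    # Illustrative call pattern (hypothetical, not part of this file): a training
    # loop typically pushes per-step losses here, e.g.
    #   visualizer.plot_current_errors({'G_GAN': 0.93, 'D_real': 0.41}, total_steps)
    # Note that tf.summary.FileWriter and tf.Summary are TensorFlow 1.x APIs.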
    # errors: same format as |errors| of plot_current_errors
    def print_current_errors(self, epoch, i, errors, t):
        message = '(epoch: %d, iters: %d, time: %.3f) ' % (epoch, i, t)
        for k, v in errors.items():
            if v != 0:
                message += '%s: %.3f ' % (k, v)

        print(message)
        with open(self.log_name, "a") as log_file:
            log_file.write('%s\n' % message)

    # save image to the disk
    def save_images(self, webpage, visuals, image_path):
        image_dir = webpage.get_image_dir()
        short_path = ntpath.basename(image_path[0])
        name = os.path.splitext(short_path)[0]

        webpage.add_header(name)
        ims = []
        txts = []
        links = []

        for label, image_numpy in visuals.items():
            image_name = '%s_%s.jpg' % (name, label)
            save_path = os.path.join(image_dir, image_name)
            util.save_image(image_numpy, save_path)

            ims.append(image_name)
            txts.append(label)
            links.append(image_name)
        webpage.add_images(ims, txts, links, width=self.win_size)
--------------------------------------------------------------------------------
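A self-contained sketch of driving Visualizer end to end (the SimpleNamespace stands in for the parsed options object; its fields are exactly the attributes Visualizer reads above, and the values are illustrative):

from collections import OrderedDict
from types import SimpleNamespace

import numpy as np

from util.visualizer import Visualizer

opt = SimpleNamespace(tf_log=False, isTrain=True, no_html=False,
                      display_winsize=512, name='demo',
                      checkpoints_dir='./checkpoints')

vis = Visualizer(opt)

# a single dummy visual: a black 256x512 RGB image
visuals = OrderedDict([('synthesized_image', np.zeros((256, 512, 3), dtype=np.uint8))])
vis.display_current_results(visuals, epoch=1, step=0)
vis.print_current_errors(epoch=1, i=0, errors={'G_GAN': 0.93}, t=0.12)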