├── .gitignore
├── LICENSE.txt
├── README.md
├── _config.yml
├── data
│   ├── __init__.py
│   ├── aligned_dataset.py
│   ├── base_data_loader.py
│   ├── base_dataset.py
│   ├── custom_dataset_data_loader.py
│   ├── data_loader.py
│   └── image_folder.py
├── datasets
│   └── cityscapes
│       ├── test_inst
│       │   ├── frankfurt_000000_000576_gtFine_instanceIds.png
│       │   ├── frankfurt_000000_001236_gtFine_instanceIds.png
│       │   ├── frankfurt_000000_003357_gtFine_instanceIds.png
│       │   ├── frankfurt_000000_011810_gtFine_instanceIds.png
│       │   ├── frankfurt_000000_012868_gtFine_instanceIds.png
│       │   ├── frankfurt_000001_013710_gtFine_instanceIds.png
│       │   ├── frankfurt_000001_015328_gtFine_instanceIds.png
│       │   ├── frankfurt_000001_023769_gtFine_instanceIds.png
│       │   ├── frankfurt_000001_028335_gtFine_instanceIds.png
│       │   ├── frankfurt_000001_032711_gtFine_instanceIds.png
│       │   ├── frankfurt_000001_033655_gtFine_instanceIds.png
│       │   ├── frankfurt_000001_042733_gtFine_instanceIds.png
│       │   ├── frankfurt_000001_047552_gtFine_instanceIds.png
│       │   ├── frankfurt_000001_054640_gtFine_instanceIds.png
│       │   └── frankfurt_000001_055387_gtFine_instanceIds.png
│       ├── test_label
│       │   ├── frankfurt_000000_000576_gtFine_labelIds.png
│       │   ├── frankfurt_000000_001236_gtFine_labelIds.png
│       │   ├── frankfurt_000000_003357_gtFine_labelIds.png
│       │   ├── frankfurt_000000_011810_gtFine_labelIds.png
│       │   ├── frankfurt_000000_012868_gtFine_labelIds.png
│       │   ├── frankfurt_000001_013710_gtFine_labelIds.png
│       │   ├── frankfurt_000001_015328_gtFine_labelIds.png
│       │   ├── frankfurt_000001_023769_gtFine_labelIds.png
│       │   ├── frankfurt_000001_028335_gtFine_labelIds.png
│       │   ├── frankfurt_000001_032711_gtFine_labelIds.png
│       │   ├── frankfurt_000001_033655_gtFine_labelIds.png
│       │   ├── frankfurt_000001_042733_gtFine_labelIds.png
│       │   ├── frankfurt_000001_047552_gtFine_labelIds.png
│       │   ├── frankfurt_000001_054640_gtFine_labelIds.png
│       │   └── frankfurt_000001_055387_gtFine_labelIds.png
│       ├── train_img
│       │   ├── aachen_000000_000019_leftImg8bit.png
│       │   ├── aachen_000001_000019_leftImg8bit.png
│       │   ├── aachen_000002_000019_leftImg8bit.png
│       │   ├── aachen_000003_000019_leftImg8bit.png
│       │   └── aachen_000004_000019_leftImg8bit.png
│       ├── train_inst
│       │   ├── aachen_000000_000019_gtFine_instanceIds.png
│       │   ├── aachen_000001_000019_gtFine_instanceIds.png
│       │   ├── aachen_000002_000019_gtFine_instanceIds.png
│       │   ├── aachen_000003_000019_gtFine_instanceIds.png
│       │   └── aachen_000004_000019_gtFine_instanceIds.png
│       └── train_label
│           ├── aachen_000000_000019_gtFine_labelIds.png
│           ├── aachen_000001_000019_gtFine_labelIds.png
│           ├── aachen_000002_000019_gtFine_labelIds.png
│           ├── aachen_000003_000019_gtFine_labelIds.png
│           └── aachen_000004_000019_gtFine_labelIds.png
├── encode_features.py
├── imgs
│   ├── city_short.gif
│   ├── cityscapes_1.jpg
│   ├── cityscapes_2.jpg
│   ├── cityscapes_3.jpg
│   ├── cityscapes_4.jpg
│   ├── face1_1.jpg
│   ├── face1_2.jpg
│   ├── face1_3.jpg
│   ├── face2_1.jpg
│   ├── face2_2.jpg
│   ├── face2_3.jpg
│   ├── face_short.gif
│   ├── teaser_720.gif
│   ├── teaser_label.gif
│   ├── teaser_label.png
│   ├── teaser_ours.jpg
│   └── teaser_style.gif
├── models
│   ├── __init__.py
│   ├── base_model.py
│   ├── models.py
│   ├── networks.py
│   └── pix2pixHD_model.py
├── options
│   ├── __init__.py
│   ├── base_options.py
│   ├── test_options.py
│   └── train_options.py
├── precompute_feature_maps.py
├── scripts
│   ├── test_1024p.sh
│   ├── test_1024p_feat.sh
│   ├── test_512p.sh
│   ├── test_512p_feat.sh
│   ├── train_1024p_12G.sh
│   ├── train_1024p_24G.sh
│   ├── train_1024p_feat_12G.sh
│   ├── train_1024p_feat_24G.sh
│   ├── train_512p.sh
│   ├── train_512p_feat.sh
│   └── train_512p_multigpu.sh
├── test.py
├── train.py
└── util
    ├── __init__.py
    ├── html.py
    ├── image_pool.py
    ├── util.py
    └── visualizer.py
/.gitignore:
--------------------------------------------------------------------------------
1 | debug*
2 | checkpoints/
3 | results/
4 | build/
5 | dist/
6 | torch.egg-info/
7 | */**/__pycache__
8 | torch/version.py
9 | torch/csrc/generic/TensorMethods.cpp
10 | torch/lib/*.so*
11 | torch/lib/*.dylib*
12 | torch/lib/*.h
13 | torch/lib/build
14 | torch/lib/tmp_install
15 | torch/lib/include
16 | torch/lib/torch_shm_manager
17 | torch/csrc/cudnn/cuDNN.cpp
18 | torch/csrc/nn/THNN.cwrap
19 | torch/csrc/nn/THNN.cpp
20 | torch/csrc/nn/THCUNN.cwrap
21 | torch/csrc/nn/THCUNN.cpp
22 | torch/csrc/nn/THNN_generic.cwrap
23 | torch/csrc/nn/THNN_generic.cpp
24 | torch/csrc/nn/THNN_generic.h
25 | docs/src/**/*
26 | test/data/legacy_modules.t7
27 | test/data/gpu_tensors.pt
28 | test/htmlcov
29 | test/.coverage
30 | */*.pyc
31 | */**/*.pyc
32 | */**/**/*.pyc
33 | */**/**/**/*.pyc
34 | */**/**/**/**/*.pyc
35 | */*.so*
36 | */**/*.so*
37 | */**/*.dylib*
38 | test/data/legacy_serialized.pt
39 | *.DS_Store
40 | *~
41 |
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | Copyright (C) 2017 NVIDIA Corporation. Ting-Chun Wang, Ming-Yu Liu, Jun-Yan Zhu.
2 | All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 |
5 | Permission to use, copy, modify, and distribute this software and its documentation
6 | for any non-commercial purpose is hereby granted without fee, provided that the above
7 | copyright notice appear in all copies and that both that copyright notice and this
8 | permission notice appear in supporting documentation, and that the name of the author
9 | not be used in advertising or publicity pertaining to distribution of the software
10 | without specific, written prior permission.
11 |
12 | THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING ALL
13 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR ANY PARTICULAR PURPOSE.
14 | IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL
15 | DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
16 | WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING
17 | OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
18 |
19 |
20 | --------------------------- LICENSE FOR pytorch-CycleGAN-and-pix2pix ----------------
21 | Copyright (c) 2017, Jun-Yan Zhu and Taesung Park
22 | All rights reserved.
23 |
24 | Redistribution and use in source and binary forms, with or without
25 | modification, are permitted provided that the following conditions are met:
26 |
27 | * Redistributions of source code must retain the above copyright notice, this
28 | list of conditions and the following disclaimer.
29 |
30 | * Redistributions in binary form must reproduce the above copyright notice,
31 | this list of conditions and the following disclaimer in the documentation
32 | and/or other materials provided with the distribution.
33 |
34 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
35 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
36 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
37 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
38 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
39 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
40 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
41 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
42 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
43 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
44 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | # pix2pixHD
6 | ### [[Project]](https://tcwang0509.github.io/pix2pixHD/) [[Youtube]](https://youtu.be/3AIpPlzM_qs) [[Paper]](https://arxiv.org/pdf/1711.11585.pdf)
7 | PyTorch implementation of our method for high-resolution (e.g. 2048x1024) photorealistic image-to-image translation. It can be used to turn semantic label maps into photorealistic images or to synthesize portraits from face label maps.
8 | [High-Resolution Image Synthesis and Semantic Manipulation with Conditional GANs](https://tcwang0509.github.io/pix2pixHD/)
9 | [Ting-Chun Wang](https://tcwang0509.github.io/)¹, [Ming-Yu Liu](http://mingyuliu.net/)¹, [Jun-Yan Zhu](http://people.eecs.berkeley.edu/~junyanz/)², Andrew Tao¹, [Jan Kautz](http://jankautz.com/)¹, [Bryan Catanzaro](http://catanzaro.name/)¹
10 | ¹NVIDIA Corporation, ²UC Berkeley
11 | In arXiv, 2017.
12 |
13 | ## Image-to-image translation at 2k/1k resolution
14 | - Our label-to-streetview results
15 |
16 |
17 |
18 |
19 | - Interactive editing results
20 |
21 |
22 |
23 |
24 | - Additional streetview results
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 | - Label-to-face and interactive editing results
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 | - Our editing interface
47 |
48 |
49 |
50 |
51 |
52 | ## Prerequisites
53 | - Linux or macOS
54 | - Python 2 or 3
55 | - NVIDIA GPU (12G or 24G memory) + CUDA and cuDNN
56 |
57 | ## Getting Started
58 | ### Installation
59 | - Install PyTorch and dependencies from http://pytorch.org
60 | - Install the python library [dominate](https://github.com/Knio/dominate):
61 | ```bash
62 | pip install dominate
63 | ```
64 | - Clone this repo:
65 | ```bash
66 | git clone https://github.com/NVIDIA/pix2pixHD
67 | cd pix2pixHD
68 | ```
69 |
70 |
71 | ### Testing
72 | - A few example Cityscapes test images are included in the `datasets` folder.
73 | - Please download the pre-trained Cityscapes model from [here](https://drive.google.com/file/d/1h9SykUnuZul7J3Nbms2QGH1wa85nbN2-/view?usp=sharing) (google drive link), and put it under `./checkpoints/label2city_1024p/`
74 | - Test the model (`bash ./scripts/test_1024p.sh`):
75 | ```bash
76 | #!./scripts/test_1024p.sh
77 | python test.py --name label2city_1024p --netG local --ngf 32 --resize_or_crop none
78 | ```
79 | The test results will be saved to an HTML file here: `./results/label2city_1024p/test_latest/index.html`.
80 |
81 | More example scripts can be found in the `scripts` directory.
82 |
83 |
84 | ### Dataset
85 | - We use the Cityscapes dataset. To train a model on the full dataset, please download it from the [official website](https://www.cityscapes-dataset.com/) (registration required).
86 | After downloading, please put it under the `datasets` folder in the same way the example images are provided.
87 |
88 |
89 | ### Training
90 | - Train a model at 1024 x 512 resolution (`bash ./scripts/train_512p.sh`):
91 | ```bash
92 | #!./scripts/train_512p.sh
93 | python train.py --name label2city_512p
94 | ```
95 | - To view training results, please check out the intermediate results in `./checkpoints/label2city_512p/web/index.html`.
96 | If you have TensorFlow installed, you can see TensorBoard logs in `./checkpoints/label2city_512p/logs` by adding `--tf_log` to the training scripts.
97 |
98 | ### Multi-GPU training
99 | - Train a model using multiple GPUs (`bash ./scripts/train_512p_multigpu.sh`):
100 | ```bash
101 | #!./scripts/train_512p_multigpu.sh
102 | python train.py --name label2city_512p --batchSize 8 --gpu_ids 0,1,2,3,4,5,6,7
103 | ```
104 | Note: this is not tested, as we trained our model using a single GPU only. Please use at your own discretion.
105 |
106 | ### Training at full resolution
107 | - Training at full resolution (2048 x 1024) requires a GPU with 24G of memory (`bash ./scripts/train_1024p_24G.sh`).
108 | If only GPUs with 12G memory are available, please use the 12G script (`bash ./scripts/train_1024p_12G.sh`), which will crop the images during training. Performance is not guaranteed with this script.
109 |
110 | ### Training with your own dataset
111 | - If you want to train with your own dataset, please generate one-channel label maps whose pixel values correspond to the object labels (i.e. 0, 1, ..., N-1, where N is the number of labels). This is because we need to generate one-hot vectors from the label maps. Please also specify `--label_nc N` during both training and testing.
112 | - If your input is not a label map, please just specify `--label_nc 0` which will directly use the RGB colors as input.
113 | - If you don't have instance maps or don't want to use them, please specify `--no_instance`.
114 | - The default setting for preprocessing is `scale_width`, which will scale the width of all training images to `opt.loadSize` (1024) while keeping the aspect ratio. If you want a different setting, please change it using the `--resize_or_crop` option. For example, `scale_width_and_crop` first resizes the image to have width `opt.loadSize` and then does a random crop of size `(opt.fineSize, opt.fineSize)`. `crop` skips the resizing step and only performs random cropping. If you don't want any preprocessing, please specify `none`, which will do nothing other than making sure the image dimensions are divisible by 32.
115 |
116 | ## More Training/Test Details
117 | - Flags: see `options/train_options.py` and `options/base_options.py` for all the training flags; see `options/test_options.py` and `options/base_options.py` for all the test flags.
118 | - Instance map: we take in both label maps and instance maps as input. If you don't want to use instance maps, please specify the flag `--no_instance`.
119 |
120 |
121 | ## Citation
122 |
123 | If you find this useful for your research, please cite our paper using the following entry.
124 |
125 | ```
126 | @article{wang2017highres,
127 | title={High-Resolution Image Synthesis and Semantic Manipulation with Conditional GANs},
128 | author={Ting-Chun Wang and Ming-Yu Liu and Jun-Yan Zhu and Andrew Tao and Jan Kautz and Bryan Catanzaro},
129 | journal={arXiv preprint arXiv:1711.11585},
130 | year={2017}
131 | }
132 | ```
133 |
134 | ## Acknowledgments
135 | This code borrows heavily from [pytorch-CycleGAN-and-pix2pix](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix).
136 |
--------------------------------------------------------------------------------
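As a concrete illustration of the custom-dataset flags documented in the README above, here is a hypothetical training invocation; the experiment name, dataset path, and class count are placeholders, not files shipped with the repo:

```bash
# Hypothetical example (name, dataroot, and label count are placeholders).
python train.py --name my_label2image \
                --dataroot ./datasets/my_dataset \
                --label_nc 10 \
                --no_instance \
                --resize_or_crop scale_width_and_crop \
                --loadSize 1024 --fineSize 512
```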
/_config.yml:
--------------------------------------------------------------------------------
1 | theme: jekyll-theme-minimal
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxli/High-Resolution-Image-Synthesis-and-Semantic-Manipulation-with-Conditional-GANsl-/dd05da797863a13f6e45ec0a2d2ff3c7f8142f38/data/__init__.py
--------------------------------------------------------------------------------
/data/aligned_dataset.py:
--------------------------------------------------------------------------------
1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
3 | import os.path
4 | from data.base_dataset import BaseDataset, get_params, get_transform, normalize
5 | from data.image_folder import make_dataset
6 | from PIL import Image
7 |
8 | class AlignedDataset(BaseDataset):
9 | def initialize(self, opt):
10 | self.opt = opt
11 | self.root = opt.dataroot
12 |
13 | ### label maps
14 | self.dir_label = os.path.join(opt.dataroot, opt.phase + '_label')
15 | self.label_paths = sorted(make_dataset(self.dir_label))
16 |
17 | ### real images
18 | if opt.isTrain:
19 | self.dir_image = os.path.join(opt.dataroot, opt.phase + '_img')
20 | self.image_paths = sorted(make_dataset(self.dir_image))
21 |
22 | ### instance maps
23 | if not opt.no_instance:
24 | self.dir_inst = os.path.join(opt.dataroot, opt.phase + '_inst')
25 | self.inst_paths = sorted(make_dataset(self.dir_inst))
26 |
27 | ### load precomputed instance-wise encoded features
28 | if opt.load_features:
29 | self.dir_feat = os.path.join(opt.dataroot, opt.phase + '_feat')
30 | print('----------- loading features from %s ----------' % self.dir_feat)
31 | self.feat_paths = sorted(make_dataset(self.dir_feat))
32 |
33 | self.dataset_size = len(self.label_paths)
34 |
35 | def __getitem__(self, index):
36 | ### label maps
37 | label_path = self.label_paths[index]
38 | label = Image.open(label_path)
39 | params = get_params(self.opt, label.size)
40 | if self.opt.label_nc == 0:
41 | transform_label = get_transform(self.opt, params)
42 | label_tensor = transform_label(label.convert('RGB'))
43 | else:
44 | transform_label = get_transform(self.opt, params, method=Image.NEAREST, normalize=False)
45 | label_tensor = transform_label(label) * 255.0
46 |
47 | image_tensor = inst_tensor = feat_tensor = 0
48 | ### real images
49 | if self.opt.isTrain:
50 | image_path = self.image_paths[index]
51 | image = Image.open(image_path).convert('RGB')
52 | transform_image = get_transform(self.opt, params)
53 | image_tensor = transform_image(image)
54 |
55 | ### if using instance maps
56 | if not self.opt.no_instance:
57 | inst_path = self.inst_paths[index]
58 | inst = Image.open(inst_path)
59 | inst_tensor = transform_label(inst)
60 |
61 | if self.opt.load_features:
62 | feat_path = self.feat_paths[index]
63 | feat = Image.open(feat_path).convert('RGB')
64 | norm = normalize()
65 | feat_tensor = norm(transform_label(feat))
66 |
67 | input_dict = {'label': label_tensor, 'inst': inst_tensor, 'image': image_tensor,
68 | 'feat': feat_tensor, 'path': label_path}
69 |
70 | return input_dict
71 |
72 | def __len__(self):
73 | return len(self.label_paths)
74 |
75 | def name(self):
76 | return 'AlignedDataset'
--------------------------------------------------------------------------------
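A minimal sketch (not part of the repo) of the option fields that `AlignedDataset.initialize` and `__getitem__` read, driven against the bundled Cityscapes test samples; every value below is an illustrative assumption, not a documented default:

```python
from types import SimpleNamespace
from data.aligned_dataset import AlignedDataset

# Illustrative assumptions; in the repo these come from the options parsers.
opt = SimpleNamespace(
    dataroot='./datasets/cityscapes',
    phase='test',            # folders resolve to <phase>_label, <phase>_inst, <phase>_img
    isTrain=False,           # real images are only loaded for training
    no_instance=False,       # also load instance maps
    load_features=False,     # no precomputed encoded features
    label_nc=35,             # keep label maps as class indices (x255 after ToTensor)
    resize_or_crop='none',   # only round sizes to a power-of-2 multiple
    fineSize=512,            # read by get_params even when no crop happens
    n_downsample_global=4,   # 'none' rounds sizes to a multiple of 2**4 here
    netG='global',
    no_flip=True,
)

dataset = AlignedDataset()
dataset.initialize(opt)
item = dataset[0]            # dict: 'label', 'inst', 'image', 'feat', 'path'
print(item['label'].shape, item['path'])
```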
/data/base_data_loader.py:
--------------------------------------------------------------------------------
1 |
2 | class BaseDataLoader():
3 | def __init__(self):
4 | pass
5 |
6 | def initialize(self, opt):
7 | self.opt = opt
8 | pass
9 |
10 | def load_data(self):
11 | return None
12 |
13 |
14 |
15 |
--------------------------------------------------------------------------------
/data/base_dataset.py:
--------------------------------------------------------------------------------
1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
3 | import torch.utils.data as data
4 | from PIL import Image
5 | import torchvision.transforms as transforms
6 | import numpy as np
7 | import random
8 |
9 | class BaseDataset(data.Dataset):
10 | def __init__(self):
11 | super(BaseDataset, self).__init__()
12 |
13 | def name(self):
14 | return 'BaseDataset'
15 |
16 | def initialize(self, opt):
17 | pass
18 |
19 | def get_params(opt, size):
20 | w, h = size
21 | new_h = h
22 | new_w = w
23 | if opt.resize_or_crop == 'resize_and_crop':
24 | new_h = new_w = opt.loadSize
25 | elif opt.resize_or_crop == 'scale_width_and_crop':
26 | new_w = opt.loadSize
27 | new_h = opt.loadSize * h // w
28 |
29 | x = random.randint(0, np.maximum(0, new_w - opt.fineSize))
30 | y = random.randint(0, np.maximum(0, new_h - opt.fineSize))
31 |
32 | flip = random.random() > 0.5
33 | return {'crop_pos': (x, y), 'flip': flip}
34 |
35 | def get_transform(opt, params, method=Image.BICUBIC, normalize=True):
36 | transform_list = []
37 | if 'resize' in opt.resize_or_crop:
38 | osize = [opt.loadSize, opt.loadSize]
39 | transform_list.append(transforms.Scale(osize, method))
40 | elif 'scale_width' in opt.resize_or_crop:
41 | transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.loadSize, method)))
42 |
43 | if 'crop' in opt.resize_or_crop:
44 | transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.fineSize)))
45 |
46 | if opt.resize_or_crop == 'none':
47 | base = float(2 ** opt.n_downsample_global)
48 | if opt.netG == 'local':
49 | base *= (2 ** opt.n_local_enhancers)
50 | transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method)))
51 |
52 | if opt.isTrain and not opt.no_flip:
53 | transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
54 |
55 | transform_list += [transforms.ToTensor()]
56 |
57 | if normalize:
58 | transform_list += [transforms.Normalize((0.5, 0.5, 0.5),
59 | (0.5, 0.5, 0.5))]
60 | return transforms.Compose(transform_list)
61 |
62 | def normalize():
63 | return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
64 |
65 | def __make_power_2(img, base, method=Image.BICUBIC):
66 | ow, oh = img.size
67 | h = int(round(oh / base) * base)
68 | w = int(round(ow / base) * base)
69 | if (h == oh) and (w == ow):
70 | return img
71 | return img.resize((w, h), method)
72 |
73 | def __scale_width(img, target_width, method=Image.BICUBIC):
74 | ow, oh = img.size
75 | if (ow == target_width):
76 | return img
77 | w = target_width
78 | h = int(target_width * oh / ow)
79 | return img.resize((w, h), method)
80 |
81 | def __crop(img, pos, size):
82 | ow, oh = img.size
83 | x1, y1 = pos
84 | tw = th = size
85 | if (ow > tw or oh > th):
86 | return img.crop((x1, y1, x1 + tw, y1 + th))
87 | return img
88 |
89 | def __flip(img, flip):
90 | if flip:
91 | return img.transpose(Image.FLIP_LEFT_RIGHT)
92 | return img
93 |
--------------------------------------------------------------------------------
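To see how `get_params` and `get_transform` cooperate, here is a self-contained sketch on a synthetic image; the `loadSize`/`fineSize` values are assumptions mirroring the README's description, not values read from the repo:

```python
from types import SimpleNamespace
from PIL import Image
from data.base_dataset import get_params, get_transform

opt = SimpleNamespace(resize_or_crop='scale_width_and_crop',
                      loadSize=1024, fineSize=512,   # assumed settings
                      isTrain=True, no_flip=False)

img = Image.new('RGB', (2048, 1024))       # synthetic stand-in image
params = get_params(opt, img.size)         # random crop position + flip decision
tensor = get_transform(opt, params)(img)   # scale width to 1024, crop 512x512, maybe flip, normalize
print(tensor.shape)                        # torch.Size([3, 512, 512])
```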
/data/custom_dataset_data_loader.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data
2 | from data.base_data_loader import BaseDataLoader
3 |
4 |
5 | def CreateDataset(opt):
6 | dataset = None
7 | from data.aligned_dataset import AlignedDataset
8 | dataset = AlignedDataset()
9 |
10 | print("dataset [%s] was created" % (dataset.name()))
11 | dataset.initialize(opt)
12 | return dataset
13 |
14 | class CustomDatasetDataLoader(BaseDataLoader):
15 | def name(self):
16 | return 'CustomDatasetDataLoader'
17 |
18 | def initialize(self, opt):
19 | BaseDataLoader.initialize(self, opt)
20 | self.dataset = CreateDataset(opt)
21 | self.dataloader = torch.utils.data.DataLoader(
22 | self.dataset,
23 | batch_size=opt.batchSize,
24 | shuffle=not opt.serial_batches,
25 | num_workers=int(opt.nThreads))
26 |
27 | def load_data(self):
28 | return self.dataloader
29 |
30 | def __len__(self):
31 | return min(len(self.dataset), self.opt.max_dataset_size)
32 |
--------------------------------------------------------------------------------
/data/data_loader.py:
--------------------------------------------------------------------------------
1 |
2 | def CreateDataLoader(opt):
3 | from data.custom_dataset_data_loader import CustomDatasetDataLoader
4 | data_loader = CustomDatasetDataLoader()
5 | print(data_loader.name())
6 | data_loader.initialize(opt)
7 | return data_loader
8 |
--------------------------------------------------------------------------------
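A short sketch of how `train.py`-style code consumes `CreateDataLoader`; the options come from the repo's command-line parser, so this only runs with appropriate flags (e.g. `--dataroot`) supplied:

```python
from options.train_options import TrainOptions
from data.data_loader import CreateDataLoader

opt = TrainOptions().parse()          # reads flags such as --dataroot, --batchSize
data_loader = CreateDataLoader(opt)   # CustomDatasetDataLoader wrapping AlignedDataset
dataset = data_loader.load_data()     # the underlying torch DataLoader
print('#training images = %d' % len(data_loader))
for i, data in enumerate(dataset):    # each batch is the dict built in AlignedDataset
    label, inst, image = data['label'], data['inst'], data['image']
    break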
/data/image_folder.py:
--------------------------------------------------------------------------------
1 | ###############################################################################
2 | # Code from
3 | # https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py
4 | # Modified the original code so that it loads images from the current
5 | # directory as well as its subdirectories
6 | ###############################################################################
7 | import torch.utils.data as data
8 | from PIL import Image
9 | import os
10 |
11 | IMG_EXTENSIONS = [
12 | '.jpg', '.JPG', '.jpeg', '.JPEG',
13 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff'
14 | ]
15 |
16 |
17 | def is_image_file(filename):
18 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
19 |
20 |
21 | def make_dataset(dir):
22 | images = []
23 | assert os.path.isdir(dir), '%s is not a valid directory' % dir
24 |
25 | for root, _, fnames in sorted(os.walk(dir)):
26 | for fname in fnames:
27 | if is_image_file(fname):
28 | path = os.path.join(root, fname)
29 | images.append(path)
30 |
31 | return images
32 |
33 |
34 | def default_loader(path):
35 | return Image.open(path).convert('RGB')
36 |
37 |
38 | class ImageFolder(data.Dataset):
39 |
40 | def __init__(self, root, transform=None, return_paths=False,
41 | loader=default_loader):
42 | imgs = make_dataset(root)
43 | if len(imgs) == 0:
44 | raise(RuntimeError("Found 0 images in: " + root + "\n"
45 | "Supported image extensions are: " +
46 | ",".join(IMG_EXTENSIONS)))
47 |
48 | self.root = root
49 | self.imgs = imgs
50 | self.transform = transform
51 | self.return_paths = return_paths
52 | self.loader = loader
53 |
54 | def __getitem__(self, index):
55 | path = self.imgs[index]
56 | img = self.loader(path)
57 | if self.transform is not None:
58 | img = self.transform(img)
59 | if self.return_paths:
60 | return img, path
61 | else:
62 | return img
63 |
64 | def __len__(self):
65 | return len(self.imgs)
66 |
--------------------------------------------------------------------------------
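A minimal usage sketch for `ImageFolder`, which scans a directory tree recursively via `make_dataset`; the path and transform here are illustrative:

```python
import torchvision.transforms as transforms
from data.image_folder import ImageFolder

folder = ImageFolder(root='./datasets/cityscapes/train_img',  # bundled samples
                     transform=transforms.ToTensor(),
                     return_paths=True)
img, path = folder[0]          # tensor plus the file it came from
print(len(folder), path)
```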
/datasets/cityscapes/test_inst/frankfurt_000000_000576_gtFine_instanceIds.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxli/High-Resolution-Image-Synthesis-and-Semantic-Manipulation-with-Conditional-GANsl-/dd05da797863a13f6e45ec0a2d2ff3c7f8142f38/datasets/cityscapes/test_inst/frankfurt_000000_000576_gtFine_instanceIds.png
--------------------------------------------------------------------------------
/datasets/cityscapes/test_inst/frankfurt_000000_001236_gtFine_instanceIds.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxli/High-Resolution-Image-Synthesis-and-Semantic-Manipulation-with-Conditional-GANsl-/dd05da797863a13f6e45ec0a2d2ff3c7f8142f38/datasets/cityscapes/test_inst/frankfurt_000000_001236_gtFine_instanceIds.png
--------------------------------------------------------------------------------
/datasets/cityscapes/test_inst/frankfurt_000000_003357_gtFine_instanceIds.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxli/High-Resolution-Image-Synthesis-and-Semantic-Manipulation-with-Conditional-GANsl-/dd05da797863a13f6e45ec0a2d2ff3c7f8142f38/datasets/cityscapes/test_inst/frankfurt_000000_003357_gtFine_instanceIds.png
--------------------------------------------------------------------------------
/datasets/cityscapes/test_inst/frankfurt_000000_011810_gtFine_instanceIds.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxli/High-Resolution-Image-Synthesis-and-Semantic-Manipulation-with-Conditional-GANsl-/dd05da797863a13f6e45ec0a2d2ff3c7f8142f38/datasets/cityscapes/test_inst/frankfurt_000000_011810_gtFine_instanceIds.png
--------------------------------------------------------------------------------
/datasets/cityscapes/test_inst/frankfurt_000000_012868_gtFine_instanceIds.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxli/High-Resolution-Image-Synthesis-and-Semantic-Manipulation-with-Conditional-GANsl-/dd05da797863a13f6e45ec0a2d2ff3c7f8142f38/datasets/cityscapes/test_inst/frankfurt_000000_012868_gtFine_instanceIds.png
--------------------------------------------------------------------------------
/datasets/cityscapes/test_inst/frankfurt_000001_013710_gtFine_instanceIds.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxli/High-Resolution-Image-Synthesis-and-Semantic-Manipulation-with-Conditional-GANsl-/dd05da797863a13f6e45ec0a2d2ff3c7f8142f38/datasets/cityscapes/test_inst/frankfurt_000001_013710_gtFine_instanceIds.png
--------------------------------------------------------------------------------
/datasets/cityscapes/test_inst/frankfurt_000001_015328_gtFine_instanceIds.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxli/High-Resolution-Image-Synthesis-and-Semantic-Manipulation-with-Conditional-GANsl-/dd05da797863a13f6e45ec0a2d2ff3c7f8142f38/datasets/cityscapes/test_inst/frankfurt_000001_015328_gtFine_instanceIds.png
--------------------------------------------------------------------------------
/datasets/cityscapes/test_inst/frankfurt_000001_023769_gtFine_instanceIds.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxli/High-Resolution-Image-Synthesis-and-Semantic-Manipulation-with-Conditional-GANsl-/dd05da797863a13f6e45ec0a2d2ff3c7f8142f38/datasets/cityscapes/test_inst/frankfurt_000001_023769_gtFine_instanceIds.png
--------------------------------------------------------------------------------
/datasets/cityscapes/test_inst/frankfurt_000001_028335_gtFine_instanceIds.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxli/High-Resolution-Image-Synthesis-and-Semantic-Manipulation-with-Conditional-GANsl-/dd05da797863a13f6e45ec0a2d2ff3c7f8142f38/datasets/cityscapes/test_inst/frankfurt_000001_028335_gtFine_instanceIds.png
--------------------------------------------------------------------------------
/datasets/cityscapes/test_inst/frankfurt_000001_032711_gtFine_instanceIds.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxli/High-Resolution-Image-Synthesis-and-Semantic-Manipulation-with-Conditional-GANsl-/dd05da797863a13f6e45ec0a2d2ff3c7f8142f38/datasets/cityscapes/test_inst/frankfurt_000001_032711_gtFine_instanceIds.png
--------------------------------------------------------------------------------
/datasets/cityscapes/test_inst/frankfurt_000001_033655_gtFine_instanceIds.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxli/High-Resolution-Image-Synthesis-and-Semantic-Manipulation-with-Conditional-GANsl-/dd05da797863a13f6e45ec0a2d2ff3c7f8142f38/datasets/cityscapes/test_inst/frankfurt_000001_033655_gtFine_instanceIds.png
--------------------------------------------------------------------------------
/datasets/cityscapes/test_inst/frankfurt_000001_042733_gtFine_instanceIds.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxli/High-Resolution-Image-Synthesis-and-Semantic-Manipulation-with-Conditional-GANsl-/dd05da797863a13f6e45ec0a2d2ff3c7f8142f38/datasets/cityscapes/test_inst/frankfurt_000001_042733_gtFine_instanceIds.png
--------------------------------------------------------------------------------
/datasets/cityscapes/test_inst/frankfurt_000001_047552_gtFine_instanceIds.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxli/High-Resolution-Image-Synthesis-and-Semantic-Manipulation-with-Conditional-GANsl-/dd05da797863a13f6e45ec0a2d2ff3c7f8142f38/datasets/cityscapes/test_inst/frankfurt_000001_047552_gtFine_instanceIds.png
--------------------------------------------------------------------------------
/datasets/cityscapes/test_inst/frankfurt_000001_054640_gtFine_instanceIds.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxli/High-Resolution-Image-Synthesis-and-Semantic-Manipulation-with-Conditional-GANsl-/dd05da797863a13f6e45ec0a2d2ff3c7f8142f38/datasets/cityscapes/test_inst/frankfurt_000001_054640_gtFine_instanceIds.png
--------------------------------------------------------------------------------
/datasets/cityscapes/test_inst/frankfurt_000001_055387_gtFine_instanceIds.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxli/High-Resolution-Image-Synthesis-and-Semantic-Manipulation-with-Conditional-GANsl-/dd05da797863a13f6e45ec0a2d2ff3c7f8142f38/datasets/cityscapes/test_inst/frankfurt_000001_055387_gtFine_instanceIds.png
--------------------------------------------------------------------------------
/datasets/cityscapes/test_label/frankfurt_000000_000576_gtFine_labelIds.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxli/High-Resolution-Image-Synthesis-and-Semantic-Manipulation-with-Conditional-GANsl-/dd05da797863a13f6e45ec0a2d2ff3c7f8142f38/datasets/cityscapes/test_label/frankfurt_000000_000576_gtFine_labelIds.png
--------------------------------------------------------------------------------
/datasets/cityscapes/test_label/frankfurt_000000_001236_gtFine_labelIds.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxli/High-Resolution-Image-Synthesis-and-Semantic-Manipulation-with-Conditional-GANsl-/dd05da797863a13f6e45ec0a2d2ff3c7f8142f38/datasets/cityscapes/test_label/frankfurt_000000_001236_gtFine_labelIds.png
--------------------------------------------------------------------------------
/datasets/cityscapes/test_label/frankfurt_000000_003357_gtFine_labelIds.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxli/High-Resolution-Image-Synthesis-and-Semantic-Manipulation-with-Conditional-GANsl-/dd05da797863a13f6e45ec0a2d2ff3c7f8142f38/datasets/cityscapes/test_label/frankfurt_000000_003357_gtFine_labelIds.png
--------------------------------------------------------------------------------
/datasets/cityscapes/test_label/frankfurt_000000_011810_gtFine_labelIds.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxli/High-Resolution-Image-Synthesis-and-Semantic-Manipulation-with-Conditional-GANsl-/dd05da797863a13f6e45ec0a2d2ff3c7f8142f38/datasets/cityscapes/test_label/frankfurt_000000_011810_gtFine_labelIds.png
--------------------------------------------------------------------------------
/datasets/cityscapes/test_label/frankfurt_000000_012868_gtFine_labelIds.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxli/High-Resolution-Image-Synthesis-and-Semantic-Manipulation-with-Conditional-GANsl-/dd05da797863a13f6e45ec0a2d2ff3c7f8142f38/datasets/cityscapes/test_label/frankfurt_000000_012868_gtFine_labelIds.png
--------------------------------------------------------------------------------
/datasets/cityscapes/test_label/frankfurt_000001_013710_gtFine_labelIds.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxli/High-Resolution-Image-Synthesis-and-Semantic-Manipulation-with-Conditional-GANsl-/dd05da797863a13f6e45ec0a2d2ff3c7f8142f38/datasets/cityscapes/test_label/frankfurt_000001_013710_gtFine_labelIds.png
--------------------------------------------------------------------------------
/datasets/cityscapes/test_label/frankfurt_000001_015328_gtFine_labelIds.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxli/High-Resolution-Image-Synthesis-and-Semantic-Manipulation-with-Conditional-GANsl-/dd05da797863a13f6e45ec0a2d2ff3c7f8142f38/datasets/cityscapes/test_label/frankfurt_000001_015328_gtFine_labelIds.png
--------------------------------------------------------------------------------
/datasets/cityscapes/test_label/frankfurt_000001_023769_gtFine_labelIds.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxli/High-Resolution-Image-Synthesis-and-Semantic-Manipulation-with-Conditional-GANsl-/dd05da797863a13f6e45ec0a2d2ff3c7f8142f38/datasets/cityscapes/test_label/frankfurt_000001_023769_gtFine_labelIds.png
--------------------------------------------------------------------------------
/datasets/cityscapes/test_label/frankfurt_000001_028335_gtFine_labelIds.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxli/High-Resolution-Image-Synthesis-and-Semantic-Manipulation-with-Conditional-GANsl-/dd05da797863a13f6e45ec0a2d2ff3c7f8142f38/datasets/cityscapes/test_label/frankfurt_000001_028335_gtFine_labelIds.png
--------------------------------------------------------------------------------
/datasets/cityscapes/test_label/frankfurt_000001_032711_gtFine_labelIds.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxli/High-Resolution-Image-Synthesis-and-Semantic-Manipulation-with-Conditional-GANsl-/dd05da797863a13f6e45ec0a2d2ff3c7f8142f38/datasets/cityscapes/test_label/frankfurt_000001_032711_gtFine_labelIds.png
--------------------------------------------------------------------------------
/datasets/cityscapes/test_label/frankfurt_000001_033655_gtFine_labelIds.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxli/High-Resolution-Image-Synthesis-and-Semantic-Manipulation-with-Conditional-GANsl-/dd05da797863a13f6e45ec0a2d2ff3c7f8142f38/datasets/cityscapes/test_label/frankfurt_000001_033655_gtFine_labelIds.png
--------------------------------------------------------------------------------
/datasets/cityscapes/test_label/frankfurt_000001_042733_gtFine_labelIds.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxli/High-Resolution-Image-Synthesis-and-Semantic-Manipulation-with-Conditional-GANsl-/dd05da797863a13f6e45ec0a2d2ff3c7f8142f38/datasets/cityscapes/test_label/frankfurt_000001_042733_gtFine_labelIds.png
--------------------------------------------------------------------------------
/datasets/cityscapes/test_label/frankfurt_000001_047552_gtFine_labelIds.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxli/High-Resolution-Image-Synthesis-and-Semantic-Manipulation-with-Conditional-GANsl-/dd05da797863a13f6e45ec0a2d2ff3c7f8142f38/datasets/cityscapes/test_label/frankfurt_000001_047552_gtFine_labelIds.png
--------------------------------------------------------------------------------
/datasets/cityscapes/test_label/frankfurt_000001_054640_gtFine_labelIds.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxli/High-Resolution-Image-Synthesis-and-Semantic-Manipulation-with-Conditional-GANsl-/dd05da797863a13f6e45ec0a2d2ff3c7f8142f38/datasets/cityscapes/test_label/frankfurt_000001_054640_gtFine_labelIds.png
--------------------------------------------------------------------------------
/datasets/cityscapes/test_label/frankfurt_000001_055387_gtFine_labelIds.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxli/High-Resolution-Image-Synthesis-and-Semantic-Manipulation-with-Conditional-GANsl-/dd05da797863a13f6e45ec0a2d2ff3c7f8142f38/datasets/cityscapes/test_label/frankfurt_000001_055387_gtFine_labelIds.png
--------------------------------------------------------------------------------
/datasets/cityscapes/train_img/aachen_000000_000019_leftImg8bit.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxli/High-Resolution-Image-Synthesis-and-Semantic-Manipulation-with-Conditional-GANsl-/dd05da797863a13f6e45ec0a2d2ff3c7f8142f38/datasets/cityscapes/train_img/aachen_000000_000019_leftImg8bit.png
--------------------------------------------------------------------------------
/datasets/cityscapes/train_img/aachen_000001_000019_leftImg8bit.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxli/High-Resolution-Image-Synthesis-and-Semantic-Manipulation-with-Conditional-GANsl-/dd05da797863a13f6e45ec0a2d2ff3c7f8142f38/datasets/cityscapes/train_img/aachen_000001_000019_leftImg8bit.png
--------------------------------------------------------------------------------
/datasets/cityscapes/train_img/aachen_000002_000019_leftImg8bit.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxli/High-Resolution-Image-Synthesis-and-Semantic-Manipulation-with-Conditional-GANsl-/dd05da797863a13f6e45ec0a2d2ff3c7f8142f38/datasets/cityscapes/train_img/aachen_000002_000019_leftImg8bit.png
--------------------------------------------------------------------------------
/datasets/cityscapes/train_img/aachen_000003_000019_leftImg8bit.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxli/High-Resolution-Image-Synthesis-and-Semantic-Manipulation-with-Conditional-GANsl-/dd05da797863a13f6e45ec0a2d2ff3c7f8142f38/datasets/cityscapes/train_img/aachen_000003_000019_leftImg8bit.png
--------------------------------------------------------------------------------
/datasets/cityscapes/train_img/aachen_000004_000019_leftImg8bit.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxli/High-Resolution-Image-Synthesis-and-Semantic-Manipulation-with-Conditional-GANsl-/dd05da797863a13f6e45ec0a2d2ff3c7f8142f38/datasets/cityscapes/train_img/aachen_000004_000019_leftImg8bit.png
--------------------------------------------------------------------------------
/datasets/cityscapes/train_inst/aachen_000000_000019_gtFine_instanceIds.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxli/High-Resolution-Image-Synthesis-and-Semantic-Manipulation-with-Conditional-GANsl-/dd05da797863a13f6e45ec0a2d2ff3c7f8142f38/datasets/cityscapes/train_inst/aachen_000000_000019_gtFine_instanceIds.png
--------------------------------------------------------------------------------
/datasets/cityscapes/train_inst/aachen_000001_000019_gtFine_instanceIds.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxli/High-Resolution-Image-Synthesis-and-Semantic-Manipulation-with-Conditional-GANsl-/dd05da797863a13f6e45ec0a2d2ff3c7f8142f38/datasets/cityscapes/train_inst/aachen_000001_000019_gtFine_instanceIds.png
--------------------------------------------------------------------------------
/datasets/cityscapes/train_inst/aachen_000002_000019_gtFine_instanceIds.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxli/High-Resolution-Image-Synthesis-and-Semantic-Manipulation-with-Conditional-GANsl-/dd05da797863a13f6e45ec0a2d2ff3c7f8142f38/datasets/cityscapes/train_inst/aachen_000002_000019_gtFine_instanceIds.png
--------------------------------------------------------------------------------
/datasets/cityscapes/train_inst/aachen_000003_000019_gtFine_instanceIds.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxli/High-Resolution-Image-Synthesis-and-Semantic-Manipulation-with-Conditional-GANsl-/dd05da797863a13f6e45ec0a2d2ff3c7f8142f38/datasets/cityscapes/train_inst/aachen_000003_000019_gtFine_instanceIds.png
--------------------------------------------------------------------------------
/datasets/cityscapes/train_inst/aachen_000004_000019_gtFine_instanceIds.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxli/High-Resolution-Image-Synthesis-and-Semantic-Manipulation-with-Conditional-GANsl-/dd05da797863a13f6e45ec0a2d2ff3c7f8142f38/datasets/cityscapes/train_inst/aachen_000004_000019_gtFine_instanceIds.png
--------------------------------------------------------------------------------
/datasets/cityscapes/train_label/aachen_000000_000019_gtFine_labelIds.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxli/High-Resolution-Image-Synthesis-and-Semantic-Manipulation-with-Conditional-GANsl-/dd05da797863a13f6e45ec0a2d2ff3c7f8142f38/datasets/cityscapes/train_label/aachen_000000_000019_gtFine_labelIds.png
--------------------------------------------------------------------------------
/datasets/cityscapes/train_label/aachen_000001_000019_gtFine_labelIds.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxli/High-Resolution-Image-Synthesis-and-Semantic-Manipulation-with-Conditional-GANsl-/dd05da797863a13f6e45ec0a2d2ff3c7f8142f38/datasets/cityscapes/train_label/aachen_000001_000019_gtFine_labelIds.png
--------------------------------------------------------------------------------
/datasets/cityscapes/train_label/aachen_000002_000019_gtFine_labelIds.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxli/High-Resolution-Image-Synthesis-and-Semantic-Manipulation-with-Conditional-GANsl-/dd05da797863a13f6e45ec0a2d2ff3c7f8142f38/datasets/cityscapes/train_label/aachen_000002_000019_gtFine_labelIds.png
--------------------------------------------------------------------------------
/datasets/cityscapes/train_label/aachen_000003_000019_gtFine_labelIds.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxli/High-Resolution-Image-Synthesis-and-Semantic-Manipulation-with-Conditional-GANsl-/dd05da797863a13f6e45ec0a2d2ff3c7f8142f38/datasets/cityscapes/train_label/aachen_000003_000019_gtFine_labelIds.png
--------------------------------------------------------------------------------
/datasets/cityscapes/train_label/aachen_000004_000019_gtFine_labelIds.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxli/High-Resolution-Image-Synthesis-and-Semantic-Manipulation-with-Conditional-GANsl-/dd05da797863a13f6e45ec0a2d2ff3c7f8142f38/datasets/cityscapes/train_label/aachen_000004_000019_gtFine_labelIds.png
--------------------------------------------------------------------------------
/encode_features.py:
--------------------------------------------------------------------------------
1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
3 | from options.train_options import TrainOptions
4 | from data.data_loader import CreateDataLoader
5 | from models.models import create_model
6 | import numpy as np
7 | import os
8 |
9 | opt = TrainOptions().parse()
10 | opt.nThreads = 1
11 | opt.batchSize = 1
12 | opt.serial_batches = True
13 | opt.no_flip = True
14 | opt.instance_feat = True
15 |
16 | name = 'features'
17 | save_path = os.path.join(opt.checkpoints_dir, opt.name)
18 |
19 | ############ Initialize #########
20 | data_loader = CreateDataLoader(opt)
21 | dataset = data_loader.load_data()
22 | dataset_size = len(data_loader)
23 | model = create_model(opt)
24 |
25 | ########### Encode features ###########
26 | reencode = True
27 | if reencode:
28 | features = {}
29 | for label in range(opt.label_nc):
30 | features[label] = np.zeros((0, opt.feat_num+1))
31 | for i, data in enumerate(dataset):
32 | feat = model.module.encode_features(data['image'], data['inst'])
33 | for label in range(opt.label_nc):
34 | features[label] = np.append(features[label], feat[label], axis=0)
35 |
36 | print('%d / %d images' % (i+1, dataset_size))
37 | save_name = os.path.join(save_path, name + '.npy')
38 | np.save(save_name, features)
39 |
40 | ############## Clustering ###########
41 | n_clusters = opt.n_clusters
42 | load_name = os.path.join(save_path, name + '.npy')
43 | features = np.load(load_name, allow_pickle=True).item()  # the saved object is a pickled dict
44 | from sklearn.cluster import KMeans
45 | centers = {}
46 | for label in range(opt.label_nc):
47 | feat = features[label]
48 | feat = feat[feat[:,-1] > 0.5, :-1]
49 | if feat.shape[0]:
50 | n_clusters = min(feat.shape[0], opt.n_clusters)
51 | kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(feat)
52 | centers[label] = kmeans.cluster_centers_
53 | save_name = os.path.join(save_path, name + '_clustered_%03d.npy' % opt.n_clusters)
54 | np.save(save_name, centers)
55 | print('saving to %s' % save_name)
--------------------------------------------------------------------------------
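The clustering stage above operates per label on arrays of shape `(num_instances, feat_num+1)`, where the last column flags valid rows. A self-contained sketch of that step, with random data standing in for the encoded features (`feat_num=3` is an assumption):

```python
import numpy as np
from sklearn.cluster import KMeans

feat_num, n_clusters = 3, 10                  # assumed sizes for illustration
features = np.random.rand(50, feat_num + 1)   # last column: validity flag
feat = features[features[:, -1] > 0.5, :-1]   # keep valid rows, drop the flag
if feat.shape[0]:
    k = min(feat.shape[0], n_clusters)        # never ask for more clusters than rows
    kmeans = KMeans(n_clusters=k, random_state=0).fit(feat)
    print(kmeans.cluster_centers_.shape)      # (k, feat_num)
```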
/imgs/city_short.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxli/High-Resolution-Image-Synthesis-and-Semantic-Manipulation-with-Conditional-GANsl-/dd05da797863a13f6e45ec0a2d2ff3c7f8142f38/imgs/city_short.gif
--------------------------------------------------------------------------------
/imgs/cityscapes_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxli/High-Resolution-Image-Synthesis-and-Semantic-Manipulation-with-Conditional-GANsl-/dd05da797863a13f6e45ec0a2d2ff3c7f8142f38/imgs/cityscapes_1.jpg
--------------------------------------------------------------------------------
/imgs/cityscapes_2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxli/High-Resolution-Image-Synthesis-and-Semantic-Manipulation-with-Conditional-GANsl-/dd05da797863a13f6e45ec0a2d2ff3c7f8142f38/imgs/cityscapes_2.jpg
--------------------------------------------------------------------------------
/imgs/cityscapes_3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxli/High-Resolution-Image-Synthesis-and-Semantic-Manipulation-with-Conditional-GANsl-/dd05da797863a13f6e45ec0a2d2ff3c7f8142f38/imgs/cityscapes_3.jpg
--------------------------------------------------------------------------------
/imgs/cityscapes_4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxli/High-Resolution-Image-Synthesis-and-Semantic-Manipulation-with-Conditional-GANsl-/dd05da797863a13f6e45ec0a2d2ff3c7f8142f38/imgs/cityscapes_4.jpg
--------------------------------------------------------------------------------
/imgs/face1_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxli/High-Resolution-Image-Synthesis-and-Semantic-Manipulation-with-Conditional-GANsl-/dd05da797863a13f6e45ec0a2d2ff3c7f8142f38/imgs/face1_1.jpg
--------------------------------------------------------------------------------
/imgs/face1_2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxli/High-Resolution-Image-Synthesis-and-Semantic-Manipulation-with-Conditional-GANsl-/dd05da797863a13f6e45ec0a2d2ff3c7f8142f38/imgs/face1_2.jpg
--------------------------------------------------------------------------------
/imgs/face1_3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxli/High-Resolution-Image-Synthesis-and-Semantic-Manipulation-with-Conditional-GANsl-/dd05da797863a13f6e45ec0a2d2ff3c7f8142f38/imgs/face1_3.jpg
--------------------------------------------------------------------------------
/imgs/face2_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxli/High-Resolution-Image-Synthesis-and-Semantic-Manipulation-with-Conditional-GANsl-/dd05da797863a13f6e45ec0a2d2ff3c7f8142f38/imgs/face2_1.jpg
--------------------------------------------------------------------------------
/imgs/face2_2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxli/High-Resolution-Image-Synthesis-and-Semantic-Manipulation-with-Conditional-GANsl-/dd05da797863a13f6e45ec0a2d2ff3c7f8142f38/imgs/face2_2.jpg
--------------------------------------------------------------------------------
/imgs/face2_3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxli/High-Resolution-Image-Synthesis-and-Semantic-Manipulation-with-Conditional-GANsl-/dd05da797863a13f6e45ec0a2d2ff3c7f8142f38/imgs/face2_3.jpg
--------------------------------------------------------------------------------
/imgs/face_short.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxli/High-Resolution-Image-Synthesis-and-Semantic-Manipulation-with-Conditional-GANsl-/dd05da797863a13f6e45ec0a2d2ff3c7f8142f38/imgs/face_short.gif
--------------------------------------------------------------------------------
/imgs/teaser_720.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxli/High-Resolution-Image-Synthesis-and-Semantic-Manipulation-with-Conditional-GANsl-/dd05da797863a13f6e45ec0a2d2ff3c7f8142f38/imgs/teaser_720.gif
--------------------------------------------------------------------------------
/imgs/teaser_label.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxli/High-Resolution-Image-Synthesis-and-Semantic-Manipulation-with-Conditional-GANsl-/dd05da797863a13f6e45ec0a2d2ff3c7f8142f38/imgs/teaser_label.gif
--------------------------------------------------------------------------------
/imgs/teaser_label.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxli/High-Resolution-Image-Synthesis-and-Semantic-Manipulation-with-Conditional-GANsl-/dd05da797863a13f6e45ec0a2d2ff3c7f8142f38/imgs/teaser_label.png
--------------------------------------------------------------------------------
/imgs/teaser_ours.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxli/High-Resolution-Image-Synthesis-and-Semantic-Manipulation-with-Conditional-GANsl-/dd05da797863a13f6e45ec0a2d2ff3c7f8142f38/imgs/teaser_ours.jpg
--------------------------------------------------------------------------------
/imgs/teaser_style.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxli/High-Resolution-Image-Synthesis-and-Semantic-Manipulation-with-Conditional-GANsl-/dd05da797863a13f6e45ec0a2d2ff3c7f8142f38/imgs/teaser_style.gif
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxli/High-Resolution-Image-Synthesis-and-Semantic-Manipulation-with-Conditional-GANsl-/dd05da797863a13f6e45ec0a2d2ff3c7f8142f38/models/__init__.py
--------------------------------------------------------------------------------
/models/base_model.py:
--------------------------------------------------------------------------------
1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
3 | import os
4 | import torch
5 |
6 | class BaseModel(torch.nn.Module):
7 | def name(self):
8 | return 'BaseModel'
9 |
10 | def initialize(self, opt):
11 | self.opt = opt
12 | self.gpu_ids = opt.gpu_ids
13 | self.isTrain = opt.isTrain
14 | self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor
15 | self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
16 |
17 | def set_input(self, input):
18 | self.input = input
19 |
20 | def forward(self):
21 | pass
22 |
23 | # used in test time, no backprop
24 | def test(self):
25 | pass
26 |
27 | def get_image_paths(self):
28 | pass
29 |
30 | def optimize_parameters(self):
31 | pass
32 |
33 | def get_current_visuals(self):
34 | return self.input
35 |
36 | def get_current_errors(self):
37 | return {}
38 |
39 | def save(self, label):
40 | pass
41 |
42 | # helper saving function that can be used by subclasses
43 | def save_network(self, network, network_label, epoch_label, gpu_ids):
44 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
45 | save_path = os.path.join(self.save_dir, save_filename)
46 | torch.save(network.cpu().state_dict(), save_path)
47 | if len(gpu_ids) and torch.cuda.is_available():
48 | network.cuda()
49 |
50 | # helper loading function that can be used by subclasses
51 | def load_network(self, network, network_label, epoch_label, save_dir=''):
52 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
53 | if not save_dir:
54 | save_dir = self.save_dir
55 | save_path = os.path.join(save_dir, save_filename)
56 | if not os.path.isfile(save_path):
57 | print('%s does not exist yet!' % save_path)
58 | if network_label == 'G':
59 | raise Exception('Generator must exist!')
60 | else:
61 | #network.load_state_dict(torch.load(save_path))
62 | try:
63 | network.load_state_dict(torch.load(save_path))
64 | except:
65 | pretrained_dict = torch.load(save_path)
66 | model_dict = network.state_dict()
67 | try:
68 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
69 | network.load_state_dict(pretrained_dict)
70 | print('Pretrained network %s has extra layers; only loading the layers that are used' % network_label)
71 | except:
72 | print('Pretrained network %s has fewer layers; the following are not initialized:' % network_label)
73 | # use the built-in set (the Python 2 'sets' module was removed in Python 3)
74 | not_initialized = set()
75 | for k, v in pretrained_dict.items():
76 | if v.size() == model_dict[k].size():
77 | model_dict[k] = v
78 |
79 | for k, v in model_dict.items():
80 | if k not in pretrained_dict or v.size() != pretrained_dict[k].size():
81 | not_initialized.add(k.split('.')[0])
82 | print(sorted(not_initialized))
83 | network.load_state_dict(model_dict)
84 |
85 | def update_learning_rate(self):
86 | pass
87 |
--------------------------------------------------------------------------------
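
BaseModel above is the thin base class the concrete models subclass; the two helpers handle checkpoint naming ('%s_net_%s.pth' under checkpoints_dir/name). A minimal sketch, not a file in this repo, of how a subclass is expected to plug into them, assuming opt provides the attributes initialize() reads:

import torch
from models.base_model import BaseModel

class ToyModel(BaseModel):
    def name(self):
        return 'ToyModel'

    def initialize(self, opt):
        BaseModel.initialize(self, opt)      # sets self.save_dir, self.gpu_ids, ...
        self.netG = torch.nn.Linear(4, 4)    # stand-in for a real generator

    def save(self, label):
        # writes <checkpoints_dir>/<name>/<label>_net_G.pth
        self.save_network(self.netG, 'G', label, self.gpu_ids)
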
/models/models.py:
--------------------------------------------------------------------------------
1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
3 | import torch
4 |
5 | def create_model(opt):
6 | from .pix2pixHD_model import Pix2PixHDModel
7 | model = Pix2PixHDModel()
8 | model.initialize(opt)
9 | print("model [%s] was created" % (model.name()))
10 |
11 | if opt.isTrain and len(opt.gpu_ids):
12 | model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)
13 |
14 | return model
15 |
--------------------------------------------------------------------------------
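
Note that during training create_model wraps the model in torch.nn.DataParallel, which is why train.py reaches the underlying Pix2PixHDModel through model.module. A one-line illustration with a stand-in module (not the real model):

import torch

net = torch.nn.DataParallel(torch.nn.Linear(4, 4))
print(type(net.module).__name__)  # 'Linear' -- the wrapped module is reachable via .module
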
/models/networks.py:
--------------------------------------------------------------------------------
1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
3 | import torch
4 | import torch.nn as nn
5 | import functools
6 | from torch.autograd import Variable
7 | import numpy as np
8 |
9 | ###############################################################################
10 | # Functions
11 | ###############################################################################
12 | def weights_init(m):
13 | classname = m.__class__.__name__
14 | if classname.find('Conv') != -1:
15 | m.weight.data.normal_(0.0, 0.02)
16 | elif classname.find('BatchNorm2d') != -1:
17 | m.weight.data.normal_(1.0, 0.02)
18 | m.bias.data.fill_(0)
19 |
20 | def get_norm_layer(norm_type='instance'):
21 | if norm_type == 'batch':
22 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
23 | elif norm_type == 'instance':
24 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False)
25 | else:
26 | raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
27 | return norm_layer
28 |
29 | def define_G(input_nc, output_nc, ngf, netG, n_downsample_global=3, n_blocks_global=9, n_local_enhancers=1,
30 | n_blocks_local=3, norm='instance', gpu_ids=[]):
31 | norm_layer = get_norm_layer(norm_type=norm)
32 | if netG == 'global':
33 | netG = GlobalGenerator(input_nc, output_nc, ngf, n_downsample_global, n_blocks_global, norm_layer)
34 | elif netG == 'local':
35 | netG = LocalEnhancer(input_nc, output_nc, ngf, n_downsample_global, n_blocks_global,
36 | n_local_enhancers, n_blocks_local, norm_layer)
37 | elif netG == 'encoder':
38 | netG = Encoder(input_nc, output_nc, ngf, n_downsample_global, norm_layer)
39 | else:
40 | raise NotImplementedError('generator not implemented!')
41 | print(netG)
42 | if len(gpu_ids) > 0:
43 | assert(torch.cuda.is_available())
44 | netG.cuda(gpu_ids[0])
45 | netG.apply(weights_init)
46 | return netG
47 |
48 | def define_D(input_nc, ndf, n_layers_D, norm='instance', use_sigmoid=False, num_D=1, getIntermFeat=False, gpu_ids=[]):
49 | norm_layer = get_norm_layer(norm_type=norm)
50 | netD = MultiscaleDiscriminator(input_nc, ndf, n_layers_D, norm_layer, use_sigmoid, num_D, getIntermFeat)
51 | print(netD)
52 | if len(gpu_ids) > 0:
53 | assert(torch.cuda.is_available())
54 | netD.cuda(gpu_ids[0])
55 | netD.apply(weights_init)
56 | return netD
57 |
58 | def print_network(net):
59 | if isinstance(net, list):
60 | net = net[0]
61 | num_params = 0
62 | for param in net.parameters():
63 | num_params += param.numel()
64 | print(net)
65 | print('Total number of parameters: %d' % num_params)
66 |
67 | ##############################################################################
68 | # Losses
69 | ##############################################################################
70 | class GANLoss(nn.Module):
71 | def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0,
72 | tensor=torch.FloatTensor):
73 | super(GANLoss, self).__init__()
74 | self.real_label = target_real_label
75 | self.fake_label = target_fake_label
76 | self.real_label_var = None
77 | self.fake_label_var = None
78 | self.Tensor = tensor
79 | if use_lsgan:
80 | self.loss = nn.MSELoss()
81 | else:
82 | self.loss = nn.BCELoss()
83 |
84 | def get_target_tensor(self, input, target_is_real):
85 | target_tensor = None
86 | if target_is_real:
87 | create_label = ((self.real_label_var is None) or
88 | (self.real_label_var.numel() != input.numel()))
89 | if create_label:
90 | real_tensor = self.Tensor(input.size()).fill_(self.real_label)
91 | self.real_label_var = Variable(real_tensor, requires_grad=False)
92 | target_tensor = self.real_label_var
93 | else:
94 | create_label = ((self.fake_label_var is None) or
95 | (self.fake_label_var.numel() != input.numel()))
96 | if create_label:
97 | fake_tensor = self.Tensor(input.size()).fill_(self.fake_label)
98 | self.fake_label_var = Variable(fake_tensor, requires_grad=False)
99 | target_tensor = self.fake_label_var
100 | return target_tensor
101 |
102 | def __call__(self, input, target_is_real):
103 | if isinstance(input[0], list):
104 | loss = 0
105 | for input_i in input:
106 | pred = input_i[-1]
107 | target_tensor = self.get_target_tensor(pred, target_is_real)
108 | loss += self.loss(pred, target_tensor)
109 | return loss
110 | else:
111 | target_tensor = self.get_target_tensor(input[-1], target_is_real)
112 | return self.loss(input[-1], target_tensor)
113 |
114 | class VGGLoss(nn.Module):
115 | def __init__(self, gpu_ids):
116 | super(VGGLoss, self).__init__()
117 | self.vgg = Vgg19().cuda()
118 | self.criterion = nn.L1Loss()
119 | self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0]
120 |
121 | def forward(self, x, y):
122 | x_vgg, y_vgg = self.vgg(x), self.vgg(y)
123 | loss = 0
124 | for i in range(len(x_vgg)):
125 | loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
126 | return loss
127 |
128 | ##############################################################################
129 | # Generator
130 | ##############################################################################
131 | class LocalEnhancer(nn.Module):
132 | def __init__(self, input_nc, output_nc, ngf=32, n_downsample_global=3, n_blocks_global=9,
133 | n_local_enhancers=1, n_blocks_local=3, norm_layer=nn.BatchNorm2d, padding_type='reflect'):
134 | super(LocalEnhancer, self).__init__()
135 | self.n_local_enhancers = n_local_enhancers
136 |
137 | ###### global generator model #####
138 | ngf_global = ngf * (2**n_local_enhancers)
139 | model_global = GlobalGenerator(input_nc, output_nc, ngf_global, n_downsample_global, n_blocks_global, norm_layer).model
140 | model_global = [model_global[i] for i in range(len(model_global)-3)] # get rid of final convolution layers
141 | self.model = nn.Sequential(*model_global)
142 |
143 | ###### local enhancer layers #####
144 | for n in range(1, n_local_enhancers+1):
145 | ### downsample
146 | ngf_global = ngf * (2**(n_local_enhancers-n))
147 | model_downsample = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf_global, kernel_size=7, padding=0),
148 | norm_layer(ngf_global), nn.ReLU(True),
149 | nn.Conv2d(ngf_global, ngf_global * 2, kernel_size=3, stride=2, padding=1),
150 | norm_layer(ngf_global * 2), nn.ReLU(True)]
151 | ### residual blocks
152 | model_upsample = []
153 | for i in range(n_blocks_local):
154 | model_upsample += [ResnetBlock(ngf_global * 2, padding_type=padding_type, norm_layer=norm_layer)]
155 |
156 | ### upsample
157 | model_upsample += [nn.ConvTranspose2d(ngf_global * 2, ngf_global, kernel_size=3, stride=2, padding=1, output_padding=1),
158 | norm_layer(ngf_global), nn.ReLU(True)]
159 |
160 | ### final convolution
161 | if n == n_local_enhancers:
162 | model_upsample += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()]
163 |
164 | setattr(self, 'model'+str(n)+'_1', nn.Sequential(*model_downsample))
165 | setattr(self, 'model'+str(n)+'_2', nn.Sequential(*model_upsample))
166 |
167 | self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)
168 |
169 | def forward(self, input):
170 | ### create input pyramid
171 | input_downsampled = [input]
172 | for i in range(self.n_local_enhancers):
173 | input_downsampled.append(self.downsample(input_downsampled[-1]))
174 |
175 | ### output at coarsest level
176 | output_prev = self.model(input_downsampled[-1])
177 | ### build up one layer at a time
178 | for n_local_enhancers in range(1, self.n_local_enhancers+1):
179 | model_downsample = getattr(self, 'model'+str(n_local_enhancers)+'_1')
180 | model_upsample = getattr(self, 'model'+str(n_local_enhancers)+'_2')
181 | input_i = input_downsampled[self.n_local_enhancers-n_local_enhancers]
182 | output_prev = model_upsample(model_downsample(input_i) + output_prev)
183 | return output_prev
184 |
185 | class GlobalGenerator(nn.Module):
186 | def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d,
187 | padding_type='reflect'):
188 | assert(n_blocks >= 0)
189 | super(GlobalGenerator, self).__init__()
190 | activation = nn.ReLU(True)
191 |
192 | model = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0), norm_layer(ngf), activation]
193 | ### downsample
194 | for i in range(n_downsampling):
195 | mult = 2**i
196 | model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1),
197 | norm_layer(ngf * mult * 2), activation]
198 |
199 | ### resnet blocks
200 | mult = 2**n_downsampling
201 | for i in range(n_blocks):
202 | model += [ResnetBlock(ngf * mult, padding_type=padding_type, activation=activation, norm_layer=norm_layer)]
203 |
204 | ### upsample
205 | for i in range(n_downsampling):
206 | mult = 2**(n_downsampling - i)
207 | model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1),
208 | norm_layer(int(ngf * mult / 2)), activation]
209 | model += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()]
210 | self.model = nn.Sequential(*model)
211 |
212 | def forward(self, input):
213 | return self.model(input)
214 |
215 | # Define a resnet block
216 | class ResnetBlock(nn.Module):
217 | def __init__(self, dim, padding_type, norm_layer, activation=nn.ReLU(True), use_dropout=False):
218 | super(ResnetBlock, self).__init__()
219 | self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, activation, use_dropout)
220 |
221 | def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout):
222 | conv_block = []
223 | p = 0
224 | if padding_type == 'reflect':
225 | conv_block += [nn.ReflectionPad2d(1)]
226 | elif padding_type == 'replicate':
227 | conv_block += [nn.ReplicationPad2d(1)]
228 | elif padding_type == 'zero':
229 | p = 1
230 | else:
231 | raise NotImplementedError('padding [%s] is not implemented' % padding_type)
232 |
233 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
234 | norm_layer(dim),
235 | activation]
236 | if use_dropout:
237 | conv_block += [nn.Dropout(0.5)]
238 |
239 | p = 0
240 | if padding_type == 'reflect':
241 | conv_block += [nn.ReflectionPad2d(1)]
242 | elif padding_type == 'replicate':
243 | conv_block += [nn.ReplicationPad2d(1)]
244 | elif padding_type == 'zero':
245 | p = 1
246 | else:
247 | raise NotImplementedError('padding [%s] is not implemented' % padding_type)
248 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
249 | norm_layer(dim)]
250 |
251 | return nn.Sequential(*conv_block)
252 |
253 | def forward(self, x):
254 | out = x + self.conv_block(x)
255 | return out
256 |
257 | class Encoder(nn.Module):
258 | def __init__(self, input_nc, output_nc, ngf=32, n_downsampling=4, norm_layer=nn.BatchNorm2d):
259 | super(Encoder, self).__init__()
260 | self.output_nc = output_nc
261 |
262 | model = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0),
263 | norm_layer(ngf), nn.ReLU(True)]
264 | ### downsample
265 | for i in range(n_downsampling):
266 | mult = 2**i
267 | model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1),
268 | norm_layer(ngf * mult * 2), nn.ReLU(True)]
269 |
270 | ### upsample
271 | for i in range(n_downsampling):
272 | mult = 2**(n_downsampling - i)
273 | model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1),
274 | norm_layer(int(ngf * mult / 2)), nn.ReLU(True)]
275 |
276 | model += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()]
277 | self.model = nn.Sequential(*model)
278 |
279 | def forward(self, input, inst):
280 | outputs = self.model(input)
281 |
282 | # instance-wise average pooling
283 | outputs_mean = outputs.clone()
284 | inst_list = np.unique(inst.cpu().numpy().astype(int))
285 | for i in inst_list:
286 | indices = (inst == i).nonzero() # n x 4
287 | for j in range(self.output_nc):
288 | output_ins = outputs[indices[:,0], indices[:,1] + j, indices[:,2], indices[:,3]]
289 | mean_feat = torch.mean(output_ins).expand_as(output_ins)
290 | outputs_mean[indices[:,0], indices[:,1] + j, indices[:,2], indices[:,3]] = mean_feat
291 | return outputs_mean
292 |
293 | class MultiscaleDiscriminator(nn.Module):
294 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d,
295 | use_sigmoid=False, num_D=3, getIntermFeat=False):
296 | super(MultiscaleDiscriminator, self).__init__()
297 | self.num_D = num_D
298 | self.n_layers = n_layers
299 | self.getIntermFeat = getIntermFeat
300 |
301 | for i in range(num_D):
302 | netD = NLayerDiscriminator(input_nc, ndf, n_layers, norm_layer, use_sigmoid, getIntermFeat)
303 | if getIntermFeat:
304 | for j in range(n_layers+2):
305 | setattr(self, 'scale'+str(i)+'_layer'+str(j), getattr(netD, 'model'+str(j)))
306 | else:
307 | setattr(self, 'layer'+str(i), netD.model)
308 |
309 | self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)
310 |
311 | def singleD_forward(self, model, input):
312 | if self.getIntermFeat:
313 | result = [input]
314 | for i in range(len(model)):
315 | result.append(model[i](result[-1]))
316 | return result[1:]
317 | else:
318 | return [model(input)]
319 |
320 | def forward(self, input):
321 | num_D = self.num_D
322 | result = []
323 | input_downsampled = input
324 | for i in range(num_D):
325 | if self.getIntermFeat:
326 | model = [getattr(self, 'scale'+str(num_D-1-i)+'_layer'+str(j)) for j in range(self.n_layers+2)]
327 | else:
328 | model = getattr(self, 'layer'+str(num_D-1-i))
329 | result.append(self.singleD_forward(model, input_downsampled))
330 | if i != (num_D-1):
331 | input_downsampled = self.downsample(input_downsampled)
332 | return result
333 |
334 | # Defines the PatchGAN discriminator with the specified arguments.
335 | class NLayerDiscriminator(nn.Module):
336 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, getIntermFeat=False):
337 | super(NLayerDiscriminator, self).__init__()
338 | self.getIntermFeat = getIntermFeat
339 | self.n_layers = n_layers
340 |
341 | kw = 4
342 | padw = int(np.ceil((kw-1.0)/2))
343 | sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]]
344 |
345 | nf = ndf
346 | for n in range(1, n_layers):
347 | nf_prev = nf
348 | nf = min(nf * 2, 512)
349 | sequence += [[
350 | nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw),
351 | norm_layer(nf), nn.LeakyReLU(0.2, True)
352 | ]]
353 |
354 | nf_prev = nf
355 | nf = min(nf * 2, 512)
356 | sequence += [[
357 | nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
358 | norm_layer(nf),
359 | nn.LeakyReLU(0.2, True)
360 | ]]
361 |
362 | sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]
363 |
364 | if use_sigmoid:
365 | sequence += [[nn.Sigmoid()]]
366 |
367 | if getIntermFeat:
368 | for n in range(len(sequence)):
369 | setattr(self, 'model'+str(n), nn.Sequential(*sequence[n]))
370 | else:
371 | sequence_stream = []
372 | for n in range(len(sequence)):
373 | sequence_stream += sequence[n]
374 | self.model = nn.Sequential(*sequence_stream)
375 |
376 | def forward(self, input):
377 | if self.getIntermFeat:
378 | res = [input]
379 | for n in range(self.n_layers+2):
380 | model = getattr(self, 'model'+str(n))
381 | res.append(model(res[-1]))
382 | return res[1:]
383 | else:
384 | return self.model(input)
385 |
386 | from torchvision import models
387 | class Vgg19(torch.nn.Module):
388 | def __init__(self, requires_grad=False):
389 | super(Vgg19, self).__init__()
390 | vgg_pretrained_features = models.vgg19(pretrained=True).features
391 | self.slice1 = torch.nn.Sequential()
392 | self.slice2 = torch.nn.Sequential()
393 | self.slice3 = torch.nn.Sequential()
394 | self.slice4 = torch.nn.Sequential()
395 | self.slice5 = torch.nn.Sequential()
396 | for x in range(2):
397 | self.slice1.add_module(str(x), vgg_pretrained_features[x])
398 | for x in range(2, 7):
399 | self.slice2.add_module(str(x), vgg_pretrained_features[x])
400 | for x in range(7, 12):
401 | self.slice3.add_module(str(x), vgg_pretrained_features[x])
402 | for x in range(12, 21):
403 | self.slice4.add_module(str(x), vgg_pretrained_features[x])
404 | for x in range(21, 30):
405 | self.slice5.add_module(str(x), vgg_pretrained_features[x])
406 | if not requires_grad:
407 | for param in self.parameters():
408 | param.requires_grad = False
409 |
410 | def forward(self, X):
411 | h_relu1 = self.slice1(X)
412 | h_relu2 = self.slice2(h_relu1)
413 | h_relu3 = self.slice3(h_relu2)
414 | h_relu4 = self.slice4(h_relu3)
415 | h_relu5 = self.slice5(h_relu4)
416 | out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
417 | return out
418 |
--------------------------------------------------------------------------------
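
As a quick, illustrative sanity check of the plumbing above (assuming PyTorch >= 0.4, where plain tensors can be fed to modules), the global generator can be built on the CPU and probed with a dummy input; 36 input channels here stands for the repo default of 35 label channels plus 1 edge channel, and with n_downsample_global=4 the upsampling stack restores the spatial size:

import torch
from models import networks

netG = networks.define_G(input_nc=36, output_nc=3, ngf=64, netG='global',
                         n_downsample_global=4, n_blocks_global=9, gpu_ids=[])
x = torch.randn(1, 36, 128, 128)
print(netG(x).shape)  # torch.Size([1, 3, 128, 128]); values lie in [-1, 1] from the final Tanh
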
/models/pix2pixHD_model.py:
--------------------------------------------------------------------------------
1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
3 | import numpy as np
4 | import torch
5 | import os
6 | from torch.autograd import Variable
7 | from util.image_pool import ImagePool
8 | from .base_model import BaseModel
9 | from . import networks
10 |
11 | class Pix2PixHDModel(BaseModel):
12 | def name(self):
13 | return 'Pix2PixHDModel'
14 |
15 | def initialize(self, opt):
16 | BaseModel.initialize(self, opt)
17 | if opt.resize_or_crop != 'none': # when training at full res this causes OOM
18 | torch.backends.cudnn.benchmark = True
19 | self.isTrain = opt.isTrain
20 | self.use_features = opt.instance_feat or opt.label_feat
21 | self.gen_features = self.use_features and not self.opt.load_features
22 | input_nc = opt.label_nc if opt.label_nc != 0 else 3
23 |
24 | ##### define networks
25 | # Generator network
26 | netG_input_nc = input_nc
27 | if not opt.no_instance:
28 | netG_input_nc += 1
29 | if self.use_features:
30 | netG_input_nc += opt.feat_num
31 | self.netG = networks.define_G(netG_input_nc, opt.output_nc, opt.ngf, opt.netG,
32 | opt.n_downsample_global, opt.n_blocks_global, opt.n_local_enhancers,
33 | opt.n_blocks_local, opt.norm, gpu_ids=self.gpu_ids)
34 |
35 | # Discriminator network
36 | if self.isTrain:
37 | use_sigmoid = opt.no_lsgan
38 | netD_input_nc = input_nc + opt.output_nc
39 | if not opt.no_instance:
40 | netD_input_nc += 1
41 | self.netD = networks.define_D(netD_input_nc, opt.ndf, opt.n_layers_D, opt.norm, use_sigmoid,
42 | opt.num_D, not opt.no_ganFeat_loss, gpu_ids=self.gpu_ids)
43 |
44 | ### Encoder network
45 | if self.gen_features:
46 | self.netE = networks.define_G(opt.output_nc, opt.feat_num, opt.nef, 'encoder',
47 | opt.n_downsample_E, norm=opt.norm, gpu_ids=self.gpu_ids)
48 |
49 | print('---------- Networks initialized -------------')
50 |
51 | # load networks
52 | if not self.isTrain or opt.continue_train or opt.load_pretrain:
53 | pretrained_path = '' if not self.isTrain else opt.load_pretrain
54 | self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)
55 | if self.isTrain:
56 | self.load_network(self.netD, 'D', opt.which_epoch, pretrained_path)
57 | if self.gen_features:
58 | self.load_network(self.netE, 'E', opt.which_epoch, pretrained_path)
59 |
60 | # set loss functions and optimizers
61 | if self.isTrain:
62 | if opt.pool_size > 0 and (len(self.gpu_ids)) > 1:
63 | raise NotImplementedError("Fake Pool Not Implemented for MultiGPU")
64 | self.fake_pool = ImagePool(opt.pool_size)
65 | self.old_lr = opt.lr
66 |
67 | # define loss functions
68 | self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
69 | self.criterionFeat = torch.nn.L1Loss()
70 | if not opt.no_vgg_loss:
71 | self.criterionVGG = networks.VGGLoss(self.gpu_ids)
72 |
73 | # Names so we can breakout loss
74 | self.loss_names = ['G_GAN', 'G_GAN_Feat', 'G_VGG', 'D_real', 'D_fake']
75 |
76 | # initialize optimizers
77 | # optimizer G
78 | if opt.niter_fix_global > 0:
79 | print('------------- Only training the local enhancer network (for %d epochs) ------------' % opt.niter_fix_global)
80 | params_dict = dict(self.netG.named_parameters())
81 | params = []
82 | for key, value in params_dict.items():
83 | if key.startswith('model' + str(opt.n_local_enhancers)):
84 | params += [{'params':[value],'lr':opt.lr}]
85 | else:
86 | params += [{'params':[value],'lr':0.0}]
87 | else:
88 | params = list(self.netG.parameters())
89 | if self.gen_features:
90 | params += list(self.netE.parameters())
91 | self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))
92 |
93 | # optimizer D
94 | params = list(self.netD.parameters())
95 | self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))
96 |
97 | def encode_input(self, label_map, inst_map=None, real_image=None, feat_map=None, infer=False):
98 | if self.opt.label_nc == 0:
99 | input_label = label_map.data.cuda()
100 | else:
101 | # create one-hot vector for label map
102 | size = label_map.size()
103 | oneHot_size = (size[0], self.opt.label_nc, size[2], size[3])
104 | input_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()
105 | input_label = input_label.scatter_(1, label_map.data.long().cuda(), 1.0)
106 |
107 | # get edges from instance map
108 | if not self.opt.no_instance:
109 | inst_map = inst_map.data.cuda()
110 | edge_map = self.get_edges(inst_map)
111 | input_label = torch.cat((input_label, edge_map), dim=1)
112 | input_label = Variable(input_label, volatile=infer)
113 |
114 | # real images for training
115 | if real_image is not None:
116 | real_image = Variable(real_image.data.cuda())
117 |
118 | # instance map for feature encoding
119 | if self.use_features:
120 | # get precomputed feature maps
121 | if self.opt.load_features:
122 | feat_map = Variable(feat_map.data.cuda())
123 |
124 | return input_label, inst_map, real_image, feat_map
125 |
126 | def discriminate(self, input_label, test_image, use_pool=False):
127 | input_concat = torch.cat((input_label, test_image.detach()), dim=1)
128 | if use_pool:
129 | fake_query = self.fake_pool.query(input_concat)
130 | return self.netD.forward(fake_query)
131 | else:
132 | return self.netD.forward(input_concat)
133 |
134 | def forward(self, label, inst, image, feat, infer=False):
135 | # Encode Inputs
136 | input_label, inst_map, real_image, feat_map = self.encode_input(label, inst, image, feat)
137 |
138 | # Fake Generation
139 | if self.use_features:
140 | if not self.opt.load_features:
141 | feat_map = self.netE.forward(real_image, inst_map)
142 | input_concat = torch.cat((input_label, feat_map), dim=1)
143 | else:
144 | input_concat = input_label
145 | fake_image = self.netG.forward(input_concat)
146 |
147 | # Fake Detection and Loss
148 | pred_fake_pool = self.discriminate(input_label, fake_image, use_pool=True)
149 | loss_D_fake = self.criterionGAN(pred_fake_pool, False)
150 |
151 | # Real Detection and Loss
152 | pred_real = self.discriminate(input_label, real_image)
153 | loss_D_real = self.criterionGAN(pred_real, True)
154 |
155 | # GAN loss (Fake Passability Loss)
156 | pred_fake = self.netD.forward(torch.cat((input_label, fake_image), dim=1))
157 | loss_G_GAN = self.criterionGAN(pred_fake, True)
158 |
159 | # GAN feature matching loss
160 | loss_G_GAN_Feat = 0
161 | if not self.opt.no_ganFeat_loss:
162 | feat_weights = 4.0 / (self.opt.n_layers_D + 1)
163 | D_weights = 1.0 / self.opt.num_D
164 | for i in range(self.opt.num_D):
165 | for j in range(len(pred_fake[i])-1):
166 | loss_G_GAN_Feat += D_weights * feat_weights * \
167 | self.criterionFeat(pred_fake[i][j], pred_real[i][j].detach()) * self.opt.lambda_feat
168 |
169 | # VGG feature matching loss
170 | loss_G_VGG = 0
171 | if not self.opt.no_vgg_loss:
172 | loss_G_VGG = self.criterionVGG(fake_image, real_image) * self.opt.lambda_feat
173 |
174 | # Only return the fake_B image if necessary to save BW
175 | return [ [ loss_G_GAN, loss_G_GAN_Feat, loss_G_VGG, loss_D_real, loss_D_fake ], None if not infer else fake_image ]
176 |
177 | def inference(self, label, inst):
178 | # Encode Inputs
179 | input_label, inst_map, _, _ = self.encode_input(Variable(label), Variable(inst), infer=True)
180 |
181 | # Fake Generation
182 | if self.use_features:
183 | # sample clusters from precomputed features
184 | feat_map = self.sample_features(inst_map)
185 | input_concat = torch.cat((input_label, feat_map), dim=1)
186 | else:
187 | input_concat = input_label
188 | fake_image = self.netG.forward(input_concat)
189 | return fake_image
190 |
191 | def sample_features(self, inst):
192 | # read precomputed feature clusters
193 | cluster_path = os.path.join(self.opt.checkpoints_dir, self.opt.name, self.opt.cluster_path)
194 | features_clustered = np.load(cluster_path, allow_pickle=True).item()  # pickled dict: label -> cluster array
195 |
196 | # randomly sample from the feature clusters
197 | inst_np = inst.cpu().numpy().astype(int)
198 | feat_map = torch.cuda.FloatTensor(1, self.opt.feat_num, inst.size()[2], inst.size()[3])
199 | for i in np.unique(inst_np):
200 | label = i if i < 1000 else i//1000
201 | if label in features_clustered:
202 | feat = features_clustered[label]
203 | cluster_idx = np.random.randint(0, feat.shape[0])
204 |
205 | idx = (inst == i).nonzero()
206 | for k in range(self.opt.feat_num):
207 | feat_map[idx[:,0], idx[:,1] + k, idx[:,2], idx[:,3]] = feat[cluster_idx, k]
208 | return feat_map
209 |
210 | def encode_features(self, image, inst):
211 | image = Variable(image.cuda(), volatile=True)
212 | feat_num = self.opt.feat_num
213 | h, w = inst.size()[2], inst.size()[3]
214 | block_num = 32
215 | feat_map = self.netE.forward(image, inst.cuda())
216 | inst_np = inst.cpu().numpy().astype(int)
217 | feature = {}
218 | for i in range(self.opt.label_nc):
219 | feature[i] = np.zeros((0, feat_num+1))
220 | for i in np.unique(inst_np):
221 | label = i if i < 1000 else i//1000
222 | idx = (inst == i).nonzero()
223 | num = idx.size()[0]
224 | idx = idx[num//2,:]
225 | val = np.zeros((1, feat_num+1))
226 | for k in range(feat_num):
227 | val[0, k] = feat_map[idx[0], idx[1] + k, idx[2], idx[3]].data[0]
228 | val[0, feat_num] = float(num) / (h * w // block_num)
229 | feature[label] = np.append(feature[label], val, axis=0)
230 | return feature
231 |
232 | def get_edges(self, t):
233 | edge = torch.cuda.ByteTensor(t.size()).zero_()
234 | edge[:,:,:,1:] = edge[:,:,:,1:] | (t[:,:,:,1:] != t[:,:,:,:-1])
235 | edge[:,:,:,:-1] = edge[:,:,:,:-1] | (t[:,:,:,1:] != t[:,:,:,:-1])
236 | edge[:,:,1:,:] = edge[:,:,1:,:] | (t[:,:,1:,:] != t[:,:,:-1,:])
237 | edge[:,:,:-1,:] = edge[:,:,:-1,:] | (t[:,:,1:,:] != t[:,:,:-1,:])
238 | return edge.float()
239 |
240 | def save(self, which_epoch):
241 | self.save_network(self.netG, 'G', which_epoch, self.gpu_ids)
242 | self.save_network(self.netD, 'D', which_epoch, self.gpu_ids)
243 | if self.gen_features:
244 | self.save_network(self.netE, 'E', which_epoch, self.gpu_ids)
245 |
246 | def update_fixed_params(self):
247 | # after fixing the global generator for a number of iterations, also start finetuning it
248 | params = list(self.netG.parameters())
249 | if self.gen_features:
250 | params += list(self.netE.parameters())
251 | self.optimizer_G = torch.optim.Adam(params, lr=self.opt.lr, betas=(self.opt.beta1, 0.999))
252 | print('------------ Now also finetuning global generator -----------')
253 |
254 | def update_learning_rate(self):
255 | lrd = self.opt.lr / self.opt.niter_decay
256 | lr = self.old_lr - lrd
257 | for param_group in self.optimizer_D.param_groups:
258 | param_group['lr'] = lr
259 | for param_group in self.optimizer_G.param_groups:
260 | param_group['lr'] = lr
261 | print('update learning rate: %f -> %f' % (self.old_lr, lr))
262 | self.old_lr = lr
263 |
--------------------------------------------------------------------------------
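
The edge map that encode_input concatenates to the one-hot label map marks pixels whose instance id differs from a horizontal or vertical neighbour. A CPU re-implementation of get_edges for illustration (the method above does the same with torch.cuda.ByteTensor):

import torch

def get_edges_cpu(t):
    edge = torch.zeros(t.size(), dtype=torch.bool)
    edge[:, :, :, 1:] |= (t[:, :, :, 1:] != t[:, :, :, :-1])   # differs from left neighbour
    edge[:, :, :, :-1] |= (t[:, :, :, 1:] != t[:, :, :, :-1])  # differs from right neighbour
    edge[:, :, 1:, :] |= (t[:, :, 1:, :] != t[:, :, :-1, :])   # differs from upper neighbour
    edge[:, :, :-1, :] |= (t[:, :, 1:, :] != t[:, :, :-1, :])  # differs from lower neighbour
    return edge.float()

inst = torch.tensor([[[[1, 1, 2],
                       [1, 2, 2]]]])
print(get_edges_cpu(inst))
# tensor([[[[0., 1., 1.],
#           [1., 1., 0.]]]])
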
/options/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxli/High-Resolution-Image-Synthesis-and-Semantic-Manipulation-with-Conditional-GANsl-/dd05da797863a13f6e45ec0a2d2ff3c7f8142f38/options/__init__.py
--------------------------------------------------------------------------------
/options/base_options.py:
--------------------------------------------------------------------------------
1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
3 | import argparse
4 | import os
5 | from util import util
6 | import torch
7 |
8 | class BaseOptions():
9 | def __init__(self):
10 | self.parser = argparse.ArgumentParser()
11 | self.initialized = False
12 |
13 | def initialize(self):
14 | # experiment specifics
15 | self.parser.add_argument('--name', type=str, default='label2city', help='name of the experiment. It decides where to store samples and models')
16 | self.parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids, e.g. 0 or 0,1,2 or 0,2; use -1 for CPU')
17 | self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
18 | self.parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization')
19 | self.parser.add_argument('--use_dropout', action='store_true', help='use dropout for the generator')
20 |
21 | # input/output sizes
22 | self.parser.add_argument('--batchSize', type=int, default=1, help='input batch size')
23 | self.parser.add_argument('--loadSize', type=int, default=1024, help='scale images to this size')
24 | self.parser.add_argument('--fineSize', type=int, default=512, help='then crop to this size')
25 | self.parser.add_argument('--label_nc', type=int, default=35, help='# of input label classes')
26 | self.parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels')
27 |
28 | # for setting inputs
29 | self.parser.add_argument('--dataroot', type=str, default='./datasets/cityscapes/')
30 | self.parser.add_argument('--resize_or_crop', type=str, default='scale_width', help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop]')
31 | self.parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
32 | self.parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
33 | self.parser.add_argument('--nThreads', default=2, type=int, help='# threads for loading data')
34 | self.parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
35 |
36 | # for displays
37 | self.parser.add_argument('--display_winsize', type=int, default=512, help='display window size')
38 | self.parser.add_argument('--tf_log', action='store_true', help='if specified, use tensorboard logging. Requires tensorflow installed')
39 |
40 | # for generator
41 | self.parser.add_argument('--netG', type=str, default='global', help='selects model to use for netG')
42 | self.parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer')
43 | self.parser.add_argument('--n_downsample_global', type=int, default=4, help='number of downsampling layers in netG')
44 | self.parser.add_argument('--n_blocks_global', type=int, default=9, help='number of residual blocks in the global generator network')
45 | self.parser.add_argument('--n_blocks_local', type=int, default=3, help='number of residual blocks in the local enhancer network')
46 | self.parser.add_argument('--n_local_enhancers', type=int, default=1, help='number of local enhancers to use')
47 | self.parser.add_argument('--niter_fix_global', type=int, default=0, help='number of epochs during which we only train the outermost local enhancer')
48 |
49 | # for instance-wise features
50 | self.parser.add_argument('--no_instance', action='store_true', help='if specified, do *not* add instance map as input')
51 | self.parser.add_argument('--instance_feat', action='store_true', help='if specified, add encoded instance features as input')
52 | self.parser.add_argument('--label_feat', action='store_true', help='if specified, add encoded label features as input')
53 | self.parser.add_argument('--feat_num', type=int, default=3, help='vector length for encoded features')
54 | self.parser.add_argument('--load_features', action='store_true', help='if specified, load precomputed feature maps')
55 | self.parser.add_argument('--n_downsample_E', type=int, default=4, help='# of downsampling layers in encoder')
56 | self.parser.add_argument('--nef', type=int, default=16, help='# of encoder filters in the first conv layer')
57 | self.parser.add_argument('--n_clusters', type=int, default=10, help='number of clusters for features')
58 |
59 | self.initialized = True
60 |
61 | def parse(self, save=True):
62 | if not self.initialized:
63 | self.initialize()
64 | self.opt = self.parser.parse_args()
65 | self.opt.isTrain = self.isTrain # train or test
66 |
67 | str_ids = self.opt.gpu_ids.split(',')
68 | self.opt.gpu_ids = []
69 | for str_id in str_ids:
70 | id = int(str_id)
71 | if id >= 0:
72 | self.opt.gpu_ids.append(id)
73 |
74 | # set gpu ids
75 | if len(self.opt.gpu_ids) > 0:
76 | torch.cuda.set_device(self.opt.gpu_ids[0])
77 |
78 | args = vars(self.opt)
79 |
80 | print('------------ Options -------------')
81 | for k, v in sorted(args.items()):
82 | print('%s: %s' % (str(k), str(v)))
83 | print('-------------- End ----------------')
84 |
85 | # save to the disk
86 | expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.name)
87 | util.mkdirs(expr_dir)
88 | if save and not self.opt.continue_train:
89 | file_name = os.path.join(expr_dir, 'opt.txt')
90 | with open(file_name, 'wt') as opt_file:
91 | opt_file.write('------------ Options -------------\n')
92 | for k, v in sorted(args.items()):
93 | opt_file.write('%s: %s\n' % (str(k), str(v)))
94 | opt_file.write('-------------- End ----------------\n')
95 | return self.opt
96 |
--------------------------------------------------------------------------------
/options/test_options.py:
--------------------------------------------------------------------------------
1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
3 | from .base_options import BaseOptions
4 |
5 | class TestOptions(BaseOptions):
6 | def initialize(self):
7 | BaseOptions.initialize(self)
8 | self.parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.')
9 | self.parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
10 | self.parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
11 | self.parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
12 | self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
13 | self.parser.add_argument('--how_many', type=int, default=50, help='how many test images to run')
14 | self.parser.add_argument('--cluster_path', type=str, default='features_clustered_010.npy', help='the path for clustered results of encoded features')
15 | self.isTrain = False
16 |
--------------------------------------------------------------------------------
/options/train_options.py:
--------------------------------------------------------------------------------
1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
3 | from .base_options import BaseOptions
4 |
5 | class TrainOptions(BaseOptions):
6 | def initialize(self):
7 | BaseOptions.initialize(self)
8 | # for displays
9 | self.parser.add_argument('--display_freq', type=int, default=100, help='frequency of showing training results on screen')
10 | self.parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
11 | self.parser.add_argument('--save_latest_freq', type=int, default=1000, help='frequency of saving the latest results')
12 | self.parser.add_argument('--save_epoch_freq', type=int, default=10, help='frequency of saving checkpoints at the end of epochs')
13 | self.parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
14 | self.parser.add_argument('--debug', action='store_true', help='only do one epoch and displays at each iteration')
15 |
16 | # for training
17 | self.parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
18 | self.parser.add_argument('--load_pretrain', type=str, default='', help='load the pretrained model from the specified location')
19 | self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
20 | self.parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
21 | self.parser.add_argument('--niter', type=int, default=100, help='# of iter at starting learning rate')
22 | self.parser.add_argument('--niter_decay', type=int, default=100, help='# of iter to linearly decay learning rate to zero')
23 | self.parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
24 | self.parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
25 |
26 | # for discriminators
27 | self.parser.add_argument('--num_D', type=int, default=2, help='number of discriminators to use')
28 | self.parser.add_argument('--n_layers_D', type=int, default=3, help='only used if which_model_netD==n_layers')
29 | self.parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer')
30 | self.parser.add_argument('--lambda_feat', type=float, default=10.0, help='weight for feature matching loss')
31 | self.parser.add_argument('--no_ganFeat_loss', action='store_true', help='if specified, do *not* use discriminator feature matching loss')
32 | self.parser.add_argument('--no_vgg_loss', action='store_true', help='if specified, do *not* use VGG feature matching loss')
33 | self.parser.add_argument('--no_lsgan', action='store_true', help='if specified, use vanilla GAN loss instead of least squares GAN')
34 | self.parser.add_argument('--pool_size', type=int, default=0, help='the size of image buffer that stores previously generated images')
35 |
36 | self.isTrain = True
37 |
--------------------------------------------------------------------------------
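
Since every option is a plain argparse flag, a configuration can also be exercised without a shell script by overriding sys.argv before parsing, as in this illustrative snippet (note that parse() creates checkpoints_dir/name on disk, and writes opt.txt there when called with save=True):

import sys
from options.train_options import TrainOptions

sys.argv = ['train.py', '--name', 'label2city_512p', '--gpu_ids', '-1', '--batchSize', '2']
opt = TrainOptions().parse(save=False)
print(opt.isTrain, opt.gpu_ids, opt.batchSize)  # True [] 2
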
/precompute_feature_maps.py:
--------------------------------------------------------------------------------
1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
3 | from options.train_options import TrainOptions
4 | from data.data_loader import CreateDataLoader
5 | from models.models import create_model
6 | import os
7 | import util.util as util
8 | from torch.autograd import Variable
9 | import torch.nn as nn
10 |
11 | opt = TrainOptions().parse()
12 | opt.nThreads = 1
13 | opt.batchSize = 1
14 | opt.serial_batches = True
15 | opt.no_flip = True
16 | opt.instance_feat = True
17 |
18 | name = 'features'
19 | save_path = os.path.join(opt.checkpoints_dir, opt.name)
20 |
21 | ############ Initialize #########
22 | data_loader = CreateDataLoader(opt)
23 | dataset = data_loader.load_data()
24 | dataset_size = len(data_loader)
25 | model = create_model(opt)
26 | util.mkdirs(os.path.join(opt.dataroot, opt.phase + '_feat'))
27 |
28 | ######## Save precomputed feature maps for 1024p training #######
29 | for i, data in enumerate(dataset):
30 | print('%d / %d images' % (i+1, dataset_size))
31 | feat_map = model.module.netE.forward(Variable(data['image'].cuda(), volatile=True), data['inst'].cuda())
32 | feat_map = nn.Upsample(scale_factor=2, mode='nearest')(feat_map)
33 | image_numpy = util.tensor2im(feat_map.data[0])
34 | save_path = data['path'][0].replace('/train_label/', '/train_feat/')
35 | util.save_image(image_numpy, save_path)
--------------------------------------------------------------------------------
/scripts/test_1024p.sh:
--------------------------------------------------------------------------------
1 | ################################ Testing ################################
2 | # labels only
3 | python test.py --name label2city_1024p --netG local --ngf 32 --resize_or_crop none
--------------------------------------------------------------------------------
/scripts/test_1024p_feat.sh:
--------------------------------------------------------------------------------
1 | ################################ Testing ################################
2 | # first precompute and cluster all features
3 | python encode_features.py --name label2city_1024p_feat --netG local --ngf 32 --resize_or_crop none;
4 | # use instance-wise features
5 | python test.py --name label2city_1024p_feat --netG local --ngf 32 --resize_or_crop none --instance_feat
--------------------------------------------------------------------------------
/scripts/test_512p.sh:
--------------------------------------------------------------------------------
1 | ################################ Testing ################################
2 | # labels only
3 | python test.py --name label2city_512p
--------------------------------------------------------------------------------
/scripts/test_512p_feat.sh:
--------------------------------------------------------------------------------
1 | ################################ Testing ################################
2 | # first precompute and cluster all features
3 | python encode_features.py --name label2city_512p_feat;
4 | # use instance-wise features
5 | python test.py --name label2city_512p_feat --instance_feat
--------------------------------------------------------------------------------
/scripts/train_1024p_12G.sh:
--------------------------------------------------------------------------------
1 | ############## To train images at 2048 x 1024 resolution after training 1024 x 512 resolution models #############
2 | ##### Using GPUs with 12G memory (not tested)
3 | # Using labels only
4 | python train.py --name label2city_1024p --netG local --ngf 32 --num_D 3 --load_pretrain checkpoints/label2city_512p/ --niter_fix_global 20 --resize_or_crop crop --fineSize 1024
--------------------------------------------------------------------------------
/scripts/train_1024p_24G.sh:
--------------------------------------------------------------------------------
1 | ############## To train images at 2048 x 1024 resolution after training 1024 x 512 resolution models #############
2 | ######## Using GPUs with 24G memory
3 | # Using labels only
4 | python train.py --name label2city_1024p --netG local --ngf 32 --num_D 3 --load_pretrain checkpoints/label2city_512p/ --niter 50 --niter_decay 50 --niter_fix_global 10 --resize_or_crop none
--------------------------------------------------------------------------------
/scripts/train_1024p_feat_12G.sh:
--------------------------------------------------------------------------------
1 | ############## To train images at 2048 x 1024 resolution after training 1024 x 512 resolution models #############
2 | ##### Using GPUs with 12G memory (not tested)
3 | # First precompute feature maps and save them
4 | python precompute_feature_maps.py --name label2city_512p_feat;
5 | # Adding instances and encoded features
6 | python train.py --name label2city_1024p_feat --netG local --ngf 32 --num_D 3 --load_pretrain checkpoints/label2city_512p_feat/ --niter_fix_global 20 --resize_or_crop crop --fineSize 896 --instance_feat --load_features
--------------------------------------------------------------------------------
/scripts/train_1024p_feat_24G.sh:
--------------------------------------------------------------------------------
1 | ############## To train images at 2048 x 1024 resolution after training 1024 x 512 resolution models #############
2 | ######## Using GPUs with 24G memory
3 | # First precompute feature maps and save them
4 | python precompute_feature_maps.py --name label2city_512p_feat;
5 | # Adding instances and encoded features
6 | python train.py --name label2city_1024p_feat --netG local --ngf 32 --num_D 3 --load_pretrain checkpoints/label2city_512p_feat/ --niter 50 --niter_decay 50 --niter_fix_global 10 --resize_or_crop none --instance_feat --load_features
--------------------------------------------------------------------------------
/scripts/train_512p.sh:
--------------------------------------------------------------------------------
1 | ### Using labels only
2 | python train.py --name label2city_512p
--------------------------------------------------------------------------------
/scripts/train_512p_feat.sh:
--------------------------------------------------------------------------------
1 | ### Adding instances and encoded features
2 | python train.py --name label2city_512p_feat --instance_feat
--------------------------------------------------------------------------------
/scripts/train_512p_multigpu.sh:
--------------------------------------------------------------------------------
1 | ######## Multi-GPU training example #######
2 | python train.py --name label2city_512p --batchSize 8 --gpu_ids 0,1,2,3,4,5,6,7
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
3 | import os
4 | from collections import OrderedDict
5 | from options.test_options import TestOptions
6 | from data.data_loader import CreateDataLoader
7 | from models.models import create_model
8 | import util.util as util
9 | from util.visualizer import Visualizer
10 | from util import html
11 |
12 | opt = TestOptions().parse(save=False)
13 | opt.nThreads = 1 # test code only supports nThreads = 1
14 | opt.batchSize = 1 # test code only supports batchSize = 1
15 | opt.serial_batches = True # no shuffle
16 | opt.no_flip = True # no flip
17 |
18 | data_loader = CreateDataLoader(opt)
19 | dataset = data_loader.load_data()
20 | model = create_model(opt)
21 | visualizer = Visualizer(opt)
22 | # create website
23 | web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch))
24 | webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch))
25 | # test
26 | for i, data in enumerate(dataset):
27 | if i >= opt.how_many:
28 | break
29 | generated = model.inference(data['label'], data['inst'])
30 | visuals = OrderedDict([('input_label', util.tensor2label(data['label'][0], opt.label_nc)),
31 | ('synthesized_image', util.tensor2im(generated.data[0]))])
32 | img_path = data['path']
33 | print('process image... %s' % img_path)
34 | visualizer.save_images(webpage, visuals, img_path)
35 |
36 | webpage.save()
37 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
3 | import time
4 | from collections import OrderedDict
5 | from options.train_options import TrainOptions
6 | from data.data_loader import CreateDataLoader
7 | from models.models import create_model
8 | import util.util as util
9 | from util.visualizer import Visualizer
10 | import os
11 | import numpy as np
12 | import torch
13 | from torch.autograd import Variable
14 |
15 | opt = TrainOptions().parse()
16 | iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')
17 | if opt.continue_train:
18 | try:
19 | start_epoch, epoch_iter = np.loadtxt(iter_path , delimiter=',', dtype=int)
20 | except:
21 | start_epoch, epoch_iter = 1, 0
22 | print('Resuming from epoch %d at iteration %d' % (start_epoch, epoch_iter))
23 | else:
24 | start_epoch, epoch_iter = 1, 0
25 |
26 | if opt.debug:
27 | opt.display_freq = 1
28 | opt.print_freq = 1
29 | opt.niter = 1
30 | opt.niter_decay = 0
31 | opt.max_dataset_size = 10
32 |
33 | data_loader = CreateDataLoader(opt)
34 | dataset = data_loader.load_data()
35 | dataset_size = len(data_loader)
36 | print('#training images = %d' % dataset_size)
37 |
38 | model = create_model(opt)
39 | visualizer = Visualizer(opt)
40 |
41 | total_steps = (start_epoch-1) * dataset_size + epoch_iter
42 | for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):
43 | epoch_start_time = time.time()
44 | if epoch != start_epoch:
45 | epoch_iter = epoch_iter % dataset_size
46 | for i, data in enumerate(dataset, start=epoch_iter):
47 | iter_start_time = time.time()
48 | total_steps += opt.batchSize
49 | epoch_iter += opt.batchSize
50 |
51 | # whether to collect output images
52 | save_fake = total_steps % opt.display_freq == 0
53 |
54 | ############## Forward Pass ######################
55 | losses, generated = model(Variable(data['label']), Variable(data['inst']),
56 | Variable(data['image']), Variable(data['feat']), infer=save_fake)
57 |
58 | # sum per device losses
59 | losses = [ torch.mean(x) if not isinstance(x, int) else x for x in losses ]
60 | loss_dict = dict(zip(model.module.loss_names, losses))
61 |
62 | # calculate final loss scalar
63 | loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5
64 | loss_G = loss_dict['G_GAN'] + loss_dict['G_GAN_Feat'] + loss_dict['G_VGG']
65 |
66 | ############### Backward Pass ####################
67 | # update generator weights
68 | model.module.optimizer_G.zero_grad()
69 | loss_G.backward()
70 | model.module.optimizer_G.step()
71 |
72 | # update discriminator weights
73 | model.module.optimizer_D.zero_grad()
74 | loss_D.backward()
75 | model.module.optimizer_D.step()
76 |
77 | #call(["nvidia-smi", "--format=csv", "--query-gpu=memory.used,memory.free"])
78 |
79 | ############## Display results and errors ##########
80 | ### print out errors
81 | if total_steps % opt.print_freq == 0:
82 | errors = {k: v.data[0] if not isinstance(v, int) else v for k, v in loss_dict.items()}
83 | t = (time.time() - iter_start_time) / opt.batchSize
84 | visualizer.print_current_errors(epoch, epoch_iter, errors, t)
85 | visualizer.plot_current_errors(errors, total_steps)
86 |
87 | ### display output images
88 | if save_fake:
89 | visuals = OrderedDict([('input_label', util.tensor2label(data['label'][0], opt.label_nc)),
90 | ('synthesized_image', util.tensor2im(generated.data[0])),
91 | ('real_image', util.tensor2im(data['image'][0]))])
92 | visualizer.display_current_results(visuals, epoch, total_steps)
93 |
94 | ### save latest model
95 | if total_steps % opt.save_latest_freq == 0:
96 | print('saving the latest model (epoch %d, total_steps %d)' % (epoch, total_steps))
97 | model.module.save('latest')
98 | np.savetxt(iter_path, (epoch, epoch_iter), delimiter=',', fmt='%d')
99 |
100 | # end of epoch
101 | iter_end_time = time.time()
102 | print('End of epoch %d / %d \t Time Taken: %d sec' %
103 | (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))
104 |
105 | ### save model for this epoch
106 | if epoch % opt.save_epoch_freq == 0:
107 | print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps))
108 | model.module.save('latest')
109 | model.module.save(epoch)
110 | np.savetxt(iter_path, (epoch+1, 0), delimiter=',', fmt='%d')
111 |
112 | ### instead of only training the local enhancer, train the entire network after certain iterations
113 | if (opt.niter_fix_global != 0) and (epoch == opt.niter_fix_global):
114 | model.module.update_fixed_params()
115 |
116 | ### linearly decay learning rate after certain iterations
117 | if epoch > opt.niter:
118 | model.module.update_learning_rate()
119 |
--------------------------------------------------------------------------------
/util/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenxli/High-Resolution-Image-Synthesis-and-Semantic-Manipulation-with-Conditional-GANsl-/dd05da797863a13f6e45ec0a2d2ff3c7f8142f38/util/__init__.py
--------------------------------------------------------------------------------
/util/html.py:
--------------------------------------------------------------------------------
1 | import dominate
2 | from dominate.tags import *
3 | import os
4 |
5 |
6 | class HTML:
7 | def __init__(self, web_dir, title, refresh=0):
8 | self.title = title
9 | self.web_dir = web_dir
10 | self.img_dir = os.path.join(self.web_dir, 'images')
11 | if not os.path.exists(self.web_dir):
12 | os.makedirs(self.web_dir)
13 | if not os.path.exists(self.img_dir):
14 | os.makedirs(self.img_dir)
15 |
16 | self.doc = dominate.document(title=title)
17 | if refresh > 0:
18 | with self.doc.head:
19 | meta(http_equiv="refresh", content=str(refresh))
20 |
21 | def get_image_dir(self):
22 | return self.img_dir
23 |
24 | def add_header(self, str):
25 | with self.doc:
26 | h3(str)
27 |
28 | def add_table(self, border=1):
29 | self.t = table(border=border, style="table-layout: fixed;")
30 | self.doc.add(self.t)
31 |
32 | def add_images(self, ims, txts, links, width=512):
33 | self.add_table()
34 | with self.t:
35 | with tr():
36 | for im, txt, link in zip(ims, txts, links):
37 | with td(style="word-wrap: break-word;", halign="center", valign="top"):
38 | with p():
39 | with a(href=os.path.join('images', link)):
40 | img(style="width:%dpx" % (width), src=os.path.join('images', im))
41 | br()
42 | p(txt)
43 |
44 | def save(self):
45 | html_file = '%s/index.html' % self.web_dir
46 | f = open(html_file, 'wt')
47 | f.write(self.doc.render())
48 | f.close()
49 |
50 |
51 | if __name__ == '__main__':
52 | html = HTML('web/', 'test_html')
53 | html.add_header('hello world')
54 |
55 | ims = []
56 | txts = []
57 | links = []
58 | for n in range(4):
59 | ims.append('image_%d.jpg' % n)
60 | txts.append('text_%d' % n)
61 | links.append('image_%d.jpg' % n)
62 | html.add_images(ims, txts, links)
63 | html.save()
64 |
--------------------------------------------------------------------------------
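html.py pulls in every tag name via a wildcard import. An explicit-import variant covering exactly the names the class uses (a suggested style tweak for traceability, not part of the original file):

# Explicit-import equivalent of `from dominate.tags import *` as used in html.py.
import dominate
from dominate.tags import meta, h3, table, tr, td, p, a, img, br
--------------------------------------------------------------------------------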
/util/image_pool.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torch
3 | from torch.autograd import Variable
4 | class ImagePool():  # history buffer of generated images, used to stabilize discriminator training
5 | def __init__(self, pool_size):
6 | self.pool_size = pool_size
7 | if self.pool_size > 0:
8 | self.num_imgs = 0
9 | self.images = []
10 |
11 |     def query(self, images):  # return a batch in which each fake may be swapped for an older one
12 | if self.pool_size == 0:
13 | return images
14 | return_images = []
15 | for image in images.data:
16 | image = torch.unsqueeze(image, 0)
17 | if self.num_imgs < self.pool_size:
18 | self.num_imgs = self.num_imgs + 1
19 | self.images.append(image)
20 | return_images.append(image)
21 | else:
22 | p = random.uniform(0, 1)
23 |                 if p > 0.5:  # with probability 0.5, return a stored image instead
24 | random_id = random.randint(0, self.pool_size-1)
25 | tmp = self.images[random_id].clone()
26 | self.images[random_id] = image
27 | return_images.append(tmp)
28 | else:
29 | return_images.append(image)
30 | return_images = Variable(torch.cat(return_images, 0))
31 | return return_images
32 |
--------------------------------------------------------------------------------
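ImagePool is an image history buffer for GAN training: once the pool is full, each queried fake has a 50% chance of being replaced by an older generated image, which damps discriminator oscillation. A minimal usage sketch (the pool size of 50 and the random batch are stand-ins, not values from this file):

# Illustrative use of ImagePool inside a discriminator update step.
import torch
from util.image_pool import ImagePool

pool = ImagePool(50)                     # pool_size=0 would disable the buffer entirely

fake_images = torch.rand(4, 3, 64, 64)   # stand-in for a batch of generator outputs
history_fakes = pool.query(fake_images)  # mix of current and previously stored fakes
# a real training loop would now score them, e.g. pred_fake = netD(history_fakes.detach())
--------------------------------------------------------------------------------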
/util/util.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import torch
3 | import numpy as np
4 | from PIL import Image
5 | import os
6 | 
7 |
8 | # Converts a Tensor into a Numpy array
9 | # |imtype|: the desired type of the converted numpy array
10 | def tensor2im(image_tensor, imtype=np.uint8, normalize=True):
11 | if isinstance(image_tensor, list):
12 | image_numpy = []
13 | for i in range(len(image_tensor)):
14 | image_numpy.append(tensor2im(image_tensor[i], imtype, normalize))
15 | return image_numpy
16 | image_numpy = image_tensor.cpu().float().numpy()
17 | if normalize:
18 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
19 | else:
20 | image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0
21 | image_numpy = np.clip(image_numpy, 0, 255)
22 | if image_numpy.shape[2] == 1:
23 | image_numpy = image_numpy[:,:,0]
24 | return image_numpy.astype(imtype)
25 |
26 | # Converts a one-hot tensor into a color-coded label map
27 | def tensor2label(label_tensor, n_label, imtype=np.uint8):
28 | if n_label == 0:
29 | return tensor2im(label_tensor, imtype)
30 | label_tensor = label_tensor.cpu().float()
31 | if label_tensor.size()[0] > 1:
32 | label_tensor = label_tensor.max(0, keepdim=True)[1]
33 | label_tensor = Colorize(n_label)(label_tensor)
34 | label_numpy = np.transpose(label_tensor.numpy(), (1, 2, 0))
35 | return label_numpy.astype(imtype)
36 |
37 | def save_image(image_numpy, image_path):
38 | image_pil = Image.fromarray(image_numpy)
39 | image_pil.save(image_path)
40 |
41 | def mkdirs(paths):
42 | if isinstance(paths, list) and not isinstance(paths, str):
43 | for path in paths:
44 | mkdir(path)
45 | else:
46 | mkdir(paths)
47 |
48 | def mkdir(path):
49 | if not os.path.exists(path):
50 | os.makedirs(path)
51 |
52 | ###############################################################################
53 | # Code from
54 | # https://github.com/ycszen/pytorch-seg/blob/master/transform.py
55 | # Modified so it matches the Cityscapes label map colors
56 | ###############################################################################
57 | def uint82bin(n, count=8):
58 |     """Return the binary representation of integer n as a string of `count` bits."""
59 |     return ''.join([str((n >> y) & 1) for y in range(count-1, -1, -1)])
60 |
61 | def labelcolormap(N):
62 |     if N == 35: # Cityscapes label colors
63 | cmap = np.array([( 0, 0, 0), ( 0, 0, 0), ( 0, 0, 0), ( 0, 0, 0), ( 0, 0, 0), (111, 74, 0), ( 81, 0, 81),
64 | (128, 64,128), (244, 35,232), (250,170,160), (230,150,140), ( 70, 70, 70), (102,102,156), (190,153,153),
65 | (180,165,180), (150,100,100), (150,120, 90), (153,153,153), (153,153,153), (250,170, 30), (220,220, 0),
66 | (107,142, 35), (152,251,152), ( 70,130,180), (220, 20, 60), (255, 0, 0), ( 0, 0,142), ( 0, 0, 70),
67 | ( 0, 60,100), ( 0, 0, 90), ( 0, 0,110), ( 0, 80,100), ( 0, 0,230), (119, 11, 32), ( 0, 0,142)],
68 | dtype=np.uint8)
69 | else:
70 |         cmap = np.zeros((N, 3), dtype=np.uint8)  # generic palette: spread each label id's bits over RGB
71 | for i in range(N):
72 | r, g, b = 0, 0, 0
73 | id = i
74 | for j in range(7):
75 | str_id = uint82bin(id)
76 | r = r ^ (np.uint8(str_id[-1]) << (7-j))
77 | g = g ^ (np.uint8(str_id[-2]) << (7-j))
78 | b = b ^ (np.uint8(str_id[-3]) << (7-j))
79 | id = id >> 3
80 | cmap[i, 0] = r
81 | cmap[i, 1] = g
82 | cmap[i, 2] = b
83 | return cmap
84 |
85 | class Colorize(object):
86 | def __init__(self, n=35):
87 | self.cmap = labelcolormap(n)
88 | self.cmap = torch.from_numpy(self.cmap[:n])
89 |
90 |     def __call__(self, gray_image):  # gray_image: 1 x H x W label map -> 3 x H x W color image
91 | size = gray_image.size()
92 | color_image = torch.ByteTensor(3, size[1], size[2]).fill_(0)
93 |
94 | for label in range(0, len(self.cmap)):
95 | mask = (label == gray_image[0]).cpu()
96 | color_image[0][mask] = self.cmap[label][0]
97 | color_image[1][mask] = self.cmap[label][1]
98 | color_image[2][mask] = self.cmap[label][2]
99 |
100 | return color_image
101 |
--------------------------------------------------------------------------------
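A short round trip through the helpers above: convert a generator-style tensor and a label map to uint8 images and write them to disk (the random tensors and the ./results path are stand-ins for this example, not values from the repo):

# Illustrative use of tensor2im, tensor2label, mkdir and save_image.
import torch
from util import util

fake = torch.rand(3, 256, 512) * 2 - 1               # C x H x W in [-1, 1], like a tanh output
label = torch.randint(0, 35, (1, 256, 512)).float()  # 1 x H x W Cityscapes-style label map

util.mkdir('./results')                              # stand-in output directory
util.save_image(util.tensor2im(fake), './results/fake.jpg')           # undoes the [-1, 1] scaling
util.save_image(util.tensor2label(label, 35), './results/label.jpg')  # colors via labelcolormap(35)
--------------------------------------------------------------------------------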
/util/visualizer.py:
--------------------------------------------------------------------------------
1 | ### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
2 | ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
3 | import numpy as np
4 | import os
5 | import ntpath
6 | import time
7 | from . import util
8 | from . import html
9 | import scipy.misc
10 | try:
11 | from StringIO import StringIO # Python 2.7
12 | except ImportError:
13 | from io import BytesIO # Python 3.x
14 |
15 | class Visualizer():
16 | def __init__(self, opt):
17 | # self.opt = opt
18 | self.tf_log = opt.tf_log
19 | self.use_html = opt.isTrain and not opt.no_html
20 | self.win_size = opt.display_winsize
21 | self.name = opt.name
22 | if self.tf_log:
23 | import tensorflow as tf
24 | self.tf = tf
25 | self.log_dir = os.path.join(opt.checkpoints_dir, opt.name, 'logs')
26 | self.writer = tf.summary.FileWriter(self.log_dir)
27 |
28 | if self.use_html:
29 | self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
30 | self.img_dir = os.path.join(self.web_dir, 'images')
31 | print('create web directory %s...' % self.web_dir)
32 | util.mkdirs([self.web_dir, self.img_dir])
33 | self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
34 | with open(self.log_name, "a") as log_file:
35 | now = time.strftime("%c")
36 | log_file.write('================ Training Loss (%s) ================\n' % now)
37 |
38 | # |visuals|: dictionary of images to display or save
39 | def display_current_results(self, visuals, epoch, step):
40 | if self.tf_log: # show images in tensorboard output
41 | img_summaries = []
42 | for label, image_numpy in visuals.items():
43 | # Write the image to a string
44 | try:
45 | s = StringIO()
46 |             except NameError:  # Python 3: StringIO was never imported, fall back to BytesIO
47 | s = BytesIO()
48 | scipy.misc.toimage(image_numpy).save(s, format="jpeg")
49 | # Create an Image object
50 | img_sum = self.tf.Summary.Image(encoded_image_string=s.getvalue(), height=image_numpy.shape[0], width=image_numpy.shape[1])
51 | # Create a Summary value
52 | img_summaries.append(self.tf.Summary.Value(tag=label, image=img_sum))
53 |
54 | # Create and write Summary
55 | summary = self.tf.Summary(value=img_summaries)
56 | self.writer.add_summary(summary, step)
57 |
58 |         if self.use_html: # save images to an HTML file
59 | for label, image_numpy in visuals.items():
60 | if isinstance(image_numpy, list):
61 | for i in range(len(image_numpy)):
62 | img_path = os.path.join(self.img_dir, 'epoch%.3d_%s_%d.jpg' % (epoch, label, i))
63 | util.save_image(image_numpy[i], img_path)
64 | else:
65 | img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.jpg' % (epoch, label))
66 | util.save_image(image_numpy, img_path)
67 |
68 | # update website
69 | webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=5)
70 | for n in range(epoch, 0, -1):
71 | webpage.add_header('epoch [%d]' % n)
72 | ims = []
73 | txts = []
74 | links = []
75 |
76 | for label, image_numpy in visuals.items():
77 | if isinstance(image_numpy, list):
78 | for i in range(len(image_numpy)):
79 | img_path = 'epoch%.3d_%s_%d.jpg' % (n, label, i)
80 | ims.append(img_path)
81 | txts.append(label+str(i))
82 | links.append(img_path)
83 | else:
84 | img_path = 'epoch%.3d_%s.jpg' % (n, label)
85 | ims.append(img_path)
86 | txts.append(label)
87 | links.append(img_path)
88 | if len(ims) < 10:
89 | webpage.add_images(ims, txts, links, width=self.win_size)
90 | else:
91 | num = int(round(len(ims)/2.0))
92 | webpage.add_images(ims[:num], txts[:num], links[:num], width=self.win_size)
93 | webpage.add_images(ims[num:], txts[num:], links[num:], width=self.win_size)
94 | webpage.save()
95 |
96 | # errors: dictionary of error labels and values
97 | def plot_current_errors(self, errors, step):
98 | if self.tf_log:
99 | for tag, value in errors.items():
100 | summary = self.tf.Summary(value=[self.tf.Summary.Value(tag=tag, simple_value=value)])
101 | self.writer.add_summary(summary, step)
102 |
103 |     # errors: same format as |errors| of plot_current_errors
104 | def print_current_errors(self, epoch, i, errors, t):
105 | message = '(epoch: %d, iters: %d, time: %.3f) ' % (epoch, i, t)
106 | for k, v in errors.items():
107 | if v != 0:
108 | message += '%s: %.3f ' % (k, v)
109 |
110 | print(message)
111 | with open(self.log_name, "a") as log_file:
112 | log_file.write('%s\n' % message)
113 |
114 | # save image to the disk
115 | def save_images(self, webpage, visuals, image_path):
116 | image_dir = webpage.get_image_dir()
117 | short_path = ntpath.basename(image_path[0])
118 | name = os.path.splitext(short_path)[0]
119 |
120 | webpage.add_header(name)
121 | ims = []
122 | txts = []
123 | links = []
124 |
125 | for label, image_numpy in visuals.items():
126 | image_name = '%s_%s.jpg' % (name, label)
127 | save_path = os.path.join(image_dir, image_name)
128 | util.save_image(image_numpy, save_path)
129 |
130 | ims.append(image_name)
131 | txts.append(label)
132 | links.append(image_name)
133 | webpage.add_images(ims, txts, links, width=self.win_size)
134 |
--------------------------------------------------------------------------------
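Note that the TensorBoard branch in visualizer.py depends on two legacy APIs: scipy.misc.toimage (removed in SciPy 1.2) and the TF 1.x tf.summary.FileWriter / tf.Summary interface. A possible Pillow replacement for just the image-encoding call, sketched under the assumption that image_numpy is the uint8 array produced by util.tensor2im:

# Sketch: encode an image to JPEG bytes without scipy.misc.toimage.
import numpy as np
from io import BytesIO
from PIL import Image

image_numpy = np.zeros((64, 64, 3), dtype=np.uint8)  # stand-in for a tensor2im() result

s = BytesIO()
Image.fromarray(image_numpy).save(s, format="jpeg")  # replaces scipy.misc.toimage(...).save(...)
encoded = s.getvalue()                               # the bytes handed to tf.Summary.Image above
--------------------------------------------------------------------------------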