├── LICENSE ├── README.md ├── data ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── base_dataset.cpython-36.pyc │ ├── base_dataset.cpython-37.pyc │ ├── custom_dataset.cpython-36.pyc │ ├── custom_dataset.cpython-37.pyc │ ├── image_folder.cpython-36.pyc │ ├── image_folder.cpython-37.pyc │ ├── pix2pix_dataset.cpython-36.pyc │ └── pix2pix_dataset.cpython-37.pyc ├── base_dataset.py ├── custom_dataset.py ├── image_folder.py └── pix2pix_dataset.py ├── models ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── pix2pix_model.cpython-36.pyc │ └── pix2pix_model.cpython-37.pyc ├── networks │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── architecture.cpython-36.pyc │ │ ├── architecture.cpython-37.pyc │ │ ├── base_network.cpython-36.pyc │ │ ├── base_network.cpython-37.pyc │ │ ├── discriminator.cpython-36.pyc │ │ ├── discriminator.cpython-37.pyc │ │ ├── encoder.cpython-36.pyc │ │ ├── encoder.cpython-37.pyc │ │ ├── generator.cpython-36.pyc │ │ ├── generator.cpython-37.pyc │ │ ├── loss.cpython-36.pyc │ │ ├── loss.cpython-37.pyc │ │ ├── normalization.cpython-36.pyc │ │ └── normalization.cpython-37.pyc │ ├── architecture.py │ ├── base_network.py │ ├── discriminator.py │ ├── encoder.py │ ├── generator.py │ ├── loss.py │ ├── normalization.py │ └── sync_batchnorm │ │ ├── __init__.py │ │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── batchnorm.cpython-36.pyc │ │ ├── batchnorm.cpython-37.pyc │ │ ├── comm.cpython-36.pyc │ │ ├── comm.cpython-37.pyc │ │ ├── replicate.cpython-36.pyc │ │ └── replicate.cpython-37.pyc │ │ ├── batchnorm.py │ │ ├── batchnorm_reimpl.py │ │ ├── comm.py │ │ ├── replicate.py │ │ └── unittest.py └── pix2pix_model.py ├── options ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── base_options.cpython-36.pyc │ ├── 
base_options.cpython-37.pyc │ ├── test_options.cpython-36.pyc │ ├── test_options.cpython-37.pyc │ ├── train_options.cpython-36.pyc │ └── train_options.cpython-37.pyc ├── base_options.py ├── test_options.py └── train_options.py ├── save_mslic.py ├── save_style_vector.py ├── test.py ├── train.py ├── trainers ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── pix2pix_trainer.cpython-36.pyc │ └── pix2pix_trainer.cpython-37.pyc └── pix2pix_trainer.py └── util ├── __init__.py ├── coco.py ├── html.py ├── util.py └── visualizer.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Jonghyun Kim 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SuperStyleNet: Deep Image Synthesis with Superpixel Based Style Encoder (BMVC 2021) 2 | 3 | ![Mix_comp](https://user-images.githubusercontent.com/42399549/137694588-28f522ee-e9aa-480c-8f85-eba8f1ebe0e6.png) 4 | **Figure:** Style mixing with multiple style images. The style vectors are replaced from source to style image on given semantic masks. 5 | 6 | ![SPSE](https://user-images.githubusercontent.com/42399549/137692560-ccb7e96e-6b9a-417c-8bbe-97db01205ea2.png) 7 | **Figure:** Superpixel based Style Encoding. To extract style codes of a specific semantic mask, we convert the input image into the five-dimensional space and cluster it in the semantic mask into superpixels. Thereafter, pixel values in each superpixel are averaged to obtain a style code. 8 | 9 | ## Update 10 | ### Update (December 28, 2021) 11 | This update is to correct minor errors in `'save_style_vector.py'`. 12 | 13 | 14 | ## Abstract 15 | 16 |
17 | CLICK ME 18 | Existing methods for image synthesis utilized a style encoder based on stacks of convolutions and pooling layers to generate style codes from input images. However, the encoded vectors do not necessarily contain local information of the corresponding images since small-scale objects tend to "wash away" through such downscaling procedures. In this paper, we propose deep image synthesis with a superpixel-based style encoder, named SuperStyleNet. First, we directly extract the style codes from the original image based on superpixels to consider local objects. Second, we recover spatial relationships in vectorized style codes based on graphical analysis. Thus, the proposed network achieves high-quality image synthesis by mapping the style codes into semantic labels. Experimental results show that the proposed method outperforms state-of-the-art ones in terms of visual quality and quantitative measurements. Furthermore, we achieve elaborate spatial style editing by adjusting style codes. 19 |
20 | 21 | > **SuperStyleNet: Deep Image Synthesis with Superpixel Based Style Encoder** 22 | > 23 | > Jonghyun Kim, Gen Li, Cheolkon Jung, Joongkyu Kim 24 | > British Machine Vision Conference **BMVC 2021** 25 | 26 | [[Paper](https://www.bmvc2021-virtualconference.com/assets/papers/0051.pdf)] [[Full Paper](https://arxiv.org/abs/2112.09367)] 27 | 28 | ## Installation 29 | 30 | Clone this repo. 31 | 32 | Install requirements: 33 | 34 |
35 | CLICK ME 36 | ``` 37 | torch==1.2.0 38 | torchvision==0.4.0 39 | easydict 40 | matplotlib 41 | opencv-python 42 | glob3 43 | pillow 44 | dill 45 | dominate>=2.3.1 46 | scikit-image 47 | QDarkStyle==2.7 48 | qdarkgraystyle==1.0.2 49 | tensorboard==1.14.0 50 | tensorboardX==1.9 51 | tqdm==4.32.1 52 | urllib3==1.25.8 53 | visdom==0.1.8.9 54 | ``` 55 |
56 | 57 | ## Dataset 58 | 59 | 1. This network uses [CelebAMask-HQ](https://github.com/switchablenorms/CelebAMask-HQ), [Cityscapes](https://www.cityscapes-dataset.com/), and [CMP-Facade](https://cmp.felk.cvut.cz/~tylecr1/facade/) datasets. After downloading these datasets, unzip and save train and test images as follows: 60 | ``` 61 | datasets 62 | ├── celeba 63 | | ├── train 64 | | | ├── images 65 | | | ├── labels 66 | | | └── codes 67 | | ├── test 68 | | | ├── images 69 | | | ├── labels 70 | | | └── codes 71 | ├── cityscapes 72 | | ├── train 73 | | | ├── images 74 | | | ├── labels 75 | | | └── codes 76 | | ├── test 77 | | | ├── images 78 | | | ├── labels 79 | | | └── codes 80 | ``` 81 | 2. Download the style codes for each dataset from [Google Drive](https://drive.google.com/file/d/1m3SAljvNebIaCy3gEM_Kzs6TRfI-T839/view?usp=sharing). After downloading them, unzip and save them in `./datasets/[dataset name]/[train or test]/codes`. **Extracting style codes with SPSE takes a long time, so we provide precomputed style codes for all three datasets.** 82 | 83 | ## Generating images using a pretrained model with style codes 84 | 85 | After preparing test images, the reconstructed images can be obtained using the pretrained model. 86 | 87 | 1. Create a `checkpoint/celeba` folder. Download the pretrained weights from [Google Drive](https://drive.google.com/file/d/1XBoHicrboLrePqJULgKdcbIrlnXbI9JS/view?usp=sharing) and unzip `checkpoint.zip` into the `./checkpoint/celeba` folder. 88 | 2. Run `test.py` with the command below to generate synthesized images, which will be saved in `./checkpoint/celeba/result`. Save path and details can be edited in `./options/base_options.py` and `./options/test_options.py`.
89 | ``` 90 | python test.py --name celeba --load_size 256 --crop_size 256 --dataset_mode custom --label_dir datasets/celeba/test/labels --image_dir datasets/celeba/test/images --label_nc 19 --instance_dir datasets/celeba/test/codes --which_epoch 50 --gpu_ids 0 91 | ``` 92 | 93 | ## Training a new model on a personal dataset 94 | 95 | ### For CelebAMask-HQ 96 | 1. Check your settings (e.g., implementation details, save path, and so on) in `./options/base_options.py` and `./options/train_options.py`. 97 | 2. Run `train.py`. 98 | ``` 99 | python train.py --name celeba --gpu_ids 0,1,2,3 --batchSize 32 --load_size 256 --crop_size 256 --dataset_mode custom --label_nc 19 --label_dir datasets/celeba/train/labels --image_dir datasets/celeba/train/images --instance_dir datasets/celeba/train/codes 100 | ``` 101 | 102 | ### For a personal dataset 103 | 1. Save train and test images with labels in `./datasets/[dataset name]/train/[images or labels]` and `./datasets/[dataset name]/test/[images or labels]` folders, respectively. 104 | 2. Run `save_style_vector.py` to extract and save style vectors. This process takes a long time. 105 | 3. Check your settings (e.g., implementation details, save path, and so on) in `./options/base_options.py` and `./options/train_options.py`. 106 | 4. Run `train.py`. 107 | ``` 108 | python train.py --name personal_data --gpu_ids 0,1,2,3 --batchSize 32 --load_size 256 --crop_size 256 --dataset_mode custom --label_nc 19 --label_dir datasets/[dataset name]/train/labels --image_dir datasets/[dataset name]/train/images --instance_dir datasets/[dataset name]/train/codes 109 | ``` 110 | 111 | ## Citation 112 | If you use this code for your research, please cite our paper.
113 | ``` 114 | @inproceedings{kim2021superstylenet, 115 | title={SuperStyleNet: Deep Image Synthesis with Superpixel Based Style Encoder}, 116 | author={Jonghyun Kim and Gen Li and Cheolkon Jung and Joongkyu Kim}, 117 | booktitle={British Machine Vision Conference}, 118 | year={2021} 119 | } 120 | ``` 121 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | import importlib 7 | import torch.utils.data 8 | from data.base_dataset import BaseDataset 9 | 10 | 11 | def find_dataset_using_name(dataset_name): 12 | # Given the option --dataset [datasetname], 13 | # the file "datasets/datasetname_dataset.py" 14 | # will be imported. 15 | dataset_filename = "data." + dataset_name + "_dataset" 16 | datasetlib = importlib.import_module(dataset_filename) 17 | 18 | # In the file, the class called DatasetNameDataset() will 19 | # be instantiated. It has to be a subclass of BaseDataset, 20 | # and it is case-insensitive. 21 | dataset = None 22 | target_dataset_name = dataset_name.replace('_', '') + 'dataset' 23 | for name, cls in datasetlib.__dict__.items(): 24 | if name.lower() == target_dataset_name.lower() \ 25 | and issubclass(cls, BaseDataset): 26 | dataset = cls 27 | 28 | if dataset is None: 29 | raise ValueError("In %s.py, there should be a subclass of BaseDataset " 30 | "with class name that matches %s in lowercase." 
% 31 | (dataset_filename, target_dataset_name)) 32 | 33 | return dataset 34 | 35 | 36 | def get_option_setter(dataset_name): 37 | dataset_class = find_dataset_using_name(dataset_name) 38 | return dataset_class.modify_commandline_options 39 | 40 | 41 | def create_dataloader(opt): 42 | dataset = find_dataset_using_name(opt.dataset_mode) 43 | instance = dataset() 44 | instance.initialize(opt) 45 | print("dataset [%s] of size %d was created" % 46 | (type(instance).__name__, len(instance))) 47 | dataloader = torch.utils.data.DataLoader( 48 | instance, 49 | batch_size=opt.batchSize, 50 | shuffle=not opt.serial_batches, 51 | num_workers=int(opt.nThreads), 52 | drop_last=opt.isTrain, 53 | pin_memory=True 54 | ) 55 | return dataloader 56 | -------------------------------------------------------------------------------- /data/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenjaminJonghyun/SuperStyleNet/2400c01b35f50b387b5f768fdece37688a077049/data/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenjaminJonghyun/SuperStyleNet/2400c01b35f50b387b5f768fdece37688a077049/data/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /data/__pycache__/base_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenjaminJonghyun/SuperStyleNet/2400c01b35f50b387b5f768fdece37688a077049/data/__pycache__/base_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/base_dataset.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenjaminJonghyun/SuperStyleNet/2400c01b35f50b387b5f768fdece37688a077049/data/__pycache__/base_dataset.cpython-37.pyc -------------------------------------------------------------------------------- /data/__pycache__/custom_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenjaminJonghyun/SuperStyleNet/2400c01b35f50b387b5f768fdece37688a077049/data/__pycache__/custom_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/custom_dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenjaminJonghyun/SuperStyleNet/2400c01b35f50b387b5f768fdece37688a077049/data/__pycache__/custom_dataset.cpython-37.pyc -------------------------------------------------------------------------------- /data/__pycache__/image_folder.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenjaminJonghyun/SuperStyleNet/2400c01b35f50b387b5f768fdece37688a077049/data/__pycache__/image_folder.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/image_folder.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenjaminJonghyun/SuperStyleNet/2400c01b35f50b387b5f768fdece37688a077049/data/__pycache__/image_folder.cpython-37.pyc -------------------------------------------------------------------------------- /data/__pycache__/pix2pix_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/BenjaminJonghyun/SuperStyleNet/2400c01b35f50b387b5f768fdece37688a077049/data/__pycache__/pix2pix_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/pix2pix_dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenjaminJonghyun/SuperStyleNet/2400c01b35f50b387b5f768fdece37688a077049/data/__pycache__/pix2pix_dataset.cpython-37.pyc -------------------------------------------------------------------------------- /data/base_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | import torch.utils.data as data 7 | from PIL import Image 8 | import torchvision.transforms as transforms 9 | import numpy as np 10 | import random 11 | 12 | 13 | class BaseDataset(data.Dataset): 14 | def __init__(self): 15 | super(BaseDataset, self).__init__() 16 | 17 | @staticmethod 18 | def modify_commandline_options(parser, is_train): 19 | return parser 20 | 21 | def initialize(self, opt): 22 | pass 23 | 24 | 25 | def get_params(opt, size): 26 | w, h = size 27 | new_h = h 28 | new_w = w 29 | if opt.preprocess_mode == 'resize_and_crop': 30 | new_h = new_w = opt.load_size 31 | elif opt.preprocess_mode == 'scale_width_and_crop': 32 | new_w = opt.load_size 33 | new_h = opt.load_size * h // w 34 | elif opt.preprocess_mode == 'scale_shortside_and_crop': 35 | ss, ls = min(w, h), max(w, h) # shortside and longside 36 | width_is_shorter = w == ss 37 | ls = int(opt.load_size * ls / ss) 38 | new_w, new_h = (ss, ls) if width_is_shorter else (ls, ss) 39 | 40 | x = random.randint(0, np.maximum(0, new_w - opt.crop_size)) 41 | y = random.randint(0, np.maximum(0, new_h - opt.crop_size)) 42 | 
43 | flip = random.random() > 0.5 44 | return {'crop_pos': (x, y), 'flip': flip} 45 | 46 | 47 | def get_transform(opt, params, method=Image.BICUBIC, normalize=True, toTensor=True): 48 | transform_list = [] 49 | if 'resize' in opt.preprocess_mode: 50 | osize = [opt.load_size, opt.load_size] 51 | transform_list.append(transforms.Resize(osize, interpolation=method)) 52 | elif 'scale_width' in opt.preprocess_mode: 53 | transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, method))) 54 | elif 'scale_shortside' in opt.preprocess_mode: 55 | transform_list.append(transforms.Lambda(lambda img: __scale_shortside(img, opt.load_size, method))) 56 | 57 | if 'crop' in opt.preprocess_mode: 58 | transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size))) 59 | 60 | if opt.preprocess_mode == 'none': 61 | base = 32 62 | transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method))) 63 | 64 | if opt.preprocess_mode == 'fixed': 65 | w = opt.crop_size 66 | h = round(opt.crop_size / opt.aspect_ratio) 67 | transform_list.append(transforms.Lambda(lambda img: __resize(img, w, h, method))) 68 | 69 | if opt.isTrain and not opt.no_flip: 70 | transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip']))) 71 | 72 | if toTensor: 73 | transform_list += [transforms.ToTensor()] 74 | 75 | if normalize: 76 | transform_list += [transforms.Normalize((0.5, 0.5, 0.5), 77 | (0.5, 0.5, 0.5))] 78 | return transforms.Compose(transform_list) 79 | 80 | def get_transform_instance(toTensor=True, normalize=True): 81 | transform_list = [] 82 | #print('==============1') 83 | 84 | 85 | if toTensor: 86 | transform_list += [transforms.ToTensor()] 87 | 88 | if normalize: 89 | transform_list += [transforms.Normalize((0.5,), 90 | (0.5,))] 91 | 92 | return transforms.Compose(transform_list) 93 | 94 | 95 | def normalize(): 96 | return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 97 | 98 | 99 | def 
__resize(img, w, h, method=Image.BICUBIC): 100 | return img.resize((w, h), method) 101 | 102 | 103 | def __make_power_2(img, base, method=Image.BICUBIC): 104 | ow, oh = img.size 105 | h = int(round(oh / base) * base) 106 | w = int(round(ow / base) * base) 107 | if (h == oh) and (w == ow): 108 | return img 109 | return img.resize((w, h), method) 110 | 111 | 112 | def __scale_width(img, target_width, method=Image.BICUBIC): 113 | ow, oh = img.size 114 | if (ow == target_width): 115 | return img 116 | w = target_width 117 | h = int(target_width * oh / ow) 118 | return img.resize((w, h), method) 119 | 120 | 121 | def __scale_shortside(img, target_width, method=Image.BICUBIC): 122 | ow, oh = img.size 123 | ss, ls = min(ow, oh), max(ow, oh) # shortside and longside 124 | width_is_shorter = ow == ss 125 | if (ss == target_width): 126 | return img 127 | ls = int(target_width * ls / ss) 128 | nw, nh = (ss, ls) if width_is_shorter else (ls, ss) 129 | return img.resize((nw, nh), method) 130 | 131 | 132 | def __crop(img, pos, size): 133 | ow, oh = img.size 134 | x1, y1 = pos 135 | tw = th = size 136 | return img.crop((x1, y1, x1 + tw, y1 + th)) 137 | 138 | 139 | def __flip(img, flip): 140 | if flip: 141 | return img.transpose(Image.FLIP_LEFT_RIGHT) 142 | return img 143 | -------------------------------------------------------------------------------- /data/custom_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | from data.pix2pix_dataset import Pix2pixDataset 7 | from data.image_folder import make_dataset 8 | 9 | 10 | class CustomDataset(Pix2pixDataset): 11 | """ Dataset that loads images from directories 12 | Use option --label_dir, --image_dir, --instance_dir to specify the directories. 
13 | The images in the directories are sorted in alphabetical order and paired in order. 14 | """ 15 | 16 | @staticmethod 17 | def modify_commandline_options(parser, is_train): 18 | parser = Pix2pixDataset.modify_commandline_options(parser, is_train) 19 | parser.set_defaults(preprocess_mode='resize_and_crop') 20 | load_size = 286 if is_train else 256 21 | parser.set_defaults(load_size=load_size) 22 | parser.set_defaults(crop_size=256) 23 | parser.set_defaults(display_winsize=256) 24 | parser.set_defaults(label_nc=13) 25 | parser.set_defaults(contain_dontcare_label=False) 26 | 27 | parser.add_argument('--label_dir', type=str, required=True, 28 | help='path to the directory that contains label images') 29 | parser.add_argument('--image_dir', type=str, required=True, 30 | help='path to the directory that contains photo images') 31 | parser.add_argument('--instance_dir', type=str, default='', 32 | help='path to the directory that contains instance maps. Leave black if not exists') 33 | ''' 34 | parser.add_argument('--label_test_dir', type=str, required=True, 35 | help='path to the directory that contains label images') 36 | parser.add_argument('--image_test_dir', type=str, required=True, 37 | help='path to the directory that contains photo images') 38 | parser.add_argument('--instance_test_dir', type=str, default='', 39 | help='path to the directory that contains instance maps. 
Leave black if not exists') 40 | ''' 41 | return parser 42 | 43 | def get_paths(self, opt): 44 | label_dir = opt.label_dir 45 | label_paths = make_dataset(label_dir, recursive=False, read_cache=True) 46 | 47 | image_dir = opt.image_dir 48 | image_paths = make_dataset(image_dir, recursive=False, read_cache=True) 49 | 50 | if len(opt.instance_dir) > 0: 51 | instance_dir = opt.instance_dir 52 | instance_paths = make_dataset(instance_dir, recursive=False, read_cache=True) 53 | else: 54 | instance_paths = [] 55 | 56 | assert len(label_paths) == len(image_paths), "The #images in %s and %s do not match. Is there something wrong?" 57 | 58 | return label_paths, image_paths, instance_paths 59 | -------------------------------------------------------------------------------- /data/image_folder.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
4 | """ 5 | 6 | ############################################################################### 7 | # Code from 8 | # https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py 9 | # Modified the original code so that it also loads images from the current 10 | # directory as well as the subdirectories 11 | ############################################################################### 12 | import torch.utils.data as data 13 | from PIL import Image 14 | import os 15 | 16 | IMG_EXTENSIONS = [ 17 | '.jpg', '.JPG', '.jpeg', '.JPEG', 18 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff', '.webp', 19 | '.npy' 20 | ] 21 | 22 | 23 | def is_image_file(filename): 24 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 25 | 26 | 27 | def make_dataset_rec(dir, images): 28 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 29 | 30 | for root, dnames, fnames in sorted(os.walk(dir, followlinks=True)): 31 | for fname in fnames: 32 | if is_image_file(fname): 33 | path = os.path.join(root, fname) 34 | images.append(path) 35 | 36 | 37 | def make_dataset(dir, recursive=False, read_cache=False, write_cache=False, read_numpy=False): 38 | images = [] 39 | 40 | if read_numpy: 41 | possible_filelist = os.path.join(dir, 'files.list') 42 | if os.path.isfile(possible_filelist): 43 | with open(possible_filelist, 'r') as f: 44 | images = f.read().splitlines() 45 | return images 46 | 47 | if read_cache: 48 | possible_filelist = os.path.join(dir, 'files.list') 49 | if os.path.isfile(possible_filelist): 50 | with open(possible_filelist, 'r') as f: 51 | images = f.read().splitlines() 52 | return images 53 | 54 | if recursive: 55 | make_dataset_rec(dir, images) 56 | else: 57 | assert os.path.isdir(dir) or os.path.islink(dir), '%s is not a valid directory' % dir 58 | 59 | for root, dnames, fnames in sorted(os.walk(dir)): 60 | for fname in fnames: 61 | if is_image_file(fname): 62 | path = os.path.join(root, fname) 63 | images.append(path) 64 
| 65 | if write_cache: 66 | filelist_cache = os.path.join(dir, 'files.list') 67 | with open(filelist_cache, 'w') as f: 68 | for path in images: 69 | f.write("%s\n" % path) 70 | print('wrote filelist cache at %s' % filelist_cache) 71 | 72 | return images 73 | 74 | 75 | def default_loader(path): 76 | return Image.open(path).convert('RGB') 77 | 78 | 79 | class ImageFolder(data.Dataset): 80 | 81 | def __init__(self, root, transform=None, return_paths=False, 82 | loader=default_loader): 83 | imgs = make_dataset(root) 84 | if len(imgs) == 0: 85 | raise(RuntimeError("Found 0 images in: " + root + "\n" 86 | "Supported image extensions are: " + 87 | ",".join(IMG_EXTENSIONS))) 88 | 89 | self.root = root 90 | self.imgs = imgs 91 | self.transform = transform 92 | self.return_paths = return_paths 93 | self.loader = loader 94 | 95 | def __getitem__(self, index): 96 | path = self.imgs[index] 97 | img = self.loader(path) 98 | if self.transform is not None: 99 | img = self.transform(img) 100 | if self.return_paths: 101 | return img, path 102 | else: 103 | return img 104 | 105 | def __len__(self): 106 | return len(self.imgs) 107 | -------------------------------------------------------------------------------- /data/pix2pix_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
4 | """ 5 | 6 | from data.base_dataset import BaseDataset, get_params, get_transform, get_transform_instance 7 | from PIL import Image 8 | import util.util as util 9 | import os 10 | import torch 11 | import numpy as np 12 | 13 | class Pix2pixDataset(BaseDataset): 14 | @staticmethod 15 | def modify_commandline_options(parser, is_train): 16 | parser.add_argument('--no_pairing_check', action='store_true', 17 | help='If specified, skip sanity check of correct label-image file pairing') 18 | return parser 19 | 20 | def initialize(self, opt): 21 | self.opt = opt 22 | 23 | label_paths, image_paths, instance_paths = self.get_paths(opt) 24 | 25 | util.natural_sort(label_paths) 26 | util.natural_sort(image_paths) 27 | if not opt.no_instance: 28 | util.natural_sort(instance_paths) 29 | 30 | label_paths = label_paths[:opt.max_dataset_size] 31 | image_paths = image_paths[:opt.max_dataset_size] 32 | instance_paths = instance_paths[:opt.max_dataset_size] 33 | 34 | if not opt.no_pairing_check: 35 | for path1, path2 in zip(label_paths, image_paths): 36 | assert self.paths_match(path1, path2), \ 37 | "The label-image pair (%s, %s) do not look like the right pair because the filenames are quite different. Are you sure about the pairing? Please see data/pix2pix_dataset.py to see what is going on, and use --no_pairing_check to bypass this." 
% (path1, path2) 38 | 39 | self.label_paths = label_paths 40 | self.image_paths = image_paths 41 | self.instance_paths = instance_paths 42 | 43 | size = len(self.label_paths) 44 | self.dataset_size = size 45 | 46 | def get_paths(self, opt): 47 | label_paths = [] 48 | image_paths = [] 49 | instance_paths = [] 50 | assert False, "A subclass of Pix2pixDataset must override self.get_paths(self, opt)" 51 | return label_paths, image_paths, instance_paths 52 | 53 | def paths_match(self, path1, path2): 54 | filename1_without_ext = os.path.splitext(os.path.basename(path1))[0] 55 | filename2_without_ext = os.path.splitext(os.path.basename(path2))[0] 56 | return filename1_without_ext == filename2_without_ext 57 | 58 | def __getitem__(self, index): 59 | # Label Image 60 | label_path = self.label_paths[index] 61 | label = Image.open(label_path) 62 | params = get_params(self.opt, label.size) 63 | transform_label = get_transform(self.opt, params, method=Image.NEAREST, normalize=False) 64 | label_tensor = transform_label(label) * 255.0 65 | label_tensor[label_tensor == 255] = self.opt.label_nc # 'unknown' is opt.label_nc 66 | 67 | # input image (real images) 68 | image_path = self.image_paths[index] 69 | assert self.paths_match(label_path, image_path), \ 70 | "The label_path %s and image_path %s don't match." 
% \ 71 | (label_path, image_path) 72 | image = Image.open(image_path) 73 | image = image.convert('RGB') 74 | 75 | transform_image = get_transform(self.opt, params) 76 | image_tensor = transform_image(image) 77 | 78 | # input style (style vectors) 79 | 80 | 81 | # if using instance maps 82 | if self.opt.no_instance: 83 | instance_tensor = 0 84 | else: 85 | instance_path = self.instance_paths[index] 86 | instance = np.load(instance_path) 87 | instance = np.squeeze(instance, axis=(1, 2, 3)) 88 | instance = np.array(instance, dtype=np.float32) 89 | #print(np.shape(instance)) 90 | 91 | transform_instance = get_transform_instance() 92 | instance_tensor = transform_instance(instance) 93 | instance_tensor.squeeze(0) 94 | 95 | input_dict = {'label': label_tensor, 96 | 'instance': instance_tensor, 97 | 'image': image_tensor, 98 | 'path': image_path, 99 | } 100 | 101 | # Give subclasses a chance to modify the final output 102 | 103 | self.postprocess(input_dict) 104 | 105 | return input_dict 106 | 107 | def postprocess(self, input_dict): 108 | return input_dict 109 | 110 | def __len__(self): 111 | return self.dataset_size 112 | 113 | 114 | # Our codes get input images and labels 115 | def get_input_by_names(self, image_path, image, label_img): 116 | label = Image.fromarray(label_img) 117 | params = get_params(self.opt, label.size) 118 | transform_label = get_transform(self.opt, params, method=Image.NEAREST, normalize=False) 119 | label_tensor = transform_label(label) * 255.0 120 | label_tensor[label_tensor == 255] = self.opt.label_nc # 'unknown' is opt.label_nc 121 | label_tensor.unsqueeze_(0) 122 | 123 | 124 | # input image (real images)] 125 | # image = Image.open(image_path) 126 | # image = image.convert('RGB') 127 | 128 | transform_image = get_transform(self.opt, params) 129 | image_tensor = transform_image(image) 130 | image_tensor.unsqueeze_(0) 131 | 132 | # if using instance maps 133 | if self.opt.no_instance: 134 | instance_tensor = torch.Tensor([0]) 135 | 136 | 
input_dict = {'label': label_tensor, 137 | 'instance': instance_tensor, 138 | 'image': image_tensor, 139 | 'path': image_path, 140 | } 141 | 142 | # Give subclasses a chance to modify the final output 143 | self.postprocess(input_dict) 144 | 145 | return input_dict 146 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | import importlib 7 | import torch 8 | 9 | 10 | def find_model_using_name(model_name): 11 | # Given the option --model [modelname], 12 | # the file "models/modelname_model.py" 13 | # will be imported. 14 | model_filename = "models." + model_name + "_model" 15 | modellib = importlib.import_module(model_filename) 16 | 17 | # In the file, the class called ModelNameModel() will 18 | # be instantiated. It has to be a subclass of torch.nn.Module, 19 | # and it is case-insensitive. 20 | model = None 21 | target_model_name = model_name.replace('_', '') + 'model' 22 | for name, cls in modellib.__dict__.items(): 23 | if name.lower() == target_model_name.lower() \ 24 | and issubclass(cls, torch.nn.Module): 25 | model = cls 26 | 27 | if model is None: 28 | print("In %s.py, there should be a subclass of torch.nn.Module with class name that matches %s in lowercase." 
% (model_filename, target_model_name)) 29 | exit(0) 30 | 31 | return model 32 | 33 | 34 | def get_option_setter(model_name): 35 | model_class = find_model_using_name(model_name) 36 | return model_class.modify_commandline_options 37 | 38 | 39 | def create_model(opt): 40 | model = find_model_using_name(opt.model) 41 | instance = model(opt) 42 | print("model [%s] was created" % (type(instance).__name__)) 43 | 44 | return instance 45 | -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenjaminJonghyun/SuperStyleNet/2400c01b35f50b387b5f768fdece37688a077049/models/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenjaminJonghyun/SuperStyleNet/2400c01b35f50b387b5f768fdece37688a077049/models/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/pix2pix_model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenjaminJonghyun/SuperStyleNet/2400c01b35f50b387b5f768fdece37688a077049/models/__pycache__/pix2pix_model.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/pix2pix_model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenjaminJonghyun/SuperStyleNet/2400c01b35f50b387b5f768fdece37688a077049/models/__pycache__/pix2pix_model.cpython-37.pyc -------------------------------------------------------------------------------- /models/networks/__init__.py: 
# --------------------------------------------------------------------------------
"""
Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
"""

import torch
from models.networks.base_network import BaseNetwork
from models.networks.loss import *
from models.networks.discriminator import *
from models.networks.generator import *
from models.networks.encoder import *
import util.util as util


def find_network_using_name(target_network_name, filename):
    """Look up class ``<target_network_name><filename>`` in ``models.networks.<filename>``.

    The lookup is delegated to ``util.find_class_in_module``. An
    AssertionError is raised if the class does not derive from BaseNetwork.
    """
    network = util.find_class_in_module(target_network_name + filename,
                                        'models.networks.' + filename)

    assert issubclass(network, BaseNetwork), \
        "Class %s should be a subclass of BaseNetwork" % network

    return network


def modify_commandline_options(parser, is_train):
    """Let the selected network classes register their own options.

    The generator always does. When training, the discriminator and then
    the ``conv`` encoder do as well, in that order.
    """
    opt, _ = parser.parse_known_args()

    parser = find_network_using_name(opt.netG, 'generator') \
        .modify_commandline_options(parser, is_train)
    if not is_train:
        return parser

    for net_name, filename in ((opt.netD, 'discriminator'), ('conv', 'encoder')):
        parser = find_network_using_name(net_name, filename) \
            .modify_commandline_options(parser, is_train)
    return parser


def create_network(cls, opt):
    """Build ``cls(opt)``, print its size, move it to CUDA if ``opt.gpu_ids``
    is non-empty, then initialise its weights from ``opt.init_type`` and
    ``opt.init_variance``."""
    network = cls(opt)
    network.print_network()
    if len(opt.gpu_ids) > 0:
        assert(torch.cuda.is_available())
        network.cuda()
    network.init_weights(opt.init_type, opt.init_variance)
    return network


def define_G(opt):
    """Create the generator selected by ``opt.netG``."""
    return create_network(find_network_using_name(opt.netG, 'generator'), opt)


def define_D(opt):
    """Create the discriminator selected by ``opt.netD``."""
    return create_network(find_network_using_name(opt.netD, 'discriminator'), opt)


def define_E(opt):
    """Create the encoder (only the ``conv`` encoder type exists)."""
    return create_network(find_network_using_name('conv', 'encoder'), opt)

# --------------------------------------------------------------------------------
# /models/networks/__pycache__/__init__.cpython-36.pyc:
# https://raw.githubusercontent.com/BenjaminJonghyun/SuperStyleNet/2400c01b35f50b387b5f768fdece37688a077049/models/networks/__pycache__/__init__.cpython-36.pyc
# --------------------------------------------------------------------------------
# /models/networks/__pycache__/__init__.cpython-37.pyc:
# https://raw.githubusercontent.com/BenjaminJonghyun/SuperStyleNet/2400c01b35f50b387b5f768fdece37688a077049/models/networks/__pycache__/__init__.cpython-37.pyc
# --------------------------------------------------------------------------------
# /models/networks/__pycache__/architecture.cpython-36.pyc:
# https://raw.githubusercontent.com/BenjaminJonghyun/SuperStyleNet/2400c01b35f50b387b5f768fdece37688a077049/models/networks/__pycache__/architecture.cpython-36.pyc
# --------------------------------------------------------------------------------
# /models/networks/__pycache__/architecture.cpython-37.pyc:
# https://raw.githubusercontent.com/BenjaminJonghyun/SuperStyleNet/2400c01b35f50b387b5f768fdece37688a077049/models/networks/__pycache__/architecture.cpython-37.pyc
# --------------------------------------------------------------------------------
# /models/networks/__pycache__/base_network.cpython-36.pyc:
https://raw.githubusercontent.com/BenjaminJonghyun/SuperStyleNet/2400c01b35f50b387b5f768fdece37688a077049/models/networks/__pycache__/base_network.cpython-36.pyc -------------------------------------------------------------------------------- /models/networks/__pycache__/base_network.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenjaminJonghyun/SuperStyleNet/2400c01b35f50b387b5f768fdece37688a077049/models/networks/__pycache__/base_network.cpython-37.pyc -------------------------------------------------------------------------------- /models/networks/__pycache__/discriminator.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenjaminJonghyun/SuperStyleNet/2400c01b35f50b387b5f768fdece37688a077049/models/networks/__pycache__/discriminator.cpython-36.pyc -------------------------------------------------------------------------------- /models/networks/__pycache__/discriminator.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenjaminJonghyun/SuperStyleNet/2400c01b35f50b387b5f768fdece37688a077049/models/networks/__pycache__/discriminator.cpython-37.pyc -------------------------------------------------------------------------------- /models/networks/__pycache__/encoder.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenjaminJonghyun/SuperStyleNet/2400c01b35f50b387b5f768fdece37688a077049/models/networks/__pycache__/encoder.cpython-36.pyc -------------------------------------------------------------------------------- /models/networks/__pycache__/encoder.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/BenjaminJonghyun/SuperStyleNet/2400c01b35f50b387b5f768fdece37688a077049/models/networks/__pycache__/encoder.cpython-37.pyc -------------------------------------------------------------------------------- /models/networks/__pycache__/generator.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenjaminJonghyun/SuperStyleNet/2400c01b35f50b387b5f768fdece37688a077049/models/networks/__pycache__/generator.cpython-36.pyc -------------------------------------------------------------------------------- /models/networks/__pycache__/generator.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenjaminJonghyun/SuperStyleNet/2400c01b35f50b387b5f768fdece37688a077049/models/networks/__pycache__/generator.cpython-37.pyc -------------------------------------------------------------------------------- /models/networks/__pycache__/loss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenjaminJonghyun/SuperStyleNet/2400c01b35f50b387b5f768fdece37688a077049/models/networks/__pycache__/loss.cpython-36.pyc -------------------------------------------------------------------------------- /models/networks/__pycache__/loss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenjaminJonghyun/SuperStyleNet/2400c01b35f50b387b5f768fdece37688a077049/models/networks/__pycache__/loss.cpython-37.pyc -------------------------------------------------------------------------------- /models/networks/__pycache__/normalization.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/BenjaminJonghyun/SuperStyleNet/2400c01b35f50b387b5f768fdece37688a077049/models/networks/__pycache__/normalization.cpython-36.pyc -------------------------------------------------------------------------------- /models/networks/__pycache__/normalization.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenjaminJonghyun/SuperStyleNet/2400c01b35f50b387b5f768fdece37688a077049/models/networks/__pycache__/normalization.cpython-37.pyc -------------------------------------------------------------------------------- /models/networks/architecture.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torchvision 10 | import torch.nn.utils.spectral_norm as spectral_norm 11 | from models.networks.normalization import SPADE, ACE 12 | 13 | 14 | # ResNet block that uses SPADE. 15 | # It differs from the ResNet block of pix2pixHD in that 16 | # it takes in the segmentation map as input, learns the skip connection if necessary, 17 | # and applies normalization first and then convolution. 18 | # This architecture seemed like a standard architecture for unconditional or 19 | # class-conditional GAN architecture using residual block. 20 | # The code was inspired from https://github.com/LMescheder/GAN_stability. 
class SPADEResnetBlock(nn.Module):
    """Residual block whose normalization layers are ACE modules.

    Main path: ``ace_0 -> leaky_relu -> conv_0 -> ace_1 -> leaky_relu -> conv_1``.
    Shortcut: identity when ``fin == fout``, otherwise ``ace_s -> conv_s``
    (1x1, no bias). The two paths are summed.

    Args:
        fin: input channels.
        fout: output channels.
        opt: options; reads ``opt.status``, ``opt.norm_G`` and ``opt.semantic_nc``.
        Block_Name: prefix of the ACE layer names (``<Block_Name>_ACE_0`` ...).
            NOTE(review): must be a str; the default ``None`` makes the name
            concatenation raise TypeError.
        use_rgb: forwarded unchanged to every ACE layer.
    """

    def __init__(self, fin, fout, opt, Block_Name=None, use_rgb=True):
        super().__init__()

        self.use_rgb = use_rgb

        self.Block_Name = Block_Name
        self.status = opt.status

        # Attributes
        self.learned_shortcut = (fin != fout)
        fmiddle = min(fin, fout)

        # Create conv layers. Keep this registration order: it fixes the order
        # of self.parameters(), which saved optimizer states depend on.
        self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1)
        self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1)
        if self.learned_shortcut:
            self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)

        # Apply spectral norm if specified
        if 'spectral' in opt.norm_G:
            self.conv_0 = spectral_norm(self.conv_0)
            self.conv_1 = spectral_norm(self.conv_1)
            if self.learned_shortcut:
                self.conv_s = spectral_norm(self.conv_s)

        # Normalization layers. The SPADE config comes from opt.norm_G, but
        # the ACE norm type itself is fixed to sync batch norm.
        spade_config_str = opt.norm_G.replace('spectral', '')
        our_norm_type = 'spadesyncbatch3x3'

        self.ace_0 = ACE(our_norm_type, fin, 3, ACE_Name=Block_Name + '_ACE_0', status=self.status,
                         spade_params=[spade_config_str, fin, opt.semantic_nc], use_rgb=use_rgb)
        self.ace_1 = ACE(our_norm_type, fmiddle, 3, ACE_Name=Block_Name + '_ACE_1', status=self.status,
                         spade_params=[spade_config_str, fmiddle, opt.semantic_nc], use_rgb=use_rgb)
        if self.learned_shortcut:
            self.ace_s = ACE(our_norm_type, fin, 3, ACE_Name=Block_Name + '_ACE_s', status=self.status,
                             spade_params=[spade_config_str, fin, opt.semantic_nc], use_rgb=use_rgb)

    def forward(self, x, seg, style_codes, obj_dic=None):
        """Return ``shortcut(x) + main_path(x)``.

        The segmentation map ``seg``, ``style_codes`` and ``obj_dic`` are
        passed unchanged to every ACE layer.
        """
        x_s = self.shortcut(x, seg, style_codes, obj_dic)

        dx = self.ace_0(x, seg, style_codes, obj_dic)
        dx = self.conv_0(self.actvn(dx))
        dx = self.ace_1(dx, seg, style_codes, obj_dic)
        dx = self.conv_1(self.actvn(dx))

        return x_s + dx

    def shortcut(self, x, seg, style_codes, obj_dic):
        """Identity, or ``conv_s(ace_s(x))`` when channel counts differ."""
        if self.learned_shortcut:
            x_s = self.ace_s(x, seg, style_codes, obj_dic)
            x_s = self.conv_s(x_s)
        else:
            x_s = x
        return x_s

    def actvn(self, x):
        return F.leaky_relu(x, 2e-1)


# ResNet block used in pix2pixHD
# We keep the same architecture as pix2pixHD.
class ResnetBlock(nn.Module):
    """Residual block ``x + conv(act(conv(x)))``.

    Uses reflection padding and wraps each conv with ``norm_layer``.
    """

    def __init__(self, dim, norm_layer, activation=nn.ReLU(False), kernel_size=3):
        super().__init__()

        pw = (kernel_size - 1) // 2
        self.conv_block = nn.Sequential(
            nn.ReflectionPad2d(pw),
            norm_layer(nn.Conv2d(dim, dim, kernel_size=kernel_size)),
            activation,
            nn.ReflectionPad2d(pw),
            norm_layer(nn.Conv2d(dim, dim, kernel_size=kernel_size))
        )

    def forward(self, x):
        y = self.conv_block(x)
        out = x + y
        return out


# VGG architecture, used for the perceptual loss using a pretrained VGG network
class VGG19(torch.nn.Module):
    """Pretrained torchvision VGG19 ``features`` split into five slices.

    The slices cover layers [0:2], [2:7], [7:12], [12:21] and [21:30].
    ``forward`` returns the five intermediate activations. Parameters are
    frozen unless ``requires_grad`` is True.
    """

    def __init__(self, requires_grad=False):
        super().__init__()
        vgg_pretrained_features = torchvision.models.vgg19(pretrained=True).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        for x in range(2):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(2, 7):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(7, 12):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(12, 21):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        for x in range(21, 30):
            self.slice5.add_module(str(x), vgg_pretrained_features[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        h_relu1 = self.slice1(X)
        h_relu2 = self.slice2(h_relu1)
        h_relu3 = self.slice3(h_relu2)
        h_relu4 = self.slice4(h_relu3)
        h_relu5 = self.slice5(h_relu4)
        out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
        return out


class Zencoder(torch.nn.Module):
    """Per-region style encoder.

    A small conv net is applied to ``input``. Its feature map is then averaged
    inside each channel of ``segmap`` (each channel is used as a mask), giving
    one feature vector per (sample, segmap channel).
    """

    def __init__(self, input_nc, output_nc, ngf=32, n_downsampling=2, norm_layer=nn.InstanceNorm2d):
        super(Zencoder, self).__init__()
        self.output_nc = output_nc

        model = [nn.ReflectionPad2d(1), nn.Conv2d(input_nc, ngf, kernel_size=3, padding=0),
                 norm_layer(ngf), nn.LeakyReLU(0.2, False)]
        # Downsample
        for i in range(n_downsampling):
            mult = 2 ** i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1),
                      norm_layer(ngf * mult * 2), nn.LeakyReLU(0.2, False)]

        # Upsample once. The channel count is derived from ngf / n_downsampling.
        # The final conv used to hard-code 256, which only matched the defaults.
        # The norm layer now also gets the real channel count; the old value
        # (ngf * mult / 2) disagreed with the 256-channel input.
        mult = 2 ** n_downsampling
        up_nc = ngf * mult * 2
        model += [nn.ConvTranspose2d(ngf * mult, up_nc, kernel_size=3, stride=2, padding=1, output_padding=1),
                  norm_layer(up_nc), nn.LeakyReLU(0.2, False)]

        model += [nn.ReflectionPad2d(1), nn.Conv2d(up_nc, output_nc, kernel_size=3, padding=0), nn.Tanh()]
        self.model = nn.Sequential(*model)

    def forward(self, input, segmap):
        """Return a ``(batch, segmap.shape[1], feature channels)`` tensor.

        Each entry is the mean feature inside one mask. Masks that are empty
        for a sample stay zero.
        """
        codes = self.model(input)

        segmap = F.interpolate(segmap, size=codes.size()[2:], mode='nearest')

        b_size = codes.shape[0]
        f_size = codes.shape[1]
        s_size = segmap.shape[1]

        codes_vector = torch.zeros((b_size, s_size, f_size), dtype=codes.dtype, device=codes.device)

        # Convert the masks once instead of on every (i, j) iteration.
        masks = segmap.bool()
        for i in range(b_size):
            for j in range(s_size):
                mask = masks[i, j]
                component_mask_area = torch.sum(mask)

                if component_mask_area > 0:
                    # masked_select broadcasts the (H, W) mask over the C
                    # channels and returns values channel by channel.
                    # The reshape therefore gives (C, area), and mean(1) is the
                    # per-channel average.
                    codes_component_feature = codes[i].masked_select(mask).reshape(f_size, component_mask_area).mean(1)
                    codes_vector[i][j] = codes_component_feature

        return codes_vector

# --------------------------------------------------------------------------------
# /models/networks/base_network.py:
# --------------------------------------------------------------------------------
"""
Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
"""

import torch.nn as nn
from torch.nn import init


class BaseNetwork(nn.Module):
    """Base class for the networks.

    Provides a no-op option hook, a parameter-count printer and weight
    initialisation.
    """

    def __init__(self):
        super(BaseNetwork, self).__init__()

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Hook for subclasses to add options; the base returns ``parser`` unchanged."""
        return parser

    def print_network(self):
        """Print the class name and the total parameter count in millions."""
        if isinstance(self, list):
            self = self[0]
        num_params = sum(param.numel() for param in self.parameters())
        print('Network [%s] was created. Total number of parameters: %.1f million. '
              'To see the architecture, do print(network).'
              % (type(self).__name__, num_params / 1000000))

    def init_weights(self, init_type='normal', gain=0.02):
        """Initialise weights in place.

        BatchNorm2d: weight ~ N(1, gain), bias 0.
        Conv*/Linear: weight according to ``init_type``, bias 0. Note that
        'xavier_uniform' uses gain 1.0 and 'kaiming' ignores ``gain``.
        Afterwards, every direct child that defines ``init_weights`` is
        re-initialised recursively.

        Raises:
            NotImplementedError: for an unknown ``init_type``.
        """
        def init_func(m):
            classname = m.__class__.__name__
            if classname.find('BatchNorm2d') != -1:
                if hasattr(m, 'weight') and m.weight is not None:
                    init.normal_(m.weight.data, 1.0, gain)
                if hasattr(m, 'bias') and m.bias is not None:
                    init.constant_(m.bias.data, 0.0)
            elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
                if init_type == 'normal':
                    init.normal_(m.weight.data, 0.0, gain)
                elif init_type == 'xavier':
                    init.xavier_normal_(m.weight.data, gain=gain)
                elif init_type == 'xavier_uniform':
                    init.xavier_uniform_(m.weight.data, gain=1.0)
                elif init_type == 'kaiming':
                    init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
                elif init_type == 'orthogonal':
                    init.orthogonal_(m.weight.data, gain=gain)
                elif init_type == 'none':  # uses pytorch's default init method
                    m.reset_parameters()
                else:
                    raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
                if hasattr(m, 'bias') and m.bias is not None:
                    init.constant_(m.bias.data, 0.0)

        self.apply(init_func)

        # propagate to children
        for m in self.children():
            if hasattr(m, 'init_weights'):
                m.init_weights(init_type, gain)

# --------------------------------------------------------------------------------
# /models/networks/discriminator.py:
# --------------------------------------------------------------------------------
"""
Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
"""

import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from models.networks.base_network import BaseNetwork
from models.networks.normalization import get_nonspade_norm_layer
import util.util as util


class MultiscaleDiscriminator(BaseNetwork):
    """``opt.num_D`` sub-discriminators applied to progressively downsampled input."""

    @staticmethod
    def modify_commandline_options(parser, is_train):
        parser.add_argument('--netD_subarch', type=str, default='n_layer',
                            help='architecture of each discriminator')
        parser.add_argument('--num_D', type=int, default=2,
                            help='number of discriminators to be used in multiscale')
        opt, _ = parser.parse_known_args()

        # define properties of each discriminator of the multiscale discriminator
        subnetD = util.find_class_in_module(opt.netD_subarch + 'discriminator',
                                            'models.networks.discriminator')
        subnetD.modify_commandline_options(parser, is_train)

        return parser

    def __init__(self, opt):
        super().__init__()
        self.opt = opt

        for i in range(opt.num_D):
            subnetD = self.create_single_discriminator(opt)
            self.add_module('discriminator_%d' % i, subnetD)

    def create_single_discriminator(self, opt):
        """Build one sub-discriminator; only the 'n_layer' subarchitecture exists.

        Raises:
            ValueError: for any other ``opt.netD_subarch``.
        """
        subarch = opt.netD_subarch
        if subarch == 'n_layer':
            netD = NLayerDiscriminator(opt)
        else:
            raise ValueError('unrecognized discriminator subarchitecture %s' % subarch)
        return netD

    def downsample(self, input):
        """3x3 stride-2 average pooling (padding excluded from the average)."""
        return F.avg_pool2d(input, kernel_size=3,
                            stride=2, padding=[1, 1],
                            count_include_pad=False)

    def forward(self, input):
        """Return a list with one entry per sub-discriminator.

        Each entry is that discriminator's list of intermediate outputs
        (opt.num_D x opt.n_layers_D overall) when the GAN feature loss is
        enabled, otherwise a one-element list holding its final output.
        """
        result = []
        get_intermediate_features = not self.opt.no_ganFeat_loss
        for D in self.children():
            out = D(input)
            if not get_intermediate_features:
                out = [out]
            result.append(out)
            input = self.downsample(input)

        return result


# Defines the PatchGAN discriminator with the specified arguments.
class NLayerDiscriminator(BaseNetwork):
    """PatchGAN discriminator.

    Stacks stride-2 4x4 convs (channels capped at 512) and ends with a
    1-channel conv. Each layer group is registered as ``model<n>`` so that
    intermediate outputs can be returned.
    """

    @staticmethod
    def modify_commandline_options(parser, is_train):
        parser.add_argument('--n_layers_D', type=int, default=3,
                            help='# layers in each discriminator')
        return parser

    def __init__(self, opt):
        super().__init__()
        self.opt = opt

        kw = 4
        padw = int(np.ceil((kw - 1.0) / 2))
        nf = opt.ndf
        input_nc = self.compute_D_input_nc(opt)

        norm_layer = get_nonspade_norm_layer(opt, opt.norm_D)
        sequence = [[nn.Conv2d(input_nc, nf, kernel_size=kw, stride=2, padding=padw),
                     nn.LeakyReLU(0.2, False)]]

        for n in range(1, opt.n_layers_D):
            nf_prev = nf
            nf = min(nf * 2, 512)
            sequence += [[norm_layer(nn.Conv2d(nf_prev, nf, kernel_size=kw,
                                               stride=2, padding=padw)),
                          nn.LeakyReLU(0.2, False)
                          ]]

        sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]

        # We divide the layers into groups to extract intermediate layer outputs
        for n in range(len(sequence)):
            self.add_module('model' + str(n), nn.Sequential(*sequence[n]))

    def compute_D_input_nc(self, opt):
        """Channels of the discriminator input.

        This is ``label_nc + output_nc``, plus 1 for a don't-care label.
        Instance maps deliberately add no channel in this project.
        """
        input_nc = opt.label_nc + opt.output_nc
        if opt.contain_dontcare_label:
            input_nc += 1

        # modify input_nc += 1 to 0
        if not opt.no_instance:
            input_nc += 0
        return input_nc

    def forward(self, input):
        """Return all intermediate outputs (GAN feature loss on) or only the last one."""
        results = [input]
        for submodel in self.children():
            intermediate_output = submodel(results[-1])
            results.append(intermediate_output)

        get_intermediate_features = not self.opt.no_ganFeat_loss
        if get_intermediate_features:
            return results[1:]
        else:
            return results[-1]

# --------------------------------------------------------------------------------
# /models/networks/encoder.py:
# --------------------------------------------------------------------------------
"""
Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
"""

import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from models.networks.base_network import BaseNetwork
from models.networks.normalization import get_nonspade_norm_layer


class ConvEncoder(BaseNetwork):
    """ Same architecture as the image discriminator.

    Maps a 3-channel image (resized to 256x256) to ``(mu, logvar)``, each a
    256-dim vector per sample.
    """

    def __init__(self, opt):
        super().__init__()

        kw = 3
        pw = int(np.ceil((kw - 1.0) / 2))
        ndf = opt.ngf
        norm_layer = get_nonspade_norm_layer(opt, opt.norm_E)
        self.layer1 = norm_layer(nn.Conv2d(3, ndf, kw, stride=2, padding=pw))
        self.layer2 = norm_layer(nn.Conv2d(ndf * 1, ndf * 2, kw, stride=2, padding=pw))
        self.layer3 = norm_layer(nn.Conv2d(ndf * 2, ndf * 4, kw, stride=2, padding=pw))
        self.layer4 = norm_layer(nn.Conv2d(ndf * 4, ndf * 8, kw, stride=2, padding=pw))
        self.layer5 = norm_layer(nn.Conv2d(ndf * 8, ndf * 8, kw, stride=2, padding=pw))
        if opt.crop_size >= 256:
            self.layer6 = norm_layer(nn.Conv2d(ndf * 8, ndf * 8, kw, stride=2, padding=pw))

        # forward() always resizes to 256x256, so the final map is
        # 256 / 2**6 = 4 wide with layer6 and 256 / 2**5 = 8 without it.
        # Hard-coding 4 made the fc layers mismatch whenever crop_size < 256.
        self.so = s0 = 4 if opt.crop_size >= 256 else 8
        self.fc_mu = nn.Linear(ndf * 8 * s0 * s0, 256)
        self.fc_var = nn.Linear(ndf * 8 * s0 * s0, 256)

        self.actvn = nn.LeakyReLU(0.2, False)
        self.opt = opt

    def forward(self, x):
        """Return ``(mu, logvar)`` for the image batch ``x``."""
        if x.size(2) != 256 or x.size(3) != 256:
            # align_corners=False is the default; passing it explicitly
            # silences PyTorch's warning without changing results.
            x = F.interpolate(x, size=(256, 256), mode='bilinear', align_corners=False)

        x = self.layer1(x)
        x = self.layer2(self.actvn(x))
        x = self.layer3(self.actvn(x))
        x = self.layer4(self.actvn(x))
        x = self.layer5(self.actvn(x))
        if self.opt.crop_size >= 256:
            x = self.layer6(self.actvn(x))
        x = self.actvn(x)

        x = x.view(x.size(0), -1)
        mu = self.fc_mu(x)
        logvar = self.fc_var(x)

        return mu, logvar

# --------------------------------------------------------------------------------
# /models/networks/generator.py:
# --------------------------------------------------------------------------------
"""
Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from models.networks.base_network import BaseNetwork
from models.networks.normalization import get_nonspade_norm_layer
from models.networks.architecture import ResnetBlock as ResnetBlock
from models.networks.architecture import SPADEResnetBlock as SPADEResnetBlock


class SPADEGenerator(BaseNetwork):
    """SPADE-style generator built from SPADEResnetBlocks.

    It starts from the segmentation map resized to ``(sh, sw)`` and
    upsamples by 2 before most blocks: 5 times for 'normal', 6 for 'more'
    and 7 for 'most'.
    """

    @staticmethod
    def modify_commandline_options(parser, is_train):
        parser.set_defaults(norm_G='spectralspadesyncbatch3x3')
        parser.add_argument('--num_upsampling_layers',
                            choices=('normal', 'more', 'most'), default='normal',
                            help="If 'more', adds upsampling layer between the two middle resnet blocks. If 'most', also add one more upsampling + resnet layer at the end of the generator")

        return parser

    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        nf = opt.ngf

        self.sw, self.sh = self.compute_latent_vector_size(opt)

        self.fc = nn.Conv2d(self.opt.semantic_nc, 16 * nf, 3, padding=1)

        self.head_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt, Block_Name='head_0')

        self.G_middle_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt, Block_Name='G_middle_0')
        self.G_middle_1 = SPADEResnetBlock(16 * nf, 16 * nf, opt, Block_Name='G_middle_1')

        self.up_0 = SPADEResnetBlock(16 * nf, 8 * nf, opt, Block_Name='up_0')
        self.up_1 = SPADEResnetBlock(8 * nf, 4 * nf, opt, Block_Name='up_1')
        self.up_2 = SPADEResnetBlock(4 * nf, 2 * nf, opt, Block_Name='up_2')
        self.up_3 = SPADEResnetBlock(2 * nf, 1 * nf, opt, Block_Name='up_3', use_rgb=False)

        final_nc = nf

        if opt.num_upsampling_layers == 'most':
            self.up_4 = SPADEResnetBlock(1 * nf, nf // 2, opt, Block_Name='up_4')
            final_nc = nf // 2

        self.conv_img = nn.Conv2d(final_nc, 3, 3, padding=1)

        self.up = nn.Upsample(scale_factor=2)

    def compute_latent_vector_size(self, opt):
        """Return ``(sw, sh)``: ``crop_size`` divided by 2**(number of
        upsamplings), with ``sh`` derived from ``opt.aspect_ratio``.

        Raises:
            ValueError: for an unknown ``opt.num_upsampling_layers``.
        """
        if opt.num_upsampling_layers == 'normal':
            num_up_layers = 5
        elif opt.num_upsampling_layers == 'more':
            num_up_layers = 6
        elif opt.num_upsampling_layers == 'most':
            num_up_layers = 7
        else:
            raise ValueError('opt.num_upsampling_layers [%s] not recognized' %
                             opt.num_upsampling_layers)

        sw = opt.crop_size // (2**num_up_layers)
        sh = round(sw / opt.aspect_ratio)

        return sw, sh

    def forward(self, input, rgb_img, input_instance, obj_dic=None):
        """Synthesize an image from a segmentation map and style codes.

        Args:
            input: segmentation map. It is resized to ``(sh, sw)`` for ``fc``
                and passed whole to every residual block.
            rgb_img: unused here. The commented-out Zencoder call suggests
                style codes were once computed from it.
            input_instance: style codes forwarded to every SPADEResnetBlock.
            obj_dic: forwarded to every SPADEResnetBlock.

        Returns:
            ``tanh(conv_img(...))``: a 3-channel tensor with values in [-1, 1].
        """
        seg = input

        x = F.interpolate(seg, size=(self.sh, self.sw))
        x = self.fc(x)

        # style_codes = self.Zencoder(input=rgb_img, segmap=seg)
        style_codes = input_instance

        x = self.head_0(x, seg, style_codes, obj_dic=obj_dic)

        x = self.up(x)
        x = self.G_middle_0(x, seg, style_codes, obj_dic=obj_dic)

        if self.opt.num_upsampling_layers == 'more' or \
                self.opt.num_upsampling_layers == 'most':
            x = self.up(x)

        x = self.G_middle_1(x, seg, style_codes, obj_dic=obj_dic)

        x = self.up(x)
        x = self.up_0(x, seg, style_codes, obj_dic=obj_dic)
        x = self.up(x)
        x = self.up_1(x, seg, style_codes, obj_dic=obj_dic)
        x = self.up(x)
        x = self.up_2(x, seg, style_codes, obj_dic=obj_dic)
        x = self.up(x)
        x = self.up_3(x, seg, style_codes, obj_dic=obj_dic)

        # 'most' builds up_4 and a conv_img expecting nf // 2 channels, and
        # compute_latent_vector_size counts a 7th upsampling. Skipping this
        # step made conv_img fail on the nf-channel input.
        if self.opt.num_upsampling_layers == 'most':
            x = self.up(x)
            x = self.up_4(x, seg, style_codes, obj_dic=obj_dic)

        x = self.conv_img(F.leaky_relu(x, 2e-1))
        x = torch.tanh(x)  # F.tanh is deprecated; identical result
        return x


# class Pix2PixHDGenerator(BaseNetwork):
#     @staticmethod
#     def modify_commandline_options(parser, is_train):
#         parser.add_argument('--resnet_n_downsample', type=int, default=4, help='number of downsampling layers in netG')
#         parser.add_argument('--resnet_n_blocks', type=int, default=9, help='number of residual blocks in the global generator network')
#         parser.add_argument('--resnet_kernel_size', type=int, default=3,
#                             help='kernel size of the resnet block')
#         parser.add_argument('--resnet_initial_kernel_size', type=int, default=7,
#                             help='kernel size of the first convolution')
#         parser.set_defaults(norm_G='instance')
#         return parser
#
#     def __init__(self, opt):
#         super().__init__()
#         input_nc = opt.label_nc + (1 if opt.contain_dontcare_label else 0) + (0 if opt.no_instance else 1)
#
#         norm_layer = get_nonspade_norm_layer(opt, opt.norm_G)
#         activation = nn.ReLU(False)
#
#         model = []
132 | # # initial conv 133 | # model += [nn.ReflectionPad2d(opt.resnet_initial_kernel_size // 2), 134 | # norm_layer(nn.Conv2d(input_nc, opt.ngf, 135 | # kernel_size=opt.resnet_initial_kernel_size, 136 | # padding=0)), 137 | # activation] 138 | # 139 | # # downsample 140 | # mult = 1 141 | # for i in range(opt.resnet_n_downsample): 142 | # model += [norm_layer(nn.Conv2d(opt.ngf * mult, opt.ngf * mult * 2, 143 | # kernel_size=3, stride=2, padding=1)), 144 | # activation] 145 | # mult *= 2 146 | # 147 | # # resnet blocks 148 | # for i in range(opt.resnet_n_blocks): 149 | # model += [ResnetBlock(opt.ngf * mult, 150 | # norm_layer=norm_layer, 151 | # activation=activation, 152 | # kernel_size=opt.resnet_kernel_size)] 153 | # 154 | # # upsample 155 | # for i in range(opt.resnet_n_downsample): 156 | # nc_in = int(opt.ngf * mult) 157 | # nc_out = int((opt.ngf * mult) / 2) 158 | # model += [norm_layer(nn.ConvTranspose2d(nc_in, nc_out, 159 | # kernel_size=3, stride=2, 160 | # padding=1, output_padding=1)), 161 | # activation] 162 | # mult = mult // 2 163 | # 164 | # # final output conv 165 | # model += [nn.ReflectionPad2d(3), 166 | # nn.Conv2d(nc_out, opt.output_nc, kernel_size=7, padding=0), 167 | # nn.Tanh()] 168 | # 169 | # self.model = nn.Sequential(*model) 170 | # 171 | # def forward(self, input, z=None): 172 | # return self.model(input) 173 | -------------------------------------------------------------------------------- /models/networks/loss.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from models.networks.architecture import VGG19 10 | 11 | 12 | # Defines the GAN loss which uses either LSGAN or the regular GAN. 
# When LSGAN is used, it is basically the same as MSELoss,
# but it abstracts away the need to create a target label tensor
# that has the same size as the input.
class GANLoss(nn.Module):
    """Adversarial loss for the 'ls', 'original', 'w' and 'hinge' GAN modes.

    Call as ``criterion(input, target_is_real, for_discriminator=True)``.
    ``input`` is either a single prediction tensor or a list/tuple of
    predictions (e.g. from a multiscale discriminator). In the list case an
    element that is itself a list/tuple contributes only its last entry, and
    the per-element losses are averaged.

    Args:
        gan_mode: one of 'ls', 'original', 'w', 'hinge'.
        target_real_label: label value used for real targets ('ls'/'original').
        target_fake_label: label value used for fake targets ('ls'/'original').
        tensor: tensor constructor used to build the cached label tensors.
        opt: stored on the instance; not read by this class.

    Raises:
        ValueError: if ``gan_mode`` is not one of the supported modes.
    """

    _GAN_MODES = ('ls', 'original', 'w', 'hinge')

    def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0,
                 tensor=torch.FloatTensor, opt=None):
        super(GANLoss, self).__init__()
        self.real_label = target_real_label
        self.fake_label = target_fake_label
        # Lazily created 1-element tensors, broadcast with expand_as().
        # They are plain attributes (not buffers), so their device is decided
        # by ``tensor`` rather than by moving the module.
        self.real_label_tensor = None
        self.fake_label_tensor = None
        self.zero_tensor = None
        self.Tensor = tensor
        self.gan_mode = gan_mode
        self.opt = opt
        if gan_mode not in self._GAN_MODES:
            raise ValueError('Unexpected gan_mode {}'.format(gan_mode))

    def get_target_tensor(self, input, target_is_real):
        """Return the real or fake label expanded (without copying) to input's shape."""
        if target_is_real:
            if self.real_label_tensor is None:
                self.real_label_tensor = self.Tensor(1).fill_(self.real_label)
                self.real_label_tensor.requires_grad_(False)
            return self.real_label_tensor.expand_as(input)
        if self.fake_label_tensor is None:
            self.fake_label_tensor = self.Tensor(1).fill_(self.fake_label)
            self.fake_label_tensor.requires_grad_(False)
        return self.fake_label_tensor.expand_as(input)

    def get_zero_tensor(self, input):
        """Return a zero tensor expanded (without copying) to input's shape."""
        if self.zero_tensor is None:
            self.zero_tensor = self.Tensor(1).fill_(0)
            self.zero_tensor.requires_grad_(False)
        return self.zero_tensor.expand_as(input)

    def loss(self, input, target_is_real, for_discriminator=True):
        """Return the scalar loss for one prediction tensor ``input``."""
        if self.gan_mode == 'original':  # cross entropy loss
            target_tensor = self.get_target_tensor(input, target_is_real)
            loss = F.binary_cross_entropy_with_logits(input, target_tensor)
            return loss
        elif self.gan_mode == 'ls':
            target_tensor = self.get_target_tensor(input, target_is_real)
            return F.mse_loss(input, target_tensor)
        elif self.gan_mode == 'hinge':
            if for_discriminator:
                if target_is_real:
                    minval = torch.min(input - 1, self.get_zero_tensor(input))
                    loss = -torch.mean(minval)
                else:
                    minval = torch.min(-input - 1, self.get_zero_tensor(input))
                    loss = -torch.mean(minval)
            else:
                assert target_is_real, "The generator's hinge loss must be aiming for real"
                loss = -torch.mean(input)
            return loss
        else:
            # wgan
            if target_is_real:
                return -input.mean()
            else:
                return input.mean()

    def forward(self, input, target_is_real, for_discriminator=True):
        """Compute the loss for one prediction or a list/tuple of predictions.

        Implemented as ``forward`` (the original overrode ``__call__``) so
        that calls go through ``nn.Module.__call__`` and registered hooks run.
        The calling convention is unchanged.

        Returns:
            For a single tensor, the scalar loss from :meth:`loss`. For a
            list/tuple, a 1-element tensor holding the mean of the
            per-entry losses.

        Raises:
            ValueError: if ``input`` is an empty list/tuple.
        """
        if not isinstance(input, (list, tuple)):
            return self.loss(input, target_is_real, for_discriminator)
        if len(input) == 0:
            raise ValueError('GANLoss received an empty list of predictions')
        loss = 0
        for pred_i in input:
            # Discriminators returning intermediate features put the
            # final prediction last.
            if isinstance(pred_i, (list, tuple)):
                pred_i = pred_i[-1]
            loss_tensor = self.loss(pred_i, target_is_real, for_discriminator)
            bs = 1 if len(loss_tensor.size()) == 0 else loss_tensor.size(0)
            new_loss = torch.mean(loss_tensor.view(bs, -1), dim=1)
            loss += new_loss
        return loss / len(input)


# Perceptual loss that uses a pretrained VGG network
class VGGLoss(nn.Module):
    """Weighted L1 distance between the VGG19 features of ``x`` and ``y``.

    ``gpu_ids`` is not read: the VGG19 network is always moved to the
    current CUDA device with ``.cuda()``. The features of ``y`` are
    detached, so no gradient flows through the target branch.
    """

    def __init__(self, gpu_ids):
        super(VGGLoss, self).__init__()
        self.vgg = VGG19().cuda()
        self.criterion = nn.L1Loss()
        # One weight per feature map; later maps get larger weights.
        # NOTE(review): assumes VGG19.forward returns 5 feature maps - confirm.
        self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]

    def forward(self, x, y):
        x_vgg, y_vgg = self.vgg(x), self.vgg(y)
        loss = 0
        for i, (x_feat, y_feat) in enumerate(zip(x_vgg, y_vgg)):
            loss += self.weights[i] * self.criterion(x_feat, y_feat.detach())
        return loss



# KL Divergence loss used in VAE with an image encoder
# class KLDLoss(nn.Module):
#     def forward(self, mu, logvar):
#         return -0.5 * torch.sum(1 + logvar - mu.pow(2) -
#         logvar.exp())
# -----------------------------------------------------------------------------
# /models/networks/normalization.py
# -----------------------------------------------------------------------------
"""
Copyright (C) 2019 NVIDIA Corporation.  All rights reserved.
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
"""

import re
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.networks.sync_batchnorm import SynchronizedBatchNorm2d
import torch.nn.utils.spectral_norm as spectral_norm
import os
import numpy as np


# Returns a function that creates a normalization function
# that does not condition on the semantic map.
def get_nonspade_norm_layer(opt, norm_type='instance'):
    """Return ``add_norm_layer(layer)``, which wraps a layer in a normalization.

    ``norm_type`` is an optional ``'spectral'`` prefix followed by a sub-type:
    ``'none'`` or ``''`` (no extra normalization), ``'batch'``,
    ``'sync_batch'`` (``'syncbatch'`` is accepted too, matching the SPADE
    config spelling) or ``'instance'``. ``opt`` is not read here.

    The returned function raises ``ValueError`` for an unknown sub-type.
    """

    def get_out_channel(layer):
        # Number of output channels of ``layer``: its ``out_channels``
        # attribute if present, otherwise the first dimension of its weight.
        if hasattr(layer, 'out_channels'):
            return getattr(layer, 'out_channels')
        return layer.weight.size(0)

    def add_norm_layer(layer):
        # Without a 'spectral' prefix the whole string is the sub-type.
        subnorm_type = norm_type
        if norm_type.startswith('spectral'):
            layer = spectral_norm(layer)
            subnorm_type = norm_type[len('spectral'):]

        if subnorm_type == 'none' or len(subnorm_type) == 0:
            return layer

        # Remove the bias of the previous layer; it is meaningless
        # since it has no effect after normalization.
        if getattr(layer, 'bias', None) is not None:
            delattr(layer, 'bias')
            layer.register_parameter('bias', None)

        if subnorm_type == 'batch':
            norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True)
        elif subnorm_type in ('sync_batch', 'syncbatch'):
            norm_layer = SynchronizedBatchNorm2d(get_out_channel(layer), affine=True)
        elif subnorm_type == 'instance':
            norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False)
        else:
            raise ValueError('normalization layer %s is not recognized' % subnorm_type)

        return nn.Sequential(layer, norm_layer)

    return add_norm_layer


# Pattern of SPADE config strings: 'spade' + <param-free norm> + '<k>x<k>'.
_SPADE_CONFIG_PATTERN = re.compile(r'spade(\D+)(\d)x\d')


def _parse_spade_config(config_text):
    """Split a config such as ``'spadesyncbatch3x3'`` into ``('syncbatch', 3)``.

    Raises:
        AssertionError: if ``config_text`` does not start with ``'spade'``.
        ValueError: if the norm type / kernel size cannot be parsed.
    """
    assert config_text.startswith('spade')
    parsed = _SPADE_CONFIG_PATTERN.search(config_text)
    if parsed is None:
        raise ValueError('could not parse SPADE config %r '
                         '(expected e.g. spadeinstance3x3)' % config_text)
    return str(parsed.group(1)), int(parsed.group(2))


def _make_param_free_norm(norm_type, norm_nc):
    """Build the affine-free normalization layer named by ``norm_type``."""
    if norm_type == 'instance':
        return nn.InstanceNorm2d(norm_nc, affine=False)
    if norm_type == 'syncbatch':
        return SynchronizedBatchNorm2d(norm_nc, affine=False)
    if norm_type == 'batch':
        return nn.BatchNorm2d(norm_nc, affine=False)
    raise ValueError('%s is not a recognized param-free norm type in SPADE'
                     % norm_type)


class ACE(nn.Module):
    """Normalization modulated by a blend of region style codes and SPADE.

    ``x`` plus learned, per-channel-scaled Gaussian noise is passed through a
    parameter-free norm, then modulated as ``normalized * (1 + gamma) + beta``.
    With ``use_rgb`` the (gamma, beta) pair is a sigmoid-weighted blend of
    (a) convolutions of a style map in which every pixel of region ``j``
    holds ``relu(fc_mu{j}(code_j))`` and (b) the output of ``SPADE``.
    Without ``use_rgb`` only (b) is used.

    Args:
        config_text: SPADE config string such as ``'spadesyncbatch3x3'``.
        norm_nc: number of channels of the normalized activations.
        label_nc: not read by this class.
        ACE_Name: identifier; only used by the debug style-code dump.
        status: ``'UI_mode'`` reads style codes from ``obj_dic``; other values
            use the ``style_codes`` argument.
        spade_params: positional arguments for ``SPADE`` (required despite the
            ``None`` default, since it is unpacked with ``*``).
        use_rgb: whether to build and use the style-code branch.
    """

    def __init__(self, config_text, norm_nc, label_nc, ACE_Name=None, status='train', spade_params=None, use_rgb=True):
        super().__init__()

        self.ACE_Name = ACE_Name
        self.status = status
        self.save_npy = True
        self.Spade = SPADE(*spade_params)
        self.use_rgb = use_rgb
        self.style_length = 512
        self.Gsas = GSAS(self.style_length)
        # Blending logits; sigmoid gives the weight of the style branch vs. SPADE.
        self.blending_gamma = nn.Parameter(torch.zeros(1), requires_grad=True)
        self.blending_beta = nn.Parameter(torch.zeros(1), requires_grad=True)
        # Per-channel scale of the Gaussian noise added before normalization.
        self.noise_var = nn.Parameter(torch.zeros(norm_nc), requires_grad=True)

        param_free_norm_type, ks = _parse_spade_config(config_text)
        pw = ks // 2
        self.param_free_norm = _make_param_free_norm(param_free_norm_type, norm_nc)

        if self.use_rgb:
            self.create_gamma_beta_fc_layers()

            self.conv_gamma = nn.Conv2d(self.style_length, norm_nc, kernel_size=ks, padding=pw)
            self.conv_beta = nn.Conv2d(self.style_length, norm_nc, kernel_size=ks, padding=pw)

    def forward(self, x, segmap, style_codes=None, obj_dic=None):
        """Normalize and modulate ``x``.

        Args:
            x: activations; indexed as (B, C, H, W) below.
            segmap: segmentation map, resized to x's spatial size. Each channel
                ``j`` is treated as the mask of region ``j`` (via ``.bool()``).
            style_codes: (B, S, 512) region style codes, required when
                ``use_rgb`` and status is not ``'UI_mode'``.
            obj_dic: in ``'UI_mode'``, ``obj_dic[str(j)]['ACE']`` supplies the
                style code of region ``j``.
        """
        # Part 1. parameter-free normalization of x plus learned noise.
        # randn is (B, W, H, 1) so that multiplying by noise_var (C,) gives
        # (B, W, H, C); transpose(1, 3) turns it into (B, C, H, W) like x.
        noise = torch.randn(x.shape[0], x.shape[3], x.shape[2], 1, device=x.device)
        added_noise = (noise * self.noise_var).transpose(1, 3)
        normalized = self.param_free_norm(x + added_noise)

        # Part 2. produce scaling and bias conditioned on the semantic map.
        segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest')

        if not self.use_rgb:
            gamma_spade, beta_spade = self.Spade(segmap)
            return normalized * (1 + gamma_spade) + beta_spade

        # Refine the region style codes (only the style branch uses them).
        style_codes = self.Gsas(style_codes)

        b_size, _, h_size, w_size = normalized.shape
        middle_avg = torch.zeros((b_size, self.style_length, h_size, w_size), device=normalized.device)
        region_masks = segmap.bool()
        num_regions = segmap.shape[1]

        if self.status == 'UI_mode':
            # Interactive mode, hard-coded to the first sample of the batch.
            for i in range(1):
                for j in range(num_regions):
                    component_mask_area = torch.sum(region_masks[i, j])
                    if component_mask_area > 0:
                        if obj_dic is None:
                            print('wrong even it is the first input')
                        else:
                            style_code_tmp = obj_dic[str(j)]['ACE']
                            self._fill_region(middle_avg, region_masks, i, j,
                                              style_code_tmp, component_mask_area)
        else:
            for i in range(b_size):
                for j in range(num_regions):
                    component_mask_area = torch.sum(region_masks[i, j])
                    if component_mask_area > 0:
                        self._fill_region(middle_avg, region_masks, i, j,
                                          style_codes[i][j], component_mask_area)

                        if self.status == 'ttttttt' and self.save_npy and self.ACE_Name == 'up_2_ACE_0':
                            # Debug dump of style codes; only runs when status is
                            # the sentinel 'ttttttt'. NOTE(review): obj_dic[i] is
                            # indexed by sample position here, unlike the str(j)
                            # keys used in UI_mode - confirm before enabling.
                            tmp = style_codes[i][j].cpu().numpy()
                            dir_path = 'styles_test'
                            im_name = os.path.basename(obj_dic[i])
                            folder_path = os.path.join(dir_path, 'style_codes', im_name, str(j))
                            os.makedirs(folder_path, exist_ok=True)
                            style_code_path = os.path.join(folder_path, 'ACE.npy')
                            np.save(style_code_path, tmp)

        gamma_avg = self.conv_gamma(middle_avg)
        beta_avg = self.conv_beta(middle_avg)

        gamma_spade, beta_spade = self.Spade(segmap)

        gamma_alpha = torch.sigmoid(self.blending_gamma)
        beta_alpha = torch.sigmoid(self.blending_beta)

        gamma_final = gamma_alpha * gamma_avg + (1 - gamma_alpha) * gamma_spade
        beta_final = beta_alpha * beta_avg + (1 - beta_alpha) * beta_spade
        return normalized * (1 + gamma_final) + beta_final

    def _fill_region(self, middle_avg, region_masks, i, j, style_code, area):
        """Write relu(fc_mu{j}(style_code)) into every masked pixel of region j of sample i."""
        middle_mu = F.relu(getattr(self, 'fc_mu' + str(j))(style_code))
        component_mu = middle_mu.reshape(self.style_length, 1).expand(self.style_length, area)
        middle_avg[i].masked_scatter_(region_masks[i, j], component_mu)

    def create_gamma_beta_fc_layers(self, num_layers=19):
        """Register ``fc_mu0`` ... ``fc_mu{num_layers - 1}`` as Linear(512, 512) layers.

        ``forward`` looks these up by region index, so ``num_layers`` must
        cover every channel of the segmentation map. The default reproduces
        the original 19 hard-coded layers. The original file kept
        ``fc_mu19``-``fc_mu33`` commented out "for city scapes", i.e.
        ``num_layers=34``. Separate attributes are kept instead of an
        ``nn.ModuleList`` so state_dict keys (``fc_muN.weight``) stay
        checkpoint-compatible.
        """
        style_length = self.style_length
        for idx in range(num_layers):
            setattr(self, 'fc_mu' + str(idx), nn.Linear(style_length, style_length))


class GSAS(nn.Module):
    """Refine each style code with an attention over its own channels.

    For a code ``v`` (one per batch element and region), a 1x1 conv scores
    every channel pair ``(i, j)`` from ``(v[j], v[i])``. A softmax over ``i``
    turns the scores into weights. Channel ``j`` receives the mean over ``i``
    of ``v[i] * weight[i, j]``, and the result is added back to ``v``.
    Note: this materializes a ``(B * S, 2, C, C)`` tensor.
    """

    def __init__(self, style_length):
        super().__init__()
        self.coefficient = nn.Conv2d(2, 1, kernel_size=1)
        self.softmax = nn.Softmax(dim=1)
        # Only used by the disabled concat + fc variant in forward; kept so
        # that the parameters / state_dict keys stay unchanged.
        self.fc = nn.Linear(style_length * 2, style_length)

    def forward(self, style_code):
        """Return refined codes with the same (B, S, C) shape as ``style_code``."""
        b_size, s_size, c_size = style_code.shape
        rows = style_code.reshape(b_size * s_size, 1, c_size)
        e_style_code = rows.expand(-1, c_size, -1)      # [n, i, j] = v[n, j]
        t_style_code = e_style_code.permute(0, 2, 1)    # [n, i, j] = v[n, i]

        pairs = torch.cat((e_style_code.unsqueeze(1), t_style_code.unsqueeze(1)), 1)
        # squeeze(1) rather than squeeze(): when B * S == 1 a bare squeeze()
        # also drops the batch axis, and the softmax over dim=1 would then
        # normalize over j instead of i.
        coefficient = self.coefficient(pairs).squeeze(1)
        weights = self.softmax(coefficient)

        new_style_code = (t_style_code * weights).mean(1)

        # output = torch.cat((style_code.reshape(b_size * s_size, c_size), new_style_code), -1)
        # output = self.fc(output).reshape(b_size, s_size, c_size)
        return style_code + new_style_code.reshape(b_size, s_size, c_size)


# Creates a SPADE normalization layer based on the given configuration.
# SPADE consists of two steps. First, it normalizes the activations using
# your favorite normalization method, such as Batch Norm or Instance Norm.
# Second, it applies scale and bias to the normalized output, conditioned on
# the segmentation map.
# The format of |config_text| is spade(norm)(ks), where
#   (norm) specifies the type of parameter-free normalization
#          (e.g. syncbatch, batch, instance)
#   (ks) specifies the size of kernel in the SPADE module (e.g. 3x3)
# Example |config_text| values are spadesyncbatch3x3 and spadeinstance5x5.
# The other arguments are
#   |norm_nc|: the #channels of the normalized activations, hence the output dim of SPADE
#   |label_nc|: the #channels of the input semantic map, hence the input dim of SPADE
class SPADE(nn.Module):
    """Predict the (gamma, beta) modulation maps from a segmentation map.

    In this file ``forward`` only computes the modulation. Normalizing and
    applying gamma/beta is done by the caller (``ACE``).
    """

    def __init__(self, config_text, norm_nc, label_nc):
        super().__init__()

        param_free_norm_type, ks = _parse_spade_config(config_text)
        # Not used by forward; kept so the module structure and state_dict
        # keys are unchanged.
        self.param_free_norm = _make_param_free_norm(param_free_norm_type, norm_nc)

        # The dimension of the intermediate embedding space. Yes, hardcoded.
        nhidden = 128

        pw = ks // 2
        self.mlp_shared = nn.Sequential(
            nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw),
            nn.ReLU()
        )

        self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)
        self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)

    def forward(self, segmap):
        """Return ``(gamma, beta)``, each with ``norm_nc`` channels."""
        actv = self.mlp_shared(segmap)
        return self.mlp_gamma(actv), self.mlp_beta(actv)

# -----------------------------------------------------------------------------
# /models/networks/sync_batchnorm/__init__.py
# -----------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# File   : __init__.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d 12 | from .batchnorm import convert_model 13 | from .replicate import DataParallelWithCallback, patch_replication_callback 14 | -------------------------------------------------------------------------------- /models/networks/sync_batchnorm/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenjaminJonghyun/SuperStyleNet/2400c01b35f50b387b5f768fdece37688a077049/models/networks/sync_batchnorm/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /models/networks/sync_batchnorm/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenjaminJonghyun/SuperStyleNet/2400c01b35f50b387b5f768fdece37688a077049/models/networks/sync_batchnorm/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /models/networks/sync_batchnorm/__pycache__/batchnorm.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenjaminJonghyun/SuperStyleNet/2400c01b35f50b387b5f768fdece37688a077049/models/networks/sync_batchnorm/__pycache__/batchnorm.cpython-36.pyc -------------------------------------------------------------------------------- /models/networks/sync_batchnorm/__pycache__/batchnorm.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenjaminJonghyun/SuperStyleNet/2400c01b35f50b387b5f768fdece37688a077049/models/networks/sync_batchnorm/__pycache__/batchnorm.cpython-37.pyc 
-------------------------------------------------------------------------------- /models/networks/sync_batchnorm/__pycache__/comm.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenjaminJonghyun/SuperStyleNet/2400c01b35f50b387b5f768fdece37688a077049/models/networks/sync_batchnorm/__pycache__/comm.cpython-36.pyc -------------------------------------------------------------------------------- /models/networks/sync_batchnorm/__pycache__/comm.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenjaminJonghyun/SuperStyleNet/2400c01b35f50b387b5f768fdece37688a077049/models/networks/sync_batchnorm/__pycache__/comm.cpython-37.pyc -------------------------------------------------------------------------------- /models/networks/sync_batchnorm/__pycache__/replicate.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenjaminJonghyun/SuperStyleNet/2400c01b35f50b387b5f768fdece37688a077049/models/networks/sync_batchnorm/__pycache__/replicate.cpython-36.pyc -------------------------------------------------------------------------------- /models/networks/sync_batchnorm/__pycache__/replicate.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenjaminJonghyun/SuperStyleNet/2400c01b35f50b387b5f768fdece37688a077049/models/networks/sync_batchnorm/__pycache__/replicate.cpython-37.pyc -------------------------------------------------------------------------------- /models/networks/sync_batchnorm/batchnorm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : batchnorm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of 
Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import collections 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | 16 | from torch.nn.modules.batchnorm import _BatchNorm 17 | from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast 18 | 19 | from .comm import SyncMaster 20 | from .replicate import DataParallelWithCallback 21 | 22 | __all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 23 | 'SynchronizedBatchNorm3d', 'convert_model'] 24 | 25 | 26 | def _sum_ft(tensor): 27 | """sum over the first and last dimention""" 28 | return tensor.sum(dim=0).sum(dim=-1) 29 | 30 | 31 | def _unsqueeze_ft(tensor): 32 | """add new dementions at the front and the tail""" 33 | return tensor.unsqueeze(0).unsqueeze(-1) 34 | 35 | 36 | _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size']) 37 | _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std']) 38 | 39 | 40 | class _SynchronizedBatchNorm(_BatchNorm): 41 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True): 42 | super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine) 43 | 44 | self._sync_master = SyncMaster(self._data_parallel_master) 45 | 46 | self._is_parallel = False 47 | self._parallel_id = None 48 | self._slave_pipe = None 49 | 50 | def forward(self, input): 51 | # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation. 52 | if not (self._is_parallel and self.training): 53 | return F.batch_norm( 54 | input, self.running_mean, self.running_var, self.weight, self.bias, 55 | self.training, self.momentum, self.eps) 56 | 57 | # Resize the input to (B, C, -1). 58 | input_shape = input.size() 59 | input = input.view(input.size(0), self.num_features, -1) 60 | 61 | # Compute the sum and square-sum. 
62 | sum_size = input.size(0) * input.size(2) 63 | input_sum = _sum_ft(input) 64 | input_ssum = _sum_ft(input ** 2) 65 | 66 | # Reduce-and-broadcast the statistics. 67 | if self._parallel_id == 0: 68 | mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) 69 | else: 70 | mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) 71 | 72 | # Compute the output. 73 | if self.affine: 74 | # MJY:: Fuse the multiplication for speed. 75 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias) 76 | else: 77 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std) 78 | 79 | # Reshape it. 80 | return output.view(input_shape) 81 | 82 | def __data_parallel_replicate__(self, ctx, copy_id): 83 | self._is_parallel = True 84 | self._parallel_id = copy_id 85 | 86 | # parallel_id == 0 means master device. 87 | if self._parallel_id == 0: 88 | ctx.sync_master = self._sync_master 89 | else: 90 | self._slave_pipe = ctx.sync_master.register_slave(copy_id) 91 | 92 | def _data_parallel_master(self, intermediates): 93 | """Reduce the sum and square-sum, compute the statistics, and broadcast it.""" 94 | 95 | # Always using same "device order" makes the ReduceAdd operation faster. 
96 | # Thanks to:: Tete Xiao (http://tetexiao.com/) 97 | intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device()) 98 | 99 | to_reduce = [i[1][:2] for i in intermediates] 100 | to_reduce = [j for i in to_reduce for j in i] # flatten 101 | target_gpus = [i[1].sum.get_device() for i in intermediates] 102 | 103 | sum_size = sum([i[1].sum_size for i in intermediates]) 104 | sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce) 105 | mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size) 106 | 107 | broadcasted = Broadcast.apply(target_gpus, mean, inv_std) 108 | 109 | outputs = [] 110 | for i, rec in enumerate(intermediates): 111 | outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2]))) 112 | 113 | return outputs 114 | 115 | def _compute_mean_std(self, sum_, ssum, size): 116 | """Compute the mean and standard-deviation with sum and square-sum. This method 117 | also maintains the moving average on the master device.""" 118 | assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.' 119 | mean = sum_ / size 120 | sumvar = ssum - sum_ * mean 121 | unbias_var = sumvar / (size - 1) 122 | bias_var = sumvar / size 123 | 124 | self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data 125 | self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data 126 | 127 | return mean, bias_var.clamp(self.eps) ** -0.5 128 | 129 | 130 | class SynchronizedBatchNorm1d(_SynchronizedBatchNorm): 131 | r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a 132 | mini-batch. 133 | 134 | .. math:: 135 | 136 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 137 | 138 | This module differs from the built-in PyTorch BatchNorm1d as the mean and 139 | standard-deviation are reduced across all devices during training. 
140 | 141 | For example, when one uses `nn.DataParallel` to wrap the network during 142 | training, PyTorch's implementation normalize the tensor on each device using 143 | the statistics only on that device, which accelerated the computation and 144 | is also easy to implement, but the statistics might be inaccurate. 145 | Instead, in this synchronized version, the statistics will be computed 146 | over all training samples distributed on multiple devices. 147 | 148 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 149 | as the built-in PyTorch implementation. 150 | 151 | The mean and standard-deviation are calculated per-dimension over 152 | the mini-batches and gamma and beta are learnable parameter vectors 153 | of size C (where C is the input size). 154 | 155 | During training, this layer keeps a running estimate of its computed mean 156 | and variance. The running sum is kept with a default momentum of 0.1. 157 | 158 | During evaluation, this running mean/variance is used for normalization. 159 | 160 | Because the BatchNorm is done over the `C` dimension, computing statistics 161 | on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm 162 | 163 | Args: 164 | num_features: num_features from an expected input of size 165 | `batch_size x num_features [x width]` 166 | eps: a value added to the denominator for numerical stability. 167 | Default: 1e-5 168 | momentum: the value used for the running_mean and running_var 169 | computation. Default: 0.1 170 | affine: a boolean value that when set to ``True``, gives the layer learnable 171 | affine parameters. 
Default: ``True`` 172 | 173 | Shape: 174 | - Input: :math:`(N, C)` or :math:`(N, C, L)` 175 | - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) 176 | 177 | Examples: 178 | >>> # With Learnable Parameters 179 | >>> m = SynchronizedBatchNorm1d(100) 180 | >>> # Without Learnable Parameters 181 | >>> m = SynchronizedBatchNorm1d(100, affine=False) 182 | >>> input = torch.autograd.Variable(torch.randn(20, 100)) 183 | >>> output = m(input) 184 | """ 185 | 186 | def _check_input_dim(self, input): 187 | if input.dim() != 2 and input.dim() != 3: 188 | raise ValueError('expected 2D or 3D input (got {}D input)' 189 | .format(input.dim())) 190 | super(SynchronizedBatchNorm1d, self)._check_input_dim(input) 191 | 192 | 193 | class SynchronizedBatchNorm2d(_SynchronizedBatchNorm): 194 | r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch 195 | of 3d inputs 196 | 197 | .. math:: 198 | 199 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 200 | 201 | This module differs from the built-in PyTorch BatchNorm2d as the mean and 202 | standard-deviation are reduced across all devices during training. 203 | 204 | For example, when one uses `nn.DataParallel` to wrap the network during 205 | training, PyTorch's implementation normalize the tensor on each device using 206 | the statistics only on that device, which accelerated the computation and 207 | is also easy to implement, but the statistics might be inaccurate. 208 | Instead, in this synchronized version, the statistics will be computed 209 | over all training samples distributed on multiple devices. 210 | 211 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 212 | as the built-in PyTorch implementation. 213 | 214 | The mean and standard-deviation are calculated per-dimension over 215 | the mini-batches and gamma and beta are learnable parameter vectors 216 | of size C (where C is the input size). 
217 | 218 | During training, this layer keeps a running estimate of its computed mean 219 | and variance. The running sum is kept with a default momentum of 0.1. 220 | 221 | During evaluation, this running mean/variance is used for normalization. 222 | 223 | Because the BatchNorm is done over the `C` dimension, computing statistics 224 | on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm 225 | 226 | Args: 227 | num_features: num_features from an expected input of 228 | size batch_size x num_features x height x width 229 | eps: a value added to the denominator for numerical stability. 230 | Default: 1e-5 231 | momentum: the value used for the running_mean and running_var 232 | computation. Default: 0.1 233 | affine: a boolean value that when set to ``True``, gives the layer learnable 234 | affine parameters. Default: ``True`` 235 | 236 | Shape: 237 | - Input: :math:`(N, C, H, W)` 238 | - Output: :math:`(N, C, H, W)` (same shape as input) 239 | 240 | Examples: 241 | >>> # With Learnable Parameters 242 | >>> m = SynchronizedBatchNorm2d(100) 243 | >>> # Without Learnable Parameters 244 | >>> m = SynchronizedBatchNorm2d(100, affine=False) 245 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45)) 246 | >>> output = m(input) 247 | """ 248 | 249 | def _check_input_dim(self, input): 250 | if input.dim() != 4: 251 | raise ValueError('expected 4D input (got {}D input)' 252 | .format(input.dim())) 253 | super(SynchronizedBatchNorm2d, self)._check_input_dim(input) 254 | 255 | 256 | class SynchronizedBatchNorm3d(_SynchronizedBatchNorm): 257 | r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch 258 | of 4d inputs 259 | 260 | .. math:: 261 | 262 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 263 | 264 | This module differs from the built-in PyTorch BatchNorm3d as the mean and 265 | standard-deviation are reduced across all devices during training. 
266 | 267 | For example, when one uses `nn.DataParallel` to wrap the network during 268 | training, PyTorch's implementation normalize the tensor on each device using 269 | the statistics only on that device, which accelerated the computation and 270 | is also easy to implement, but the statistics might be inaccurate. 271 | Instead, in this synchronized version, the statistics will be computed 272 | over all training samples distributed on multiple devices. 273 | 274 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 275 | as the built-in PyTorch implementation. 276 | 277 | The mean and standard-deviation are calculated per-dimension over 278 | the mini-batches and gamma and beta are learnable parameter vectors 279 | of size C (where C is the input size). 280 | 281 | During training, this layer keeps a running estimate of its computed mean 282 | and variance. The running sum is kept with a default momentum of 0.1. 283 | 284 | During evaluation, this running mean/variance is used for normalization. 285 | 286 | Because the BatchNorm is done over the `C` dimension, computing statistics 287 | on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm 288 | or Spatio-temporal BatchNorm 289 | 290 | Args: 291 | num_features: num_features from an expected input of 292 | size batch_size x num_features x depth x height x width 293 | eps: a value added to the denominator for numerical stability. 294 | Default: 1e-5 295 | momentum: the value used for the running_mean and running_var 296 | computation. Default: 0.1 297 | affine: a boolean value that when set to ``True``, gives the layer learnable 298 | affine parameters. 
Default: ``True`` 299 | 300 | Shape: 301 | - Input: :math:`(N, C, D, H, W)` 302 | - Output: :math:`(N, C, D, H, W)` (same shape as input) 303 | 304 | Examples: 305 | >>> # With Learnable Parameters 306 | >>> m = SynchronizedBatchNorm3d(100) 307 | >>> # Without Learnable Parameters 308 | >>> m = SynchronizedBatchNorm3d(100, affine=False) 309 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10)) 310 | >>> output = m(input) 311 | """ 312 | 313 | def _check_input_dim(self, input): 314 | if input.dim() != 5: 315 | raise ValueError('expected 5D input (got {}D input)' 316 | .format(input.dim())) 317 | super(SynchronizedBatchNorm3d, self)._check_input_dim(input) 318 | 319 | 320 | def convert_model(module): 321 | """Traverse the input module and its child recursively 322 | and replace all instance of torch.nn.modules.batchnorm.BatchNorm*N*d 323 | to SynchronizedBatchNorm*N*d 324 | 325 | Args: 326 | module: the input module needs to be convert to SyncBN model 327 | 328 | Examples: 329 | >>> import torch.nn as nn 330 | >>> import torchvision 331 | >>> # m is a standard pytorch model 332 | >>> m = torchvision.models.resnet18(True) 333 | >>> m = nn.DataParallel(m) 334 | >>> # after convert, m is using SyncBN 335 | >>> m = convert_model(m) 336 | """ 337 | if isinstance(module, torch.nn.DataParallel): 338 | mod = module.module 339 | mod = convert_model(mod) 340 | mod = DataParallelWithCallback(mod) 341 | return mod 342 | 343 | mod = module 344 | for pth_module, sync_module in zip([torch.nn.modules.batchnorm.BatchNorm1d, 345 | torch.nn.modules.batchnorm.BatchNorm2d, 346 | torch.nn.modules.batchnorm.BatchNorm3d], 347 | [SynchronizedBatchNorm1d, 348 | SynchronizedBatchNorm2d, 349 | SynchronizedBatchNorm3d]): 350 | if isinstance(module, pth_module): 351 | mod = sync_module(module.num_features, module.eps, module.momentum, module.affine) 352 | mod.running_mean = module.running_mean 353 | mod.running_var = module.running_var 354 | if module.affine: 355 | 
mod.weight.data = module.weight.data.clone().detach() 356 | mod.bias.data = module.bias.data.clone().detach() 357 | 358 | for name, child in module.named_children(): 359 | mod.add_module(name, convert_model(child)) 360 | 361 | return mod 362 | -------------------------------------------------------------------------------- /models/networks/sync_batchnorm/batchnorm_reimpl.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # File : batchnorm_reimpl.py 4 | # Author : acgtyrant 5 | # Date : 11/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.init as init 14 | 15 | __all__ = ['BatchNormReimpl'] 16 | 17 | 18 | class BatchNorm2dReimpl(nn.Module): 19 | """ 20 | A re-implementation of batch normalization, used for testing the numerical 21 | stability. 
22 | 23 | Author: acgtyrant 24 | See also: 25 | https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14 26 | """ 27 | def __init__(self, num_features, eps=1e-5, momentum=0.1): 28 | super().__init__() 29 | 30 | self.num_features = num_features 31 | self.eps = eps 32 | self.momentum = momentum 33 | self.weight = nn.Parameter(torch.empty(num_features)) 34 | self.bias = nn.Parameter(torch.empty(num_features)) 35 | self.register_buffer('running_mean', torch.zeros(num_features)) 36 | self.register_buffer('running_var', torch.ones(num_features)) 37 | self.reset_parameters() 38 | 39 | def reset_running_stats(self): 40 | self.running_mean.zero_() 41 | self.running_var.fill_(1) 42 | 43 | def reset_parameters(self): 44 | self.reset_running_stats() 45 | init.uniform_(self.weight) 46 | init.zeros_(self.bias) 47 | 48 | def forward(self, input_): 49 | batchsize, channels, height, width = input_.size() 50 | numel = batchsize * height * width 51 | input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel) 52 | sum_ = input_.sum(1) 53 | sum_of_square = input_.pow(2).sum(1) 54 | mean = sum_ / numel 55 | sumvar = sum_of_square - sum_ * mean 56 | 57 | self.running_mean = ( 58 | (1 - self.momentum) * self.running_mean 59 | + self.momentum * mean.detach() 60 | ) 61 | unbias_var = sumvar / (numel - 1) 62 | self.running_var = ( 63 | (1 - self.momentum) * self.running_var 64 | + self.momentum * unbias_var.detach() 65 | ) 66 | 67 | bias_var = sumvar / numel 68 | inv_std = 1 / (bias_var + self.eps).pow(0.5) 69 | output = ( 70 | (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) * 71 | self.weight.unsqueeze(1) + self.bias.unsqueeze(1)) 72 | 73 | return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous() 74 | 75 | -------------------------------------------------------------------------------- /models/networks/sync_batchnorm/comm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 
2 | # File : comm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import queue 12 | import collections 13 | import threading 14 | 15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] 16 | 17 | 18 | class FutureResult(object): 19 | """A thread-safe future implementation. Used only as one-to-one pipe.""" 20 | 21 | def __init__(self): 22 | self._result = None 23 | self._lock = threading.Lock() 24 | self._cond = threading.Condition(self._lock) 25 | 26 | def put(self, result): 27 | with self._lock: 28 | assert self._result is None, 'Previous result has\'t been fetched.' 29 | self._result = result 30 | self._cond.notify() 31 | 32 | def get(self): 33 | with self._lock: 34 | if self._result is None: 35 | self._cond.wait() 36 | 37 | res = self._result 38 | self._result = None 39 | return res 40 | 41 | 42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) 43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) 44 | 45 | 46 | class SlavePipe(_SlavePipeBase): 47 | """Pipe for master-slave communication.""" 48 | 49 | def run_slave(self, msg): 50 | self.queue.put((self.identifier, msg)) 51 | ret = self.result.get() 52 | self.queue.put(True) 53 | return ret 54 | 55 | 56 | class SyncMaster(object): 57 | """An abstract `SyncMaster` object. 58 | 59 | - During the replication, as the data parallel will trigger an callback of each module, all slave devices should 60 | call `register(id)` and obtain an `SlavePipe` to communicate with the master. 61 | - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, 62 | and passed to a registered callback. 
63 | - After receiving the messages, the master device should gather the information and determine to message passed 64 | back to each slave devices. 65 | """ 66 | 67 | def __init__(self, master_callback): 68 | """ 69 | 70 | Args: 71 | master_callback: a callback to be invoked after having collected messages from slave devices. 72 | """ 73 | self._master_callback = master_callback 74 | self._queue = queue.Queue() 75 | self._registry = collections.OrderedDict() 76 | self._activated = False 77 | 78 | def __getstate__(self): 79 | return {'master_callback': self._master_callback} 80 | 81 | def __setstate__(self, state): 82 | self.__init__(state['master_callback']) 83 | 84 | def register_slave(self, identifier): 85 | """ 86 | Register an slave device. 87 | 88 | Args: 89 | identifier: an identifier, usually is the device id. 90 | 91 | Returns: a `SlavePipe` object which can be used to communicate with the master device. 92 | 93 | """ 94 | if self._activated: 95 | assert self._queue.empty(), 'Queue is not clean before next initialization.' 96 | self._activated = False 97 | self._registry.clear() 98 | future = FutureResult() 99 | self._registry[identifier] = _MasterRegistry(future) 100 | return SlavePipe(identifier, self._queue, future) 101 | 102 | def run_master(self, master_msg): 103 | """ 104 | Main entry for the master device in each forward pass. 105 | The messages were first collected from each devices (including the master device), and then 106 | an callback will be invoked to compute the message to be sent back to each devices 107 | (including the master device). 108 | 109 | Args: 110 | master_msg: the message that the master want to send to itself. This will be placed as the first 111 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. 112 | 113 | Returns: the message to be sent back to the master device. 
114 | 115 | """ 116 | self._activated = True 117 | 118 | intermediates = [(0, master_msg)] 119 | for i in range(self.nr_slaves): 120 | intermediates.append(self._queue.get()) 121 | 122 | results = self._master_callback(intermediates) 123 | assert results[0][0] == 0, 'The first result should belongs to the master.' 124 | 125 | for i, res in results: 126 | if i == 0: 127 | continue 128 | self._registry[i].result.put(res) 129 | 130 | for i in range(self.nr_slaves): 131 | assert self._queue.get() is True 132 | 133 | return results[0][1] 134 | 135 | @property 136 | def nr_slaves(self): 137 | return len(self._registry) 138 | -------------------------------------------------------------------------------- /models/networks/sync_batchnorm/replicate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : replicate.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import functools 12 | 13 | from torch.nn.parallel.data_parallel import DataParallel 14 | 15 | __all__ = [ 16 | 'CallbackContext', 17 | 'execute_replication_callbacks', 18 | 'DataParallelWithCallback', 19 | 'patch_replication_callback' 20 | ] 21 | 22 | 23 | class CallbackContext(object): 24 | pass 25 | 26 | 27 | def execute_replication_callbacks(modules): 28 | """ 29 | Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. 30 | 31 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 32 | 33 | Note that, as all modules are isomorphism, we assign each sub-module with a context 34 | (shared among multiple copies of this module on different devices). 35 | Through this context, different copies can share some information. 
36 | 37 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback 38 | of any slave copies. 39 | """ 40 | master_copy = modules[0] 41 | nr_modules = len(list(master_copy.modules())) 42 | ctxs = [CallbackContext() for _ in range(nr_modules)] 43 | 44 | for i, module in enumerate(modules): 45 | for j, m in enumerate(module.modules()): 46 | if hasattr(m, '__data_parallel_replicate__'): 47 | m.__data_parallel_replicate__(ctxs[j], i) 48 | 49 | 50 | class DataParallelWithCallback(DataParallel): 51 | """ 52 | Data Parallel with a replication callback. 53 | 54 | An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by 55 | original `replicate` function. 56 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 57 | 58 | Examples: 59 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 60 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 61 | # sync_bn.__data_parallel_replicate__ will be invoked. 62 | """ 63 | 64 | def replicate(self, module, device_ids): 65 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids) 66 | execute_replication_callbacks(modules) 67 | return modules 68 | 69 | 70 | def patch_replication_callback(data_parallel): 71 | """ 72 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 73 | Useful when you have customized `DataParallel` implementation. 
74 | 75 | Examples: 76 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 77 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 78 | > patch_replication_callback(sync_bn) 79 | # this is equivalent to 80 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 81 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 82 | """ 83 | 84 | assert isinstance(data_parallel, DataParallel) 85 | 86 | old_replicate = data_parallel.replicate 87 | 88 | @functools.wraps(old_replicate) 89 | def new_replicate(module, device_ids): 90 | modules = old_replicate(module, device_ids) 91 | execute_replication_callbacks(modules) 92 | return modules 93 | 94 | data_parallel.replicate = new_replicate 95 | -------------------------------------------------------------------------------- /models/networks/sync_batchnorm/unittest.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : unittest.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import unittest 12 | import torch 13 | 14 | 15 | class TorchTestCase(unittest.TestCase): 16 | def assertTensorClose(self, x, y): 17 | adiff = float((x - y).abs().max()) 18 | if (y == 0).all(): 19 | rdiff = 'NaN' 20 | else: 21 | rdiff = float((adiff / y).abs().max()) 22 | 23 | message = ( 24 | 'Tensor close check failed\n' 25 | 'adiff={}\n' 26 | 'rdiff={}\n' 27 | ).format(adiff, rdiff) 28 | self.assertTrue(torch.allclose(x, y), message) 29 | 30 | -------------------------------------------------------------------------------- /models/pix2pix_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | import torch 7 | import models.networks as networks 8 | import util.util as util 9 | 10 | 11 | class Pix2PixModel(torch.nn.Module): 12 | @staticmethod 13 | def modify_commandline_options(parser, is_train): 14 | networks.modify_commandline_options(parser, is_train) 15 | return parser 16 | 17 | def __init__(self, opt): 18 | super().__init__() 19 | self.opt = opt 20 | self.FloatTensor = torch.cuda.FloatTensor if self.use_gpu() \ 21 | else torch.FloatTensor 22 | self.ByteTensor = torch.cuda.ByteTensor if self.use_gpu() \ 23 | else torch.ByteTensor 24 | #opt.isTrain = False 25 | self.netG, self.netD, self.netE = self.initialize_networks(opt) 26 | 27 | # set loss functions 28 | if opt.isTrain: 29 | self.criterionGAN = networks.GANLoss( 30 | opt.gan_mode, tensor=self.FloatTensor, opt=self.opt) 31 | self.criterionFeat = torch.nn.L1Loss() 32 | if not opt.no_vgg_loss: 33 | self.criterionVGG = networks.VGGLoss(self.opt.gpu_ids) 34 | 35 | 36 | # Entry point for all calls involving forward pass 37 | # of deep networks. We used this approach since DataParallel module 38 | # can't parallelize custom functions, we branch to different 39 | # routines based on |mode|. 
def forward(self, data, mode):
    """Run one of the model's routines, selected by ``mode``.

    ``DataParallel`` can only parallelize ``forward``, so every call into the
    networks is routed through this method and dispatched on ``mode``.

    Args:
        data: input dict handed to ``self.preprocess_input`` (which reads the
            ``'label'``, ``'instance'`` and ``'image'`` entries). For
            ``'UI_mode'`` it must also contain ``'obj_dic'``.
        mode: one of ``'generator'``, ``'discriminator'``, ``'encode_only'``,
            ``'inference'`` or ``'UI_mode'``.

    Returns:
        - ``'generator'``: ``(g_losses, generated_image)`` from
          ``self.compute_generator_loss``.
        - ``'discriminator'``: the loss dict from
          ``self.compute_discriminator_loss``.
        - ``'encode_only'``: ``(mu, logvar)`` from ``self.encode_z``.
        - ``'inference'``: the image from ``self.generate_fake``, computed
          under ``torch.no_grad()``.
        - ``'UI_mode'``: the image from ``self.use_style_codes``, computed
          under ``torch.no_grad()``.

    Raises:
        ValueError: if ``mode`` is not one of the values above.
    """
    input_semantics, real_image, input_instance = self.preprocess_input(data)
    # Remove axis 1 of the instance map (a no-op unless that axis has size 1).
    input_instance = input_instance.squeeze(1)

    if mode == 'generator':
        g_loss, generated = self.compute_generator_loss(
            input_semantics, real_image, input_instance)
        return g_loss, generated
    if mode == 'discriminator':
        return self.compute_discriminator_loss(
            input_semantics, real_image, input_instance)
    if mode == 'encode_only':
        _, mu, logvar = self.encode_z(real_image)
        return mu, logvar
    if mode == 'inference':
        with torch.no_grad():
            return self.generate_fake(input_semantics, real_image, input_instance)
    if mode == 'UI_mode':
        with torch.no_grad():
            # NOTE(review): this path was previously flagged "some problems
            # here"; obj_dic is forwarded as-is (a bare str is NOT wrapped
            # into a list) — confirm the format use_style_codes expects.
            obj_dic = data['obj_dic']
            return self.use_style_codes(input_semantics, real_image, obj_dic)
    raise ValueError("|mode| is invalid")


def create_optimizers(self, opt):
    """Build the Adam optimizers for the generator and the discriminator.

    The generator optimizer covers ``self.netG`` and, when ``opt.use_vae`` is
    set, ``self.netE`` too. With ``opt.no_TTUR`` a shared learning rate
    ``opt.lr`` and the user betas ``(opt.beta1, opt.beta2)`` are used.
    Otherwise the two-time-scale update rule applies: betas ``(0, 0.9)``,
    generator lr ``opt.lr / 2`` and discriminator lr ``opt.lr * 2``.

    Args:
        opt: options namespace; reads ``use_vae``, ``isTrain``, ``no_TTUR``,
            ``beta1``, ``beta2`` and ``lr``.

    Returns:
        ``(optimizer_G, optimizer_D)``. ``optimizer_D`` is ``None`` when
        ``opt.isTrain`` is false, since no discriminator parameters are
        collected in that case.
    """
    G_params = list(self.netG.parameters())
    if opt.use_vae:
        G_params += list(self.netE.parameters())

    if opt.no_TTUR:
        beta1, beta2 = opt.beta1, opt.beta2
        G_lr, D_lr = opt.lr, opt.lr
    else:
        # TTUR: the generator learns more slowly, the discriminator faster.
        beta1, beta2 = 0, 0.9
        G_lr, D_lr = opt.lr / 2, opt.lr * 2

    optimizer_G = torch.optim.Adam(G_params, lr=G_lr, betas=(beta1, beta2))

    # netD is only built in training mode (initialize_networks leaves it None
    # otherwise), so there is nothing to optimize outside training. Return
    # None instead of referencing an unbound D_params.
    optimizer_D = None
    if opt.isTrain:
        D_params = list(self.netD.parameters())
        optimizer_D = torch.optim.Adam(D_params, lr=D_lr, betas=(beta1, beta2))

    return optimizer_G, optimizer_D
95 | def save(self, epoch): 96 | util.save_network(self.netG, 'G', epoch, self.opt) 97 | util.save_network(self.netD, 'D', epoch, self.opt) 98 | 99 | ############################################################################ 100 | # Private helper methods 101 | ############################################################################ 102 | 103 | def initialize_networks(self, opt): 104 | netG = networks.define_G(opt) 105 | 106 | netD = networks.define_D(opt) if opt.isTrain else None 107 | netE = networks.define_E(opt) if opt.use_vae else None 108 | 109 | if not opt.isTrain or opt.continue_train: 110 | netG = util.load_network(netG, 'G', opt.which_epoch, opt) 111 | if opt.isTrain: 112 | netD = util.load_network(netD, 'D', opt.which_epoch, opt) 113 | 114 | return netG, netD, netE 115 | 116 | # preprocess the input, such as moving the tensors to GPUs and 117 | # transforming the label map to one-hot encoding 118 | # |data|: dictionary of the input data 119 | 120 | def preprocess_input(self, data): 121 | # move to GPU and change data types 122 | data['label'] = data['label'].long() 123 | if self.use_gpu(): 124 | data['label'] = data['label'].cuda(non_blocking=True) 125 | data['instance'] = data['instance'].cuda(non_blocking=True) 126 | data['image'] = data['image'].cuda(non_blocking=True) 127 | 128 | # create one-hot label map 129 | label_map = data['label'] 130 | bs, _, h, w = label_map.size() 131 | nc = self.opt.label_nc + 1 if self.opt.contain_dontcare_label \ 132 | else self.opt.label_nc 133 | input_label = self.FloatTensor(bs, nc, h, w).zero_() 134 | input_semantics = input_label.scatter_(1, label_map, 1.0) 135 | 136 | ''' 137 | # concatenate instance map if it exists 138 | if not self.opt.no_instance: 139 | inst_map = data['instance'] 140 | instance_edge_map = self.get_edges(inst_map) 141 | input_semantics = torch.cat((input_semantics, instance_edge_map), dim=1) 142 | ''' 143 | 144 | return input_semantics, data['image'], data['instance'] 145 | 146 | def 
compute_generator_loss(self, input_semantics, real_image, input_instance): 147 | G_losses = {} 148 | 149 | fake_image = self.generate_fake( 150 | input_semantics, real_image, input_instance, compute_kld_loss=self.opt.use_vae) 151 | 152 | 153 | pred_fake, pred_real = self.discriminate( 154 | input_semantics, fake_image, real_image) 155 | 156 | G_losses['GAN'] = self.criterionGAN(pred_fake, True, 157 | for_discriminator=False) 158 | 159 | if not self.opt.no_ganFeat_loss: 160 | num_D = len(pred_fake) 161 | GAN_Feat_loss = self.FloatTensor(1).fill_(0) 162 | for i in range(num_D): # for each discriminator 163 | # last output is the final prediction, so we exclude it 164 | num_intermediate_outputs = len(pred_fake[i]) - 1 165 | for j in range(num_intermediate_outputs): # for each layer output 166 | unweighted_loss = self.criterionFeat( 167 | pred_fake[i][j], pred_real[i][j].detach()) 168 | GAN_Feat_loss += unweighted_loss * self.opt.lambda_feat / num_D 169 | G_losses['GAN_Feat'] = GAN_Feat_loss 170 | 171 | if not self.opt.no_vgg_loss: 172 | G_losses['VGG'] = self.criterionVGG(fake_image, real_image) \ 173 | * self.opt.lambda_vgg 174 | 175 | return G_losses, fake_image 176 | 177 | def compute_discriminator_loss(self, input_semantics, real_image, input_instance): 178 | D_losses = {} 179 | with torch.no_grad(): 180 | fake_image = self.generate_fake(input_semantics, real_image, input_instance) 181 | fake_image = fake_image.detach() 182 | fake_image.requires_grad_() 183 | 184 | pred_fake, pred_real = self.discriminate( 185 | input_semantics, fake_image, real_image) 186 | 187 | D_losses['D_Fake'] = self.criterionGAN(pred_fake, False, 188 | for_discriminator=True) 189 | D_losses['D_real'] = self.criterionGAN(pred_real, True, 190 | for_discriminator=True) 191 | 192 | return D_losses 193 | 194 | def encode_z(self, real_image): 195 | mu, logvar = self.netE(real_image) 196 | z = self.reparameterize(mu, logvar) 197 | return z, mu, logvar 198 | 199 | def generate_fake(self, 
input_semantics, real_image, input_instance, compute_kld_loss=False): 200 | 201 | 202 | fake_image = self.netG(input_semantics, real_image, input_instance) 203 | 204 | 205 | return fake_image 206 | 207 | ############################################################### 208 | 209 | def save_style_codes(self, input_semantics, real_image, obj_dic): 210 | 211 | fake_image = self.netG(input_semantics, real_image, obj_dic=obj_dic) 212 | 213 | return fake_image 214 | 215 | 216 | def use_style_codes(self, input_semantics, real_image, obj_dic): 217 | 218 | fake_image = self.netG(input_semantics, real_image, obj_dic=obj_dic) 219 | 220 | return fake_image 221 | 222 | 223 | 224 | # Given fake and real image, return the prediction of discriminator 225 | # for each fake and real image. 226 | 227 | def discriminate(self, input_semantics, fake_image, real_image): 228 | fake_concat = torch.cat([input_semantics, fake_image], dim=1) 229 | real_concat = torch.cat([input_semantics, real_image], dim=1) 230 | 231 | # In Batch Normalization, the fake and real images are 232 | # recommended to be in the same batch to avoid disparate 233 | # statistics in fake and real images. 234 | # So both fake and real images are fed to D all at once. 
235 | fake_and_real = torch.cat([fake_concat, real_concat], dim=0) 236 | 237 | discriminator_out = self.netD(fake_and_real) 238 | 239 | pred_fake, pred_real = self.divide_pred(discriminator_out) 240 | 241 | return pred_fake, pred_real 242 | 243 | # Take the prediction of fake and real images from the combined batch 244 | def divide_pred(self, pred): 245 | # the prediction contains the intermediate outputs of multiscale GAN, 246 | # so it's usually a list 247 | if type(pred) == list: 248 | fake = [] 249 | real = [] 250 | for p in pred: 251 | fake.append([tensor[:tensor.size(0) // 2] for tensor in p]) 252 | real.append([tensor[tensor.size(0) // 2:] for tensor in p]) 253 | else: 254 | fake = pred[:pred.size(0) // 2] 255 | real = pred[pred.size(0) // 2:] 256 | 257 | return fake, real 258 | 259 | def get_edges(self, t): 260 | edge = self.ByteTensor(t.size()).zero_() 261 | edge[:, :, :, 1:] = edge[:, :, :, 1:] | (t[:, :, :, 1:] != t[:, :, :, :-1]) 262 | edge[:, :, :, :-1] = edge[:, :, :, :-1] | (t[:, :, :, 1:] != t[:, :, :, :-1]) 263 | edge[:, :, 1:, :] = edge[:, :, 1:, :] | (t[:, :, 1:, :] != t[:, :, :-1, :]) 264 | edge[:, :, :-1, :] = edge[:, :, :-1, :] | (t[:, :, 1:, :] != t[:, :, :-1, :]) 265 | return edge.float() 266 | 267 | def reparameterize(self, mu, logvar): 268 | std = torch.exp(0.5 * logvar) 269 | eps = torch.randn_like(std) 270 | return eps.mul(std) + mu 271 | 272 | def use_gpu(self): 273 | return len(self.opt.gpu_ids) > 0 274 | -------------------------------------------------------------------------------- /options/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
4 | """ -------------------------------------------------------------------------------- /options/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenjaminJonghyun/SuperStyleNet/2400c01b35f50b387b5f768fdece37688a077049/options/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /options/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenjaminJonghyun/SuperStyleNet/2400c01b35f50b387b5f768fdece37688a077049/options/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /options/__pycache__/base_options.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenjaminJonghyun/SuperStyleNet/2400c01b35f50b387b5f768fdece37688a077049/options/__pycache__/base_options.cpython-36.pyc -------------------------------------------------------------------------------- /options/__pycache__/base_options.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenjaminJonghyun/SuperStyleNet/2400c01b35f50b387b5f768fdece37688a077049/options/__pycache__/base_options.cpython-37.pyc -------------------------------------------------------------------------------- /options/__pycache__/test_options.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BenjaminJonghyun/SuperStyleNet/2400c01b35f50b387b5f768fdece37688a077049/options/__pycache__/test_options.cpython-36.pyc -------------------------------------------------------------------------------- /options/__pycache__/test_options.cpython-37.pyc: 
# https://raw.githubusercontent.com/BenjaminJonghyun/SuperStyleNet/2400c01b35f50b387b5f768fdece37688a077049/options/__pycache__/test_options.cpython-37.pyc
# ----------------------------------------------------------------------------
# /options/__pycache__/train_options.cpython-36.pyc:
# https://raw.githubusercontent.com/BenjaminJonghyun/SuperStyleNet/2400c01b35f50b387b5f768fdece37688a077049/options/__pycache__/train_options.cpython-36.pyc
# ----------------------------------------------------------------------------
# /options/__pycache__/train_options.cpython-37.pyc:
# https://raw.githubusercontent.com/BenjaminJonghyun/SuperStyleNet/2400c01b35f50b387b5f768fdece37688a077049/options/__pycache__/train_options.cpython-37.pyc
# ----------------------------------------------------------------------------
# /options/base_options.py:
# ----------------------------------------------------------------------------
"""
Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
"""

import sys
import argparse
import os
from util import util
import torch
import models
import data
import pickle


class BaseOptions():
    """Command-line options shared by training and testing.

    Subclasses extend ``initialize`` with their own flags and are expected to
    set ``self.isTrain`` there; ``gather_options`` reads it to pick the
    model/dataset option setters.
    """

    def __init__(self):
        self.initialized = False

    def initialize(self, parser):
        """Register the shared arguments on ``parser`` and return it."""
        # experiment specifics
        parser.add_argument('--name', type=str, default='slic_test_first_back', help='name of the experiment. It decides where to store samples and models')

        parser.add_argument('--gpu_ids', type=str, default='3', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')
        parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
        parser.add_argument('--model', type=str, default='pix2pix', help='which model to use')
        parser.add_argument('--norm_G', type=str, default='spectralinstance', help='instance normalization or batch normalization')
        parser.add_argument('--norm_D', type=str, default='spectralinstance', help='instance normalization or batch normalization')
        parser.add_argument('--norm_E', type=str, default='spectralinstance', help='instance normalization or batch normalization')
        parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')

        # input/output sizes
        parser.add_argument('--batchSize', type=int, default=1, help='input batch size')
        parser.add_argument('--preprocess_mode', type=str, default='scale_width_and_crop', help='scaling and cropping of images at load time.', choices=("resize_and_crop", "crop", "scale_width", "scale_width_and_crop", "scale_shortside", "scale_shortside_and_crop", "fixed", "none"))
        parser.add_argument('--load_size', type=int, default=256, help='Scale images to this size. The final image will be cropped to --crop_size.')
        parser.add_argument('--crop_size', type=int, default=256, help='Crop to the width of crop_size (after initially scaling the images to load_size.)')
        parser.add_argument('--aspect_ratio', type=float, default=1.0, help='The ratio width/height. The final height of the load image will be crop_size/aspect_ratio')
        # Help text previously pointed at a misspelled flag (--contain_dopntcare_label).
        parser.add_argument('--label_nc', type=int, default=19, help='# of input label classes without unknown class. If you have unknown class as class label, specify --contain_dontcare_label.')
        parser.add_argument('--contain_dontcare_label', action='store_true', help='if the label map contains dontcare label (dontcare=255)')
        parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels')

        # for setting inputs
        parser.add_argument('--dataroot', type=str, default='./datasets/cityscapes/')
        parser.add_argument('--dataset_mode', type=str, default='coco')
        parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
        parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
        parser.add_argument('--nThreads', default=28, type=int, help='# threads for loading data')
        parser.add_argument('--max_dataset_size', type=int, default=sys.maxsize, help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
        parser.add_argument('--load_from_opt_file', action='store_true', help='load the options from checkpoints and use that as default')
        parser.add_argument('--cache_filelist_write', action='store_true', help='saves the current filelist into a text file, so that it loads faster')
        parser.add_argument('--cache_filelist_read', action='store_true', help='reads from the file list cache')

        # for displays
        parser.add_argument('--display_winsize', type=int, default=400, help='display window size')

        # for generator
        parser.add_argument('--netG', type=str, default='spade', help='selects model to use for netG (pix2pixhd | spade)')
        parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer')
        parser.add_argument('--init_type', type=str, default='xavier', help='network initialization [normal|xavier|kaiming|orthogonal]')
        parser.add_argument('--init_variance', type=float, default=0.02, help='variance of the initialization distribution')
        parser.add_argument('--z_dim', type=int, default=512,
                            help="dimension of the latent z vector")

        # for instance-wise features
        parser.add_argument('--no_instance', action='store_true', help='if specified, do *not* add instance map as input')
        parser.add_argument('--nef', type=int, default=16, help='# of encoder filters in the first conv layer')
        parser.add_argument('--use_vae', action='store_true', help='enable training with an image encoder.')

        self.initialized = True
        return parser

    def gather_options(self):
        """Build the full parser (base + model + dataset options) and parse argv.

        Parsing happens in stages because the model- and dataset-specific
        options depend on the values of --model and --dataset_mode.
        """
        # initialize parser with basic options
        if not self.initialized:
            parser = argparse.ArgumentParser(
                formatter_class=argparse.ArgumentDefaultsHelpFormatter)
            parser = self.initialize(parser)

        # get the basic options (unknown args are expected at this stage)
        opt, _ = parser.parse_known_args()

        # modify model-related parser options
        model_name = opt.model
        model_option_setter = models.get_option_setter(model_name)
        parser = model_option_setter(parser, self.isTrain)

        # modify dataset-related parser options
        dataset_mode = opt.dataset_mode
        dataset_option_setter = data.get_option_setter(dataset_mode)
        parser = dataset_option_setter(parser, self.isTrain)

        opt, _ = parser.parse_known_args()

        # if there is opt_file, load it.
        # The previous default options will be overwritten
        if opt.load_from_opt_file:
            parser = self.update_options_from_file(parser, opt)

        opt = parser.parse_args()
        self.parser = parser
        return opt

    def _option_lines(self, opt):
        """Return one formatted line per option, annotating values that differ from the default."""
        lines = []
        for k, v in sorted(vars(opt).items()):
            default = self.parser.get_default(k)
            comment = '\t[default: %s]' % str(default) if v != default else ''
            lines.append('{:>25}: {:<30}{}'.format(str(k), str(v), comment))
        return lines

    def print_options(self, opt):
        """Print every option, marking non-default values."""
        message = '----------------- Options ---------------\n'
        message += ''.join(line + '\n' for line in self._option_lines(opt))
        message += '----------------- End -------------------'
        print(message)

    def option_file_path(self, opt, makedir=False):
        """Return ``<checkpoints_dir>/<name>/opt`` (no extension), optionally creating the directory."""
        expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
        if makedir:
            util.mkdirs(expr_dir)
        file_name = os.path.join(expr_dir, 'opt')
        return file_name

    def save_options(self, opt):
        """Write the options as human-readable ``opt.txt`` and as pickled ``opt.pkl``."""
        file_name = self.option_file_path(opt, makedir=True)
        with open(file_name + '.txt', 'wt') as opt_file:
            opt_file.writelines(line + '\n' for line in self._option_lines(opt))

        with open(file_name + '.pkl', 'wb') as opt_file:
            pickle.dump(opt, opt_file)

    def update_options_from_file(self, parser, opt):
        """Make the values stored in the experiment's ``opt.pkl`` the parser defaults."""
        new_opt = self.load_options(opt)
        for k, v in sorted(vars(opt).items()):
            if hasattr(new_opt, k) and v != getattr(new_opt, k):
                new_val = getattr(new_opt, k)
                parser.set_defaults(**{k: new_val})
        return parser

    def load_options(self, opt):
        """Unpickle the options previously saved by ``save_options``."""
        file_name = self.option_file_path(opt, makedir=False)
        # NOTE: pickle can execute arbitrary code on load; only load opt.pkl from
        # checkpoint directories you trust. The context manager closes the file
        # (the previous bare open() leaked the handle).
        with open(file_name + '.pkl', 'rb') as opt_file:
            new_opt = pickle.load(opt_file)
        return new_opt

    def parse(self, save=False):
        """Parse options, save them when training, and select the GPU.

        :param save: unused; kept for backward compatibility.
        :raises ValueError: if batchSize is not a multiple of the number of GPUs.
        """
        opt = self.gather_options()
        opt.isTrain = self.isTrain  # train or test

        self.print_options(opt)
        if opt.isTrain:
            self.save_options(opt)

        # Set semantic_nc based on the option.
        # This will be convenient in many places. The instance-map channel
        # (no_instance) is deliberately not counted here.
        opt.semantic_nc = opt.label_nc + \
            (1 if opt.contain_dontcare_label else 0)

        # set gpu ids; negative ids (e.g. -1) mean CPU and are dropped
        opt.gpu_ids = [gpu_id for gpu_id in (int(s) for s in opt.gpu_ids.split(','))
                       if gpu_id >= 0]
        if len(opt.gpu_ids) > 0:
            torch.cuda.set_device(opt.gpu_ids[0])

        # Explicit check instead of assert so it still runs under `python -O`.
        if len(opt.gpu_ids) > 0 and opt.batchSize % len(opt.gpu_ids) != 0:
            raise ValueError(
                "Batch size %d is wrong. It must be a multiple of # GPUs %d."
                % (opt.batchSize, len(opt.gpu_ids)))

        self.opt = opt
        return self.opt


# ----------------------------------------------------------------------------
# /options/test_options.py:
# ----------------------------------------------------------------------------
# Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
# Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
from .base_options import BaseOptions


class TestOptions(BaseOptions):
    """Options for inference (test.py): where to write results and which checkpoint to load."""

    def initialize(self, parser):
        """Register the shared options plus test-only flags, then switch defaults to test mode."""
        super().initialize(parser)

        parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
        parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
        parser.add_argument('--how_many', type=int, default=float("inf"), help='how many test images to run')

        # Test-time defaults: load/crop size 256, images taken in order, no flipping, phase "test".
        parser.set_defaults(preprocess_mode='scale_width_and_crop',
                            crop_size=256,
                            load_size=256,
                            display_winsize=256,
                            serial_batches=True,
                            no_flip=True,
                            phase='test')

        parser.add_argument('--status', type=str, default='test')

        self.isTrain = False
        return parser


# ----------------------------------------------------------------------------
# /options/train_options.py:
# ----------------------------------------------------------------------------
# Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
# Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
from .base_options import BaseOptions


class TrainOptions(BaseOptions):
    """Options for training (train.py): logging/checkpoint frequencies, optimizer and loss settings."""

    def initialize(self, parser):
        """Register the shared options plus training-only flags and mark this as a training run."""
        BaseOptions.initialize(self, parser)
        # for displays
        parser.add_argument('--display_freq', type=int, default=2970, help='frequency of showing training results on screen')
        parser.add_argument('--print_freq', type=int, default=2970, help='frequency of showing training results on console')
        parser.add_argument('--save_latest_freq', type=int, default=2970, help='frequency of saving the latest results')
        parser.add_argument('--save_epoch_freq', type=int, default=1, help='frequency of saving checkpoints at the end of epochs')
        # store_true: with only `default=False`, --no_html consumed a value and any
        # non-empty string (even "False") enabled it.
        parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
        parser.add_argument('--debug', action='store_true', help='only do one epoch and displays at each iteration')
        parser.add_argument('--tf_log', action='store_true', help='if specified, use tensorboard logging. Requires tensorflow installed')

        # for training
        parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
        parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
        parser.add_argument('--niter', type=int, default=50, help='# of iter at starting learning rate. This is NOT the total #epochs. Total #epochs is niter + niter_decay')
        parser.add_argument('--niter_decay', type=int, default=150, help='# of iter to linearly decay learning rate to zero')
        parser.add_argument('--optimizer', type=str, default='adam')
        parser.add_argument('--beta1', type=float, default=0.0, help='momentum term of adam')
        parser.add_argument('--beta2', type=float, default=0.999, help='momentum term of adam')
        parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
        parser.add_argument('--D_steps_per_G', type=int, default=1, help='number of discriminator iterations per generator iterations.')

        # for discriminators
        parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer')
        parser.add_argument('--lambda_feat', type=float, default=10.0, help='weight for feature matching loss')
        parser.add_argument('--lambda_vgg', type=float, default=10.0, help='weight for vgg loss')
        parser.add_argument('--no_ganFeat_loss', action='store_true', help='if specified, do *not* use discriminator feature matching loss')
        parser.add_argument('--no_vgg_loss', action='store_true', help='if specified, do *not* use VGG feature matching loss')
        parser.add_argument('--gan_mode', type=str, default='hinge', help='(ls|original|hinge)')
        parser.add_argument('--netD', type=str, default='multiscale', help='(n_layers|multiscale|image)')
        # The flag *disables* TTUR (see Pix2PixTrainer.update_learning_rate); the old help said the opposite.
        parser.add_argument('--no_TTUR', action='store_true', help='if specified, do *not* use TTUR training scheme')
        parser.add_argument('--lambda_kld', type=float, default=0.005)

        parser.add_argument('--status', type=str, default='train')

        self.isTrain = True
        return parser


# ----------------------------------------------------------------------------
# /save_mslic.py:
# ----------------------------------------------------------------------------

import maskslic as seg
import scipy.interpolate as interp
import numpy as np


def _adaptive_interp(input, num_of_style_feature):
    """Stretch a 1-D sequence to ``num_of_style_feature`` samples with a spline.

    Sequences shorter than 3, and sequences already at least
    ``num_of_style_feature`` long, are returned unchanged.

    NOTE(review): InterpolatedUnivariateSpline defaults to a cubic (k=3) spline,
    which needs more than k data points, so a 3-element input raises here.
    ``_style_encoder`` catches that and substitutes a zero vector. Changing the
    threshold would change the produced style codes, so it is left as-is.
    """
    if len(input) < 3:
        return input
    if len(input) >= num_of_style_feature:
        return input

    # Spread the known samples evenly over the target index range, then resample densely.
    x = np.linspace(0, num_of_style_feature - 1, num=len(input))
    x_new = np.linspace(0, num_of_style_feature - 1, num=num_of_style_feature)

    interp_out = interp.InterpolatedUnivariateSpline(x, input)
    return interp_out(x_new)


def _encoding(label_field, image, num_of_vectors, bg_label=-1, num_of_style=512):
    '''
    Generating Style codes
    :param label_field: super-pixel label map (output of mask-SLIC)
    :param image: style image
    :param num_of_vectors: length each per-channel sequence of super-pixel means is stretched to
    :param bg_label: background label in super-pixel
    :param num_of_style: length of the returned style code (previously hard-coded to 512)
    :return: style codes, a 1-D array
    '''
    # Channel means are taken on the raw image values; the rgb2yuv / rgb2lab
    # conversions used in earlier experiments are disabled.
    lab = image
    labels = np.unique(label_field)
    labels = labels[labels != bg_label]  # drop the background super-pixel, if present

    l, a, b = [], [], []
    for label in labels:
        mask = (label_field == label).nonzero()
        feature = lab[mask].mean(axis=0)  # per-channel mean inside this super-pixel
        l.append(feature[0])
        a.append(feature[1])
        b.append(feature[2])

    l = _adaptive_interp(np.reshape(l, (-1)), num_of_vectors)
    a = _adaptive_interp(np.reshape(a, (-1)), num_of_vectors)
    b = _adaptive_interp(np.reshape(b, (-1)), num_of_vectors)

    # Concatenate the three channel sequences, then stretch to the final code length.
    out = np.reshape([l, a, b], (-1))
    out = _adaptive_interp(out, num_of_style)
    return np.reshape(out, (-1,))


def _style_encoder(images, masks, num_of_style=512, num_of_vectors=128):
    """Build one style code per (image, semantic class) pair.

    :param images: batch of style images, indexed ``images[i]``
    :param masks: per-class binary masks indexed ``masks[i, :, :, j]`` (class on the last axis)
    :param num_of_style: length of each style code
    :param num_of_vectors: number of super-pixels requested from mask-SLIC
    :return: array of shape (num_classes, len(images), 1, 1, num_of_style); a class that is
        absent from the mask, or whose encoding fails, gets an all-zero code
    """
    num_classes = np.shape(masks)[-1]
    null = np.zeros(num_of_style)
    style_vectors = []
    for i in range(len(images)):
        styles = []
        for j in range(num_classes):
            if np.count_nonzero(masks[i, :, :, j]) == 0:
                styles.append(null)
                continue
            try:
                m_slic = seg.slic(images[i], compactness=10, seed_type='nplace', mask=masks[i, :, :, j],
                                  n_segments=num_of_vectors, recompute_seeds=True, enforce_connectivity=False)
                style_vector = _encoding(m_slic, images[i], num_of_vectors, num_of_style=num_of_style)
            except Exception:
                # Best effort: segmentation or spline fitting can fail on small regions.
                # `except Exception` (not a bare except) still lets Ctrl-C through.
                style_vector = null
            styles.append(style_vector)
        style_vectors.append(styles)

    # style_vectors is laid out (image, class, code); move the class axis first.
    # A plain reshape only produced the right layout for a single image.
    style_vectors = np.transpose(np.asarray(style_vectors), (1, 0, 2))
    return np.reshape(style_vectors, (num_classes, len(images), 1, 1, num_of_style))


# ----------------------------------------------------------------------------
# /save_style_vector.py:
# ----------------------------------------------------------------------------
from save_mslic import _style_encoder
import numpy as np
import cv2
import os
from tqdm import tqdm
import glob

save_path = './cityscapes/test/codes'
data_path = './cityscapes/test/image'
anno_path = './cityscapes/test/label'

data = glob.glob(os.path.join(data_path, '*'))


class data_augmentation(object):
    """Precompute style codes for every image in ``img_list`` and save them as ``.npy`` files.

    Despite the name, no augmentation happens: each image is paired with the label map
    of the same file name in ``anno_path``, split into per-class masks and encoded with
    ``_style_encoder``.
    """

    def __init__(self, img_list, anno_path, save_path):
        self.img_list = img_list
        self.anno_path = anno_path
        self.save_path = save_path
        # makedirs also creates missing parent directories (os.mkdir did not).
        os.makedirs(save_path, exist_ok=True)

    def open_mask(self, path, width, height, nlabel, isResize=True):
        """Return an (nlabel, H, W) stack of 0/1 masks, one per class id in [0, nlabel).

        :raises FileNotFoundError: if the label map cannot be read.
        """
        image_name = os.path.basename(path)
        anno_file = os.path.join(self.anno_path, image_name)
        anno = cv2.imread(anno_file)
        if anno is None:  # cv2.imread signals failure by returning None
            raise FileNotFoundError('could not read label map: %s' % anno_file)
        if isResize:
            # nearest-neighbour keeps class ids intact
            anno = cv2.resize(anno, (width, height), interpolation=cv2.INTER_NEAREST)
        anno = anno[:, :, 0]
        class_ids = np.arange(nlabel).reshape(-1, 1, 1)
        return (anno[np.newaxis] == class_ids).astype(anno.dtype)

    def open_image(self, path, width, height, isResize=True):
        """Load an image as float32 RGB divided by 255, optionally resized to (width, height).

        :raises FileNotFoundError: if the image cannot be read.
        """
        img = cv2.imread(path)
        if img is None:
            raise FileNotFoundError('could not read image: %s' % path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        if isResize:
            img = cv2.resize(img, (width, height), interpolation=cv2.INTER_CUBIC)
        return img.astype(np.float32) / 255.0

    def next_batch(self, width, height, nlabel):
        """Encode every image and save its style codes as ``<save_path>/<number>.npy``.

        Image file names are expected to start with an integer (``<number>.<ext>``).
        """
        for i in tqdm(self.img_list):
            number = os.path.basename(i).split('.')[0]

            # (nlabel, H, W) -> (H, W, nlabel): _style_encoder expects the class on the last axis
            mask_img = np.transpose(self.open_mask(i, width, height, nlabel), (1, 2, 0))
            img = self.open_image(i, width, height)

            styles = _style_encoder(np.array([img]), np.array([mask_img]))

            # Write to this instance's save_path (the module-level global was used before).
            np.save(os.path.join(self.save_path, '%d' % int(number)), np.array(styles))


if __name__ == '__main__':
    data_generator = data_augmentation(data, anno_path, save_path)

    data_generator.next_batch(512, 512, nlabel=19)  # You should change nlabel depending on dataset


# ----------------------------------------------------------------------------
# /test.py:
# ----------------------------------------------------------------------------
# Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
# Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).

import os
from collections import OrderedDict

import data
from options.test_options import TestOptions
from models.pix2pix_model import Pix2PixModel
from util.visualizer import Visualizer
from util import html
from tqdm import tqdm

# Inference script: synthesize an image for every test sample and collect the
# label / synthesized / real triplets on an HTML results page.
opt = TestOptions().parse()
opt.status = 'test'

dataloader = data.create_dataloader(opt)

model = Pix2PixModel(opt)
model.eval()

visualizer = Visualizer(opt)

# The results page lives under <results_dir>/<name>/<phase>_<which_epoch>.
web_dir = os.path.join(opt.results_dir, opt.name,
                       '%s_%s' % (opt.phase, opt.which_epoch))
webpage = html.HTML(web_dir,
                    'Experiment = %s, Phase = %s, Epoch = %s' %
                    (opt.name, opt.phase, opt.which_epoch))

for i, data_i in enumerate(tqdm(dataloader)):
    # Stop once --how_many images have been covered (checked per batch).
    if i * opt.batchSize >= opt.how_many:
        break

    generated = model(data_i, mode='inference')

    img_path = data_i['path']
    for b, synthesized in enumerate(generated):
        visuals = OrderedDict([('input_label', data_i['label'][b]),
                               ('synthesized_image', synthesized),
                               ('real_image', data_i['image'][b])])
        visualizer.save_images(webpage, visuals, img_path[b:b + 1])

webpage.save()


# ----------------------------------------------------------------------------
# /train.py:
# ----------------------------------------------------------------------------
# Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
# Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).

import sys
from collections import OrderedDict
from options.train_options import TrainOptions
import data
from util.iter_counter import IterationCounter
from util.visualizer import Visualizer
from trainers.pix2pix_trainer import Pix2PixTrainer
from tqdm import tqdm


def _collect_visuals(data_i, with_label=False):
    """Bundle the latest synthesized batch and the real batch (optionally the input label map) for the visualizer."""
    entries = []
    if with_label:
        entries.append(('input_label', data_i['label']))
    entries.append(('synthesized_image', trainer.get_latest_generated()))
    entries.append(('real_image', data_i['image']))
    return OrderedDict(entries)


# parse options
opt = TrainOptions().parse()

# echo the exact command line to help debugging
print(' '.join(sys.argv))

dataloader = data.create_dataloader(opt)    # dataset
trainer = Pix2PixTrainer(opt)               # model + optimizers
iter_counter = IterationCounter(opt, len(dataloader))
visualizer = Visualizer(opt)

for epoch in iter_counter.training_epochs():
    iter_counter.record_epoch_start(epoch)
    for i, data_i in enumerate(tqdm(dataloader), start=iter_counter.epoch_iter):
        iter_counter.record_one_iteration()

        # The generator steps only every D_steps_per_G iterations; the discriminator steps every iteration.
        if i % opt.D_steps_per_G == 0:
            trainer.run_generator_one_step(data_i)
        trainer.run_discriminator_one_step(data_i)

        if iter_counter.needs_printing():
            visuals = _collect_visuals(data_i)
            losses = trainer.get_latest_losses()
            visualizer.print_current_errors(epoch, iter_counter.epoch_iter,
                                            losses, iter_counter.time_per_iter, visuals)
            visualizer.plot_current_errors(losses, iter_counter.total_steps_so_far)

        if iter_counter.needs_displaying():
            visualizer.display_current_results(_collect_visuals(data_i, with_label=True),
                                               epoch, iter_counter.total_steps_so_far)

        if iter_counter.needs_saving():
            print('saving the latest model (epoch %d, total_steps %d)' %
                  (epoch, iter_counter.total_steps_so_far))
            trainer.save('latest')
            iter_counter.record_current_iter()

    trainer.update_learning_rate(epoch)
    iter_counter.record_epoch_end()

    # Keep a per-epoch snapshot every save_epoch_freq epochs and after the final epoch.
    if epoch % opt.save_epoch_freq == 0 or \
            epoch == iter_counter.total_epochs:
        print('saving the model at the end of epoch %d, iters %d' %
              (epoch, iter_counter.total_steps_so_far))
        trainer.save('latest')
        trainer.save(epoch)

print('Training was successfully finished.')


# ----------------------------------------------------------------------------
# /trainers/__init__.py:
# ----------------------------------------------------------------------------
# Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
# Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
# ----------------------------------------------------------------------------
# /trainers/__pycache__/__init__.cpython-36.pyc:
# https://raw.githubusercontent.com/BenjaminJonghyun/SuperStyleNet/2400c01b35f50b387b5f768fdece37688a077049/trainers/__pycache__/__init__.cpython-36.pyc
# ----------------------------------------------------------------------------
# /trainers/__pycache__/__init__.cpython-37.pyc:
# https://raw.githubusercontent.com/BenjaminJonghyun/SuperStyleNet/2400c01b35f50b387b5f768fdece37688a077049/trainers/__pycache__/__init__.cpython-37.pyc
# ----------------------------------------------------------------------------
# /trainers/__pycache__/pix2pix_trainer.cpython-36.pyc:
# https://raw.githubusercontent.com/BenjaminJonghyun/SuperStyleNet/2400c01b35f50b387b5f768fdece37688a077049/trainers/__pycache__/pix2pix_trainer.cpython-36.pyc
# ----------------------------------------------------------------------------
# /trainers/__pycache__/pix2pix_trainer.cpython-37.pyc:
# https://raw.githubusercontent.com/BenjaminJonghyun/SuperStyleNet/2400c01b35f50b387b5f768fdece37688a077049/trainers/__pycache__/pix2pix_trainer.cpython-37.pyc
# ----------------------------------------------------------------------------
# /trainers/pix2pix_trainer.py:
# ----------------------------------------------------------------------------
"""
Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
"""

from models.networks.sync_batchnorm import DataParallelWithCallback
from models.pix2pix_model import Pix2PixModel


class Pix2PixTrainer():
    """
    Trainer creates the model and optimizers, and uses them to
    update the weights of the network while reporting losses
    and the latest visuals to visualize the progress in training.
    """

    def __init__(self, opt):
        self.opt = opt
        self.pix2pix_model = Pix2PixModel(opt)
        if len(opt.gpu_ids) > 0:
            self.pix2pix_model = DataParallelWithCallback(self.pix2pix_model,
                                                          device_ids=opt.gpu_ids)
            # Unwrapped model for calls that bypass DataParallel (optimizer creation, saving).
            self.pix2pix_model_on_one_gpu = self.pix2pix_model.module
        else:
            self.pix2pix_model_on_one_gpu = self.pix2pix_model

        self.generated = None
        # Start empty so get_latest_losses() works before both a generator and a
        # discriminator step have run (e.g. when D_steps_per_G skips the G step).
        self.g_losses = {}
        self.d_losses = {}
        if opt.isTrain:
            self.optimizer_G, self.optimizer_D = \
                self.pix2pix_model_on_one_gpu.create_optimizers(opt)
            self.old_lr = opt.lr

    def run_generator_one_step(self, data):
        """Do one generator update on ``data`` and cache its losses and generated images."""
        self.optimizer_G.zero_grad()
        g_losses, generated = self.pix2pix_model(data, mode='generator')
        # .mean() reduces per-GPU loss values gathered by DataParallel.
        g_loss = sum(g_losses.values()).mean()
        g_loss.backward()
        self.optimizer_G.step()
        self.g_losses = g_losses
        self.generated = generated

    def run_discriminator_one_step(self, data):
        """Do one discriminator update on ``data`` and cache its losses."""
        self.optimizer_D.zero_grad()
        d_losses = self.pix2pix_model(data, mode='discriminator')
        d_loss = sum(d_losses.values()).mean()
        d_loss.backward()
        self.optimizer_D.step()
        self.d_losses = d_losses

    def run_eval(self, data):
        """Run the model in inference mode and return the generated images (previously discarded)."""
        return self.pix2pix_model(data, mode='inference')

    def get_latest_losses(self):
        """Merge the most recent generator and discriminator loss dicts."""
        return {**self.g_losses, **self.d_losses}

    def get_latest_generated(self):
        return self.generated

    # NOTE: an earlier self-recursive `update_learning_rate` stub lived here. It was
    # shadowed by the real implementation below, and would have recursed forever if reached.

    def save(self, epoch):
        self.pix2pix_model_on_one_gpu.save(epoch)

    ##################################################################
    # Helper functions
    ##################################################################

    def update_learning_rate(self, epoch):
        """Linearly decay the learning rate once ``epoch`` exceeds ``opt.niter``.

        Each such call subtracts ``lr / niter_decay`` from the current rate. When the
        rate changes, the generator gets half and the discriminator twice the new
        rate, unless ``--no_TTUR`` is set (then both get the same rate).
        """
        if epoch > self.opt.niter:
            lrd = self.opt.lr / self.opt.niter_decay
            new_lr = self.old_lr - lrd
        else:
            new_lr = self.old_lr

        if new_lr != self.old_lr:
            if self.opt.no_TTUR:
                new_lr_G = new_lr
                new_lr_D = new_lr
            else:
                new_lr_G = new_lr / 2
                new_lr_D = new_lr * 2

            for param_group in self.optimizer_D.param_groups:
                param_group['lr'] = new_lr_D
            for param_group in self.optimizer_G.param_groups:
                param_group['lr'] = new_lr_G
            print('update learning rate: %f -> %f' % (self.old_lr, new_lr))
            self.old_lr = new_lr


# ----------------------------------------------------------------------------
# /util/__init__.py:
# ----------------------------------------------------------------------------
# Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
# Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).

# ----------------------------------------------------------------------------
# /util/coco.py:
# ----------------------------------------------------------------------------
# Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
# Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """ 5 | 6 | 7 | def id2label(id): 8 | if id == 182: 9 | id = 0 10 | else: 11 | id = id + 1 12 | labelmap = \ 13 | {0: 'unlabeled', 14 | 1: 'person', 15 | 2: 'bicycle', 16 | 3: 'car', 17 | 4: 'motorcycle', 18 | 5: 'airplane', 19 | 6: 'bus', 20 | 7: 'train', 21 | 8: 'truck', 22 | 9: 'boat', 23 | 10: 'traffic light', 24 | 11: 'fire hydrant', 25 | 12: 'street sign', 26 | 13: 'stop sign', 27 | 14: 'parking meter', 28 | 15: 'bench', 29 | 16: 'bird', 30 | 17: 'cat', 31 | 18: 'dog', 32 | 19: 'horse', 33 | 20: 'sheep', 34 | 21: 'cow', 35 | 22: 'elephant', 36 | 23: 'bear', 37 | 24: 'zebra', 38 | 25: 'giraffe', 39 | 26: 'hat', 40 | 27: 'backpack', 41 | 28: 'umbrella', 42 | 29: 'shoe', 43 | 30: 'eye glasses', 44 | 31: 'handbag', 45 | 32: 'tie', 46 | 33: 'suitcase', 47 | 34: 'frisbee', 48 | 35: 'skis', 49 | 36: 'snowboard', 50 | 37: 'sports ball', 51 | 38: 'kite', 52 | 39: 'baseball bat', 53 | 40: 'baseball glove', 54 | 41: 'skateboard', 55 | 42: 'surfboard', 56 | 43: 'tennis racket', 57 | 44: 'bottle', 58 | 45: 'plate', 59 | 46: 'wine glass', 60 | 47: 'cup', 61 | 48: 'fork', 62 | 49: 'knife', 63 | 50: 'spoon', 64 | 51: 'bowl', 65 | 52: 'banana', 66 | 53: 'apple', 67 | 54: 'sandwich', 68 | 55: 'orange', 69 | 56: 'broccoli', 70 | 57: 'carrot', 71 | 58: 'hot dog', 72 | 59: 'pizza', 73 | 60: 'donut', 74 | 61: 'cake', 75 | 62: 'chair', 76 | 63: 'couch', 77 | 64: 'potted plant', 78 | 65: 'bed', 79 | 66: 'mirror', 80 | 67: 'dining table', 81 | 68: 'window', 82 | 69: 'desk', 83 | 70: 'toilet', 84 | 71: 'door', 85 | 72: 'tv', 86 | 73: 'laptop', 87 | 74: 'mouse', 88 | 75: 'remote', 89 | 76: 'keyboard', 90 | 77: 'cell phone', 91 | 78: 'microwave', 92 | 79: 'oven', 93 | 80: 'toaster', 94 | 81: 'sink', 95 | 82: 'refrigerator', 96 | 83: 'blender', 97 | 84: 'book', 98 | 85: 'clock', 99 | 86: 'vase', 100 | 87: 'scissors', 101 | 88: 'teddy bear', 102 | 89: 'hair drier', 103 | 90: 'toothbrush', 104 | 91: 'hair brush', # Last class of Thing 105 | 92: 'banner', # Beginning of Stuff 106 | 93: 
'blanket', 107 | 94: 'branch', 108 | 95: 'bridge', 109 | 96: 'building-other', 110 | 97: 'bush', 111 | 98: 'cabinet', 112 | 99: 'cage', 113 | 100: 'cardboard', 114 | 101: 'carpet', 115 | 102: 'ceiling-other', 116 | 103: 'ceiling-tile', 117 | 104: 'cloth', 118 | 105: 'clothes', 119 | 106: 'clouds', 120 | 107: 'counter', 121 | 108: 'cupboard', 122 | 109: 'curtain', 123 | 110: 'desk-stuff', 124 | 111: 'dirt', 125 | 112: 'door-stuff', 126 | 113: 'fence', 127 | 114: 'floor-marble', 128 | 115: 'floor-other', 129 | 116: 'floor-stone', 130 | 117: 'floor-tile', 131 | 118: 'floor-wood', 132 | 119: 'flower', 133 | 120: 'fog', 134 | 121: 'food-other', 135 | 122: 'fruit', 136 | 123: 'furniture-other', 137 | 124: 'grass', 138 | 125: 'gravel', 139 | 126: 'ground-other', 140 | 127: 'hill', 141 | 128: 'house', 142 | 129: 'leaves', 143 | 130: 'light', 144 | 131: 'mat', 145 | 132: 'metal', 146 | 133: 'mirror-stuff', 147 | 134: 'moss', 148 | 135: 'mountain', 149 | 136: 'mud', 150 | 137: 'napkin', 151 | 138: 'net', 152 | 139: 'paper', 153 | 140: 'pavement', 154 | 141: 'pillow', 155 | 142: 'plant-other', 156 | 143: 'plastic', 157 | 144: 'platform', 158 | 145: 'playingfield', 159 | 146: 'railing', 160 | 147: 'railroad', 161 | 148: 'river', 162 | 149: 'road', 163 | 150: 'rock', 164 | 151: 'roof', 165 | 152: 'rug', 166 | 153: 'salad', 167 | 154: 'sand', 168 | 155: 'sea', 169 | 156: 'shelf', 170 | 157: 'sky-other', 171 | 158: 'skyscraper', 172 | 159: 'snow', 173 | 160: 'solid-other', 174 | 161: 'stairs', 175 | 162: 'stone', 176 | 163: 'straw', 177 | 164: 'structural-other', 178 | 165: 'table', 179 | 166: 'tent', 180 | 167: 'textile-other', 181 | 168: 'towel', 182 | 169: 'tree', 183 | 170: 'vegetable', 184 | 171: 'wall-brick', 185 | 172: 'wall-concrete', 186 | 173: 'wall-other', 187 | 174: 'wall-panel', 188 | 175: 'wall-stone', 189 | 176: 'wall-tile', 190 | 177: 'wall-wood', 191 | 178: 'water-other', 192 | 179: 'waterdrops', 193 | 180: 'window-blind', 194 | 181: 'window-other', 195 | 182: 
'wood'} 196 | if id in labelmap: 197 | return labelmap[id] 198 | else: 199 | return 'unknown' 200 | -------------------------------------------------------------------------------- /util/html.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | import datetime 7 | import dominate 8 | from dominate.tags import * 9 | import os 10 | 11 | 12 | class HTML: 13 | def __init__(self, web_dir, title, refresh=0): 14 | if web_dir.endswith('.html'): 15 | web_dir, html_name = os.path.split(web_dir) 16 | else: 17 | web_dir, html_name = web_dir, 'index.html' 18 | self.title = title 19 | self.web_dir = web_dir 20 | self.html_name = html_name 21 | self.img_dir = os.path.join(self.web_dir, 'images') 22 | if len(self.web_dir) > 0 and not os.path.exists(self.web_dir): 23 | os.makedirs(self.web_dir) 24 | if len(self.web_dir) > 0 and not os.path.exists(self.img_dir): 25 | os.makedirs(self.img_dir) 26 | 27 | self.doc = dominate.document(title=title) 28 | with self.doc: 29 | h1(datetime.datetime.now().strftime("%I:%M%p on %B %d, %Y")) 30 | if refresh > 0: 31 | with self.doc.head: 32 | meta(http_equiv="refresh", content=str(refresh)) 33 | 34 | def get_image_dir(self): 35 | return self.img_dir 36 | 37 | def add_header(self, str): 38 | with self.doc: 39 | h3(str) 40 | 41 | def add_table(self, border=1): 42 | self.t = table(border=border, style="table-layout: fixed;") 43 | self.doc.add(self.t) 44 | 45 | def add_images(self, ims, txts, links, width=512): 46 | self.add_table() 47 | with self.t: 48 | with tr(): 49 | for im, txt, link in zip(ims, txts, links): 50 | with td(style="word-wrap: break-word;", halign="center", valign="top"): 51 | with p(): 52 | with a(href=os.path.join('images', link)): 53 | img(style="width:%dpx" % (width), src=os.path.join('images', im)) 54 | 
br() 55 | p(txt.encode('utf-8')) 56 | 57 | def save(self): 58 | html_file = os.path.join(self.web_dir, self.html_name) 59 | f = open(html_file, 'wt') 60 | f.write(self.doc.render()) 61 | f.close() 62 | 63 | 64 | if __name__ == '__main__': 65 | html = HTML('web/', 'test_html') 66 | html.add_header('hello world') 67 | 68 | ims = [] 69 | txts = [] 70 | links = [] 71 | for n in range(4): 72 | ims.append('image_%d.jpg' % n) 73 | txts.append('text_%d' % n) 74 | links.append('image_%d.jpg' % n) 75 | html.add_images(ims, txts, links) 76 | html.save() 77 | -------------------------------------------------------------------------------- /util/util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | import re 7 | import importlib 8 | import torch 9 | from argparse import Namespace 10 | import numpy as np 11 | from PIL import Image 12 | import os 13 | import argparse 14 | import dill as pickle 15 | import util.coco 16 | 17 | 18 | def save_obj(obj, name): 19 | with open(name, 'wb') as f: 20 | pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL) 21 | 22 | 23 | def load_obj(name): 24 | with open(name, 'rb') as f: 25 | return pickle.load(f) 26 | 27 | # returns a configuration for creating a generator 28 | # |default_opt| should be the opt of the current experiment 29 | # |**kwargs|: if any configuration should be overriden, it can be specified here 30 | 31 | 32 | def copyconf(default_opt, **kwargs): 33 | conf = argparse.Namespace(**vars(default_opt)) 34 | for key in kwargs: 35 | print(key, kwargs[key]) 36 | setattr(conf, key, kwargs[key]) 37 | return conf 38 | 39 | 40 | def tile_images(imgs, picturesPerRow=4): 41 | """ Code borrowed from 42 | https://stackoverflow.com/questions/26521365/cleanly-tile-numpy-array-of-images-stored-in-a-flattened-1d-format/26521997 
43 | """ 44 | 45 | # Padding 46 | if imgs.shape[0] % picturesPerRow == 0: 47 | rowPadding = 0 48 | else: 49 | rowPadding = picturesPerRow - imgs.shape[0] % picturesPerRow 50 | if rowPadding > 0: 51 | imgs = np.concatenate([imgs, np.zeros((rowPadding, *imgs.shape[1:]), dtype=imgs.dtype)], axis=0) 52 | 53 | # Tiling Loop (The conditionals are not necessary anymore) 54 | tiled = [] 55 | for i in range(0, imgs.shape[0], picturesPerRow): 56 | tiled.append(np.concatenate([imgs[j] for j in range(i, i + picturesPerRow)], axis=1)) 57 | 58 | tiled = np.concatenate(tiled, axis=0) 59 | return tiled 60 | 61 | 62 | # Converts a Tensor into a Numpy array 63 | # |imtype|: the desired type of the converted numpy array 64 | def tensor2im(image_tensor, imtype=np.uint8, normalize=True, tile=False): 65 | if isinstance(image_tensor, list): 66 | image_numpy = [] 67 | for i in range(len(image_tensor)): 68 | image_numpy.append(tensor2im(image_tensor[i], imtype, normalize)) 69 | return image_numpy 70 | 71 | if image_tensor.dim() == 4: 72 | # transform each image in the batch 73 | images_np = [] 74 | for b in range(image_tensor.size(0)): 75 | one_image = image_tensor[b] 76 | one_image_np = tensor2im(one_image) 77 | images_np.append(one_image_np.reshape(1, *one_image_np.shape)) 78 | images_np = np.concatenate(images_np, axis=0) 79 | if tile: 80 | images_tiled = tile_images(images_np) 81 | return images_tiled 82 | else: 83 | return images_np 84 | 85 | if image_tensor.dim() == 2: 86 | image_tensor = image_tensor.unsqueeze(0) 87 | image_numpy = image_tensor.detach().cpu().float().numpy() 88 | if normalize: 89 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 90 | else: 91 | image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0 92 | image_numpy = np.clip(image_numpy, 0, 255) 93 | if image_numpy.shape[2] == 1: 94 | image_numpy = image_numpy[:, :, 0] 95 | return image_numpy.astype(imtype) 96 | 97 | 98 | # Converts a one-hot tensor into a colorful label map 99 | def 
tensor2label(label_tensor, n_label, imtype=np.uint8, tile=False): 100 | if label_tensor.dim() == 4: 101 | # transform each image in the batch 102 | images_np = [] 103 | for b in range(label_tensor.size(0)): 104 | one_image = label_tensor[b] 105 | one_image_np = tensor2label(one_image, n_label, imtype) 106 | images_np.append(one_image_np.reshape(1, *one_image_np.shape)) 107 | images_np = np.concatenate(images_np, axis=0) 108 | if tile: 109 | images_tiled = tile_images(images_np) 110 | return images_tiled 111 | else: 112 | images_np = images_np[0] 113 | return images_np 114 | 115 | if label_tensor.dim() == 1: 116 | return np.zeros((64, 64, 3), dtype=np.uint8) 117 | if n_label == 0: 118 | return tensor2im(label_tensor, imtype) 119 | label_tensor = label_tensor.cpu().float() 120 | if label_tensor.size()[0] > 1: 121 | label_tensor = label_tensor.max(0, keepdim=True)[1] 122 | label_tensor = Colorize(n_label)(label_tensor) 123 | label_numpy = np.transpose(label_tensor.numpy(), (1, 2, 0)) 124 | result = label_numpy.astype(imtype) 125 | return result 126 | 127 | 128 | def save_image(image_numpy, image_path, create_dir=False): 129 | if create_dir: 130 | os.makedirs(os.path.dirname(image_path), exist_ok=True) 131 | if len(image_numpy.shape) == 2: 132 | image_numpy = np.expand_dims(image_numpy, axis=2) 133 | if image_numpy.shape[2] == 1: 134 | image_numpy = np.repeat(image_numpy, 3, 2) 135 | image_pil = Image.fromarray(image_numpy) 136 | 137 | # save to png 138 | image_pil.save(image_path.replace('.jpg', '.png')) 139 | 140 | 141 | def mkdirs(paths): 142 | if isinstance(paths, list) and not isinstance(paths, str): 143 | for path in paths: 144 | mkdir(path) 145 | else: 146 | mkdir(paths) 147 | 148 | 149 | def mkdir(path): 150 | if not os.path.exists(path): 151 | os.makedirs(path) 152 | 153 | 154 | def atoi(text): 155 | return int(text) if text.isdigit() else text 156 | 157 | 158 | def natural_keys(text): 159 | ''' 160 | alist.sort(key=natural_keys) sorts in human order 161 | 
http://nedbatchelder.com/blog/200712/human_sorting.html 162 | (See Toothy's implementation in the comments) 163 | ''' 164 | return [atoi(c) for c in re.split('(\d+)', text)] 165 | 166 | 167 | def natural_sort(items): 168 | items.sort(key=natural_keys) 169 | 170 | 171 | def str2bool(v): 172 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 173 | return True 174 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 175 | return False 176 | else: 177 | raise argparse.ArgumentTypeError('Boolean value expected.') 178 | 179 | 180 | def find_class_in_module(target_cls_name, module): 181 | target_cls_name = target_cls_name.replace('_', '').lower() 182 | clslib = importlib.import_module(module) 183 | cls = None 184 | for name, clsobj in clslib.__dict__.items(): 185 | if name.lower() == target_cls_name: 186 | cls = clsobj 187 | 188 | if cls is None: 189 | print("In %s, there should be a class whose name matches %s in lowercase without underscore(_)" % (module, target_cls_name)) 190 | exit(0) 191 | 192 | return cls 193 | 194 | 195 | def save_network(net, label, epoch, opt): 196 | save_filename = '%s_net_%s.pth' % (epoch, label) 197 | save_path = os.path.join(opt.checkpoints_dir, opt.name, save_filename) 198 | torch.save(net.cpu().state_dict(), save_path) 199 | if len(opt.gpu_ids) and torch.cuda.is_available(): 200 | net.cuda() 201 | 202 | 203 | def load_network(net, label, epoch, opt): 204 | save_filename = '%s_net_%s.pth' % (epoch, label) 205 | save_dir = os.path.join(opt.checkpoints_dir, opt.name) 206 | save_path = os.path.join(save_dir, save_filename) 207 | weights = torch.load(save_path) 208 | net.load_state_dict(weights) 209 | return net 210 | 211 | 212 | ############################################################################### 213 | # Code from 214 | # https://github.com/ycszen/pytorch-seg/blob/master/transform.py 215 | # Modified so it complies with the Citscape label map colors 216 | ############################################################################### 
def uint82bin(n, count=8):
    """returns the binary of integer n, count refers to amount of bits"""
    return ''.join([str((n >> y) & 1) for y in range(count - 1, -1, -1)])


def labelcolormap(N):
    """Return an (N, 3) uint8 RGB palette for N labels.

    N == 35 returns the fixed Cityscapes palette. Otherwise colors are
    generated by spreading the bits of (label + 1) over the high bits of
    R, G and B (the usual PASCAL-VOC scheme). For N == 182 (COCO-Stuff) a few
    classes are overridden with hand-picked colors.
    """
    if N == 35:  # cityscape
        cmap = np.array([(0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0), (111, 74, 0), (81, 0, 81),
                         (128, 64, 128), (244, 35, 232), (250, 170, 160), (230, 150, 140), (70, 70, 70), (102, 102, 156), (190, 153, 153),
                         (180, 165, 180), (150, 100, 100), (150, 120, 90), (153, 153, 153), (153, 153, 153), (250, 170, 30), (220, 220, 0),
                         (107, 142, 35), (152, 251, 152), (70, 130, 180), (220, 20, 60), (255, 0, 0), (0, 0, 142), (0, 0, 70),
                         (0, 60, 100), (0, 0, 90), (0, 0, 110), (0, 80, 100), (0, 0, 230), (119, 11, 32), (0, 0, 142)],
                        dtype=np.uint8)
    else:
        cmap = np.zeros((N, 3), dtype=np.uint8)
        for i in range(N):
            r, g, b = 0, 0, 0
            label_id = i + 1  # let's give 0 a color
            for j in range(7):
                # Bits 0/1/2 of label_id go to R/G/B, filling from the MSB down
                # (same bits uint82bin(label_id)[-1], [-2], [-3] yield).
                shift = 7 - j
                r ^= (label_id & 1) << shift
                g ^= ((label_id >> 1) & 1) << shift
                b ^= ((label_id >> 2) & 1) << shift
                label_id >>= 3
            cmap[i] = (r, g, b)

        if N == 182:  # COCO
            important_colors = {
                'sea': (54, 62, 167),
                'sky-other': (95, 219, 255),
                'tree': (140, 104, 47),
                'clouds': (170, 170, 170),
                'grass': (29, 195, 49)
            }
            for i in range(N):
                name = util.coco.id2label(i)
                if name in important_colors:
                    cmap[i] = np.array(important_colors[name])

    return cmap


class Colorize(object):
    """Callable mapping a (1, H, W) label-index tensor to a (3, H, W) uint8 image.

    Pixels whose label is outside [0, n) stay black.
    """

    def __init__(self, n=35):
        self.cmap = labelcolormap(n)
        self.cmap = torch.from_numpy(self.cmap[:n])

    def __call__(self, gray_image):
        size = gray_image.size()
        color_image = torch.zeros(3, size[1], size[2], dtype=torch.uint8)
        labels = gray_image[0].cpu()  # moved once instead of once per label

        for label in range(len(self.cmap)):
            mask = labels == label
            for channel in range(3):
                color_image[channel][mask] = self.cmap[label][channel]

        return color_image

# -----------------------------------------------------------------------------
# /util/visualizer.py
# -----------------------------------------------------------------------------
"""
Copyright (C) 2019 NVIDIA Corporation.  All rights reserved.
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
"""

import os
import ntpath
import time
from io import BytesIO
import numpy as np
# skimage.measure.compare_* were removed in scikit-image 0.18; the
# skimage.metrics names exist since 0.16.
try:
    from skimage.metrics import structural_similarity as compare_ssim
    from skimage.metrics import peak_signal_noise_ratio as compare_psnr
    from skimage.metrics import normalized_root_mse as compare_nrmse
except ImportError:  # scikit-image < 0.16
    from skimage.measure import compare_ssim, compare_psnr, compare_nrmse
# PIL replaces scipy.misc.toimage, which was removed in SciPy 1.2.
from PIL import Image
from . import util
from . import html


class Visualizer():
    """Logs training visuals/losses to TensorBoard, an HTML page and a text log."""

    def __init__(self, opt):
        self.opt = opt
        self.tf_log = opt.isTrain and opt.tf_log
        self.use_html = opt.isTrain and not opt.no_html
        self.win_size = opt.display_winsize
        self.name = opt.name
        if self.tf_log:
            import tensorflow as tf
            self.tf = tf
            self.log_dir = os.path.join(opt.checkpoints_dir, opt.name, 'logs')
            self.writer = tf.summary.FileWriter(self.log_dir)

        if self.use_html:
            self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
            self.img_dir = os.path.join(self.web_dir, 'images')
            print('create web directory %s...' % self.web_dir)
            util.mkdirs([self.web_dir, self.img_dir])
        if opt.isTrain:
            self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
            with open(self.log_name, "a") as log_file:
                now = time.strftime("%c")
                log_file.write('================ Training Loss (%s) ================\n' % now)

    # |visuals|: dictionary of images to display or save
    def display_current_results(self, visuals, epoch, step):
        """Write *visuals* to TensorBoard and/or the HTML page for this step."""

        # convert tensors to numpy arrays
        visuals = self.convert_visuals_to_numpy(visuals)

        if self.tf_log:  # show images in tensorboard output
            img_summaries = []
            for label, image_numpy in visuals.items():
                if len(image_numpy.shape) >= 4:
                    image_numpy = image_numpy[0]
                # Encode the image as JPEG in memory
                s = BytesIO()
                Image.fromarray(image_numpy).save(s, format="jpeg")
                # Create an Image object
                img_sum = self.tf.Summary.Image(encoded_image_string=s.getvalue(), height=image_numpy.shape[0], width=image_numpy.shape[1])
                # Create a Summary value
                img_summaries.append(self.tf.Summary.Value(tag=label, image=img_sum))

            # Create and write Summary
            summary = self.tf.Summary(value=img_summaries)
            self.writer.add_summary(summary, step)

        if self.use_html:  # save images to a html file
            for label, image_numpy in visuals.items():
                if isinstance(image_numpy, list):
                    for i, image in enumerate(image_numpy):
                        img_path = os.path.join(self.img_dir, 'epoch%.3d_iter%.3d_%s_%d.png' % (epoch, step, label, i))
                        util.save_image(image, img_path)
                else:
                    img_path = os.path.join(self.img_dir, 'epoch%.3d_iter%.3d_%s.png' % (epoch, step, label))
                    if len(image_numpy.shape) >= 4:
                        image_numpy = image_numpy[0]
                    util.save_image(image_numpy, img_path)

            # update website
            webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=5)
            for n in range(epoch, 0, -1):
                webpage.add_header('epoch [%d]' % n)
                ims = []
                txts = []
                links = []

                for label, image_numpy in visuals.items():
                    if isinstance(image_numpy, list):
                        for i in range(len(image_numpy)):
                            img_path = 'epoch%.3d_iter%.3d_%s_%d.png' % (n, step, label, i)
                            ims.append(img_path)
                            txts.append(label + str(i))
                            links.append(img_path)
                    else:
                        img_path = 'epoch%.3d_iter%.3d_%s.png' % (n, step, label)
                        ims.append(img_path)
                        txts.append(label)
                        links.append(img_path)
                if len(ims) < 10:
                    webpage.add_images(ims, txts, links, width=self.win_size)
                else:
                    num = int(round(len(ims) / 2.0))
                    webpage.add_images(ims[:num], txts[:num], links[:num], width=self.win_size)
                    webpage.add_images(ims[num:], txts[num:], links[num:], width=self.win_size)
            webpage.save()

    # errors: dictionary of error labels and values
    def plot_current_errors(self, errors, step):
        """Write the mean of each loss tensor as a TensorBoard scalar."""
        if self.tf_log:
            for tag, value in errors.items():
                # Protobuf's simple_value needs a Python float, not a tensor.
                value = float(value.mean())
                summary = self.tf.Summary(value=[self.tf.Summary.Value(tag=tag, simple_value=value)])
                self.writer.add_summary(summary, step)

    # errors: same format as |errors| of plotCurrentErrors
    def print_current_errors(self, epoch, i, errors, t, visuals):
        """Print and append to the loss log: losses plus SSIM/PSNR/NRMSE."""
        message = '(epoch: %d, iters: %d, time: %.3f) ' % (epoch, i, t)
        for k, v in errors.items():
            message += '%s: %.3f ' % (k, float(v.mean()))

        ssim, psnr, nrmse = self._image_metrics(visuals)
        message += 'SSIM: %.3f PSNR: %.3f NRMSE: %.3f' % (ssim, psnr, nrmse)

        print(message)
        with open(self.log_name, "a") as log_file:
            log_file.write('%s\n' % message)

    def cal_eval(self, visuals):
        """Return batch-mean (SSIM, PSNR, NRMSE) of synthesized vs real images."""
        return self._image_metrics(visuals)

    def print_current_eval(self, visuals):
        """Print batch-mean SSIM and PSNR of synthesized vs real images."""
        ssim, psnr, _ = self._image_metrics(visuals)
        print('SSIM: %.3f PSNR: %.3f' % (ssim, psnr))

    def _image_metrics(self, visuals):
        """Convert *visuals* in place and compare 'synthesized_image' to 'real_image'.

        Returns the per-image mean of (SSIM, PSNR, NRMSE).
        """
        visuals = self.convert_visuals_to_numpy(visuals)
        real = visuals['real_image']
        fake = visuals['synthesized_image']
        if self.opt.batchSize > 8:
            # convert_visuals_to_numpy tiled the batch into one mosaic image;
            # treat it as a single image instead of iterating over its rows.
            real, fake = [real], [fake]
        count = len(real)
        ssim = psnr = nrmse = 0
        for real_img, fake_img in zip(real, fake):
            ssim += compare_ssim(real_img, fake_img, multichannel=True, gaussian_weights=True,
                                 use_sample_covariance=False) / count
            psnr += compare_psnr(real_img, fake_img) / count
            nrmse += compare_nrmse(real_img, fake_img) / count
        return ssim, psnr, nrmse

    def convert_visuals_to_numpy(self, visuals):
        """Convert every tensor in *visuals* to numpy, in place, and return it.

        'input_label' is color-coded; everything else goes through tensor2im.
        Batches are tiled into one image when opt.batchSize > 8.
        """
        for key, t in visuals.items():
            tile = self.opt.batchSize > 8
            if 'input_label' == key:
                t = util.tensor2label(t, self.opt.label_nc + 2, tile=tile)
            else:
                t = util.tensor2im(t, tile=tile)
            visuals[key] = t
        return visuals

    # save image to the disk
    def save_images(self, webpage, visuals, image_path):
        """Save each visual under <image_dir>/<label>/<name>.png and add a page row."""
        visuals = self.convert_visuals_to_numpy(visuals)

        image_dir = webpage.get_image_dir()
        short_path = ntpath.basename(image_path[0])
        name = os.path.splitext(short_path)[0]

        webpage.add_header(name)
        ims = []
        txts = []
        links = []

        for label, image_numpy in visuals.items():
            image_name = os.path.join(label, '%s.png' % (name))
            save_path = os.path.join(image_dir, image_name)
            util.save_image(image_numpy, save_path, create_dir=True)

            ims.append(image_name)
            txts.append(label)
            links.append(image_name)
        webpage.add_images(ims, txts, links, width=self.win_size)

    def convert_image(self, generated):
        """Convert *generated* with tensor2im and return element [0].

        For an untiled batch this is the first image; NOTE(review): when
        batchSize > 8 the result is tiled, so [0] is the mosaic's first pixel
        row — confirm callers never hit that case.
        """
        tile = self.opt.batchSize > 8
        return util.tensor2im(generated, tile=tile)[0]