├── .gitignore
├── README.md
├── dataset
│   ├── __init__.py
│   ├── base_dataset.py
│   ├── cityscapes_dataset.py
│   ├── image_folder.py
│   └── pix2pix_dataset.py
├── imgs
│   └── teaser.png
├── main.py
├── models
│   ├── __init__.py
│   ├── networks
│   │   ├── __init__.py
│   │   ├── architecture.py
│   │   ├── base_network.py
│   │   ├── discriminator.py
│   │   ├── fcn.py
│   │   ├── loss.py
│   │   ├── normalization.py
│   │   ├── sgnet.py
│   │   ├── spade_discriminator.py
│   │   └── spnet.py
│   └── seg_inpaint_model.py
├── trainers
│   ├── __init__.py
│   └── seg_inpaint_trainer.py
└── util
    ├── __init__.py
    ├── coco.py
    ├── html.py
    ├── util.py
    └── visualizer.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PyTorch - SegInpaint
2 |
3 | ![teaser](imgs/teaser.png)
4 |
5 | This repo is an adapted PyTorch implementation of the BMVC 2018 paper [SPG-Net](http://bmvc2018.org/contents/papers/0317.pdf) by Song et al.
6 |
7 | Usage
8 | ---
9 |
10 | **Prepare the data**
11 |
12 | Download the [Cityscapes](https://www.cityscapes-dataset.com/) dataset and the masks from [PartialConv](https://nv-adlr.github.io/publication/partialconv-inpainting). Create a directory `data` and put the downloaded data under `data`.
13 |
14 | **Training**
15 |
16 | If you want to use synchronized batchnorm, you can set it up with these steps (credits: [SPADE](https://github.com/NVlabs/SPADE)):
17 | ```
18 | cd models/networks/
19 | git clone https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
20 | cp -rf Synchronized-BatchNorm-PyTorch/sync_batchnorm .
21 | cd ../../
22 | ```
23 |
24 | Train with the following command:
25 | ```
26 | python main.py --gpu_ids 0 --batch_size 2
27 | ```
28 | and check the results in `logs`.
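For reference, `main.py` resolves `opt.dataroot` to `data/<dataset>`, and `CityscapesDataset.get_paths` (see `dataset/cityscapes_dataset.py` below) reads labels, images, and masks from fixed subfolders, so the expected layout is roughly:

```
data/cityscapes/
├── gtFine/            # Cityscapes annotations; *_labelIds.png / *_instanceIds.png under train/ and val/
├── leftImg8bit/       # Cityscapes images under train/ and val/
└── irregular_mask/    # mask images from PartialConv
```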
29 |
30 |
31 | Acknowledgments
32 | ---
33 | This code borrows heavily from [SPADE](https://github.com/NVlabs/SPADE) and [pix2pixHD](https://github.com/NVIDIA/pix2pixHD). I also adapted the mask-processing code for training from [EdgeConnect](https://github.com/knazeri/edge-connect). Thanks for their amazing work!
34 |
--------------------------------------------------------------------------------
/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import torch.utils.data
3 | from dataset.base_dataset import BaseDataset
4 |
5 |
6 | def find_dataset_using_name(dataset_name):
7 |     dataset_filename = "dataset." + dataset_name + "_dataset"
8 |     datasetlib = importlib.import_module(dataset_filename)
9 |
10 |     # In the file, the class called DatasetNameDataset() will
11 |     # be instantiated. It has to be a subclass of BaseDataset,
12 |     # and it is case-insensitive.
13 |     dataset = None
14 |     target_dataset_name = dataset_name.replace('_', '') + 'dataset'
15 |     for name, cls in datasetlib.__dict__.items():
16 |         if name.lower() == target_dataset_name.lower() \
17 |            and issubclass(cls, BaseDataset):
18 |             dataset = cls
19 |
20 |     if dataset is None:
21 |         raise ValueError("In %s.py, there should be a subclass of BaseDataset "
22 |                          "with class name that matches %s in lowercase." %
23 |                          (dataset_filename, target_dataset_name))
24 |
25 |     return dataset
26 |
27 |
28 | def get_option_setter(dataset_name):
29 |     dataset_class = find_dataset_using_name(dataset_name)
30 |     return dataset_class.modify_commandline_options
31 |
32 |
33 | def create_dataloader(opt):
34 |     dataset = find_dataset_using_name(opt.dataset_mode)
35 |     instance = dataset()
36 |     instance.initialize(opt)
37 |     print("dataset [%s] of size %d was created" %
38 |           (type(instance).__name__, len(instance)))
39 |
40 |     dataloader = torch.utils.data.DataLoader(
41 |         instance,
42 |         batch_size=opt.batch_size,  # matches the --batch_size option defined in main.py
43 |         shuffle=not opt.serial_batches,
44 |         num_workers=int(opt.nThreads),
45 |         drop_last=opt.isTrain
46 |     )
47 |
48 |     return dataloader
49 |
--------------------------------------------------------------------------------
/dataset/base_dataset.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 | from PIL import Image
3 | import torchvision.transforms as transforms
4 | import numpy as np
5 | import random
6 |
7 |
8 | class BaseDataset(data.Dataset):
9 |     def __init__(self):
10 |         super(BaseDataset, self).__init__()
11 |
12 |     @staticmethod
13 |     def modify_commandline_options(parser, is_train):
14 |         return parser
15 |
16 |     def initialize(self, opt):
17 |         pass
18 |
19 |
20 | def get_params(opt, size):
21 |     w, h = size
22 |     new_h = h
23 |     new_w = w
24 |     if opt.preprocess_mode == 'resize_and_crop':
25 |         new_h = new_w = opt.load_size
26 |     elif opt.preprocess_mode == 'scale_width_and_crop':
27 |         new_w = opt.load_size
28 |         new_h = opt.load_size * h // w
29 |     elif opt.preprocess_mode == 'scale_shortside_and_crop':
30 |         ss, ls = min(w, h), max(w, h)  # shortside and longside
31 |         width_is_shorter = w == ss
32 |         ls = int(opt.load_size * ls / ss)
33 |         new_w, new_h = (ss, ls) if width_is_shorter else (ls, ss)
34 |
35 |     x = random.randint(0, np.maximum(0, new_w - opt.crop_size))
36 |     y = random.randint(0, np.maximum(0, new_h - opt.crop_size))
37 |
38 |     flip = random.random() > 0.5
39 |     return {'crop_pos': (x, y), 'flip': flip}
40 |
41 |
42 | def get_transform(opt, params, method=Image.BICUBIC, normalize=True,
toTensor=True): 43 | transform_list = [] 44 | if 'resize' in opt.preprocess_mode: 45 | osize = [opt.load_size, opt.load_size] 46 | transform_list.append(transforms.Resize(osize, interpolation=method)) 47 | elif 'scale_width' in opt.preprocess_mode: 48 | transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, method))) 49 | elif 'scale_shortside' in opt.preprocess_mode: 50 | transform_list.append(transforms.Lambda(lambda img: __scale_shortside(img, opt.load_size, method))) 51 | 52 | if 'crop' in opt.preprocess_mode: 53 | transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size))) 54 | 55 | if opt.preprocess_mode == 'none': 56 | base = 32 57 | transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method))) 58 | 59 | if opt.preprocess_mode == 'fixed': 60 | w = opt.crop_size 61 | h = round(opt.crop_size / opt.aspect_ratio) 62 | transform_list.append(transforms.Lambda(lambda img: __resize(img, w, h, method))) 63 | 64 | if opt.isTrain and not opt.no_flip: 65 | transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip']))) 66 | 67 | if toTensor: 68 | transform_list += [transforms.ToTensor()] 69 | 70 | if normalize: 71 | transform_list += [transforms.Normalize((0.5, 0.5, 0.5), 72 | (0.5, 0.5, 0.5))] 73 | return transforms.Compose(transform_list) 74 | 75 | 76 | def normalize(): 77 | return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 78 | 79 | 80 | def __resize(img, w, h, method=Image.BICUBIC): 81 | return img.resize((w, h), method) 82 | 83 | 84 | def __make_power_2(img, base, method=Image.BICUBIC): 85 | ow, oh = img.size 86 | h = int(round(oh / base) * base) 87 | w = int(round(ow / base) * base) 88 | if (h == oh) and (w == ow): 89 | return img 90 | return img.resize((w, h), method) 91 | 92 | 93 | def __scale_width(img, target_width, method=Image.BICUBIC): 94 | ow, oh = img.size 95 | if (ow == target_width): 96 | return img 97 | w = target_width 98 | h = int(target_width * oh / ow) 99 | return img.resize((w, h), method) 100 | 101 | 102 | def __scale_shortside(img, target_width, method=Image.BICUBIC): 103 | ow, oh = img.size 104 | ss, ls = min(ow, oh), max(ow, oh) # shortside and longside 105 | width_is_shorter = ow == ss 106 | if (ss == target_width): 107 | return img 108 | ls = int(target_width * ls / ss) 109 | nw, nh = (ss, ls) if width_is_shorter else (ls, ss) 110 | return img.resize((nw, nh), method) 111 | 112 | 113 | def __crop(img, pos, size): 114 | ow, oh = img.size 115 | x1, y1 = pos 116 | tw = th = size 117 | return img.crop((x1, y1, x1 + tw, y1 + th)) 118 | 119 | 120 | def __flip(img, flip): 121 | if flip: 122 | return img.transpose(Image.FLIP_LEFT_RIGHT) 123 | return img 124 | -------------------------------------------------------------------------------- /dataset/cityscapes_dataset.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | from dataset.pix2pix_dataset import Pix2pixDataset 3 | from dataset.image_folder import make_dataset 4 | 5 | class CityscapesDataset(Pix2pixDataset): 6 | 7 | @staticmethod 8 | def modify_commandline_options(parser, is_train): 9 | parser = Pix2pixDataset.modify_commandline_options(parser, is_train) 10 | parser.set_defaults(preprocess_mode='fixed') 11 | parser.set_defaults(load_size=256) 12 | parser.set_defaults(crop_size=256) 13 | parser.set_defaults(display_winsize=256) 14 | parser.set_defaults(label_nc=35) 15 | parser.set_defaults(aspect_ratio=1.0) 16 | opt, _ = 
parser.parse_known_args() 17 | if hasattr(opt, 'num_upsampling_layers'): 18 | parser.set_defaults(num_upsampling_layers='more') 19 | return parser 20 | 21 | def get_paths(self, opt): 22 | root = opt.dataroot 23 | phase = 'val' if opt.phase == 'test' else 'train' 24 | 25 | label_dir = os.path.join(root, 'gtFine', phase) 26 | label_paths_all = make_dataset(label_dir, recursive=True) 27 | label_paths = [p for p in label_paths_all if p.endswith('_labelIds.png')] 28 | 29 | image_dir = os.path.join(root, 'leftImg8bit', phase) 30 | image_paths = make_dataset(image_dir, recursive=True) 31 | 32 | if not opt.no_instance: 33 | instance_paths = [p for p in label_paths_all if p.endswith('_instanceIds.png')] 34 | else: 35 | instance_paths = [] 36 | 37 | # load mask 38 | mask_dir = os.path.join(root, 'irregular_mask') 39 | mask_paths = make_dataset(mask_dir, recursive=True) 40 | 41 | return label_paths, image_paths, instance_paths, mask_paths 42 | 43 | def paths_match(self, path1, path2): 44 | name1 = os.path.basename(path1) 45 | name2 = os.path.basename(path2) 46 | # compare the first 3 components, [city]_[id1]_[id2] 47 | return '_'.join(name1.split('_')[:3]) == \ 48 | '_'.join(name2.split('_')[:3]) 49 | -------------------------------------------------------------------------------- /dataset/image_folder.py: -------------------------------------------------------------------------------- 1 | ############################################################################### 2 | # Code from 3 | # https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py 4 | # Modified the original code so that it also loads images from the current 5 | # directory as well as the subdirectories 6 | ############################################################################### 7 | import torch.utils.data as data 8 | from PIL import Image 9 | import os 10 | 11 | IMG_EXTENSIONS = [ 12 | '.jpg', '.JPG', '.jpeg', '.JPEG', 13 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff', '.webp' 14 | ] 15 | 16 | 17 | def is_image_file(filename): 18 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 19 | 20 | 21 | def make_dataset_rec(dir, images): 22 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 23 | 24 | for root, dnames, fnames in sorted(os.walk(dir, followlinks=True)): 25 | for fname in fnames: 26 | if is_image_file(fname): 27 | path = os.path.join(root, fname) 28 | images.append(path) 29 | 30 | 31 | def make_dataset(dir, recursive=False, read_cache=False, write_cache=False): 32 | images = [] 33 | 34 | if read_cache: 35 | possible_filelist = os.path.join(dir, 'files.list') 36 | if os.path.isfile(possible_filelist): 37 | with open(possible_filelist, 'r') as f: 38 | images = f.read().splitlines() 39 | return images 40 | 41 | if recursive: 42 | make_dataset_rec(dir, images) 43 | else: 44 | assert os.path.isdir(dir) or os.path.islink(dir), '%s is not a valid directory' % dir 45 | 46 | for root, dnames, fnames in sorted(os.walk(dir)): 47 | for fname in fnames: 48 | if is_image_file(fname): 49 | path = os.path.join(root, fname) 50 | images.append(path) 51 | 52 | if write_cache: 53 | filelist_cache = os.path.join(dir, 'files.list') 54 | with open(filelist_cache, 'w') as f: 55 | for path in images: 56 | f.write("%s\n" % path) 57 | print('wrote filelist cache at %s' % filelist_cache) 58 | 59 | return images 60 | 61 | 62 | def default_loader(path): 63 | return Image.open(path).convert('RGB') 64 | 65 | 66 | class ImageFolder(data.Dataset): 67 | 68 | def __init__(self, root, 
transform=None, return_paths=False,
69 |                  loader=default_loader):
70 |         imgs = make_dataset(root)
71 |         if len(imgs) == 0:
72 |             raise(RuntimeError("Found 0 images in: " + root + "\n"
73 |                                "Supported image extensions are: " +
74 |                                ",".join(IMG_EXTENSIONS)))
75 |
76 |         self.root = root
77 |         self.imgs = imgs
78 |         self.transform = transform
79 |         self.return_paths = return_paths
80 |         self.loader = loader
81 |
82 |     def __getitem__(self, index):
83 |         path = self.imgs[index]
84 |         img = self.loader(path)
85 |         if self.transform is not None:
86 |             img = self.transform(img)
87 |         if self.return_paths:
88 |             return img, path
89 |         else:
90 |             return img
91 |
92 |     def __len__(self):
93 |         return len(self.imgs)
94 |
--------------------------------------------------------------------------------
/dataset/pix2pix_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import imageio
4 | import numpy as np
5 | from PIL import Image
6 |
7 | import util.util as util
8 | from dataset.base_dataset import BaseDataset, get_params, get_transform, normalize
9 |
10 | import torchvision.utils as vutils
11 | import torchvision.transforms.functional as F
12 | import torchvision.transforms as transforms
13 |
14 | norm = normalize()
15 |
16 | class Pix2pixDataset(BaseDataset):
17 |     """ adapted from the SPADE repo. """
18 |     @staticmethod
19 |     def modify_commandline_options(parser, is_train):
20 |         parser.add_argument('--no_pairing_check', action='store_true',
21 |                             help='If specified, skip sanity check of correct label-image file pairing')
22 |         return parser
23 |
24 |     def initialize(self, opt):
25 |         self.opt = opt
26 |         self.mask = opt.mask
27 |
28 |         label_paths, image_paths, instance_paths, mask_paths = self.get_paths(opt)
29 |
30 |         util.natural_sort(label_paths)
31 |         util.natural_sort(image_paths)
32 |         util.natural_sort(mask_paths)  # MASK
33 |         if not opt.no_instance:
34 |             util.natural_sort(instance_paths)
35 |
36 |         label_paths = label_paths[:opt.max_dataset_size]
37 |         image_paths = image_paths[:opt.max_dataset_size]
38 |         instance_paths = instance_paths[:opt.max_dataset_size]
39 |         # mask_paths is deliberately not truncated; a random mask is drawn from the full list at load time.
40 |
41 |
42 |         if not opt.no_pairing_check:
43 |             for path1, path2 in zip(label_paths, image_paths):
44 |                 assert self.paths_match(path1, path2), \
45 |                     "The label-image pair (%s, %s) does not look like the right pair because the filenames are quite different. Are you sure about the pairing? Please see dataset/pix2pix_dataset.py to see what is going on, and use --no_pairing_check to bypass this."
% (path1, path2) 46 | 47 | self.label_paths = label_paths 48 | self.image_paths = image_paths 49 | self.instance_paths = instance_paths 50 | self.mask_paths = mask_paths # MASK 51 | 52 | size = len(self.label_paths) 53 | self.dataset_size = size 54 | 55 | def get_paths(self, opt): 56 | label_paths = [] 57 | image_paths = [] 58 | instance_paths = [] 59 | assert False, "A subclass of Pix2pixDataset must override self.get_paths(self, opt)" 60 | return label_paths, image_paths, instance_paths 61 | 62 | def paths_match(self, path1, path2): 63 | filename1_without_ext = os.path.splitext(os.path.basename(path1))[0] 64 | filename2_without_ext = os.path.splitext(os.path.basename(path2))[0] 65 | return filename1_without_ext == filename2_without_ext 66 | 67 | def __getitem__(self, index): 68 | # Label Image 69 | label_path = self.label_paths[index] 70 | label = Image.open(label_path) 71 | params = get_params(self.opt, label.size) 72 | transform_label = get_transform(self.opt, params, method=Image.NEAREST, normalize=False) 73 | label_tensor = transform_label(label) * 255.0 74 | label_tensor[label_tensor == 255] = self.opt.label_nc # 'unknown' is opt.label_nc 75 | 76 | # input image (real images) 77 | image_path = self.image_paths[index] 78 | assert self.paths_match(label_path, image_path), \ 79 | "The label_path %s and image_path %s don't match." % \ 80 | (label_path, image_path) 81 | image = Image.open(image_path) 82 | image = image.convert('RGB') 83 | 84 | transform_image = get_transform(self.opt, params, normalize=False) 85 | image_tensor = transform_image(image) 86 | 87 | # reference: https://github.com/knazeri/edge-connect/blob/master/src/dataset.py#L116-L151 88 | mask, mask_ix = self.load_mask(image_tensor, index) 89 | mask_tensor = self.edge_connect_to_tensor(mask) 90 | 91 | # if using instance maps 92 | if self.opt.no_instance: 93 | instance_tensor = 0 94 | else: 95 | instance_path = self.instance_paths[index] 96 | instance = Image.open(instance_path) 97 | if instance.mode == 'L': 98 | instance_tensor = transform_label(instance) * 255 99 | instance_tensor = instance_tensor.long() 100 | else: 101 | instance_tensor = transform_label(instance) 102 | 103 | input_dict = {'label': label_tensor, 104 | 'instance': instance_tensor, 105 | 'image': image_tensor, 106 | 'mask': mask_tensor, 107 | 'path': image_path, 108 | } 109 | 110 | # Give subclasses a chance to modify the final output 111 | self.postprocess(input_dict) 112 | 113 | return input_dict 114 | 115 | def load_mask(self, img_tensor, index): 116 | imgh, imgw = img_tensor.shape[1:] 117 | 118 | mask_type = self.mask 119 | 120 | # external + random block 121 | if mask_type == 4: 122 | mask_type = 1 if np.random.binomial(1, 0.5) == 1 else 3 123 | 124 | # external + random block + half 125 | elif mask_type == 5: 126 | mask_type = np.random.randint(1, 4) 127 | 128 | # # random block 129 | # if mask_type == 1: 130 | # return create_mask(imgw, imgh, imgw // 2, imgh // 2) 131 | 132 | # # half 133 | # if mask_type == 2: 134 | # # randomly choose right or left 135 | # return create_mask(imgw, imgh, imgw // 2, imgh, 0 if random.random() < 0.5 else imgw // 2, 0) 136 | 137 | # external 138 | if mask_type == 3: 139 | mask_ix = np.random.randint(0, len(self.mask_paths)) 140 | mask = imageio.imread(self.mask_paths[mask_ix]) 141 | mask = self.edge_connect_resize(mask, imgh, imgw) 142 | mask = (mask > 0).astype(np.uint8) * 255 143 | return mask, mask_ix 144 | 145 | def edge_connect_resize(self, img, height, width, centerCrop=True): 146 | if len(img.shape) == 2: # 
gray image
147 |             imgh, imgw = img.shape
148 |         else:
149 |             imgh, imgw = img.shape[0:2]
150 |
151 |         if centerCrop and imgh != imgw:
152 |             # center crop
153 |             side = np.minimum(imgh, imgw)
154 |             j = (imgh - side) // 2
155 |             i = (imgw - side) // 2
156 |             if len(img.shape) == 2:  # gray image
157 |                 img = img[j:j + side, i:i + side]
158 |             else:
159 |                 img = img[j:j + side, i:i + side, ...]
160 |
161 |         img = cv2.resize(img, (width, height))  # cv2.resize takes dsize as (width, height)
162 |
163 |         return img
164 |
165 |     def edge_connect_to_tensor(self, img):
166 |         img = Image.fromarray(img)
167 |         img_t = F.to_tensor(img).float()
168 |         return img_t
169 |
170 |     def postprocess(self, input_dict):
171 |         return input_dict
172 |
173 |     def __len__(self):
174 |         return self.dataset_size
175 |
--------------------------------------------------------------------------------
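A quick sketch of how one training sample comes out of `Pix2pixDataset.__getitem__` above. Note that only the external-mask branch (`mask_type == 3`) is currently active in `load_mask`; the random-block and half branches are commented out. How the trainer composes the corrupted input lives in `models/seg_inpaint_model.py` (not shown in this listing), so the `corrupt` helper below is a hypothetical illustration:

```python
import torch

def corrupt(image: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # mask is a float tensor in {0, 1} after edge_connect_to_tensor()
    # (mask files store holes as 255); 1 marks the pixels to remove.
    return image * (1.0 - mask)

sample = dataset[0]               # dict returned by Pix2pixDataset.__getitem__
image = sample['image']           # (3, H, W), un-normalized, values in [0, 1]
mask = sample['mask']             # (1, H, W)
label = sample['label']           # (1, H, W), class ids in [0, label_nc]
corrupted = corrupt(image, mask)  # hypothetical corruption step
```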
/imgs/teaser.png:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/yccyenchicheng/pytorch-SegInpaint/13de7deb7ad11508294d8ed2e7a2c8f67ebc04d9/imgs/teaser.png
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | """
2 | usage: python main.py --gpu_ids 0 --batch_size 2
3 | """
4 | import argparse
5 | import os
6 | import sys
7 | from collections import OrderedDict
8 | from datetime import datetime
9 | from tqdm import tqdm
10 |
11 | import torch
12 | import torch.nn as nn
13 | from torch.utils.data import DataLoader
14 |
15 | import dataset
16 | import util.util as util
17 | from dataset.cityscapes_dataset import CityscapesDataset
18 | from util.visualizer import Visualizer
19 | from trainers.seg_inpaint_trainer import SegInpaintTrainer
20 |
21 | current_time = datetime.now().strftime("%m%d-%H%M")
22 |
23 | def get_opt():
24 |     parser = argparse.ArgumentParser()
25 |     ### base options ###
26 |     parser.add_argument('--name', type=str, default='exp1-%s' % current_time, help="name of this experiment")
27 |     parser.add_argument('--phase', type=str, default='train', help="'train' or 'test'")
28 |     parser.add_argument('--gpu_ids', type=str, default='0,2', help="comma-separated GPU ids, e.g. '0,1,2'; note that CUDA device numbering may differ from nvidia-smi's ordering")
29 |     parser.add_argument('--norm_D', type=str, default='spectralinstance', help='instance normalization or batch normalization')
30 |
31 |     # input/output sizes
32 |     parser.add_argument('--batch_size', type=int, default=4, help='input batch size')
33 |     parser.add_argument('--contain_dontcare_label', action='store_true', help='if the label map contains dontcare label (dontcare=255)')
34 |
35 |     # for setting input
36 |     parser.add_argument('--dataset', type=str, default='cityscapes')  # dataroot resolves to data/<dataset> (set below)
37 |     parser.add_argument('--max_dataset_size', type=int, default=sys.maxsize, help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
38 |     parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
39 |     parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
40 |     parser.add_argument('--nThreads', default=0, type=int, help='# threads for loading data')
41 |     parser.add_argument('--mask', type=int, default=3, choices=[1, 2, 3, 4, 5], help='1: random block, 2: half, 3: external, 4: (external, random block), 5: (external, random block, half)')
42 |
43 |     # for instance-wise features
44 |     parser.add_argument('--no_instance', action='store_true', help='if specified, do *not* add instance map as input')
45 |     #####################
46 |
47 |     # for deeplab
48 |     parser.add_argument('--deeplab_backbone', type=str, default='resnet', choices=['resnet', 'xception', 'drn', 'mobilenet'])
49 |     parser.add_argument('--deeplab_output_stride', type=int, default=16, help='network output stride (default: 16)')
50 |
51 |     ### train options ###
52 |     # for displays
53 |     parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
54 |     parser.add_argument('--debug', action='store_true', help='only run one epoch and display at each iteration')
55 |     parser.add_argument('--tf_log', action='store_true', help='if specified, use tensorboard logging. Requires tensorflow installed')
56 |
57 |     # for training
58 |     parser.add_argument('--niter', type=int, default=100, help='# of iters at starting learning rate. This is NOT the total #epochs. Total #epochs is niter + niter_decay')
59 |     parser.add_argument('--niter_decay', type=int, default=0, help='# of iters to linearly decay learning rate to zero')
60 |     parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
61 |     parser.add_argument('--beta2', type=float, default=0.999, help='momentum term of adam')
62 |     parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
63 |     parser.add_argument('--gan_mode', type=str, default='hinge', help='(ls|original|hinge)')
64 |
65 |     # for discriminator
66 |     parser.add_argument('--ndf', type=int, default=64, help='# of discriminator filters in the first conv layer')
67 |     parser.add_argument('--num_D', type=int, default=3, help='how many D to use in multiscale discriminator')
68 |     parser.add_argument('--n_layers_D', type=int, default=3, help='only used if which_model_netD==n_layers')
69 |     parser.add_argument('--lambda_feat', type=float, default=10.0, help='weight for feature matching loss')
70 |     parser.add_argument('--no_ganFeat_loss', action='store_true', help='if specified, do *not* use discriminator feature matching loss')
71 |     parser.add_argument('--no_TTUR', action='store_true', help='if specified, do *not* use the TTUR training scheme')
72 |     parser.add_argument('--netD_subarch', type=str, default='n_layer', help='architecture of each discriminator')
73 |     #####################
74 |
75 |     parser.add_argument('--test', action='store_true')
76 |
77 |     # copied from the original code: let the dataset add more command-line options
78 |     dataset_mode = 'cityscapes'
79 |     isTrain = True
80 |     dataset_option_setter = dataset.get_option_setter(dataset_mode)
81 |     parser = dataset_option_setter(parser, isTrain)
82 |
83 |     opt = parser.parse_args()
84 |
85 |     opt.isTrain = isTrain
86 |
87 |     return opt
88 |
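For context, a sketch (not part of `main.py`) of what the `dataset_option_setter` call inside `get_opt` resolves to for Cityscapes, based on `dataset/cityscapes_dataset.py` above:

```python
# CityscapesDataset.modify_commandline_options adds --no_pairing_check
# (via Pix2pixDataset) and pins these defaults:
opt = get_opt()
assert opt.preprocess_mode == 'fixed'
assert opt.load_size == opt.crop_size == opt.display_winsize == 256
assert opt.label_nc == 35 and opt.aspect_ratio == 1.0
```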
89 | with open('latest_cmd.txt', 'w') as f:
90 |     cmd = ' '.join(sys.argv) + '\n'
91 |     f.write(cmd)
92 |
93 | opt = get_opt()
94 | opt.dataroot = os.path.join('data', opt.dataset)
95 | opt.total_epochs = opt.niter + opt.niter_decay
96 |
97 | # logs and checkpoints
98 | if not os.path.exists('logs'):
99 |     os.mkdir('logs')
100 |
101 | log_root = os.path.join('logs', 'seg_inpaint_logs')
102 | if not os.path.exists(log_root):
103 |     os.mkdir(log_root)
104 |
105 | opt.log_root = log_root
106 | exp_dir = os.path.join(log_root, opt.name)
107 | util.mkdir(exp_dir)
108 | ckpt_dir = os.path.join(log_root, opt.name, 'checkpoint')
109 | util.mkdir(ckpt_dir)
110 |
111 | dataset = CityscapesDataset()
112 | dataset.initialize(opt)
113 | dataloader = DataLoader(dataset,
114 |                         batch_size=opt.batch_size,
115 |                         shuffle=not opt.serial_batches,
116 |                         num_workers=int(opt.nThreads),
117 |                         drop_last=opt.isTrain)
118 |
119 | # setup GPU, optimizer.
120 | # borrowed from SPADE
121 | str_ids = opt.gpu_ids.split(',')
122 | opt.gpu_ids = []
123 | for str_id in str_ids:
124 |     id = int(str_id)
125 |     if id >= 0:
126 |         opt.gpu_ids.append(id)
127 | if len(opt.gpu_ids) > 0:
128 |     torch.cuda.set_device(opt.gpu_ids[0])
129 |
130 | assert len(opt.gpu_ids) == 0 or opt.batch_size % len(opt.gpu_ids) == 0, \
131 |     "Batch size %d is wrong. It must be a multiple of # GPUs %d." \
132 |     % (opt.batch_size, len(opt.gpu_ids))
133 |
134 | trainer = SegInpaintTrainer(opt)
135 |
136 | # create tool for visualization
137 | visualizer = Visualizer(opt)
138 |
139 | total_steps_so_far = 0
140 |
141 | for epoch in range(opt.total_epochs):
142 |
143 |     for i, data_i in tqdm(enumerate(dataloader), total=len(dataloader)):
144 |         current_step = epoch*len(dataloader) + i
145 |
146 |         # Training
147 |         if not opt.test or i == 0:
148 |             trainer.run_generator_one_step(data_i)
149 |             trainer.run_discriminator_one_step(data_i)
150 |         else:
151 |             pass
152 |
153 |         if current_step % 10 == 0:
154 |             if i != 0:
155 |                 sys.stdout.write("\033[F")  # move back to the previous line
156 |                 sys.stdout.write("\033[K")  # clear the line
157 |
158 |             loss_str = trainer.get_loss_str()
159 |             print("[%d/%d] %d %s" % (epoch, opt.total_epochs, current_step, loss_str))
160 |             with open(visualizer.log_name, "a") as log_file:
161 |                 log_file.write('%s\n' % loss_str)
162 |
163 |
164 |         if current_step % 50 == 0:
165 |             real_img, corrupted_img, generated_seg, generated_img = \
166 |                 trainer.get_latest_results()
167 |             visuals = OrderedDict([('input_label', data_i['label']),
168 |                                    ('synthesized_image', generated_img),
169 |                                    ('synthesized_segmentation', generated_seg),
170 |                                    ('real_image', real_img),
171 |                                    ('corrupted_image', corrupted_img)],
172 |                                   )
173 |             visualizer.display_current_results(visuals, epoch, current_step)
174 |
175 |         if opt.test:
176 |             break
177 |
178 |     trainer.update_learning_rate(epoch)
179 |
180 |     model_path = os.path.join(ckpt_dir, 'model_%d.pth' % epoch)
181 |     trainer.save(model_path, epoch)
182 |
183 | print('Training complete.')
184 |
--------------------------------------------------------------------------------
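`trainers/seg_inpaint_trainer.py` is not included in this listing, but the call sites in `main.py` above pin down the interface it must expose; a sketch (method bodies elided, names taken directly from the calls):

```python
class SegInpaintTrainer:
    def __init__(self, opt): ...
    def run_generator_one_step(self, data): ...      # one G update on a batch dict
    def run_discriminator_one_step(self, data): ...  # one D update on a batch dict
    def get_loss_str(self): ...                      # formatted loss string for logging
    def get_latest_results(self): ...                # (real_img, corrupted_img, generated_seg, generated_img)
    def update_learning_rate(self, epoch): ...
    def save(self, path, epoch): ...
```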
/models/__init__.py:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/yccyenchicheng/pytorch-SegInpaint/13de7deb7ad11508294d8ed2e7a2c8f67ebc04d9/models/__init__.py
--------------------------------------------------------------------------------
/models/networks/__init__.py:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/yccyenchicheng/pytorch-SegInpaint/13de7deb7ad11508294d8ed2e7a2c8f67ebc04d9/models/networks/__init__.py
--------------------------------------------------------------------------------
/models/networks/architecture.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torchvision
5 | import torch.nn.utils.spectral_norm as spectral_norm
6 |
7 | # ResNet block used in pix2pixHD
8 | # We keep the same architecture as pix2pixHD.
9 | class ResnetBlock(nn.Module):
10 |     def __init__(self, dim, norm_layer, activation=nn.ReLU(False), kernel_size=3):
11 |         super().__init__()
12 |
13 |         pw = (kernel_size - 1) // 2
14 |         self.conv_block = nn.Sequential(
15 |             nn.ReflectionPad2d(pw),
16 |             nn.Conv2d(dim, dim, kernel_size=kernel_size),
17 |             norm_layer(dim),
18 |             activation,
19 |             nn.ReflectionPad2d(pw),
20 |             nn.Conv2d(dim, dim, kernel_size=kernel_size),
21 |             norm_layer(dim),
22 |         )
23 |
24 |     def forward(self, x):
25 |         y = self.conv_block(x)
26 |         out = x + y
27 |         return out
28 |
29 | class Alex(torch.nn.Module):
30 |     def __init__(self, requires_grad=False):
31 |         super().__init__()
32 |         alex_pretrained_features = torchvision.models.alexnet(pretrained=True).features
33 |         self.slice1 = torch.nn.Sequential()
34 |         self.slice2 = torch.nn.Sequential()
35 |         self.slice3 = torch.nn.Sequential()
36 |         self.slice4 = torch.nn.Sequential()
37 |         self.slice5 = torch.nn.Sequential()
38 |         for x in range(2):  # AlexNet's feature extractor has only 13 layers (0-12); slice at ReLU outputs (LPIPS-style), since the original VGG-style ranges (12, 21)/(21, 30) would raise an IndexError
39 |             self.slice1.add_module(str(x), alex_pretrained_features[x])
40 |         for x in range(2, 5):
41 |             self.slice2.add_module(str(x), alex_pretrained_features[x])
42 |         for x in range(5, 8):
43 |             self.slice3.add_module(str(x), alex_pretrained_features[x])
44 |         for x in range(8, 10):
45 |             self.slice4.add_module(str(x), alex_pretrained_features[x])
46 |         for x in range(10, 12):
47 |             self.slice5.add_module(str(x), alex_pretrained_features[x])
48 |         if not requires_grad:
49 |             for param in self.parameters():
50 |                 param.requires_grad = False
51 |
52 |     def forward(self, X):
53 |         h_relu1 = self.slice1(X)
54 |         h_relu2 = self.slice2(h_relu1)
55 |         h_relu3 = self.slice3(h_relu2)
56 |         h_relu4 = self.slice4(h_relu3)
57 |         h_relu5 = self.slice5(h_relu4)
58 |         out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
59 |         return out
60 |
61 | # VGG architecture, used for the perceptual loss using a pretrained VGG network
62 | class VGG19(torch.nn.Module):
63 |     def __init__(self, requires_grad=False):
64 |         super().__init__()
65 |         vgg_pretrained_features = torchvision.models.vgg19(pretrained=True).features
66 |         self.slice1 = torch.nn.Sequential()
67 |         self.slice2 = torch.nn.Sequential()
68 |         self.slice3 = torch.nn.Sequential()
69 |         self.slice4 = torch.nn.Sequential()
70 |         self.slice5 = torch.nn.Sequential()
71 |         for x in range(2):
72 |             self.slice1.add_module(str(x), vgg_pretrained_features[x])
73 |         for x in range(2, 7):
74 |             self.slice2.add_module(str(x), vgg_pretrained_features[x])
75 |         for x in range(7, 12):
76 |             self.slice3.add_module(str(x), vgg_pretrained_features[x])
77 |         for x in range(12, 21):
78 |             self.slice4.add_module(str(x), vgg_pretrained_features[x])
79 |         for x in range(21, 30):
80 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 81 | if not requires_grad: 82 | for param in self.parameters(): 83 | param.requires_grad = False 84 | 85 | def forward(self, X): 86 | h_relu1 = self.slice1(X) 87 | h_relu2 = self.slice2(h_relu1) 88 | h_relu3 = self.slice3(h_relu2) 89 | h_relu4 = self.slice4(h_relu3) 90 | h_relu5 = self.slice5(h_relu4) 91 | out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5] 92 | return out 93 | -------------------------------------------------------------------------------- /models/networks/base_network.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch.nn import init 3 | 4 | 5 | class BaseNetwork(nn.Module): 6 | def __init__(self): 7 | super(BaseNetwork, self).__init__() 8 | 9 | @staticmethod 10 | def modify_commandline_options(parser, is_train): 11 | return parser 12 | 13 | def print_network(self): 14 | if isinstance(self, list): 15 | self = self[0] 16 | num_params = 0 17 | for param in self.parameters(): 18 | num_params += param.numel() 19 | print('Network [%s] was created. Total number of parameters: %.1f million. ' 20 | 'To see the architecture, do print(network).' 21 | % (type(self).__name__, num_params / 1000000)) 22 | 23 | def init_weights(self, init_type='normal', gain=0.02): 24 | def init_func(m): 25 | classname = m.__class__.__name__ 26 | if classname.find('BatchNorm2d') != -1: 27 | if hasattr(m, 'weight') and m.weight is not None: 28 | init.normal_(m.weight.data, 1.0, gain) 29 | if hasattr(m, 'bias') and m.bias is not None: 30 | init.constant_(m.bias.data, 0.0) 31 | elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 32 | if init_type == 'normal': 33 | init.normal_(m.weight.data, 0.0, gain) 34 | elif init_type == 'xavier': 35 | init.xavier_normal_(m.weight.data, gain=gain) 36 | elif init_type == 'xavier_uniform': 37 | init.xavier_uniform_(m.weight.data, gain=1.0) 38 | elif init_type == 'kaiming': 39 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 40 | elif init_type == 'orthogonal': 41 | init.orthogonal_(m.weight.data, gain=gain) 42 | elif init_type == 'none': # uses pytorch's default init method 43 | m.reset_parameters() 44 | else: 45 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 46 | if hasattr(m, 'bias') and m.bias is not None: 47 | init.constant_(m.bias.data, 0.0) 48 | 49 | self.apply(init_func) 50 | 51 | # propagate to children 52 | for m in self.children(): 53 | if hasattr(m, 'init_weights'): 54 | m.init_weights(init_type, gain) 55 | -------------------------------------------------------------------------------- /models/networks/discriminator.py: -------------------------------------------------------------------------------- 1 | """ 2 | from pix2pixHD 3 | """ 4 | import torch.nn as nn 5 | import numpy as np 6 | import torch.nn.functional as F 7 | import util.util as util 8 | 9 | class MultiscaleDiscriminator(nn.Module): 10 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, 11 | use_sigmoid=False, num_D=3, getIntermFeat=False): 12 | super(MultiscaleDiscriminator, self).__init__() 13 | self.num_D = num_D 14 | self.n_layers = n_layers 15 | self.getIntermFeat = getIntermFeat 16 | 17 | for i in range(num_D): 18 | netD = NLayerDiscriminator(input_nc, ndf, n_layers, norm_layer, use_sigmoid, getIntermFeat) 19 | if getIntermFeat: 20 | for j in range(n_layers+2): 21 | setattr(self, 'scale'+str(i)+'_layer'+str(j), getattr(netD, 
'model'+str(j))) 22 | else: 23 | setattr(self, 'layer'+str(i), netD.model) 24 | 25 | self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False) 26 | 27 | def singleD_forward(self, model, input): 28 | if self.getIntermFeat: 29 | result = [input] 30 | for i in range(len(model)): 31 | result.append(model[i](result[-1])) 32 | return result[1:] 33 | else: 34 | return [model(input)] 35 | 36 | def forward(self, input): 37 | num_D = self.num_D 38 | result = [] 39 | input_downsampled = input 40 | for i in range(num_D): 41 | if self.getIntermFeat: 42 | model = [getattr(self, 'scale'+str(num_D-1-i)+'_layer'+str(j)) for j in range(self.n_layers+2)] 43 | else: 44 | model = getattr(self, 'layer'+str(num_D-1-i)) 45 | result.append(self.singleD_forward(model, input_downsampled)) 46 | if i != (num_D-1): 47 | input_downsampled = self.downsample(input_downsampled) 48 | 49 | return result 50 | 51 | # Defines the PatchGAN discriminator with the specified arguments. 52 | class NLayerDiscriminator(nn.Module): 53 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, getIntermFeat=False): 54 | super(NLayerDiscriminator, self).__init__() 55 | self.getIntermFeat = getIntermFeat 56 | self.n_layers = n_layers 57 | 58 | kw = 4 59 | padw = int(np.ceil((kw-1.0)/2)) 60 | sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]] 61 | 62 | nf = ndf 63 | for n in range(1, n_layers): 64 | nf_prev = nf 65 | nf = min(nf * 2, 512) 66 | sequence += [[ 67 | nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw), 68 | norm_layer(nf), nn.LeakyReLU(0.2, True) 69 | ]] 70 | 71 | nf_prev = nf 72 | nf = min(nf * 2, 512) 73 | sequence += [[ 74 | nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw), 75 | norm_layer(nf), 76 | nn.LeakyReLU(0.2, True) 77 | ]] 78 | 79 | sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]] 80 | 81 | if use_sigmoid: 82 | sequence += [[nn.Sigmoid()]] 83 | 84 | if getIntermFeat: 85 | for n in range(len(sequence)): 86 | setattr(self, 'model'+str(n), nn.Sequential(*sequence[n])) 87 | else: 88 | sequence_stream = [] 89 | for n in range(len(sequence)): 90 | sequence_stream += sequence[n] 91 | self.model = nn.Sequential(*sequence_stream) 92 | 93 | def forward(self, input): 94 | if self.getIntermFeat: 95 | res = [input] 96 | for n in range(self.n_layers+2): 97 | model = getattr(self, 'model'+str(n)) 98 | res.append(model(res[-1])) 99 | 100 | return res[1:] 101 | else: 102 | return self.model(input) -------------------------------------------------------------------------------- /models/networks/fcn.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from __future__ import print_function 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.optim as optim 8 | from torchvision import models 9 | from torchvision.models.vgg import VGG 10 | 11 | 12 | class FCN32s(nn.Module): 13 | 14 | def __init__(self, pretrained_net, n_class): 15 | super().__init__() 16 | self.n_class = n_class 17 | self.pretrained_net = pretrained_net 18 | self.relu = nn.ReLU(inplace=True) 19 | self.deconv1 = nn.ConvTranspose2d(512, 512, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1) 20 | self.bn1 = nn.BatchNorm2d(512) 21 | self.deconv2 = nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1) 22 | self.bn2 = nn.BatchNorm2d(256) 23 | self.deconv3 = 
nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1) 24 | self.bn3 = nn.BatchNorm2d(128) 25 | self.deconv4 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1) 26 | self.bn4 = nn.BatchNorm2d(64) 27 | self.deconv5 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1) 28 | self.bn5 = nn.BatchNorm2d(32) 29 | self.classifier = nn.Conv2d(32, n_class, kernel_size=1) 30 | 31 | def forward(self, x): 32 | output = self.pretrained_net(x) 33 | x5 = output['x5'] # size=(N, 512, x.H/32, x.W/32) 34 | 35 | score = self.bn1(self.relu(self.deconv1(x5))) # size=(N, 512, x.H/16, x.W/16) 36 | score = self.bn2(self.relu(self.deconv2(score))) # size=(N, 256, x.H/8, x.W/8) 37 | score = self.bn3(self.relu(self.deconv3(score))) # size=(N, 128, x.H/4, x.W/4) 38 | score = self.bn4(self.relu(self.deconv4(score))) # size=(N, 64, x.H/2, x.W/2) 39 | score = self.bn5(self.relu(self.deconv5(score))) # size=(N, 32, x.H, x.W) 40 | score = self.classifier(score) # size=(N, n_class, x.H/1, x.W/1) 41 | 42 | return score # size=(N, n_class, x.H/1, x.W/1) 43 | 44 | 45 | class FCN16s(nn.Module): 46 | 47 | def __init__(self, pretrained_net, n_class): 48 | super().__init__() 49 | self.n_class = n_class 50 | self.pretrained_net = pretrained_net 51 | self.relu = nn.ReLU(inplace=True) 52 | self.deconv1 = nn.ConvTranspose2d(512, 512, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1) 53 | self.bn1 = nn.BatchNorm2d(512) 54 | self.deconv2 = nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1) 55 | self.bn2 = nn.BatchNorm2d(256) 56 | self.deconv3 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1) 57 | self.bn3 = nn.BatchNorm2d(128) 58 | self.deconv4 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1) 59 | self.bn4 = nn.BatchNorm2d(64) 60 | self.deconv5 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1) 61 | self.bn5 = nn.BatchNorm2d(32) 62 | self.classifier = nn.Conv2d(32, n_class, kernel_size=1) 63 | 64 | def forward(self, x): 65 | output = self.pretrained_net(x) 66 | x5 = output['x5'] # size=(N, 512, x.H/32, x.W/32) 67 | x4 = output['x4'] # size=(N, 512, x.H/16, x.W/16) 68 | 69 | score = self.relu(self.deconv1(x5)) # size=(N, 512, x.H/16, x.W/16) 70 | score = self.bn1(score + x4) # element-wise add, size=(N, 512, x.H/16, x.W/16) 71 | score = self.bn2(self.relu(self.deconv2(score))) # size=(N, 256, x.H/8, x.W/8) 72 | score = self.bn3(self.relu(self.deconv3(score))) # size=(N, 128, x.H/4, x.W/4) 73 | score = self.bn4(self.relu(self.deconv4(score))) # size=(N, 64, x.H/2, x.W/2) 74 | score = self.bn5(self.relu(self.deconv5(score))) # size=(N, 32, x.H, x.W) 75 | score = self.classifier(score) # size=(N, n_class, x.H/1, x.W/1) 76 | 77 | return score # size=(N, n_class, x.H/1, x.W/1) 78 | 79 | 80 | class FCN8s(nn.Module): 81 | 82 | def __init__(self, pretrained_net, n_class): 83 | super().__init__() 84 | self.n_class = n_class 85 | self.pretrained_net = pretrained_net 86 | self.relu = nn.ReLU(inplace=True) 87 | self.deconv1 = nn.ConvTranspose2d(512, 512, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1) 88 | self.bn1 = nn.BatchNorm2d(512) 89 | self.deconv2 = nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1) 90 | self.bn2 = nn.BatchNorm2d(256) 91 | self.deconv3 = 
nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1) 92 | self.bn3 = nn.BatchNorm2d(128) 93 | self.deconv4 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1) 94 | self.bn4 = nn.BatchNorm2d(64) 95 | self.deconv5 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1) 96 | self.bn5 = nn.BatchNorm2d(32) 97 | self.classifier = nn.Conv2d(32, n_class, kernel_size=1) 98 | 99 | def forward(self, x): 100 | output = self.pretrained_net(x) 101 | x5 = output['x5'] # size=(N, 512, x.H/32, x.W/32) 102 | x4 = output['x4'] # size=(N, 512, x.H/16, x.W/16) 103 | x3 = output['x3'] # size=(N, 256, x.H/8, x.W/8) 104 | 105 | score = self.relu(self.deconv1(x5)) # size=(N, 512, x.H/16, x.W/16) 106 | score = self.bn1(score + x4) # element-wise add, size=(N, 512, x.H/16, x.W/16) 107 | score = self.relu(self.deconv2(score)) # size=(N, 256, x.H/8, x.W/8) 108 | score = self.bn2(score + x3) # element-wise add, size=(N, 256, x.H/8, x.W/8) 109 | score = self.bn3(self.relu(self.deconv3(score))) # size=(N, 128, x.H/4, x.W/4) 110 | score = self.bn4(self.relu(self.deconv4(score))) # size=(N, 64, x.H/2, x.W/2) 111 | score = self.bn5(self.relu(self.deconv5(score))) # size=(N, 32, x.H, x.W) 112 | score = self.classifier(score) # size=(N, n_class, x.H/1, x.W/1) 113 | 114 | return score # size=(N, n_class, x.H/1, x.W/1) 115 | 116 | 117 | class FCNs(nn.Module): 118 | 119 | def __init__(self, pretrained_net, n_class): 120 | super().__init__() 121 | self.n_class = n_class 122 | self.pretrained_net = pretrained_net 123 | self.relu = nn.ReLU(inplace=True) 124 | self.deconv1 = nn.ConvTranspose2d(512, 512, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1) 125 | self.bn1 = nn.BatchNorm2d(512) 126 | self.deconv2 = nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1) 127 | self.bn2 = nn.BatchNorm2d(256) 128 | self.deconv3 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1) 129 | self.bn3 = nn.BatchNorm2d(128) 130 | self.deconv4 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1) 131 | self.bn4 = nn.BatchNorm2d(64) 132 | self.deconv5 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1) 133 | self.bn5 = nn.BatchNorm2d(32) 134 | self.classifier = nn.Conv2d(32, n_class, kernel_size=1) 135 | 136 | def forward(self, x): 137 | output = self.pretrained_net(x) 138 | x5 = output['x5'] # size=(N, 512, x.H/32, x.W/32) 139 | x4 = output['x4'] # size=(N, 512, x.H/16, x.W/16) 140 | x3 = output['x3'] # size=(N, 256, x.H/8, x.W/8) 141 | x2 = output['x2'] # size=(N, 128, x.H/4, x.W/4) 142 | x1 = output['x1'] # size=(N, 64, x.H/2, x.W/2) 143 | 144 | score = self.bn1(self.relu(self.deconv1(x5))) # size=(N, 512, x.H/16, x.W/16) 145 | score = score + x4 # element-wise add, size=(N, 512, x.H/16, x.W/16) 146 | score = self.bn2(self.relu(self.deconv2(score))) # size=(N, 256, x.H/8, x.W/8) 147 | score = score + x3 # element-wise add, size=(N, 256, x.H/8, x.W/8) 148 | score = self.bn3(self.relu(self.deconv3(score))) # size=(N, 128, x.H/4, x.W/4) 149 | score = score + x2 # element-wise add, size=(N, 128, x.H/4, x.W/4) 150 | score = self.bn4(self.relu(self.deconv4(score))) # size=(N, 64, x.H/2, x.W/2) 151 | score = score + x1 # element-wise add, size=(N, 64, x.H/2, x.W/2) 152 | score = self.bn5(self.relu(self.deconv5(score))) # size=(N, 32, x.H, x.W) 153 | 
score = self.classifier(score) # size=(N, n_class, x.H/1, x.W/1) 154 | 155 | return score # size=(N, n_class, x.H/1, x.W/1) 156 | 157 | 158 | class VGGNet(VGG): 159 | def __init__(self, pretrained=True, model='vgg16', requires_grad=True, remove_fc=True, show_params=False): 160 | super().__init__(make_layers(cfg[model])) 161 | self.ranges = ranges[model] 162 | 163 | if pretrained: 164 | exec("self.load_state_dict(models.%s(pretrained=True).state_dict())" % model) 165 | 166 | if not requires_grad: 167 | for param in super().parameters(): 168 | param.requires_grad = False 169 | 170 | if remove_fc: # delete redundant fully-connected layer params, can save memory 171 | del self.classifier 172 | 173 | if show_params: 174 | for name, param in self.named_parameters(): 175 | print(name, param.size()) 176 | 177 | def forward(self, x): 178 | output = {} 179 | 180 | # get the output of each maxpooling layer (5 maxpool in VGG net) 181 | for idx in range(len(self.ranges)): 182 | for layer in range(self.ranges[idx][0], self.ranges[idx][1]): 183 | x = self.features[layer](x) 184 | output["x%d"%(idx+1)] = x 185 | 186 | return output 187 | 188 | 189 | ranges = { 190 | 'vgg11': ((0, 3), (3, 6), (6, 11), (11, 16), (16, 21)), 191 | 'vgg13': ((0, 5), (5, 10), (10, 15), (15, 20), (20, 25)), 192 | 'vgg16': ((0, 5), (5, 10), (10, 17), (17, 24), (24, 31)), 193 | 'vgg19': ((0, 5), (5, 10), (10, 19), (19, 28), (28, 37)) 194 | } 195 | 196 | # cropped version from https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py 197 | cfg = { 198 | 'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 199 | 'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 200 | 'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 201 | 'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 202 | } 203 | 204 | def make_layers(cfg, batch_norm=False): 205 | layers = [] 206 | in_channels = 3 207 | for v in cfg: 208 | if v == 'M': 209 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 210 | else: 211 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 212 | if batch_norm: 213 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 214 | else: 215 | layers += [conv2d, nn.ReLU(inplace=True)] 216 | in_channels = v 217 | return nn.Sequential(*layers) 218 | 219 | 220 | if __name__ == "__main__": 221 | batch_size, n_class, h, w = 10, 20, 160, 160 222 | 223 | # test output size 224 | vgg_model = VGGNet(requires_grad=True) 225 | input = torch.autograd.Variable(torch.randn(batch_size, 3, 224, 224)) 226 | output = vgg_model(input) 227 | assert output['x5'].size() == torch.Size([batch_size, 512, 7, 7]) 228 | 229 | fcn_model = FCN32s(pretrained_net=vgg_model, n_class=n_class) 230 | input = torch.autograd.Variable(torch.randn(batch_size, 3, h, w)) 231 | output = fcn_model(input) 232 | assert output.size() == torch.Size([batch_size, n_class, h, w]) 233 | 234 | fcn_model = FCN16s(pretrained_net=vgg_model, n_class=n_class) 235 | input = torch.autograd.Variable(torch.randn(batch_size, 3, h, w)) 236 | output = fcn_model(input) 237 | assert output.size() == torch.Size([batch_size, n_class, h, w]) 238 | 239 | fcn_model = FCN8s(pretrained_net=vgg_model, n_class=n_class) 240 | input = torch.autograd.Variable(torch.randn(batch_size, 3, h, w)) 241 | output = fcn_model(input) 242 | assert output.size() == torch.Size([batch_size, n_class, h, w]) 243 | 244 | fcn_model = 
FCNs(pretrained_net=vgg_model, n_class=n_class)
245 |     input = torch.autograd.Variable(torch.randn(batch_size, 3, h, w))
246 |     output = fcn_model(input)
247 |     assert output.size() == torch.Size([batch_size, n_class, h, w])
248 |
249 |     print("Pass size check")
250 |
251 |     # test a random batch, loss should decrease
252 |     fcn_model = FCNs(pretrained_net=vgg_model, n_class=n_class)
253 |     criterion = nn.BCELoss()
254 |     optimizer = optim.SGD(fcn_model.parameters(), lr=1e-3, momentum=0.9)
255 |     input = torch.autograd.Variable(torch.randn(batch_size, 3, h, w))
256 |     y = torch.autograd.Variable(torch.rand(batch_size, n_class, h, w), requires_grad=False)  # BCELoss expects targets in [0, 1]
257 |     for iter in range(10):
258 |         optimizer.zero_grad()
259 |         output = fcn_model(input)
260 |         output = torch.sigmoid(output)  # nn.functional.sigmoid is deprecated
261 |         loss = criterion(output, y)
262 |         loss.backward()
263 |         print("iter{}, loss {}".format(iter, loss.item()))  # loss.data[0] no longer works on 0-dim tensors
264 |         optimizer.step()
265 |
--------------------------------------------------------------------------------
/models/networks/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from models.networks.architecture import VGG19
5 |
6 | # Defines the GAN loss which uses either LSGAN or the regular GAN.
7 | # When LSGAN is used, it is basically same as MSELoss,
8 | # but it abstracts away the need to create the target label tensor
9 | # that has the same size as the input
10 | class GANLoss(nn.Module):
11 |     def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0,
12 |                  tensor=torch.FloatTensor, opt=None):
13 |         super(GANLoss, self).__init__()
14 |         self.real_label = target_real_label
15 |         self.fake_label = target_fake_label
16 |         self.real_label_tensor = None
17 |         self.fake_label_tensor = None
18 |         self.zero_tensor = None
19 |         self.Tensor = tensor
20 |         self.gan_mode = gan_mode
21 |         self.opt = opt
22 |         if gan_mode == 'ls':
23 |             pass
24 |         elif gan_mode == 'original':
25 |             pass
26 |         elif gan_mode == 'w':
27 |             pass
28 |         elif gan_mode == 'hinge':
29 |             pass
30 |         else:
31 |             raise ValueError('Unexpected gan_mode {}'.format(gan_mode))
32 |
33 |     def get_target_tensor(self, input, target_is_real):
34 |         if target_is_real:
35 |             if self.real_label_tensor is None:
36 |                 self.real_label_tensor = self.Tensor(1).fill_(self.real_label)
37 |                 self.real_label_tensor.requires_grad_(False)
38 |             return self.real_label_tensor.expand_as(input)
39 |         else:
40 |             if self.fake_label_tensor is None:
41 |                 self.fake_label_tensor = self.Tensor(1).fill_(self.fake_label)
42 |                 self.fake_label_tensor.requires_grad_(False)
43 |             return self.fake_label_tensor.expand_as(input)
44 |
45 |     def get_zero_tensor(self, input):
46 |         if self.zero_tensor is None:
47 |             self.zero_tensor = self.Tensor(1).fill_(0)
48 |             self.zero_tensor.requires_grad_(False)
49 |         return self.zero_tensor.expand_as(input)
50 |
51 |     def loss(self, input, target_is_real, for_discriminator=True):
52 |         if self.gan_mode == 'original':  # cross entropy loss
53 |             target_tensor = self.get_target_tensor(input, target_is_real)
54 |             loss = F.binary_cross_entropy_with_logits(input, target_tensor)
55 |             return loss
56 |         elif self.gan_mode == 'ls':
57 |             target_tensor = self.get_target_tensor(input, target_is_real)
58 |             return F.mse_loss(input, target_tensor)
59 |         elif self.gan_mode == 'hinge':
60 |             if for_discriminator:
61 |                 if target_is_real:
62 |                     minval = torch.min(input - 1, self.get_zero_tensor(input))
63 |                     loss = -torch.mean(minval)
64 |                 else:
65 |                     minval = torch.min(-input - 1, self.get_zero_tensor(input))
66 |                     loss = -torch.mean(minval)
67 |             else:
68 |                 assert target_is_real, "The generator's hinge loss must be aiming for real"
69 |                 loss = -torch.mean(input)
70 |             return loss
71 |         else:
72 |             # wgan
73 |             if target_is_real:
74 |                 return -input.mean()
75 |             else:
76 |                 return input.mean()
77 |
78 |     def __call__(self, input, target_is_real, for_discriminator=True):
79 |         # computing loss is a bit complicated because |input| may not be
80 |         # a tensor, but list of tensors in case of multiscale discriminator
81 |         if isinstance(input, list):
82 |             loss = 0
83 |             for pred_i in input:
84 |                 if isinstance(pred_i, list):
85 |                     pred_i = pred_i[-1]
86 |                 loss_tensor = self.loss(pred_i, target_is_real, for_discriminator)
87 |                 bs = 1 if len(loss_tensor.size()) == 0 else loss_tensor.size(0)
88 |                 new_loss = torch.mean(loss_tensor.view(bs, -1), dim=1)
89 |                 loss += new_loss
90 |             return loss / len(input)
91 |         else:
92 |             return self.loss(input, target_is_real, for_discriminator)
93 |
94 |
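How `GANLoss` is typically wired up with the multiscale discriminator (a hypothetical sketch; the actual wiring lives in the trainer, which is not shown here). Its `__call__` accepts either a single tensor or the list-of-lists that `MultiscaleDiscriminator` returns:

```python
import torch

criterion_gan = GANLoss(gan_mode='hinge', tensor=torch.cuda.FloatTensor)

pred_fake = netD(fake_img)  # list of per-scale predictions (lists if getIntermFeat)
pred_real = netD(real_img)

# hinge loss: D pushes real outputs above 1 and fake outputs below -1
loss_d = criterion_gan(pred_fake, False, for_discriminator=True) + \
         criterion_gan(pred_real, True, for_discriminator=True)

# G only maximizes D's score on fakes
loss_g = criterion_gan(netD(fake_img), True, for_discriminator=False)
```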
95 | # Perceptual loss that uses a pretrained VGG network
96 | class VGGLoss(nn.Module):
97 |     #def __init__(self, gpu_ids):
98 |     def __init__(self):
99 |         super(VGGLoss, self).__init__()
100 |         self.vgg = VGG19().cuda()
101 |         self.criterion = nn.L1Loss()
102 |         self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]
103 |
104 |     def forward(self, x, y):
105 |         x_vgg, y_vgg = self.vgg(x), self.vgg(y)
106 |         loss = 0
107 |         for i in range(len(x_vgg)):
108 |             loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
109 |         return loss
110 |
111 |
112 | # KL Divergence loss used in VAE with an image encoder
113 | class KLDLoss(nn.Module):
114 |     def forward(self, mu, logvar):
115 |         return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
116 |
117 | # GAN feature matching loss
118 | class GANFeatMatchingLoss(nn.Module):
119 |     def __init__(self, opt, criterionFeat=nn.L1Loss):
120 |         super().__init__()
121 |         self.opt = opt
122 |         self.criterionFeat = criterionFeat()
123 |
124 |     def forward(self, pred_fake, pred_real):
125 |         loss_G_GAN_Feat = 0
126 |         #if not self.opt.no_ganFeat_loss:
127 |         feat_weights = 4.0 / (self.opt.n_layers_D + 1)
128 |         D_weights = 1.0 / self.opt.num_D
129 |         for i in range(self.opt.num_D):
130 |             for j in range(len(pred_fake[i])-1):
131 |                 loss_G_GAN_Feat += D_weights * feat_weights * \
132 |                     self.criterionFeat(pred_fake[i][j], pred_real[i][j].detach()) * self.opt.lambda_feat
133 |
134 |         return loss_G_GAN_Feat
135 |
--------------------------------------------------------------------------------
/models/networks/normalization.py:
--------------------------------------------------------------------------------
1 | import re
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from models.networks.sync_batchnorm import SynchronizedBatchNorm2d  # provided by the Synchronized-BatchNorm-PyTorch clone described in the README
6 | import torch.nn.utils.spectral_norm as spectral_norm
7 |
8 |
9 | # Returns a function that creates a normalization function
10 | # that does not condition on semantic map
11 | def get_nonspade_norm_layer(opt, norm_type='instance'):
12 |     # helper function to get # output channels of the previous layer
13 |     def get_out_channel(layer):
14 |         if hasattr(layer, 'out_channels'):
15 |             return getattr(layer, 'out_channels')
16 |         return layer.weight.size(0)
17 |
18 |     # this function will be returned
19 |     def add_norm_layer(layer):
20 |         nonlocal norm_type
21 |         if norm_type.startswith('spectral'):
22 |             layer = spectral_norm(layer)
23 |             subnorm_type = norm_type[len('spectral'):]
24 |
25 |         if subnorm_type ==
'none' or len(subnorm_type) == 0: 26 | return layer 27 | 28 | # remove bias in the previous layer, which is meaningless 29 | # since it has no effect after normalization 30 | if getattr(layer, 'bias', None) is not None: 31 | delattr(layer, 'bias') 32 | layer.register_parameter('bias', None) 33 | 34 | if subnorm_type == 'batch': 35 | norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True) 36 | elif subnorm_type == 'sync_batch': 37 | norm_layer = SynchronizedBatchNorm2d(get_out_channel(layer), affine=True) 38 | elif subnorm_type == 'instance': 39 | norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False) 40 | else: 41 | raise ValueError('normalization layer %s is not recognized' % subnorm_type) 42 | 43 | return nn.Sequential(layer, norm_layer) 44 | 45 | return add_norm_layer 46 | 47 | 48 | # Creates SPADE normalization layer based on the given configuration 49 | # SPADE consists of two steps. First, it normalizes the activations using 50 | # your favorite normalization method, such as Batch Norm or Instance Norm. 51 | # Second, it applies scale and bias to the normalized output, conditioned on 52 | # the segmentation map. 53 | # The format of |config_text| is spade(norm)(ks), where 54 | # (norm) specifies the type of parameter-free normalization. 55 | # (e.g. syncbatch, batch, instance) 56 | # (ks) specifies the size of kernel in the SPADE module (e.g. 3x3) 57 | # Example |config_text| will be spadesyncbatch3x3, or spadeinstance5x5. 58 | # Also, the other arguments are 59 | # |norm_nc|: the #channels of the normalized activations, hence the output dim of SPADE 60 | # |label_nc|: the #channels of the input semantic map, hence the input dim of SPADE 61 | class SPADE(nn.Module): 62 | def __init__(self, config_text, norm_nc, label_nc): 63 | super().__init__() 64 | 65 | assert config_text.startswith('spade') 66 | parsed = re.search('spade(\D+)(\d)x\d', config_text) 67 | param_free_norm_type = str(parsed.group(1)) 68 | ks = int(parsed.group(2)) 69 | 70 | if param_free_norm_type == 'instance': 71 | self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False) 72 | elif param_free_norm_type == 'syncbatch': 73 | self.param_free_norm = SynchronizedBatchNorm2d(norm_nc, affine=False) 74 | elif param_free_norm_type == 'batch': 75 | self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False) 76 | else: 77 | raise ValueError('%s is not a recognized param-free norm type in SPADE' 78 | % param_free_norm_type) 79 | 80 | # The dimension of the intermediate embedding space. Yes, hardcoded. 81 | nhidden = 128 82 | 83 | pw = ks // 2 84 | self.mlp_shared = nn.Sequential( 85 | nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw), 86 | nn.ReLU() 87 | ) 88 | self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw) 89 | self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw) 90 | 91 | def forward(self, x, segmap): 92 | 93 | # Part 1. generate parameter-free normalized activations 94 | normalized = self.param_free_norm(x) 95 | 96 | # Part 2. 
produce scaling and bias conditioned on semantic map 97 | segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest') 98 | actv = self.mlp_shared(segmap) 99 | gamma = self.mlp_gamma(actv) 100 | beta = self.mlp_beta(actv) 101 | 102 | # apply scale and bias 103 | out = normalized * (1 + gamma) + beta 104 | 105 | return out 106 | -------------------------------------------------------------------------------- /models/networks/sgnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | try: 5 | from models.networks.architecture import ResnetBlock 6 | except: 7 | from architecture import ResnetBlock 8 | 9 | 10 | class SGNet(nn.Module): 11 | def __init__(self, opt, block=ResnetBlock): 12 | super().__init__() 13 | self.n_class = opt.label_nc 14 | 15 | self.resnet_initial_kernel_size = 7 16 | self.resnet_n_blocks = 9 17 | ngf = 64 18 | activation = nn.ReLU(False) 19 | 20 | self.down = nn.Sequential( 21 | nn.ReflectionPad2d(self.resnet_initial_kernel_size // 2), 22 | nn.Conv2d(self.n_class+3, ngf, kernel_size=self.resnet_initial_kernel_size, stride=2, padding=0), 23 | nn.BatchNorm2d(ngf), 24 | activation, 25 | 26 | nn.Conv2d(ngf, ngf*2, kernel_size=3, stride=2, padding=1), 27 | nn.BatchNorm2d(ngf*2), 28 | activation, 29 | 30 | nn.Conv2d(ngf*2, ngf*4, kernel_size=3, stride=2, padding=1), 31 | nn.BatchNorm2d(ngf*4), 32 | activation, 33 | 34 | nn.Conv2d(ngf*4, ngf*8, kernel_size=3, stride=2, padding=1), 35 | nn.BatchNorm2d(ngf*8), 36 | activation, 37 | ) 38 | 39 | # resnet blocks 40 | resnet_blocks = [] 41 | for i in range(self.resnet_n_blocks): 42 | resnet_blocks += [block(ngf*8, norm_layer=nn.BatchNorm2d, kernel_size=3)] 43 | self.bottle_neck = nn.Sequential(*resnet_blocks) 44 | 45 | self.up = nn.Sequential( 46 | nn.ConvTranspose2d(ngf*8, ngf*8, kernel_size=3, stride=2, padding=1, output_padding=1), 47 | nn.BatchNorm2d(ngf*8), 48 | activation, 49 | 50 | nn.ConvTranspose2d(ngf*8, ngf*4, kernel_size=3, stride=2, padding=1, output_padding=1), 51 | nn.BatchNorm2d(ngf*4), 52 | activation, 53 | 54 | nn.ConvTranspose2d(ngf*4, ngf*2, kernel_size=3, stride=2, padding=1, output_padding=1), 55 | nn.BatchNorm2d(ngf*2), 56 | activation, 57 | 58 | nn.ConvTranspose2d(ngf*2, ngf, kernel_size=3, stride=2, padding=1, output_padding=1), 59 | nn.BatchNorm2d(ngf), 60 | activation, 61 | ) 62 | 63 | self.out = nn.Sequential( 64 | nn.ReflectionPad2d(self.resnet_initial_kernel_size // 2), 65 | nn.Conv2d(ngf, 3, kernel_size=7, padding=0), 66 | nn.Tanh() 67 | ) 68 | 69 | def forward(self, x): 70 | x = self.down(x) 71 | x = self.bottle_neck(x) 72 | x = self.up(x) 73 | out = self.out(x) 74 | 75 | return out # shape: 76 | 77 | def generate_fake(self, x): 78 | return self(x) 79 | 80 | if __name__ == '__main__': 81 | class Opt(): 82 | def __init__(self, label_nc=35): 83 | self.label_nc = label_nc 84 | 85 | label_nc = 35 86 | nc = 3 87 | opt = Opt(label_nc=label_nc) 88 | x = torch.randn(2, label_nc+nc, 256, 256).cuda() 89 | model = SGNet(opt) 90 | model.cuda() 91 | 92 | out = model(x) 93 | import pdb; pdb.set_trace() -------------------------------------------------------------------------------- /models/networks/spade_discriminator.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import numpy as np 3 | import torch.nn.functional as F 4 | from models.networks.base_network import BaseNetwork 5 | from models.networks.normalization import get_nonspade_norm_layer 6 | import util.util 
as util 7 | 8 | 9 | class MultiscaleDiscriminator(BaseNetwork): 10 | @staticmethod 11 | def modify_commandline_options(parser, is_train): 12 | parser.add_argument('--netD_subarch', type=str, default='n_layer', 13 | help='architecture of each discriminator') 14 | #parser.add_argument('--num_D', type=int, default=2, 15 | parser.add_argument('--num_D', type=int, default=3, 16 | help='number of discriminators to be used in multiscale') 17 | opt, _ = parser.parse_known_args() 18 | 19 | # define properties of each discriminator of the multiscale discriminator 20 | subnetD = util.find_class_in_module(opt.netD_subarch + 'discriminator', 21 | 'models.networks.discriminator') 22 | subnetD.modify_commandline_options(parser, is_train) 23 | 24 | return parser 25 | 26 | def __init__(self, opt, input_nc=3): 27 | super().__init__() 28 | self.opt = opt 29 | 30 | for i in range(opt.num_D): 31 | subnetD = self.create_single_discriminator(opt, input_nc) 32 | self.add_module('discriminator_%d' % i, subnetD) 33 | 34 | def create_single_discriminator(self, opt, input_nc=3): 35 | subarch = opt.netD_subarch 36 | if subarch == 'n_layer': 37 | netD = NLayerDiscriminator(opt, input_nc) 38 | else: 39 | raise ValueError('unrecognized discriminator subarchitecture %s' % subarch) 40 | return netD 41 | 42 | def downsample(self, input): 43 | return F.avg_pool2d(input, kernel_size=3, 44 | stride=2, padding=[1, 1], 45 | count_include_pad=False) 46 | 47 | # Returns list of lists of discriminator outputs. 48 | # The final result is of size opt.num_D x opt.n_layers_D 49 | def forward(self, input): 50 | result = [] 51 | get_intermediate_features = not self.opt.no_ganFeat_loss 52 | for name, D in self.named_children(): 53 | out = D(input) 54 | if not get_intermediate_features: 55 | out = [out] 56 | result.append(out) 57 | input = self.downsample(input) 58 | 59 | return result 60 | 61 | 62 | # Defines the PatchGAN discriminator with the specified arguments. 
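# A back-of-envelope sketch of what one logit of this PatchGAN "sees", assuming the
# default n_layers_D=3 set below (convs of kernel 4 with strides 2, 2, 1, plus the
# final stride-1 prediction conv). Walking from the output back to the input with
# r_in = (r_out - 1) * stride + kernel gives 1 -> 4 -> 7 -> 16 -> 34, so each output
# value classifies roughly a 34x34 patch at its scale; the multiscale wrapper above
# repeats the same classifier on 2x-downsampled copies of the input.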
63 | class NLayerDiscriminator(BaseNetwork): 64 | @staticmethod 65 | def modify_commandline_options(parser, is_train): 66 | #parser.add_argument('--n_layers_D', type=int, default=4, 67 | parser.add_argument('--n_layers_D', type=int, default=3, 68 | help='# layers in each discriminator') 69 | return parser 70 | 71 | def __init__(self, opt, input_nc): 72 | super().__init__() 73 | self.opt = opt 74 | 75 | kw = 4 76 | padw = int(np.ceil((kw - 1.0) / 2)) 77 | nf = opt.ndf 78 | #input_nc = self.compute_D_input_nc(opt) 79 | #input_nc = opt.input_nc 80 | 81 | #norm_layer = get_nonspade_norm_layer(opt, opt.norm_D) 82 | norm_layer = nn.BatchNorm2d 83 | # sequence = [[nn.Conv2d(input_nc, nf, kernel_size=kw, stride=2, padding=padw), 84 | # nn.LeakyReLU(0.2, False)]] 85 | sequence = [[nn.Conv2d(input_nc, nf, kernel_size=kw, stride=2, padding=padw), 86 | #norm_layer(input_nc), 87 | nn.LeakyReLU(0.2, False)]] 88 | 89 | for n in range(1, opt.n_layers_D): 90 | nf_prev = nf 91 | nf = min(nf * 2, 512) 92 | stride = 1 if n == opt.n_layers_D - 1 else 2 93 | # sequence += [[norm_layer(nn.Conv2d(nf_prev, nf, kernel_size=kw, 94 | # stride=stride, padding=padw)), 95 | # nn.LeakyReLU(0.2, False) 96 | # ]] 97 | sequence += [[nn.Conv2d(nf_prev, nf, kernel_size=kw, 98 | stride=stride, padding=padw), 99 | norm_layer(nf), 100 | nn.LeakyReLU(0.2, False) 101 | ]] 102 | 103 | sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]] 104 | 105 | # We divide the layers into groups to extract intermediate layer outputs 106 | for n in range(len(sequence)): 107 | self.add_module('model' + str(n), nn.Sequential(*sequence[n])) 108 | """ 109 | def compute_D_input_nc(self, opt): 110 | #input_nc = opt.label_nc + opt.output_nc 111 | input_nc = opt.label_nc 112 | if opt.contain_dontcare_label: 113 | input_nc += 1 114 | if not opt.no_instance: 115 | input_nc += 1 116 | return input_nc 117 | """ 118 | 119 | def forward(self, input): 120 | results = [input] 121 | for submodel in self.children(): 122 | intermediate_output = submodel(results[-1]) 123 | results.append(intermediate_output) 124 | 125 | get_intermediate_features = not self.opt.no_ganFeat_loss 126 | if get_intermediate_features: 127 | return results[1:] 128 | else: 129 | return results[-1] 130 | -------------------------------------------------------------------------------- /models/networks/spnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | try: 4 | from models.networks.architecture import ResnetBlock 5 | except: 6 | from architecture import ResnetBlock 7 | class SPNet(nn.Module): 8 | def __init__(self, opt, block=ResnetBlock): 9 | super().__init__() 10 | self.n_class = opt.label_nc 11 | self.resnet_initial_kernel_size = 7 12 | self.resnet_n_blocks = 9 13 | ngf = 64 14 | activation = nn.ReLU(False) 15 | 16 | self.down = nn.Sequential( 17 | nn.ReflectionPad2d(self.resnet_initial_kernel_size // 2), 18 | nn.Conv2d(self.n_class+3, ngf, kernel_size=self.resnet_initial_kernel_size, stride=2, padding=0), 19 | nn.BatchNorm2d(ngf), 20 | activation, 21 | 22 | nn.Conv2d(ngf, ngf*2, kernel_size=3, stride=2, padding=1), 23 | nn.BatchNorm2d(ngf*2), 24 | activation, 25 | 26 | nn.Conv2d(ngf*2, ngf*4, kernel_size=3, stride=2, padding=1), 27 | nn.BatchNorm2d(ngf*4), 28 | activation, 29 | 30 | nn.Conv2d(ngf*4, ngf*8, kernel_size=3, stride=2, padding=1), 31 | nn.BatchNorm2d(ngf*8), 32 | activation, 33 | ) 34 | 35 | # resnet blocks 36 | resnet_blocks = [] 37 | for i in range(self.resnet_n_blocks): 
38 | resnet_blocks += [block(ngf*8, norm_layer=nn.BatchNorm2d, kernel_size=3)] 39 | self.bottle_neck = nn.Sequential(*resnet_blocks) 40 | 41 | self.up = nn.Sequential( 42 | nn.ConvTranspose2d(ngf*8, ngf*8, kernel_size=3, stride=2, padding=1, output_padding=1), 43 | nn.BatchNorm2d(ngf*8), 44 | activation, 45 | 46 | nn.ConvTranspose2d(ngf*8, ngf*4, kernel_size=3, stride=2, padding=1, output_padding=1), 47 | nn.BatchNorm2d(ngf*4), 48 | activation, 49 | 50 | nn.ConvTranspose2d(ngf*4, ngf*2, kernel_size=3, stride=2, padding=1, output_padding=1), 51 | nn.BatchNorm2d(ngf*2), 52 | activation, 53 | 54 | nn.ConvTranspose2d(ngf*2, ngf, kernel_size=3, stride=2, padding=1, output_padding=1), 55 | nn.BatchNorm2d(ngf), 56 | activation, 57 | ) 58 | 59 | self.out = nn.Sequential( 60 | nn.ReflectionPad2d(self.resnet_initial_kernel_size // 2), 61 | nn.Conv2d(ngf, self.n_class, kernel_size=7, padding=0), 62 | nn.Softmax2d() 63 | ) 64 | 65 | def forward(self, x): 66 | x = self.down(x) 67 | x = self.bottle_neck(x) 68 | x = self.up(x) 69 | out = self.out(x) 70 | 71 | return out # shape: (B, n_class, H, W), per-pixel class probabilities from Softmax2d 72 | 73 | def generate_fake(self, x): 74 | return self(x) 75 | 76 | 77 | if __name__ == '__main__': 78 | class Opt(): 79 | def __init__(self, label_nc=35): 80 | self.label_nc = label_nc 81 | 82 | label_nc = 35 83 | nc = 3 84 | opt = Opt(label_nc=label_nc) 85 | x = torch.zeros(2, label_nc+nc, 256, 256).cuda() 86 | model = SPNet(opt) 87 | model.cuda() 88 | 89 | out = model(x) 90 | 91 | -------------------------------------------------------------------------------- /models/seg_inpaint_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | import models.networks as networks 5 | 6 | from models.networks.spnet import SPNet 7 | from models.networks.sgnet import SGNet 8 | from models.networks.spade_discriminator import MultiscaleDiscriminator 9 | from models.networks.loss import GANLoss, GANFeatMatchingLoss, VGGLoss 10 | 11 | import torchvision.utils as vutils 12 | import torchvision.transforms as transforms 13 | 14 | 15 | def normalize(tensor): 16 | # assumes a batch of image tensors in [0, 1]; maps them to [-1, 1] 17 | return (tensor - 0.5)/0.5 18 | 19 | def weights_init(m): 20 | classname = m.__class__.__name__ 21 | if classname.find('Conv') != -1: 22 | m.weight.data.normal_(0.0, 0.02) 23 | elif classname.find('BatchNorm2d') != -1: 24 | m.weight.data.normal_(1.0, 0.02) 25 | m.bias.data.fill_(0) 26 | 27 | class SegInpaintModel(torch.nn.Module): 28 | 29 | def __init__(self, opt): 30 | super().__init__() 31 | self.opt = opt 32 | self.ByteTensor = torch.cuda.ByteTensor if self.use_gpu() \ 33 | else torch.ByteTensor 34 | self.FloatTensor = torch.cuda.FloatTensor if self.use_gpu() \ 35 | else torch.FloatTensor 36 | 37 | self.SPNet = SPNet(opt=opt) 38 | self.SGNet = SGNet(opt=opt) 39 | self.D_seg = MultiscaleDiscriminator(opt, input_nc=opt.label_nc) 40 | self.D_img = MultiscaleDiscriminator(opt, input_nc=3) 41 | 42 | if len(opt.gpu_ids) > 0: 43 | assert(torch.cuda.is_available()) 44 | self.SPNet.cuda() 45 | self.SGNet.cuda() 46 | self.D_seg.cuda() 47 | self.D_img.cuda() 48 | 49 | self.SPNet.apply(weights_init) 50 | self.SGNet.apply(weights_init) 51 | self.D_seg.apply(weights_init) 52 | self.D_img.apply(weights_init) 53 | 54 | print("=> finished initializing model") 55 | 56 | # loss 57 | self.criterion_GAN = GANLoss(gan_mode=opt.gan_mode, tensor=self.FloatTensor, opt=self.opt) 58 | self.criterion_Feat = GANFeatMatchingLoss(opt=opt) 59 | self.criterion_VGG = VGGLoss() 
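# (descriptive note) Three generator-side criteria are kept here: the adversarial
# term criterion_GAN, the discriminator feature-matching term criterion_Feat, and
# the VGG19 perceptual term criterion_VGG. SPNet is trained with the first two
# against D_seg, while SGNet uses all three against D_img; see compute_spn_loss
# and compute_sgn_loss below.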
60 | 61 | self.normalize = normalize 62 | 63 | def forward(self, data, mode, fake_seg=None, fake_img=None): 64 | """ 65 | - image: in SPG-Net, img_size = cx256x256 66 | - label, instance, image, path = data 67 | """ 68 | input_semantics, real_image, corrupted_seg, corrupted_img, occ_mask = self.preprocess_input(data) 69 | 70 | if mode == 'spn': 71 | spn_loss, generated_seg = self.compute_spn_loss(input_semantics, corrupted_seg, corrupted_img) 72 | return spn_loss, generated_seg 73 | elif mode == 'sgn': 74 | sgn_loss, generated_img = self.compute_sgn_loss(fake_seg, real_image, corrupted_img) 75 | return sgn_loss, generated_img 76 | elif mode == 'd_seg': 77 | d_seg_loss = self.compute_d_seg_loss(input_semantics, corrupted_img, corrupted_seg) 78 | return d_seg_loss 79 | elif mode == 'd_img': 80 | d_img_loss = self.compute_d_img_loss(real_image, fake_seg, corrupted_img) 81 | #return d_img_loss 82 | return d_img_loss, real_image, corrupted_img 83 | else: 84 | raise ValueError("|mode| is invalid") 85 | 86 | def preprocess_input(self, data): 87 | # copied from SPADE's Pix2PixModel 88 | # label: 1~33 89 | 90 | # move to GPU and change data types 91 | data['label'] = data['label'].long() 92 | if self.use_gpu(): 93 | data['label'] = data['label'].cuda() 94 | data['instance'] = data['instance'].cuda() 95 | data['image'] = data['image'].cuda() 96 | data['mask'] = data['mask'].cuda() 97 | 98 | image = data['image'] 99 | occ_mask = data['mask'] 100 | 101 | # create one-hot label map 102 | label_map = data['label'] 103 | bs, _, h, w = label_map.size() 104 | nc = self.opt.label_nc + 1 if self.opt.contain_dontcare_label \ 105 | else self.opt.label_nc 106 | 107 | # NOTE: let label=0 be masked region 108 | input_label = self.FloatTensor(bs, nc, h, w).zero_() 109 | input_semantics = input_label.scatter_(1, label_map, 1.0) 110 | 111 | corrupted_seg = occ_mask * input_semantics # NOTE haven't checked yet 112 | corrupted_img = occ_mask * image 113 | 114 | # normalize 115 | image = self.normalize(image) 116 | corrupted_img = self.normalize(corrupted_img) 117 | 118 | return input_semantics, image, corrupted_seg, corrupted_img, occ_mask 119 | 120 | def compute_spn_loss(self, input_semantics, corrupted_seg, corrupted_img): 121 | # Generator - SPNet 122 | G_SPNet_losses = {} 123 | 124 | real_seg = input_semantics 125 | 126 | # for GAN feature matching loss 127 | pred_real_seg = self.discriminate_seg(real_seg) 128 | 129 | # SP-Net 130 | input_spn = torch.cat([corrupted_img, corrupted_seg], dim=1) # concatenate corrupted image + corrupted seg 131 | fake_seg = self.generate_fake_seg(input_spn) # generate fake_seg 132 | 133 | pred_fake_seg = self.discriminate_seg(fake_seg) 134 | G_SPNet_losses['GAN'] = self.criterion_GAN(pred_fake_seg, target_is_real=True, for_discriminator=False) 135 | G_SPNet_losses['GAN_Feat'] = self.criterion_Feat(pred_fake_seg, pred_real_seg) # or perceptual loss 136 | return G_SPNet_losses, fake_seg 137 | 138 | def compute_sgn_loss(self, fake_seg, real_image, corrupted_img): 139 | # Generator - SGNet 140 | G_SGNet_losses = {} 141 | 142 | # for GAN feature matching loss 143 | with torch.no_grad(): 144 | pred_real_img = self.D_img(real_image) 145 | 146 | fake_seg = fake_seg.detach() 147 | fake_seg.requires_grad_() 148 | input_sgn = torch.cat([corrupted_img, fake_seg], dim=1) # concatenate corrupted image + predicted seg 149 | fake_image = self.SGNet.generate_fake(input_sgn) # generate fake_image 150 | 151 | pred_fake_img = self.D_img(fake_image) 152 | G_SGNet_losses['GAN'] = self.criterion_GAN(pred_fake_img, 
target_is_real=True, for_discriminator=False) 153 | G_SGNet_losses['GAN_Feat'] = self.criterion_Feat(pred_fake_img, pred_real_img) 154 | G_SGNet_losses['VGG'] = self.criterion_VGG(real_image, fake_image)*self.opt.lambda_feat # or alex 155 | 156 | return G_SGNet_losses, fake_image 157 | 158 | def compute_d_seg_loss(self, real_seg, corrupted_img, corrupted_seg): 159 | # real_seg is input_semantics 160 | # reference: https://github.com/NVlabs/SPADE/blob/master/models/pix2pix_model.py#L166-L181 161 | D_seg_losses = {} 162 | with torch.no_grad(): 163 | input_spn = torch.cat([corrupted_img, corrupted_seg], dim=1) 164 | fake_seg = self.SPNet.generate_fake(input_spn) 165 | fake_seg = fake_seg.detach() 166 | fake_seg.requires_grad_() 167 | 168 | pred_fake_seg = self.D_seg(fake_seg) 169 | D_seg_losses['D_fake'] = self.criterion_GAN(pred_fake_seg, target_is_real=False, for_discriminator=True) 170 | 171 | pred_real_seg = self.D_seg(real_seg) 172 | D_seg_losses['D_real'] = self.criterion_GAN(pred_real_seg, target_is_real=True, for_discriminator=True) 173 | return D_seg_losses 174 | 175 | def compute_d_img_loss(self, real_image, fake_seg, corrupted_img): 176 | D_img_losses = {} 177 | with torch.no_grad(): 178 | input_sgn = torch.cat([corrupted_img, fake_seg], dim=1) 179 | fake_image = self.SGNet.generate_fake(input_sgn) 180 | fake_image = fake_image.detach() 181 | fake_image.requires_grad_() 182 | 183 | pred_fake_image = self.D_img(fake_image) 184 | D_img_losses['D_fake'] = self.criterion_GAN(pred_fake_image, target_is_real=False, for_discriminator=True) 185 | 186 | pred_real_image = self.D_img(real_image) 187 | D_img_losses['D_real'] = self.criterion_GAN(pred_real_image, target_is_real=True, for_discriminator=True) 188 | return D_img_losses 189 | 190 | def generate_fake_seg(self, input_spn): 191 | # NOTE: should produce one_hot? 
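# Answering the NOTE above: SPNet ends in Softmax2d, so the value returned below is a
# soft per-pixel class-probability map, and SGNet consumes it as-is. If a hard one-hot
# map were ever required, a hypothetical conversion would be
#   hard = torch.zeros_like(fake_prob).scatter_(1, fake_prob.argmax(dim=1, keepdim=True), 1.0)
# (sketch only; not used anywhere in this model).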
192 | fake_prob = self.SPNet(input_spn) 193 | return fake_prob 194 | 195 | def generate_fake_img(self, input_sgn): 196 | return self.SGNet(input_sgn) 197 | 198 | def discriminate_seg(self, seg): 199 | pred_seg = self.D_seg(seg) 200 | return pred_seg 201 | 202 | def discriminate_img(self, img): 203 | pred_img = self.D_img(img) 204 | return pred_img 205 | 206 | def create_optimizers(self, opt): 207 | SPNet_param = list(self.SPNet.parameters()) 208 | SGNet_param = list(self.SGNet.parameters()) 209 | D_seg_param = list(self.D_seg.parameters()) 210 | D_img_param = list(self.D_img.parameters()) 211 | 212 | if opt.no_TTUR: 213 | beta1, beta2 = opt.beta1, opt.beta2 214 | G_lr, D_lr = opt.lr, opt.lr 215 | else: 216 | beta1, beta2 = 0, 0.9 217 | G_lr, D_lr = opt.lr / 2, opt.lr * 2 218 | 219 | optimizer_SPNet = optim.Adam(SPNet_param, lr=G_lr, betas=(beta1, beta2)) 220 | optimizer_SGNet = optim.Adam(SGNet_param, lr=G_lr, betas=(beta1, beta2)) 221 | optimizer_D_seg = optim.Adam(D_seg_param, lr=D_lr, betas=(beta1, beta2)) 222 | optimizer_D_img = optim.Adam(D_img_param, lr=D_lr, betas=(beta1, beta2)) 223 | 224 | return optimizer_SPNet, optimizer_SGNet, optimizer_D_seg, optimizer_D_img 225 | 226 | 227 | def get_edges(self, t): 228 | edge = self.ByteTensor(t.size()).zero_() 229 | edge[:, :, :, 1:] = edge[:, :, :, 1:] | (t[:, :, :, 1:] != t[:, :, :, :-1]) 230 | edge[:, :, :, :-1] = edge[:, :, :, :-1] | (t[:, :, :, 1:] != t[:, :, :, :-1]) 231 | edge[:, :, 1:, :] = edge[:, :, 1:, :] | (t[:, :, 1:, :] != t[:, :, :-1, :]) 232 | edge[:, :, :-1, :] = edge[:, :, :-1, :] | (t[:, :, 1:, :] != t[:, :, :-1, :]) 233 | return edge.float() 234 | 235 | def save(self, path, epoch): 236 | states = { 237 | 238 | 'SPNet': self.SPNet.cpu().state_dict(), 239 | 'SGNet': self.SGNet.cpu().state_dict(), 240 | 'D_seg': self.D_seg.cpu().state_dict(), 241 | 'D_img': self.D_img.cpu().state_dict(), 242 | 243 | # 'optimizer_SPNet': self.optimizer_SPNet.state_dict(), 244 | # 'optimizer_SGNet': self.optimizer_SGNet.state_dict(), 245 | # 'optimizer_D_seg': self.optimizer_D_seg.state_dict(), 246 | # 'optimizer_D_img': self.optimizer_D_img.state_dict(), 247 | 'epoch': epoch, 248 | } 249 | torch.save(states, path) 250 | if len(self.opt.gpu_ids) > 0: 251 | assert(torch.cuda.is_available()) 252 | self.SPNet.cuda() 253 | self.SGNet.cuda() 254 | self.D_seg.cuda() 255 | self.D_img.cuda() 256 | 257 | def load(self, path): 258 | pass 259 | 260 | def use_gpu(self): 261 | return len(self.opt.gpu_ids) > 0 -------------------------------------------------------------------------------- /trainers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yccyenchicheng/pytorch-SegInpaint/13de7deb7ad11508294d8ed2e7a2c8f67ebc04d9/trainers/__init__.py -------------------------------------------------------------------------------- /trainers/seg_inpaint_trainer.py: -------------------------------------------------------------------------------- 1 | """ 2 | adapted from SPADE 3 | """ 4 | import torch 5 | from models.networks.sync_batchnorm import DataParallelWithCallback 6 | from models.seg_inpaint_model import SegInpaintModel 7 | 8 | class SegInpaintTrainer(): 9 | """ 10 | Trainer creates the model and optimizers, and uses them to 11 | update the weights of the network while reporting losses 12 | and the latest visuals to visualize the progress in training. 
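One full iteration runs run_generator_one_step (SPNet first, then SGNet fed with SPNet's prediction as fake_seg) followed by run_discriminator_one_step (D_seg, then D_img), each network stepped by its own Adam optimizer from create_optimizers.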
13 | """ 14 | 15 | def __init__(self, opt): 16 | self.opt = opt 17 | self.seg_inpaint_model = SegInpaintModel(opt) 18 | if len(opt.gpu_ids) > 0: 19 | self.seg_inpaint_model = DataParallelWithCallback(self.seg_inpaint_model, 20 | device_ids=opt.gpu_ids) 21 | self.seg_inpaint_model_on_one_gpu = self.seg_inpaint_model.module 22 | else: 23 | self.seg_inpaint_model_on_one_gpu = self.seg_inpaint_model 24 | 25 | self.generated = None 26 | 27 | self.optimizer_SPNet, self.optimizer_SGNet, self.optimizer_D_seg, self.optimizer_D_img = \ 28 | self.seg_inpaint_model_on_one_gpu.create_optimizers(opt) 29 | 30 | self.old_lr = opt.lr 31 | 32 | def run_generator_one_step(self, data): 33 | # first update SPNet 34 | self.optimizer_SPNet.zero_grad() 35 | spn_losses, generated_seg = self.seg_inpaint_model(data, mode='spn') 36 | spn_loss = sum(spn_losses.values()).mean() 37 | spn_loss.backward() 38 | self.optimizer_SPNet.step() 39 | 40 | self.spn_losses = spn_losses 41 | self.generated_seg = generated_seg 42 | 43 | # then udpate SGNet 44 | self.optimizer_SGNet.zero_grad() 45 | sgn_losses, generated_img = self.seg_inpaint_model(data, mode='sgn', fake_seg=generated_seg) 46 | sgn_loss = sum(sgn_losses.values()).mean() 47 | sgn_loss.backward() 48 | self.optimizer_SGNet.step() 49 | 50 | self.sgn_losses = sgn_losses 51 | self.generated_img = generated_img 52 | 53 | def run_discriminator_one_step(self, data): 54 | # first D_seg 55 | self.optimizer_D_seg.zero_grad() 56 | d_seg_losses = self.seg_inpaint_model(data, mode='d_seg') 57 | d_seg_loss = sum(d_seg_losses.values()).mean() 58 | d_seg_loss.backward() 59 | self.optimizer_D_seg.step() 60 | self.d_seg_losses = d_seg_losses 61 | 62 | self.optimizer_D_img.zero_grad() 63 | #d_img_losses = self.seg_inpaint_model(data, mode='d_img', fake_seg=self.generated_seg) 64 | d_img_losses, real_img, corruped_img = self.seg_inpaint_model(data, mode='d_img', fake_seg=self.generated_seg) 65 | d_img_loss = sum(d_img_losses.values()).mean() 66 | d_img_loss.backward() 67 | self.optimizer_D_img.step() 68 | self.d_img_losses = d_img_losses 69 | 70 | # NOTE: for display current results 71 | self.real_img = real_img 72 | self.corruped_img = corruped_img 73 | 74 | def get_latest_results(self): 75 | return self.real_img, self.corruped_img, self.generated_seg, self.generated_img 76 | 77 | def get_loss_str(self): 78 | def gather_str(name, errors): 79 | msg = '%s: ' % name 80 | for k, v in errors.items(): 81 | v = v.mean().float() 82 | msg += '%s: %.3f ' % (k, v) 83 | msg += '| ' 84 | return msg 85 | spn_l, sgn_l, d_seg_l, d_img_l = self.spn_losses, self.sgn_losses, self.d_seg_losses, self.d_img_losses 86 | 87 | return gather_str('SPN', spn_l) + gather_str('SGN', sgn_l) + \ 88 | gather_str('DSeg', d_seg_l) + gather_str('DImg', d_img_l) 89 | 90 | def save(self, path, epoch): 91 | self.seg_inpaint_model_on_one_gpu.save(path, epoch) 92 | 93 | ################################################################## 94 | # Helper functions 95 | ################################################################## 96 | 97 | def update_learning_rate(self, epoch): 98 | if epoch > self.opt.niter: 99 | lrd = self.opt.lr / self.opt.niter_decay 100 | new_lr = self.old_lr - lrd 101 | else: 102 | new_lr = self.old_lr 103 | 104 | if new_lr != self.old_lr: 105 | if self.opt.no_TTUR: 106 | new_lr_G = new_lr 107 | new_lr_D = new_lr 108 | else: 109 | new_lr_G = new_lr / 2 110 | new_lr_D = new_lr * 2 111 | 112 | for param_group in self.seg_inpaint_model_on_one_gpu.optimizer_SPNet.param_groups: 113 | param_group['lr'] = 
new_lr_G 114 | for param_group in self.optimizer_SGNet.param_groups: 115 | param_group['lr'] = new_lr_G 116 | for param_group in self.optimizer_D_seg.param_groups: 117 | param_group['lr'] = new_lr_D 118 | for param_group in self.optimizer_D_img.param_groups: 119 | param_group['lr'] = new_lr_D 120 | print('update learning rate: %f -> %f' % (self.old_lr, new_lr)) 121 | self.old_lr = new_lr 122 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yccyenchicheng/pytorch-SegInpaint/13de7deb7ad11508294d8ed2e7a2c8f67ebc04d9/util/__init__.py -------------------------------------------------------------------------------- /util/coco.py: -------------------------------------------------------------------------------- 1 | def id2label(id): 2 | if id == 182: 3 | id = 0 4 | else: 5 | id = id + 1 6 | labelmap = \ 7 | {0: 'unlabeled', 8 | 1: 'person', 9 | 2: 'bicycle', 10 | 3: 'car', 11 | 4: 'motorcycle', 12 | 5: 'airplane', 13 | 6: 'bus', 14 | 7: 'train', 15 | 8: 'truck', 16 | 9: 'boat', 17 | 10: 'traffic light', 18 | 11: 'fire hydrant', 19 | 12: 'street sign', 20 | 13: 'stop sign', 21 | 14: 'parking meter', 22 | 15: 'bench', 23 | 16: 'bird', 24 | 17: 'cat', 25 | 18: 'dog', 26 | 19: 'horse', 27 | 20: 'sheep', 28 | 21: 'cow', 29 | 22: 'elephant', 30 | 23: 'bear', 31 | 24: 'zebra', 32 | 25: 'giraffe', 33 | 26: 'hat', 34 | 27: 'backpack', 35 | 28: 'umbrella', 36 | 29: 'shoe', 37 | 30: 'eye glasses', 38 | 31: 'handbag', 39 | 32: 'tie', 40 | 33: 'suitcase', 41 | 34: 'frisbee', 42 | 35: 'skis', 43 | 36: 'snowboard', 44 | 37: 'sports ball', 45 | 38: 'kite', 46 | 39: 'baseball bat', 47 | 40: 'baseball glove', 48 | 41: 'skateboard', 49 | 42: 'surfboard', 50 | 43: 'tennis racket', 51 | 44: 'bottle', 52 | 45: 'plate', 53 | 46: 'wine glass', 54 | 47: 'cup', 55 | 48: 'fork', 56 | 49: 'knife', 57 | 50: 'spoon', 58 | 51: 'bowl', 59 | 52: 'banana', 60 | 53: 'apple', 61 | 54: 'sandwich', 62 | 55: 'orange', 63 | 56: 'broccoli', 64 | 57: 'carrot', 65 | 58: 'hot dog', 66 | 59: 'pizza', 67 | 60: 'donut', 68 | 61: 'cake', 69 | 62: 'chair', 70 | 63: 'couch', 71 | 64: 'potted plant', 72 | 65: 'bed', 73 | 66: 'mirror', 74 | 67: 'dining table', 75 | 68: 'window', 76 | 69: 'desk', 77 | 70: 'toilet', 78 | 71: 'door', 79 | 72: 'tv', 80 | 73: 'laptop', 81 | 74: 'mouse', 82 | 75: 'remote', 83 | 76: 'keyboard', 84 | 77: 'cell phone', 85 | 78: 'microwave', 86 | 79: 'oven', 87 | 80: 'toaster', 88 | 81: 'sink', 89 | 82: 'refrigerator', 90 | 83: 'blender', 91 | 84: 'book', 92 | 85: 'clock', 93 | 86: 'vase', 94 | 87: 'scissors', 95 | 88: 'teddy bear', 96 | 89: 'hair drier', 97 | 90: 'toothbrush', 98 | 91: 'hair brush', # Last class of Thing 99 | 92: 'banner', # Beginning of Stuff 100 | 93: 'blanket', 101 | 94: 'branch', 102 | 95: 'bridge', 103 | 96: 'building-other', 104 | 97: 'bush', 105 | 98: 'cabinet', 106 | 99: 'cage', 107 | 100: 'cardboard', 108 | 101: 'carpet', 109 | 102: 'ceiling-other', 110 | 103: 'ceiling-tile', 111 | 104: 'cloth', 112 | 105: 'clothes', 113 | 106: 'clouds', 114 | 107: 'counter', 115 | 108: 'cupboard', 116 | 109: 'curtain', 117 | 110: 'desk-stuff', 118 | 111: 'dirt', 119 | 112: 'door-stuff', 120 | 113: 'fence', 121 | 114: 'floor-marble', 122 | 115: 'floor-other', 123 | 116: 'floor-stone', 124 | 117: 'floor-tile', 125 | 118: 'floor-wood', 126 | 119: 'flower', 127 | 120: 
'fog', 128 | 121: 'food-other', 129 | 122: 'fruit', 130 | 123: 'furniture-other', 131 | 124: 'grass', 132 | 125: 'gravel', 133 | 126: 'ground-other', 134 | 127: 'hill', 135 | 128: 'house', 136 | 129: 'leaves', 137 | 130: 'light', 138 | 131: 'mat', 139 | 132: 'metal', 140 | 133: 'mirror-stuff', 141 | 134: 'moss', 142 | 135: 'mountain', 143 | 136: 'mud', 144 | 137: 'napkin', 145 | 138: 'net', 146 | 139: 'paper', 147 | 140: 'pavement', 148 | 141: 'pillow', 149 | 142: 'plant-other', 150 | 143: 'plastic', 151 | 144: 'platform', 152 | 145: 'playingfield', 153 | 146: 'railing', 154 | 147: 'railroad', 155 | 148: 'river', 156 | 149: 'road', 157 | 150: 'rock', 158 | 151: 'roof', 159 | 152: 'rug', 160 | 153: 'salad', 161 | 154: 'sand', 162 | 155: 'sea', 163 | 156: 'shelf', 164 | 157: 'sky-other', 165 | 158: 'skyscraper', 166 | 159: 'snow', 167 | 160: 'solid-other', 168 | 161: 'stairs', 169 | 162: 'stone', 170 | 163: 'straw', 171 | 164: 'structural-other', 172 | 165: 'table', 173 | 166: 'tent', 174 | 167: 'textile-other', 175 | 168: 'towel', 176 | 169: 'tree', 177 | 170: 'vegetable', 178 | 171: 'wall-brick', 179 | 172: 'wall-concrete', 180 | 173: 'wall-other', 181 | 174: 'wall-panel', 182 | 175: 'wall-stone', 183 | 176: 'wall-tile', 184 | 177: 'wall-wood', 185 | 178: 'water-other', 186 | 179: 'waterdrops', 187 | 180: 'window-blind', 188 | 181: 'window-other', 189 | 182: 'wood'} 190 | if id in labelmap: 191 | return labelmap[id] 192 | else: 193 | return 'unknown' 194 | -------------------------------------------------------------------------------- /util/html.py: -------------------------------------------------------------------------------- 1 | 2 | import datetime 3 | import dominate 4 | from dominate.tags import * 5 | import os 6 | 7 | 8 | class HTML: 9 | def __init__(self, web_dir, title, refresh=0): 10 | if web_dir.endswith('.html'): 11 | web_dir, html_name = os.path.split(web_dir) 12 | else: 13 | web_dir, html_name = web_dir, 'index.html' 14 | self.title = title 15 | self.web_dir = web_dir 16 | self.html_name = html_name 17 | self.img_dir = os.path.join(self.web_dir, 'images') 18 | if len(self.web_dir) > 0 and not os.path.exists(self.web_dir): 19 | os.makedirs(self.web_dir) 20 | if len(self.web_dir) > 0 and not os.path.exists(self.img_dir): 21 | os.makedirs(self.img_dir) 22 | 23 | self.doc = dominate.document(title=title) 24 | with self.doc: 25 | h1(datetime.datetime.now().strftime("%I:%M%p on %B %d, %Y")) 26 | if refresh > 0: 27 | with self.doc.head: 28 | meta(http_equiv="refresh", content=str(refresh)) 29 | 30 | def get_image_dir(self): 31 | return self.img_dir 32 | 33 | def add_header(self, str): 34 | with self.doc: 35 | h3(str) 36 | 37 | def add_table(self, border=1): 38 | self.t = table(border=border, style="table-layout: fixed;") 39 | self.doc.add(self.t) 40 | 41 | def add_images(self, ims, txts, links, width=512): 42 | self.add_table() 43 | with self.t: 44 | with tr(): 45 | for im, txt, link in zip(ims, txts, links): 46 | with td(style="word-wrap: break-word;", halign="center", valign="top"): 47 | with p(): 48 | with a(href=os.path.join('images', link)): 49 | img(style="width:%dpx" % (width), src=os.path.join('images', im)) 50 | br() 51 | p(txt.encode('utf-8')) 52 | 53 | def save(self): 54 | html_file = os.path.join(self.web_dir, self.html_name) 55 | f = open(html_file, 'wt') 56 | f.write(self.doc.render()) 57 | f.close() 58 | 59 | 60 | if __name__ == '__main__': 61 | html = HTML('web/', 'test_html') 62 | html.add_header('hello world') 63 | 64 | ims = [] 65 | txts = [] 66 | links = [] 67 | 
for n in range(4): 68 | ims.append('image_%d.jpg' % n) 69 | txts.append('text_%d' % n) 70 | links.append('image_%d.jpg' % n) 71 | html.add_images(ims, txts, links) 72 | html.save() 73 | -------------------------------------------------------------------------------- /util/util.py: -------------------------------------------------------------------------------- 1 | 2 | import re 3 | import importlib 4 | import torch 5 | from argparse import Namespace 6 | import numpy as np 7 | from PIL import Image 8 | import os 9 | import argparse 10 | import dill as pickle 11 | import util.coco 12 | 13 | 14 | def save_obj(obj, name): 15 | with open(name, 'wb') as f: 16 | pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL) 17 | 18 | 19 | def load_obj(name): 20 | with open(name, 'rb') as f: 21 | return pickle.load(f) 22 | 23 | # returns a configuration for creating a generator 24 | # |default_opt| should be the opt of the current experiment 25 | # |**kwargs|: if any configuration should be overridden, it can be specified here 26 | 27 | 28 | def copyconf(default_opt, **kwargs): 29 | conf = argparse.Namespace(**vars(default_opt)) 30 | for key in kwargs: 31 | print(key, kwargs[key]) 32 | setattr(conf, key, kwargs[key]) 33 | return conf 34 | 35 | 36 | def tile_images(imgs, picturesPerRow=4): 37 | """ Code borrowed from 38 | https://stackoverflow.com/questions/26521365/cleanly-tile-numpy-array-of-images-stored-in-a-flattened-1d-format/26521997 39 | """ 40 | 41 | # Padding 42 | if imgs.shape[0] % picturesPerRow == 0: 43 | rowPadding = 0 44 | else: 45 | rowPadding = picturesPerRow - imgs.shape[0] % picturesPerRow 46 | if rowPadding > 0: 47 | imgs = np.concatenate([imgs, np.zeros((rowPadding, *imgs.shape[1:]), dtype=imgs.dtype)], axis=0) 48 | 49 | # Tiling Loop (The conditionals are not necessary anymore) 50 | tiled = [] 51 | for i in range(0, imgs.shape[0], picturesPerRow): 52 | tiled.append(np.concatenate([imgs[j] for j in range(i, i + picturesPerRow)], axis=1)) 53 | 54 | tiled = np.concatenate(tiled, axis=0) 55 | return tiled 56 | 57 | 58 | # Converts a Tensor into a Numpy array 59 | # |imtype|: the desired type of the converted numpy array 60 | def tensor2im(image_tensor, imtype=np.uint8, normalize=True, tile=False): 61 | if isinstance(image_tensor, list): 62 | image_numpy = [] 63 | for i in range(len(image_tensor)): 64 | image_numpy.append(tensor2im(image_tensor[i], imtype, normalize)) 65 | return image_numpy 66 | 67 | if image_tensor.dim() == 4: 68 | # transform each image in the batch 69 | images_np = [] 70 | for b in range(image_tensor.size(0)): 71 | one_image = image_tensor[b] 72 | one_image_np = tensor2im(one_image, imtype, normalize) 73 | images_np.append(one_image_np.reshape(1, *one_image_np.shape)) 74 | images_np = np.concatenate(images_np, axis=0) 75 | if tile: 76 | images_tiled = tile_images(images_np) 77 | return images_tiled 78 | else: 79 | return images_np 80 | 81 | if image_tensor.dim() == 2: 82 | image_tensor = image_tensor.unsqueeze(0) 83 | image_numpy = image_tensor.detach().cpu().float().numpy() 84 | if normalize: 85 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 86 | else: 87 | image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0 88 | image_numpy = np.clip(image_numpy, 0, 255) 89 | if image_numpy.shape[2] == 1: 90 | image_numpy = image_numpy[:, :, 0] 91 | return image_numpy.astype(imtype) 92 | 93 | 94 | # Converts a one-hot tensor into a colorful label map 95 | def tensor2label(label_tensor, n_label, imtype=np.uint8, tile=False): 96 | if label_tensor.dim() == 4: 97 | # transform each 
image in the batch 98 | images_np = [] 99 | for b in range(label_tensor.size(0)): 100 | one_image = label_tensor[b] 101 | one_image_np = tensor2label(one_image, n_label, imtype) 102 | images_np.append(one_image_np.reshape(1, *one_image_np.shape)) 103 | images_np = np.concatenate(images_np, axis=0) 104 | if tile: 105 | images_tiled = tile_images(images_np) 106 | return images_tiled 107 | else: 108 | images_np = images_np[0] 109 | return images_np 110 | 111 | if label_tensor.dim() == 1: 112 | return np.zeros((64, 64, 3), dtype=np.uint8) 113 | if n_label == 0: 114 | return tensor2im(label_tensor, imtype) 115 | label_tensor = label_tensor.cpu().float() 116 | if label_tensor.size()[0] > 1: 117 | label_tensor = label_tensor.max(0, keepdim=True)[1] 118 | label_tensor = Colorize(n_label)(label_tensor) 119 | label_numpy = np.transpose(label_tensor.numpy(), (1, 2, 0)) 120 | result = label_numpy.astype(imtype) 121 | return result 122 | 123 | 124 | def save_image(image_numpy, image_path, create_dir=False): 125 | if create_dir: 126 | os.makedirs(os.path.dirname(image_path), exist_ok=True) 127 | if len(image_numpy.shape) == 2: 128 | image_numpy = np.expand_dims(image_numpy, axis=2) 129 | if image_numpy.shape[2] == 1: 130 | image_numpy = np.repeat(image_numpy, 3, 2) 131 | image_pil = Image.fromarray(image_numpy) 132 | 133 | # save to png 134 | image_pil.save(image_path.replace('.jpg', '.png')) 135 | 136 | 137 | def mkdirs(paths): 138 | if isinstance(paths, list) and not isinstance(paths, str): 139 | for path in paths: 140 | mkdir(path) 141 | else: 142 | mkdir(paths) 143 | 144 | 145 | def mkdir(path): 146 | if not os.path.exists(path): 147 | os.makedirs(path) 148 | 149 | 150 | def atoi(text): 151 | return int(text) if text.isdigit() else text 152 | 153 | 154 | def natural_keys(text): 155 | ''' 156 | alist.sort(key=natural_keys) sorts in human order 157 | http://nedbatchelder.com/blog/200712/human_sorting.html 158 | (See Toothy's implementation in the comments) 159 | ''' 160 | return [atoi(c) for c in re.split('(\d+)', text)] 161 | 162 | 163 | def natural_sort(items): 164 | items.sort(key=natural_keys) 165 | 166 | 167 | def str2bool(v): 168 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 169 | return True 170 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 171 | return False 172 | else: 173 | raise argparse.ArgumentTypeError('Boolean value expected.') 174 | 175 | 176 | def find_class_in_module(target_cls_name, module): 177 | target_cls_name = target_cls_name.replace('_', '').lower() 178 | clslib = importlib.import_module(module) 179 | cls = None 180 | for name, clsobj in clslib.__dict__.items(): 181 | if name.lower() == target_cls_name: 182 | cls = clsobj 183 | 184 | if cls is None: 185 | print("In %s, there should be a class whose name matches %s in lowercase without underscore(_)" % (module, target_cls_name)) 186 | exit(0) 187 | 188 | return cls 189 | 190 | 191 | def save_network(net, label, epoch, opt): 192 | save_filename = '%s_net_%s.pth' % (epoch, label) 193 | save_path = os.path.join(opt.checkpoints_dir, opt.name, save_filename) 194 | torch.save(net.cpu().state_dict(), save_path) 195 | if len(opt.gpu_ids) and torch.cuda.is_available(): 196 | net.cuda() 197 | 198 | 199 | def load_network(net, label, epoch, opt): 200 | save_filename = '%s_net_%s.pth' % (epoch, label) 201 | save_dir = os.path.join(opt.checkpoints_dir, opt.name) 202 | save_path = os.path.join(save_dir, save_filename) 203 | weights = torch.load(save_path) 204 | net.load_state_dict(weights) 205 | return net 206 | 207 | 208 | 
############################################################################### 209 | # Code from 210 | # https://github.com/ycszen/pytorch-seg/blob/master/transform.py 211 | # Modified so it complies with the Citscape label map colors 212 | ############################################################################### 213 | def uint82bin(n, count=8): 214 | """returns the binary of integer n, count refers to amount of bits""" 215 | return ''.join([str((n >> y) & 1) for y in range(count - 1, -1, -1)]) 216 | 217 | 218 | def labelcolormap(N): 219 | if N == 35: # cityscape 220 | cmap = np.array([(0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0), (111, 74, 0), (81, 0, 81), 221 | (128, 64, 128), (244, 35, 232), (250, 170, 160), (230, 150, 140), (70, 70, 70), (102, 102, 156), (190, 153, 153), 222 | (180, 165, 180), (150, 100, 100), (150, 120, 90), (153, 153, 153), (153, 153, 153), (250, 170, 30), (220, 220, 0), 223 | (107, 142, 35), (152, 251, 152), (70, 130, 180), (220, 20, 60), (255, 0, 0), (0, 0, 142), (0, 0, 70), 224 | (0, 60, 100), (0, 0, 90), (0, 0, 110), (0, 80, 100), (0, 0, 230), (119, 11, 32), (0, 0, 142)], 225 | dtype=np.uint8) 226 | else: 227 | cmap = np.zeros((N, 3), dtype=np.uint8) 228 | for i in range(N): 229 | r, g, b = 0, 0, 0 230 | id = i + 1 # let's give 0 a color 231 | for j in range(7): 232 | str_id = uint82bin(id) 233 | r = r ^ (np.uint8(str_id[-1]) << (7 - j)) 234 | g = g ^ (np.uint8(str_id[-2]) << (7 - j)) 235 | b = b ^ (np.uint8(str_id[-3]) << (7 - j)) 236 | id = id >> 3 237 | cmap[i, 0] = r 238 | cmap[i, 1] = g 239 | cmap[i, 2] = b 240 | 241 | if N == 182: # COCO 242 | important_colors = { 243 | 'sea': (54, 62, 167), 244 | 'sky-other': (95, 219, 255), 245 | 'tree': (140, 104, 47), 246 | 'clouds': (170, 170, 170), 247 | 'grass': (29, 195, 49) 248 | } 249 | for i in range(N): 250 | name = util.coco.id2label(i) 251 | if name in important_colors: 252 | color = important_colors[name] 253 | cmap[i] = np.array(list(color)) 254 | 255 | return cmap 256 | 257 | 258 | class Colorize(object): 259 | def __init__(self, n=35): 260 | self.cmap = labelcolormap(n) 261 | self.cmap = torch.from_numpy(self.cmap[:n]) 262 | 263 | def __call__(self, gray_image): 264 | size = gray_image.size() 265 | color_image = torch.ByteTensor(3, size[1], size[2]).fill_(0) 266 | 267 | for label in range(0, len(self.cmap)): 268 | mask = (label == gray_image[0]).cpu() 269 | color_image[0][mask] = self.cmap[label][0] 270 | color_image[1][mask] = self.cmap[label][1] 271 | color_image[2][mask] = self.cmap[label][2] 272 | 273 | return color_image 274 | 275 | 276 | # added by yenchi 277 | def print_loss_dict(loss): 278 | msg = "" 279 | for k, v in loss.items(): 280 | msg += "%s: %.4f | " % (k, v) 281 | 282 | return msg[:-3] -------------------------------------------------------------------------------- /util/visualizer.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import os 4 | import ntpath 5 | import time 6 | from . import util 7 | from . 
import html 8 | from PIL import Image # replaces the removed scipy.misc.toimage helper 9 | try: 10 | from StringIO import StringIO # Python 2.7 11 | except ImportError: 12 | from io import BytesIO # Python 3.x 13 | 14 | class Visualizer(): 15 | def __init__(self, opt): 16 | self.opt = opt 17 | self.tf_log = opt.isTrain and opt.tf_log 18 | self.use_html = opt.isTrain and not opt.no_html 19 | self.win_size = opt.display_winsize 20 | self.name = opt.name 21 | if self.tf_log: 22 | import tensorflow as tf 23 | self.tf = tf 24 | self.log_dir = os.path.join(opt.log_root, opt.name, 'logs') 25 | self.writer = tf.summary.FileWriter(self.log_dir) 26 | 27 | if self.use_html: 28 | #self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web') 29 | self.web_dir = os.path.join(opt.log_root, opt.name, 'web') 30 | self.img_dir = os.path.join(self.web_dir, 'images') 31 | print('create web directory %s...' % self.web_dir) 32 | util.mkdirs([self.web_dir, self.img_dir]) 33 | if opt.isTrain: 34 | self.log_name = os.path.join(opt.log_root, opt.name, 'loss_log.txt') 35 | with open(self.log_name, "a") as log_file: 36 | now = time.strftime("%c") 37 | log_file.write('================ Training Loss (%s) ================\n' % now) 38 | 39 | # |visuals|: dictionary of images to display or save 40 | def display_current_results(self, visuals, epoch, step): 41 | 42 | ## convert tensors to numpy arrays 43 | visuals = self.convert_visuals_to_numpy(visuals) 44 | 45 | if self.tf_log: # show images in tensorboard output 46 | img_summaries = [] 47 | for label, image_numpy in visuals.items(): 48 | # Write the image to a string 49 | try: 50 | s = StringIO() 51 | except NameError: # StringIO is undefined under Python 3 52 | s = BytesIO() 53 | if len(image_numpy.shape) >= 4: 54 | image_numpy = image_numpy[0] 55 | Image.fromarray(image_numpy).save(s, format="jpeg") 56 | # Create an Image object 57 | img_sum = self.tf.Summary.Image(encoded_image_string=s.getvalue(), height=image_numpy.shape[0], width=image_numpy.shape[1]) 58 | # Create a Summary value 59 | img_summaries.append(self.tf.Summary.Value(tag=label, image=img_sum)) 60 | 61 | # Create and write Summary 62 | summary = self.tf.Summary(value=img_summaries) 63 | self.writer.add_summary(summary, step) 64 | 65 | if self.use_html: # save images to a html file 66 | for label, image_numpy in visuals.items(): 67 | if isinstance(image_numpy, list): 68 | for i in range(len(image_numpy)): 69 | img_path = os.path.join(self.img_dir, 'epoch%.3d_iter%.3d_%s_%d.png' % (epoch, step, label, i)) 70 | util.save_image(image_numpy[i], img_path) 71 | else: 72 | img_path = os.path.join(self.img_dir, 'epoch%.3d_iter%.3d_%s.png' % (epoch, step, label)) 73 | if len(image_numpy.shape) >= 4: 74 | image_numpy = image_numpy[0] 75 | util.save_image(image_numpy, img_path) 76 | 77 | # update website 78 | webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=5) 79 | for n in range(epoch, 0, -1): 80 | webpage.add_header('epoch [%d]' % n) 81 | ims = [] 82 | txts = [] 83 | links = [] 84 | 85 | for label, image_numpy in visuals.items(): 86 | if isinstance(image_numpy, list): 87 | for i in range(len(image_numpy)): 88 | img_path = 'epoch%.3d_iter%.3d_%s_%d.png' % (n, step, label, i) 89 | ims.append(img_path) 90 | txts.append(label+str(i)) 91 | links.append(img_path) 92 | else: 93 | img_path = 'epoch%.3d_iter%.3d_%s.png' % (n, step, label) 94 | ims.append(img_path) 95 | txts.append(label) 96 | links.append(img_path) 97 | if len(ims) < 10: 98 | webpage.add_images(ims, txts, links, width=self.win_size) 99 | else: 100 | num = int(round(len(ims)/2.0)) 101 | webpage.add_images(ims[:num], 
txts[:num], links[:num], width=self.win_size) 102 | webpage.add_images(ims[num:], txts[num:], links[num:], width=self.win_size) 103 | webpage.save() 104 | 105 | # errors: dictionary of error labels and values 106 | def plot_current_errors(self, errors, step): 107 | if self.tf_log: 108 | for tag, value in errors.items(): 109 | value = value.mean().float() 110 | summary = self.tf.Summary(value=[self.tf.Summary.Value(tag=tag, simple_value=value)]) 111 | self.writer.add_summary(summary, step) 112 | 113 | # errors: same format as |errors| of plotCurrentErrors 114 | def print_current_errors(self, epoch, i, errors, t): 115 | message = '(epoch: %d, iters: %d, time: %.3f) ' % (epoch, i, t) 116 | for k, v in errors.items(): 117 | #print(v) 118 | #if v != 0: 119 | v = v.mean().float() 120 | message += '%s: %.3f ' % (k, v) 121 | 122 | print(message) 123 | with open(self.log_name, "a") as log_file: 124 | log_file.write('%s\n' % message) 125 | 126 | def convert_visuals_to_numpy(self, visuals): 127 | for key, t in visuals.items(): 128 | #tile = self.opt.batchSize > 8 129 | tile = self.opt.batch_size > 8 130 | if 'input_label' == key or 'synthesized_segmentation' == key: 131 | t = util.tensor2label(t, self.opt.label_nc + 2, tile=tile) 132 | else: 133 | t = util.tensor2im(t, tile=tile) 134 | visuals[key] = t 135 | return visuals 136 | 137 | # save image to the disk 138 | def save_images(self, webpage, visuals, image_path): 139 | visuals = self.convert_visuals_to_numpy(visuals) 140 | 141 | image_dir = webpage.get_image_dir() 142 | short_path = ntpath.basename(image_path[0]) 143 | name = os.path.splitext(short_path)[0] 144 | 145 | webpage.add_header(name) 146 | ims = [] 147 | txts = [] 148 | links = [] 149 | 150 | for label, image_numpy in visuals.items(): 151 | image_name = os.path.join(label, '%s.png' % (name)) 152 | save_path = os.path.join(image_dir, image_name) 153 | util.save_image(image_numpy, save_path, create_dir=True) 154 | 155 | ims.append(image_name) 156 | txts.append(label) 157 | links.append(image_name) 158 | webpage.add_images(ims, txts, links, width=self.win_size) 159 | --------------------------------------------------------------------------------
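A minimal sketch of how the two stages above compose at inference time (hypothetical usage, not a script shipped in this repo; it assumes a constructed SegInpaintModel and inputs that were already masked and normalized as in SegInpaintModel.preprocess_input):

import torch

@torch.no_grad()
def inpaint(model, corrupted_img, corrupted_seg):
    # stage 1: SP-Net completes the segmentation from the masked image + masked one-hot seg
    fake_seg = model.SPNet(torch.cat([corrupted_img, corrupted_seg], dim=1))
    # stage 2: SG-Net synthesizes the missing pixels conditioned on that segmentation
    fake_img = model.SGNet(torch.cat([corrupted_img, fake_seg], dim=1))
    return fake_seg, fake_img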