├── .gitmodules ├── README.md ├── checkpoints ├── .gitkeep ├── comod-ffhq-1024 │ └── .gitkeep ├── comod-ffhq-512 │ └── .gitkeep └── comod-places-512 │ └── .gitkeep ├── data ├── __init__.py ├── base_dataset.py ├── image_folder.py ├── testimage_dataset.py ├── trainimage_dataset.py └── valimage_dataset.py ├── datasets └── .gitkeep ├── download ├── data.sh ├── ffhq1024.sh ├── ffhq512.sh └── places512.sh ├── ffhq_debug ├── 1.png ├── example_image.jpg ├── images.txt ├── images │ └── 1.png ├── masks │ └── 1.png └── masks_inv │ └── 1.png ├── imgs ├── example_image.jpg ├── example_mask.jpg ├── example_output.jpg ├── ffhq_in.png └── ffhq_m.png ├── models ├── __init__.py ├── comod_model.py ├── create_mask.py └── networks │ ├── __init__.py │ ├── architecture.py │ ├── base_network.py │ ├── co_mod_gan.py │ ├── discriminator.py │ ├── generator.py │ ├── loss.py │ ├── op │ ├── __init__.py │ ├── fused_act.py │ ├── fused_bias_act.cpp │ ├── fused_bias_act_kernel.cu │ ├── upfirdn2d.cpp │ ├── upfirdn2d.py │ └── upfirdn2d_kernel.cu │ └── stylegan2.py ├── options ├── __init__.py ├── base_options.py ├── test_options.py └── train_options.py ├── output └── .gitkeep ├── save_remote_gs.py ├── test.py ├── test.sh ├── train.py ├── train.sh ├── trainers ├── __init__.py └── stylegan2_trainer.py └── util ├── __init__.py ├── coco.py ├── html.py ├── iter_counter.py ├── util.py └── visualizer.py /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "logger"] 2 | path = logger 3 | url = https://github.com/zengxianyu/logger 4 | [submodule "models/networks/sync_batchnorm"] 5 | path = models/networks/sync_batchnorm 6 | url = https://github.com/zengxianyu/sync_batchnorm 7 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # co-mod-gan-pytorch 2 | Implementation of the paper ``Large Scale Image Completion via Co-Modulated 
Generative Adversarial Networks" 3 | 4 | official tensorflow version: https://github.com/zsyzzsoft/co-mod-gan 5 | 6 | Input image Mask Result 7 | 8 | ## Usage 9 | 10 | ### requirements 11 | ``` 12 | conda install pytorch torchvision cudatoolkit=11 -c pytorch 13 | conda install matplotlib jinja2 ninja dill 14 | pip install git+https://github.com/zengxianyu/pytorch-fid 15 | ``` 16 | 17 | Download the code: 18 | 19 | ``` 20 | git clone https://github.com/zengxianyu/co-mod-gan-pytorch 21 | cd co-mod-gan-pytorch && git checkout train 22 | git submodule init 23 | git submodule update 24 | ``` 25 | 26 | ### inference 27 | 28 | 1. download pretrained model using ``download/*.sh" (converted from the tensorflow pretrained model) 29 | 30 | e.g. ffhq512 31 | 32 | ``` 33 | ./download/ffhq512.sh 34 | ``` 35 | 36 | converted model: 37 | * FFHQ 512 checkpoints/comod-ffhq-512/co-mod-gan-ffhq-9-025000_net_G_ema.pth 38 | * FFHQ 1024 checkpoints/comod-ffhq-1024/co-mod-gan-ffhq-10-025000_net_G_ema.pth 39 | * Places 512 checkpoints/comod-places-512/co-mod-gan-places2-050000_net_G_ema.pth 40 | 41 | 2. use the following command as a minimal example of usage 42 | 43 | ``` 44 | ./test.sh 45 | ``` 46 | 47 | ### Training 48 | 1. download example datasets for training and validation 49 | 50 | ``` 51 | ./download/data.sh 52 | ``` 53 | 54 | 2.
use the following command as a minimal example of usage 55 | 56 | ``` 57 | ./train.sh 58 | ``` 59 | 60 | ### Demo 61 | Coming soon 62 | 63 | ## Reference 64 | 65 | [1] official tensorflow version: https://github.com/zsyzzsoft/co-mod-gan 66 | 67 | [2] stylegan2-pytorch https://github.com/rosinality/stylegan2-pytorch 68 | 69 | [3] pix2pixHD https://github.com/NVIDIA/pix2pixHD 70 | 71 | [4] SPADE https://github.com/NVlabs/SPADE 72 | -------------------------------------------------------------------------------- /checkpoints/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zengxianyu/co-mod-gan-pytorch/a178bfd15d2675214532151037ef8f6e34b3fd91/checkpoints/.gitkeep -------------------------------------------------------------------------------- /checkpoints/comod-ffhq-1024/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zengxianyu/co-mod-gan-pytorch/a178bfd15d2675214532151037ef8f6e34b3fd91/checkpoints/comod-ffhq-1024/.gitkeep -------------------------------------------------------------------------------- /checkpoints/comod-ffhq-512/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zengxianyu/co-mod-gan-pytorch/a178bfd15d2675214532151037ef8f6e34b3fd91/checkpoints/comod-ffhq-512/.gitkeep -------------------------------------------------------------------------------- /checkpoints/comod-places-512/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zengxianyu/co-mod-gan-pytorch/a178bfd15d2675214532151037ef8f6e34b3fd91/checkpoints/comod-places-512/.gitkeep -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA 
Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | import importlib 7 | import torch.utils.data 8 | from data.base_dataset import BaseDataset 9 | 10 | 11 | def find_dataset_using_name(dataset_name): 12 | # Given the option --dataset [datasetname], 13 | # the file "datasets/datasetname_dataset.py" 14 | # will be imported. 15 | dataset_filename = "data." + dataset_name + "_dataset" 16 | datasetlib = importlib.import_module(dataset_filename) 17 | 18 | # In the file, the class called DatasetNameDataset() will 19 | # be instantiated. It has to be a subclass of BaseDataset, 20 | # and it is case-insensitive. 21 | dataset = None 22 | target_dataset_name = dataset_name.replace('_', '') + 'dataset' 23 | for name, cls in datasetlib.__dict__.items(): 24 | if name.lower() == target_dataset_name.lower() \ 25 | and issubclass(cls, BaseDataset): 26 | dataset = cls 27 | 28 | if dataset is None: 29 | raise ValueError("In %s.py, there should be a subclass of BaseDataset " 30 | "with class name that matches %s in lowercase." 
% 31 | (dataset_filename, target_dataset_name)) 32 | 33 | return dataset 34 | 35 | 36 | def get_option_setter(dataset_name): 37 | dataset_class = find_dataset_using_name(dataset_name) 38 | return dataset_class.modify_commandline_options 39 | 40 | 41 | def create_dataloader(opt): 42 | dataset = find_dataset_using_name(opt.dataset_mode) 43 | instance = dataset() 44 | instance.initialize(opt) 45 | print("dataset [%s] of size %d was created" % 46 | (type(instance).__name__, len(instance))) 47 | dataloader = torch.utils.data.DataLoader( 48 | instance, 49 | batch_size=opt.batchSize, 50 | shuffle=not opt.serial_batches, 51 | num_workers=int(opt.nThreads), 52 | drop_last=opt.isTrain 53 | ) 54 | return dataloader 55 | 56 | def create_dataloader_trainval(opt): 57 | assert opt.isTrain 58 | dataset = find_dataset_using_name(opt.dataset_mode_train) 59 | instance = dataset() 60 | instance.initialize(opt) 61 | print("dataset [%s] of size %d was created" % 62 | (type(instance).__name__, len(instance))) 63 | dataloader_train = torch.utils.data.DataLoader( 64 | instance, 65 | batch_size=opt.batchSize, 66 | shuffle=not opt.serial_batches, 67 | num_workers=int(opt.nThreads), 68 | drop_last=True 69 | ) 70 | dataset = find_dataset_using_name(opt.dataset_mode_val) 71 | instance = dataset() 72 | instance.initialize(opt) 73 | print("dataset [%s] of size %d was created" % 74 | (type(instance).__name__, len(instance))) 75 | dataloader_val = torch.utils.data.DataLoader( 76 | instance, 77 | batch_size=opt.batchSize, 78 | shuffle=False, 79 | num_workers=int(opt.nThreads), 80 | drop_last=False 81 | ) 82 | return dataloader_train, dataloader_val 83 | -------------------------------------------------------------------------------- /data/base_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | import torch.utils.data as data 7 | from PIL import Image 8 | import torchvision.transforms as transforms 9 | import numpy as np 10 | import random 11 | 12 | 13 | class BaseDataset(data.Dataset): 14 | def __init__(self): 15 | super(BaseDataset, self).__init__() 16 | 17 | @staticmethod 18 | def modify_commandline_options(parser, is_train): 19 | return parser 20 | 21 | def initialize(self, opt): 22 | pass 23 | 24 | 25 | def get_params(opt, size): 26 | w, h = size 27 | new_h = h 28 | new_w = w 29 | if opt.preprocess_mode == 'resize_and_crop': 30 | new_h = new_w = opt.load_size 31 | elif opt.preprocess_mode == 'scale_width_and_crop': 32 | new_w = opt.load_size 33 | new_h = opt.load_size * h // w 34 | elif opt.preprocess_mode == 'scale_shortside_and_crop': 35 | ss, ls = min(w, h), max(w, h) # shortside and longside 36 | width_is_shorter = w == ss 37 | ls = int(opt.load_size * ls / ss) 38 | new_w, new_h = (ss, ls) if width_is_shorter else (ls, ss) 39 | 40 | x = random.randint(0, np.maximum(0, new_w - opt.crop_size)) 41 | y = random.randint(0, np.maximum(0, new_h - opt.crop_size)) 42 | 43 | flip = random.random() > 0.5 44 | return {'crop_pos': (x, y), 'flip': flip} 45 | 46 | def bbox_transform(opt, params, bbox, size, force_flip=False): 47 | w,h = size 48 | if opt.isTrain and (not opt.no_flip or force_flip): 49 | if params['flip']: 50 | bbox[0] = w-bbox[0]-bbox[2] 51 | rate_h = rate_w = 1 52 | if 'resize' in opt.preprocess_mode: 53 | rate_h = float(opt.load_size)/h 54 | rate_w = float(opt.load_size)/w 55 | elif 'scale_width' in opt.preprocess_mode: 56 | rate_w = rate_h = float(opt.load_size)/w 57 | elif 'scale_shortside' in opt.preprocess_mode: 58 | ss, ls = min(w, h), max(w, h) # shortside and longside 59 | width_is_shorter = w == ss 60 | if (ss == opt.load_size): 61 | rate_w = rate_h = 1 62 | ls = int(opt.load_size * ls / ss) 63 | nw, nh = (ss, 
ls) if width_is_shorter else (ls, ss) 64 | rate_h = float(nh)/h 65 | rate_w = float(nw)/w 66 | if opt.preprocess_mode == 'fixed': 67 | rate_w = float(opt.crop_size)/w 68 | ht = round(opt.crop_size / opt.aspect_ratio) 69 | rate_h = float(opt.crop_size)/ht 70 | 71 | bbox[0] *= rate_w 72 | bbox[2] *= rate_w 73 | bbox[1] *= rate_h 74 | bbox[3] *= rate_h 75 | w *= rate_w 76 | h *= rate_h 77 | 78 | if 'crop' in opt.preprocess_mode: 79 | x,y = params['crop_pos'] 80 | bbox[0] -= x 81 | bbox[1] -= y 82 | w -= x 83 | h -= y 84 | y2 = bbox[1]+bbox[3] 85 | x2 = bbox[0]+bbox[2] 86 | y2 = min(opt.crop_size,y2) 87 | x2 = min(opt.crop_size,x2) 88 | x1 = max(0,bbox[0]) 89 | y1 = max(0,bbox[1]) 90 | bbox[0] = x1 91 | bbox[1] = y1 92 | bbox[2] = x2-x1 93 | bbox[3] = y2-y1 94 | if opt.preprocess_mode == 'none': 95 | base = 32 96 | _h = int(round(oh / base) * base) 97 | _w = int(round(ow / base) * base) 98 | rate_h = float(_h)/h 99 | rate_w = float(_w)/w 100 | bbox[0] *= rate_w 101 | bbox[2] *= rate_w 102 | bbox[1] *= rate_h 103 | bbox[3] *= rate_h 104 | return bbox 105 | 106 | 107 | def get_transform(opt, params, method=Image.BICUBIC, normalize=True, toTensor=True, force_flip=False): 108 | transform_list = [] 109 | if opt.isTrain and (not opt.no_flip or force_flip): 110 | transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip']))) 111 | 112 | if 'resize' in opt.preprocess_mode: 113 | osize = [opt.load_size, opt.load_size] 114 | transform_list.append(transforms.Resize(osize, interpolation=method)) 115 | elif 'scale_width' in opt.preprocess_mode: 116 | transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, method))) 117 | elif 'scale_shortside' in opt.preprocess_mode: 118 | transform_list.append(transforms.Lambda(lambda img: __scale_shortside(img, opt.load_size, method))) 119 | 120 | if 'crop' in opt.preprocess_mode: 121 | transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size))) 122 | 123 | 
if opt.preprocess_mode == 'none': 124 | base = 32 125 | transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method))) 126 | 127 | if opt.preprocess_mode == 'fixed': 128 | w = opt.crop_size 129 | h = round(opt.crop_size / opt.aspect_ratio) 130 | transform_list.append(transforms.Lambda(lambda img: __resize(img, w, h, method))) 131 | 132 | if toTensor: 133 | transform_list += [transforms.ToTensor()] 134 | 135 | if normalize: 136 | transform_list += [transforms.Normalize((0.5, 0.5, 0.5), 137 | (0.5, 0.5, 0.5))] 138 | return transforms.Compose(transform_list) 139 | 140 | 141 | def normalize(): 142 | return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 143 | 144 | 145 | def __resize(img, w, h, method=Image.BICUBIC): 146 | return img.resize((w, h), method) 147 | 148 | 149 | def __make_power_2(img, base, method=Image.BICUBIC): 150 | ow, oh = img.size 151 | h = int(round(oh / base) * base) 152 | w = int(round(ow / base) * base) 153 | if (h == oh) and (w == ow): 154 | return img 155 | return img.resize((w, h), method) 156 | 157 | 158 | def __scale_width(img, target_width, method=Image.BICUBIC): 159 | ow, oh = img.size 160 | if (ow == target_width): 161 | return img 162 | w = target_width 163 | h = int(target_width * oh / ow) 164 | return img.resize((w, h), method) 165 | 166 | 167 | def __scale_shortside(img, target_width, method=Image.BICUBIC): 168 | ow, oh = img.size 169 | ss, ls = min(ow, oh), max(ow, oh) # shortside and longside 170 | width_is_shorter = ow == ss 171 | if (ss == target_width): 172 | return img 173 | ls = int(target_width * ls / ss) 174 | nw, nh = (ss, ls) if width_is_shorter else (ls, ss) 175 | return img.resize((nw, nh), method) 176 | 177 | 178 | def __crop(img, pos, size): 179 | ow, oh = img.size 180 | x1, y1 = pos 181 | tw = th = size 182 | return img.crop((x1, y1, x1 + tw, y1 + th)) 183 | 184 | 185 | def __flip(img, flip): 186 | if flip: 187 | return img.transpose(Image.FLIP_LEFT_RIGHT) 188 | return img 189 | 
-------------------------------------------------------------------------------- /data/image_folder.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | ############################################################################### 7 | # Code from 8 | # https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py 9 | # Modified the original code so that it also loads images from the current 10 | # directory as well as the subdirectories 11 | ############################################################################### 12 | import torch.utils.data as data 13 | from PIL import Image 14 | import os 15 | 16 | IMG_EXTENSIONS = [ 17 | '.jpg', '.JPG', '.jpeg', '.JPEG', 18 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff', '.webp' 19 | ] 20 | 21 | 22 | def is_image_file(filename): 23 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 24 | 25 | 26 | def make_dataset_rec(dir, images): 27 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 28 | 29 | for root, dnames, fnames in sorted(os.walk(dir, followlinks=True)): 30 | for fname in fnames: 31 | if is_image_file(fname): 32 | path = os.path.join(root, fname) 33 | images.append(path) 34 | 35 | 36 | def make_dataset(dir, recursive=False, read_cache=False, write_cache=False): 37 | images = [] 38 | 39 | if read_cache: 40 | possible_filelist = os.path.join(dir, 'files.list') 41 | if os.path.isfile(possible_filelist): 42 | with open(possible_filelist, 'r') as f: 43 | images = f.read().splitlines() 44 | return images 45 | 46 | if recursive: 47 | make_dataset_rec(dir, images) 48 | else: 49 | assert os.path.isdir(dir) or os.path.islink(dir), '%s is not a valid directory' % dir 50 | 51 | for root, dnames, fnames in sorted(os.walk(dir)): 52 | for 
fname in fnames: 53 | if is_image_file(fname): 54 | path = os.path.join(root, fname) 55 | images.append(path) 56 | 57 | if write_cache: 58 | filelist_cache = os.path.join(dir, 'files.list') 59 | with open(filelist_cache, 'w') as f: 60 | for path in images: 61 | f.write("%s\n" % path) 62 | print('wrote filelist cache at %s' % filelist_cache) 63 | 64 | return images 65 | 66 | 67 | def default_loader(path): 68 | return Image.open(path).convert('RGB') 69 | 70 | 71 | class ImageFolder(data.Dataset): 72 | 73 | def __init__(self, root, transform=None, return_paths=False, 74 | loader=default_loader): 75 | imgs = make_dataset(root) 76 | if len(imgs) == 0: 77 | raise(RuntimeError("Found 0 images in: " + root + "\n" 78 | "Supported image extensions are: " + 79 | ",".join(IMG_EXTENSIONS))) 80 | 81 | self.root = root 82 | self.imgs = imgs 83 | self.transform = transform 84 | self.return_paths = return_paths 85 | self.loader = loader 86 | 87 | def __getitem__(self, index): 88 | path = self.imgs[index] 89 | img = self.loader(path) 90 | if self.transform is not None: 91 | img = self.transform(img) 92 | if self.return_paths: 93 | return img, path 94 | else: 95 | return img 96 | 97 | def __len__(self): 98 | return len(self.imgs) 99 | -------------------------------------------------------------------------------- /data/testimage_dataset.py: -------------------------------------------------------------------------------- 1 | import torchvision.transforms as transforms 2 | import torch 3 | from data.base_dataset import get_params, get_transform, BaseDataset 4 | from PIL import Image 5 | from data.image_folder import make_dataset 6 | import os 7 | import pdb 8 | 9 | 10 | class TestImageDataset(BaseDataset): 11 | """ Dataset that loads images from directories 12 | Use option --label_dir, --image_dir, --instance_dir to specify the directories. 13 | The images in the directories are sorted in alphabetical order and paired in order. 
14 | """ 15 | 16 | @staticmethod 17 | def modify_commandline_options(parser, is_train): 18 | parser.add_argument('--list_dir', type=str, required=False, 19 | help='path to the directory that contains photo images') 20 | parser.add_argument('--image_dir', type=str, required=True, 21 | help='path to the directory that contains photo images') 22 | parser.add_argument('--mask_dir', type=str, required=True, 23 | help='path to the directory that contains photo images') 24 | parser.add_argument('--output_dir', type=str, required=True, 25 | help='path to the directory that contains photo images') 26 | return parser 27 | 28 | def initialize(self, opt): 29 | self.opt = opt 30 | if not os.path.exists(opt.output_dir): 31 | os.mkdir(opt.output_dir) 32 | 33 | image_paths, mask_paths, output_paths = self.get_paths(opt) 34 | 35 | self.image_paths = image_paths 36 | self.mask_paths = mask_paths 37 | self.output_paths = output_paths 38 | 39 | size = len(self.image_paths) 40 | self.dataset_size = size 41 | transform_list = [ 42 | transforms.ToTensor(), 43 | transforms.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5)) 44 | ] 45 | self.image_transform = transforms.Compose(transform_list) 46 | self.mask_transform = transforms.Compose([ 47 | transforms.ToTensor() 48 | ]) 49 | 50 | def get_paths(self, opt): 51 | img_names = os.listdir(opt.image_dir) 52 | img_postfix = img_names[0].split(".")[-1] 53 | if opt.list_dir is not None: 54 | with open(opt.list_dir, "r") as f: 55 | msk_names = f.readlines() 56 | msk_names = [n.strip("\n") for n in msk_names] 57 | else: 58 | msk_names = os.listdir(opt.mask_dir) 59 | img_names = [n.replace("png", img_postfix) for n in msk_names] 60 | image_paths = [f"{opt.image_dir}/{n}" for n in img_names] 61 | output_paths = [f"{opt.output_dir}/{n}" for n in img_names] 62 | mask_paths = [f"{opt.mask_dir}/{n}" for n in msk_names] 63 | 64 | return image_paths, mask_paths, output_paths 65 | 66 | def __len__(self): 67 | return self.dataset_size 68 | 69 | def 
__getitem__(self, index): 70 | # input image (real images) 71 | output_path = self.output_paths[index] 72 | image_path = self.image_paths[index] 73 | image = Image.open(image_path) 74 | image = image.convert('RGB') 75 | w, h = image.size 76 | image_tensor = self.image_transform(image) 77 | # mask image 78 | mask_path = self.mask_paths[index] 79 | mask = Image.open(mask_path) 80 | mask = mask.convert("L") 81 | mask = mask.resize((w,h)) 82 | mask_tensor = self.mask_transform(mask) 83 | mask_tensor = (mask_tensor>0).float() 84 | input_dict = { 85 | 'image': image_tensor, 86 | 'mask': mask_tensor, 87 | 'path': output_path, 88 | } 89 | 90 | return input_dict 91 | -------------------------------------------------------------------------------- /data/trainimage_dataset.py: -------------------------------------------------------------------------------- 1 | from data.base_dataset import get_params, get_transform, BaseDataset 2 | from PIL import Image 3 | from data.image_folder import make_dataset 4 | import os 5 | import pdb 6 | 7 | 8 | class TrainImageDataset(BaseDataset): 9 | """ Dataset that loads images from directories 10 | Use option --label_dir, --image_dir, --instance_dir to specify the directories. 11 | The images in the directories are sorted in alphabetical order and paired in order. 
12 | """ 13 | 14 | @staticmethod 15 | def modify_commandline_options(parser, is_train): 16 | parser.add_argument('--train_image_dir', type=str, required=True, 17 | help='path to the directory that contains photo images') 18 | parser.add_argument('--train_image_list', type=str, required=True, 19 | help='path to the directory that contains photo images') 20 | parser.add_argument('--train_image_postfix', type=str, default="", 21 | help='path to the directory that contains photo images') 22 | return parser 23 | 24 | def initialize(self, opt): 25 | self.opt = opt 26 | image_paths = self.get_paths(opt) 27 | 28 | self.image_paths = image_paths 29 | 30 | size = len(self.image_paths) 31 | self.dataset_size = size 32 | 33 | def get_paths(self, opt): 34 | image_dir = opt.train_image_dir 35 | image_list = opt.train_image_list 36 | names = open(image_list).readlines() 37 | filenames = list(map(lambda x: x.strip('\n')+opt.train_image_postfix, names)) 38 | image_paths = list(map(lambda x: os.path.join(image_dir, x), filenames)) 39 | return image_paths 40 | 41 | def __len__(self): 42 | return self.dataset_size 43 | 44 | def __getitem__(self, index): 45 | # input image (real images) 46 | image_path = self.image_paths[index] 47 | image = Image.open(image_path) 48 | image = image.convert('RGB') 49 | params = get_params(self.opt, image.size) 50 | transform_image = get_transform(self.opt, params) 51 | image_tensor = transform_image(image) 52 | input_dict = { 53 | 'image': image_tensor, 54 | 'path': image_path, 55 | } 56 | return input_dict 57 | #except: 58 | # print(f"skip {image_path}") 59 | # return self.__getitem__((index+1)%self.__len__()) 60 | -------------------------------------------------------------------------------- /data/valimage_dataset.py: -------------------------------------------------------------------------------- 1 | import torchvision.transforms as transforms 2 | import torch 3 | from data.base_dataset import get_params, get_transform, BaseDataset 4 | from PIL 
import Image 5 | from data.image_folder import make_dataset 6 | import os 7 | import pdb 8 | 9 | 10 | class ValImageDataset(BaseDataset): 11 | """ Dataset that loads images from directories 12 | Use option --label_dir, --image_dir, --instance_dir to specify the directories. 13 | The images in the directories are sorted in alphabetical order and paired in order. 14 | """ 15 | 16 | @staticmethod 17 | def modify_commandline_options(parser, is_train): 18 | parser.add_argument('--val_image_dir', type=str, required=True, 19 | help='path to the directory that contains photo images') 20 | parser.add_argument('--val_image_list', type=str, required=True, 21 | help='path to the directory that contains photo images') 22 | parser.add_argument('--val_mask_dir', type=str, required=True, 23 | help='path to the directory that contains photo images') 24 | parser.add_argument('--val_image_postfix', type=str, default=".jpg", 25 | help='path to the directory that contains photo images') 26 | parser.add_argument('--val_mask_postfix', type=str, default=".png", 27 | help='path to the directory that contains photo images') 28 | return parser 29 | 30 | def initialize(self, opt): 31 | self.opt = opt 32 | 33 | image_paths, mask_paths = self.get_paths(opt) 34 | 35 | self.image_paths = image_paths 36 | self.mask_paths = mask_paths 37 | 38 | size = len(self.image_paths) 39 | self.dataset_size = size 40 | transform_list = [ 41 | transforms.Resize((opt.crop_size, opt.crop_size), 42 | interpolation=Image.NEAREST), 43 | transforms.ToTensor(), 44 | transforms.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5)) 45 | ] 46 | self.image_transform = transforms.Compose(transform_list) 47 | self.mask_transform = transforms.Compose([ 48 | transforms.Resize((opt.crop_size, opt.crop_size),interpolation=Image.NEAREST), 49 | transforms.ToTensor() 50 | ]) 51 | 52 | def get_paths(self, opt): 53 | image_dir = opt.val_image_dir 54 | image_list = opt.val_image_list 55 | names = open(image_list).readlines() 56 | filenames = 
list(map(lambda x: x.strip('\n')+opt.val_image_postfix, names)) 57 | image_paths = list(map(lambda x: os.path.join(image_dir, x), filenames)) 58 | filenames = list(map(lambda x: x.strip('\n')+opt.val_mask_postfix, names)) 59 | mask_paths = list(map(lambda x: os.path.join(opt.val_mask_dir, x), filenames)) 60 | return image_paths, mask_paths 61 | 62 | def __len__(self): 63 | return self.dataset_size 64 | 65 | def __getitem__(self, index): 66 | # input image (real images) 67 | image_path = self.image_paths[index] 68 | image = Image.open(image_path) 69 | image = image.convert('RGB') 70 | w, h = image.size 71 | image_tensor = self.image_transform(image) 72 | # mask image 73 | mask_path = self.mask_paths[index] 74 | mask = Image.open(mask_path) 75 | mask = mask.convert("L") 76 | mask = mask.resize((w,h)) 77 | mask_tensor = self.mask_transform(mask) 78 | mask_tensor = (mask_tensor>0).float() 79 | input_dict = { 80 | 'image': image_tensor, 81 | 'mask': mask_tensor, 82 | 'path': image_path, 83 | } 84 | 85 | return input_dict 86 | -------------------------------------------------------------------------------- /datasets/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zengxianyu/co-mod-gan-pytorch/a178bfd15d2675214532151037ef8f6e34b3fd91/datasets/.gitkeep -------------------------------------------------------------------------------- /download/data.sh: -------------------------------------------------------------------------------- 1 | wget https://maildluteducn-my.sharepoint.com/:u:/g/personal/zengyu_mail_dlut_edu_cn/Ed6KS2wg-olJsLicvZOpUHkB4nak9nYtJPXxwvM8W_d9PQ?download=1 2 | mv Ed6KS2wg-olJsLicvZOpUHkB4nak9nYtJPXxwvM8W_d9PQ?download=1 ./datasets/places2sample1k_val.zip 3 | unzip datasets/places2sample1k_val.zip -d datasets/ 4 | -------------------------------------------------------------------------------- /download/ffhq1024.sh: 
-------------------------------------------------------------------------------- 1 | wget https://maildluteducn-my.sharepoint.com/:u:/g/personal/zengyu_mail_dlut_edu_cn/EXcZ9OHEFQFAqhe1GQVFZ_4BwWoTvWyDM429gP5XrPAdaQ?download=1 2 | mv EXcZ9OHEFQFAqhe1GQVFZ_4BwWoTvWyDM429gP5XrPAdaQ?download=1 ./checkpoints/comod-ffhq-1024/co-mod-gan-ffhq-10-025000_net_G_ema.pth 3 | -------------------------------------------------------------------------------- /download/ffhq512.sh: -------------------------------------------------------------------------------- 1 | wget https://maildluteducn-my.sharepoint.com/:u:/g/personal/zengyu_mail_dlut_edu_cn/Ee1YPJG2Y7NDnUjJBf-SipoBBSlbv8QfFy6K7lsiiiiFHg?download=1 2 | mv Ee1YPJG2Y7NDnUjJBf-SipoBBSlbv8QfFy6K7lsiiiiFHg?download=1 ./checkpoints/comod-ffhq-512/co-mod-gan-ffhq-9-025000_net_G_ema.pth 3 | -------------------------------------------------------------------------------- /download/places512.sh: -------------------------------------------------------------------------------- 1 | wget https://maildluteducn-my.sharepoint.com/:u:/g/personal/zengyu_mail_dlut_edu_cn/EQG9jJzkFLJDsOWmVJJSoqQB2jRDkXlYt3wnt9Fb9dJDsQ?download=1 2 | mv EQG9jJzkFLJDsOWmVJJSoqQB2jRDkXlYt3wnt9Fb9dJDsQ?download=1 ./checkpoints/comod-places-512/co-mod-gan-places2-050000_net_G_ema.pth 3 | -------------------------------------------------------------------------------- /ffhq_debug/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zengxianyu/co-mod-gan-pytorch/a178bfd15d2675214532151037ef8f6e34b3fd91/ffhq_debug/1.png -------------------------------------------------------------------------------- /ffhq_debug/example_image.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zengxianyu/co-mod-gan-pytorch/a178bfd15d2675214532151037ef8f6e34b3fd91/ffhq_debug/example_image.jpg
-------------------------------------------------------------------------------- /ffhq_debug/images.txt: -------------------------------------------------------------------------------- 1 | 1.png 2 | -------------------------------------------------------------------------------- /ffhq_debug/images/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zengxianyu/co-mod-gan-pytorch/a178bfd15d2675214532151037ef8f6e34b3fd91/ffhq_debug/images/1.png -------------------------------------------------------------------------------- /ffhq_debug/masks/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zengxianyu/co-mod-gan-pytorch/a178bfd15d2675214532151037ef8f6e34b3fd91/ffhq_debug/masks/1.png -------------------------------------------------------------------------------- /ffhq_debug/masks_inv/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zengxianyu/co-mod-gan-pytorch/a178bfd15d2675214532151037ef8f6e34b3fd91/ffhq_debug/masks_inv/1.png -------------------------------------------------------------------------------- /imgs/example_image.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zengxianyu/co-mod-gan-pytorch/a178bfd15d2675214532151037ef8f6e34b3fd91/imgs/example_image.jpg -------------------------------------------------------------------------------- /imgs/example_mask.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zengxianyu/co-mod-gan-pytorch/a178bfd15d2675214532151037ef8f6e34b3fd91/imgs/example_mask.jpg -------------------------------------------------------------------------------- /imgs/example_output.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zengxianyu/co-mod-gan-pytorch/a178bfd15d2675214532151037ef8f6e34b3fd91/imgs/example_output.jpg -------------------------------------------------------------------------------- /imgs/ffhq_in.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zengxianyu/co-mod-gan-pytorch/a178bfd15d2675214532151037ef8f6e34b3fd91/imgs/ffhq_in.png -------------------------------------------------------------------------------- /imgs/ffhq_m.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zengxianyu/co-mod-gan-pytorch/a178bfd15d2675214532151037ef8f6e34b3fd91/imgs/ffhq_m.png -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | import importlib 7 | import torch 8 | 9 | 10 | def find_model_using_name(model_name): 11 | # Given the option --model [modelname], 12 | # the file "models/modelname_model.py" 13 | # will be imported. 14 | model_filename = "models." + model_name + "_model" 15 | modellib = importlib.import_module(model_filename) 16 | 17 | # In the file, the class called ModelNameModel() will 18 | # be instantiated. It has to be a subclass of torch.nn.Module, 19 | # and it is case-insensitive. 20 | model = None 21 | target_model_name = model_name.replace('_', '') + 'model' 22 | for name, cls in modellib.__dict__.items(): 23 | if name.lower() == target_model_name.lower() \ 24 | and issubclass(cls, torch.nn.Module): 25 | model = cls 26 | 27 | if model is None: 28 | print("In %s.py, there should be a subclass of torch.nn.Module with class name that matches %s in lowercase." 
% (model_filename, target_model_name)) 29 | exit(0) 30 | 31 | return model 32 | 33 | 34 | def get_option_setter(model_name): 35 | model_class = find_model_using_name(model_name) 36 | return model_class.modify_commandline_options 37 | 38 | 39 | def create_model(opt): 40 | model = find_model_using_name(opt.model) 41 | instance = model(opt) 42 | print("model [%s] was created" % (type(instance).__name__)) 43 | 44 | return instance 45 | -------------------------------------------------------------------------------- /models/comod_model.py: -------------------------------------------------------------------------------- 1 | import pdb 2 | import torch.nn.functional as F 3 | import torchvision.ops as ops 4 | import math 5 | import torch 6 | import models.networks as networks 7 | import util.util as util 8 | import random 9 | import numpy as np 10 | from models.create_mask import MaskCreator 11 | 12 | 13 | class CoModModel(torch.nn.Module): 14 | @staticmethod 15 | def modify_commandline_options(parser, is_train): 16 | networks.modify_commandline_options(parser, is_train) 17 | parser.add_argument('--no_g_reg', action='store_true') 18 | parser.add_argument('--path_objectshape_base', type=str, required=False, help='path obj base') 19 | parser.add_argument('--path_objectshape_list', type=str, required=False, help='path obj list') 20 | parser.add_argument('--mixing', type=float, default=0.9) 21 | parser.add_argument('--r1', type=float, default=10) 22 | parser.add_argument('--d_reg_every', type=int, default=16) 23 | parser.add_argument('--g_reg_every', type=int, default=4) 24 | parser.add_argument('--path_batch_shrink', type=int, default=2) 25 | parser.add_argument('--truncation', type=float, required=False) 26 | parser.add_argument('--path_regularize', type=int, default=2) 27 | parser.set_defaults(init_type=None) 28 | parser.set_defaults(gan_mode='softplus') 29 | parser.set_defaults(lr=0.002) 30 | parser.set_defaults(z_dim=512) 31 | # factor 32 | parser.add_argument('--factor', 
type=str, required=False) 33 | parser.add_argument('--factor_d', type=int, default=5) 34 | parser.add_argument('--factor_i', type=int, default=0) 35 | parser.add_argument('--load_pretrained_g', type=str, required=False, help='load pt g') 36 | parser.add_argument('--load_pretrained_g_ema', type=str, required=False, help='load pt g') 37 | parser.add_argument('--load_pretrained_d', type=str, required=False, help='load pt d') 38 | return parser 39 | 40 | def __init__(self, opt): 41 | super().__init__() 42 | self.opt = opt 43 | self.truncation_mean = None 44 | 45 | self.device = torch.device("cuda") if self.use_gpu() \ 46 | else torch.device("cpu") 47 | self.FloatTensor = torch.cuda.FloatTensor if self.use_gpu() \ 48 | else torch.FloatTensor 49 | self.ByteTensor = torch.cuda.ByteTensor if self.use_gpu() \ 50 | else torch.ByteTensor 51 | 52 | self.netG, self.netG_ema, self.netD = self.initialize_networks(opt) 53 | if opt.factor is not None: 54 | self.eigvec = torch.load(opt.factor)["eigvec"].to(self.device) 55 | # set loss functions 56 | if opt.isTrain: 57 | if not opt.continue_train: 58 | if opt.load_pretrained_g is not None: 59 | print(f"looad {opt.load_pretrained_g}") 60 | self.netG = util.load_network_path( 61 | self.netG, opt.load_pretrained_g) 62 | if opt.load_pretrained_g_ema is not None: 63 | print(f"looad {opt.load_pretrained_g}") 64 | self.netG_ema = util.load_network_path( 65 | self.netG_ema, opt.load_pretrained_g_ema) 66 | if opt.load_pretrained_d is not None: 67 | print(f"looad {opt.load_pretrained_d}") 68 | self.netD = util.load_network_path( 69 | self.netD, opt.load_pretrained_d) 70 | self.mask_creator = MaskCreator(opt.path_objectshape_list, opt.path_objectshape_base) 71 | self.criterionGAN = networks.GANLoss( 72 | opt.gan_mode, tensor=self.FloatTensor, opt=self.opt) 73 | if not opt.no_vgg_loss: 74 | self.criterionVGG = networks.VGGLoss(self.opt.gpu_ids) 75 | if opt.truncation is not None: 76 | self.truncation_mean = self.mean_latent(4096) 77 | 78 | def 
accumulate(self, decay=0.999): 79 | par1 = dict(self.netG_ema.named_parameters()) 80 | par2 = dict(self.netG.named_parameters()) 81 | 82 | for k in par1.keys(): 83 | par1[k].data.mul_(decay).add_(par2[k].data, alpha=1 - decay) 84 | 85 | # set loss functions 86 | 87 | # Entry point for all calls involving forward pass 88 | # of deep networks. We used this approach since DataParallel module 89 | # can't parallelize custom functions, we branch to different 90 | # routines based on |mode|. 91 | def forward(self, data, mode): 92 | real_image, mask, mean_path_length = self.preprocess_input(data) 93 | bsize = real_image.size(0) 94 | if mode == 'generator': 95 | g_loss, fake_image = self.compute_generator_loss(real_image, mask) 96 | generated = {'fake':fake_image, 97 | 'input':real_image*(1-mask), 98 | 'gt':real_image, 99 | } 100 | return g_loss, real_image, generated 101 | elif mode == 'dreal': 102 | d_loss = self.compute_discriminator_loss( 103 | real_image, fake_image=None, mask=mask) 104 | return d_loss 105 | elif mode == 'dfake': 106 | with torch.no_grad(): 107 | fake_image, uc_image,_ = self.generate_fake(real_image, mask) 108 | fake_image = fake_image.detach() 109 | fake_image.requires_grad_() 110 | d_loss = self.compute_discriminator_loss( 111 | real_image=None, fake_image=fake_image, mask=mask) 112 | return d_loss 113 | elif mode == 'd_reg': 114 | d_regs = self.compute_discriminator_reg(real_image, mask) 115 | return d_regs 116 | elif mode == 'g_reg': 117 | g_regs, path_lengths, mean_path_length = self.compute_generator_reg( 118 | real_image, 119 | mask, 120 | mean_path_length) 121 | return g_regs, mean_path_length 122 | elif mode == 'inference': 123 | with torch.no_grad(): 124 | if self.opt.factor is None: 125 | fake_image, uc_image,_ = self.generate_fake(real_image, mask, ema=True) 126 | else: 127 | fake_image, _ = self.factorize_fake(real_image, mask) 128 | inp = real_image*(1-mask) 129 | return fake_image, inp 130 | else: 131 | raise ValueError("|mode| is 
invalid") 132 | 133 | def create_optimizers(self, opt): 134 | G_params = list(self.netG.parameters()) 135 | #G_params = [p for name, p in self.netG.named_parameters() \ 136 | # if (not name.startswith("coarse"))] 137 | if opt.isTrain: 138 | D_params = list(self.netD.parameters()) 139 | 140 | g_reg_ratio = self.opt.g_reg_every / (self.opt.g_reg_every + 1) 141 | d_reg_ratio = self.opt.d_reg_every / (self.opt.d_reg_every + 1) 142 | 143 | g_optim = torch.optim.Adam( 144 | G_params, 145 | lr=self.opt.lr * g_reg_ratio, 146 | betas=(0 ** g_reg_ratio, 0.99 ** g_reg_ratio), 147 | ) 148 | d_optim = torch.optim.Adam( 149 | D_params, 150 | lr=self.opt.lr * d_reg_ratio, 151 | betas=(0 ** d_reg_ratio, 0.99 ** d_reg_ratio), 152 | ) 153 | 154 | return g_optim, d_optim 155 | 156 | def save(self, epoch): 157 | util.save_network(self.netG, 'G', epoch, self.opt) 158 | util.save_network(self.netG_ema, 'G_ema', epoch, self.opt) 159 | util.save_network(self.netD, 'D', epoch, self.opt) 160 | 161 | ############################################################################ 162 | # Private helper methods 163 | ############################################################################ 164 | 165 | def initialize_networks(self, opt): 166 | netG_ema = networks.define_G(opt) 167 | if opt.isTrain: 168 | netG = networks.define_G(opt) 169 | netD = networks.define_D(opt) 170 | else: 171 | netD=None 172 | netG=None 173 | 174 | if not opt.isTrain or opt.continue_train: 175 | netG_ema = util.load_network(netG_ema, 'G_ema', opt.which_epoch, opt) 176 | if opt.isTrain: 177 | netG = util.load_network(netG, 'G', opt.which_epoch, opt) 178 | netD = util.load_network(netD, 'D', opt.which_epoch, opt) 179 | return netG, netG_ema, netD 180 | 181 | # preprocess the input, such as moving the tensors to GPUs and 182 | # transforming the label map to one-hot encoding 183 | # |data|: dictionary of the input data 184 | 185 | def mean_latent(self, n_latent): 186 | self.netG_ema.eval() 187 | latent_in = 
torch.randn(n_latent, self.opt.z_dim, device=self.device) 188 | dlatent = self.netG_ema(latents_in=[latent_in], get_latent=True)[0] 189 | latent_mean = dlatent.mean(0, keepdim=True) 190 | self.truncation_mean = latent_mean 191 | return self.truncation_mean 192 | 193 | def make_noise(self, batch, n_noise): 194 | if n_noise == 1: 195 | return torch.randn(batch, self.opt.z_dim, device=self.device) 196 | 197 | noises = torch.randn(n_noise, batch, self.opt.z_dim, 198 | device=self.device).unbind(0) 199 | 200 | return noises 201 | 202 | def make_mask(self, data): 203 | b,c,h,w = data['image'].shape 204 | if self.opt.isTrain: 205 | # generate random stroke mask 206 | mask1 = self.mask_creator.stroke_mask(h, w, max_length=min(h,w)/2) 207 | # generate object/square mask 208 | ri = random.randint(0,3) 209 | if self.opt.path_objectshape_base is not None and (ri == 1 or ri == 0): 210 | mask2 = self.mask_creator.object_mask(h, w) 211 | else: 212 | mask2 = self.mask_creator.rectangle_mask(h, w, 213 | min(h,w)//4, min(h,w)//2) 214 | # use the mix of two masks 215 | mask = (mask1+mask2>0) 216 | mask = mask.astype(np.float) 217 | mask = self.FloatTensor(mask)[None, None,...].expand(b,-1,-1,-1) 218 | data['mask'] = mask 219 | else: 220 | if self.use_gpu(): 221 | data['mask'] = data['mask'].cuda() 222 | mask = data['mask'] 223 | return mask 224 | 225 | def mixing_noise(self, batch): 226 | if self.opt.mixing > 0 and random.random() < self.opt.mixing: 227 | noise = self.make_noise(batch, 2) 228 | return noise 229 | else: 230 | return [self.make_noise(batch, 1)] 231 | 232 | def preprocess_input(self, data): 233 | b,c,h,w = data['image'].shape 234 | if 'mask' in data: 235 | if self.use_gpu(): 236 | data['mask'] = data['mask'].cuda() 237 | mask = data['mask'] 238 | else: 239 | mask = self.make_mask(data) 240 | if self.use_gpu(): 241 | data['image'] = data['image'].cuda() 242 | if 'mean_path_length' in data: 243 | mean_path_length = data['mean_path_length'].detach().cuda() 244 | else: 245 
| mean_path_length = 0 246 | return data['image'], data['mask'], mean_path_length 247 | 248 | def g_path_regularize(self, fake_image, latents, mean_path_length, decay=0.01): 249 | noise = torch.randn_like(fake_image) / math.sqrt( 250 | fake_image.shape[2] * fake_image.shape[3] 251 | ) 252 | grad, = torch.autograd.grad( 253 | outputs=(fake_image * noise).sum(), inputs=latents, create_graph=True 254 | ) 255 | path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1)) 256 | 257 | path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length) 258 | 259 | path_penalty = (path_lengths - path_mean).pow(2).mean() 260 | 261 | return path_penalty, path_mean.detach(), path_lengths 262 | 263 | def compute_generator_reg(self, real_image, mask, mean_path_length): 264 | G_regs = {} 265 | bsize = real_image.size(0) 266 | path_batch_size = max(1, bsize // self.opt.path_batch_shrink) 267 | fake_image, _, latents = self.generate_fake(real_image, mask, True) 268 | path_loss, mean_path_length, path_lengths = self.g_path_regularize( 269 | fake_image, latents, mean_path_length 270 | ) 271 | weighted_path_loss = self.opt.path_regularize * self.opt.g_reg_every * path_loss 272 | 273 | if self.opt.path_batch_shrink: 274 | weighted_path_loss += 0 * fake_image[0, 0, 0, 0] 275 | G_regs['path'] = weighted_path_loss 276 | return G_regs, path_lengths, mean_path_length 277 | 278 | def compute_generator_loss(self, real_image, mask): 279 | fake_image, uc_image, _ = self.generate_fake(real_image, mask) 280 | #pred_fake, pred_real = self.discriminate( 281 | # fake_image, real_image) 282 | pred_fake = self.netD(fake_image, mask) 283 | 284 | G_losses = {} 285 | G_losses['GAN'] = self.criterionGAN(pred_fake, True, 286 | for_discriminator=False) 287 | if not self.opt.no_vgg_loss: 288 | G_losses['VGG'] = self.criterionVGG(uc_image, real_image) \ 289 | * self.opt.lambda_vgg 290 | if not self.opt.no_l1_loss: 291 | G_losses['L1'] = torch.nn.functional.l1_loss(uc_image, real_image) * 
self.opt.lambda_l1 292 | return G_losses, fake_image 293 | 294 | def compute_discriminator_reg(self, real_image, mask): 295 | real_image.requires_grad = True 296 | real_pred = self.netD(real_image, mask) 297 | grad_real, = torch.autograd.grad( 298 | outputs=real_pred.sum(), inputs=real_image, create_graph=True 299 | ) 300 | r1_loss = grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean() 301 | 302 | r1_loss = self.opt.r1 / 2 * r1_loss * self.opt.d_reg_every + 0 * real_pred[0] 303 | D_regs = {'r1': r1_loss} 304 | 305 | return D_regs 306 | 307 | def compute_discriminator_loss(self, real_image, fake_image=None, mask=None): 308 | D_losses = {} 309 | assert mask is not None 310 | assert fake_image is not None or real_image is not None 311 | assert fake_image is None or real_image is None 312 | if fake_image is not None: 313 | fake_image = fake_image.detach() 314 | pred_fake = self.netD(fake_image, mask) 315 | D_losses['D_Fake'] = self.criterionGAN(pred_fake, False, 316 | for_discriminator=True) 317 | elif real_image is not None: 318 | pred_real = self.netD(real_image, mask) 319 | D_losses['D_real'] = self.criterionGAN(pred_real, True, 320 | for_discriminator=True) 321 | 322 | return D_losses 323 | 324 | def factorize_fake(self, real_image, mask, return_latents=False): 325 | self.netG_ema.eval() 326 | bsize = real_image.size(0) 327 | latent_in = torch.randn(bsize, self.opt.z_dim, device=self.device) 328 | dlatent = self.netG_ema(latents_in=[latent_in], get_latent=True)[0] 329 | 330 | direction = self.opt.factor_d * self.eigvec[:, self.opt.factor_i].unsqueeze(0) 331 | img1, _, latent = self.netG_ema( 332 | real_image, 333 | mask, 334 | [dlatent-direction], 335 | return_latents=return_latents, 336 | truncation=self.opt.truncation, 337 | truncation_latent=self.truncation_mean, 338 | input_is_latent=True) 339 | img2, _, latent = self.netG_ema( 340 | real_image, 341 | mask, 342 | [dlatent], 343 | return_latents=return_latents, 344 | truncation=self.opt.truncation, 345 
| truncation_latent=self.truncation_mean, 346 | input_is_latent=True) 347 | img3, _, latent = self.netG_ema( 348 | real_image, 349 | mask, 350 | [dlatent+direction], 351 | return_latents=return_latents, 352 | truncation=self.opt.truncation, 353 | truncation_latent=self.truncation_mean, 354 | input_is_latent=True) 355 | fake_image = torch.cat((img1,img2,img3),3) 356 | return fake_image, dlatent 357 | 358 | def generate_fake(self, real_image, mask, return_latents=False, ema=False): 359 | bsize = real_image.size(0) 360 | noise = self.mixing_noise(bsize) 361 | if ema: 362 | self.netG_ema.eval() 363 | fake_image, uc_image, latent = self.netG_ema( 364 | real_image, 365 | mask, 366 | noise, 367 | return_latents=return_latents, 368 | truncation=self.opt.truncation, 369 | truncation_latent=self.truncation_mean, 370 | ) 371 | else: 372 | fake_image, uc_image, latent = self.netG( 373 | real_image, 374 | mask, 375 | noise, 376 | return_latents=return_latents, 377 | ) 378 | return fake_image, uc_image, latent 379 | 380 | # Given fake and real image, return the prediction of discriminator 381 | # for each fake and real image. 382 | 383 | def discriminate(self, fake_image, real_image): 384 | raise NotImplementedError 385 | fake_concat = fake_image 386 | real_concat = real_image 387 | 388 | # In Batch Normalization, the fake and real images are 389 | # recommended to be in the same batch to avoid disparate 390 | # statistics in fake and real images. 391 | # So both fake and real images are fed to D all at once. 
392 | fake_and_real = torch.cat([fake_concat, real_concat], dim=0) 393 | discriminator_out = self.netD(fake_and_real) 394 | 395 | pred_fake, pred_real = self.divide_pred(discriminator_out) 396 | 397 | return pred_fake, pred_real 398 | 399 | # Take the prediction of fake and real images from the combined batch 400 | def divide_pred(self, pred): 401 | # the prediction contains the intermediate outputs of multiscale GAN, 402 | # so it's usually a list 403 | if type(pred) == list: 404 | fake = [] 405 | real = [] 406 | for p in pred: 407 | fake.append([tensor[:tensor.size(0) // 2] for tensor in p]) 408 | real.append([tensor[tensor.size(0) // 2:] for tensor in p]) 409 | else: 410 | fake = pred[:pred.size(0) // 2] 411 | real = pred[pred.size(0) // 2:] 412 | 413 | return fake, real 414 | 415 | def use_gpu(self): 416 | return len(self.opt.gpu_ids) > 0 417 | -------------------------------------------------------------------------------- /models/create_mask.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import random 4 | from PIL import Image, ImageDraw 5 | import os 6 | import pdb 7 | import math 8 | 9 | class MaskCreator: 10 | def __init__(self, list_mask_path=None, base_mask_path=None, match_size=False): 11 | self.match_size = match_size 12 | if list_mask_path is not None: 13 | filenames = open(list_mask_path).readlines() 14 | msk_filenames = list(map(lambda x: os.path.join(base_mask_path, x.strip('\n')), filenames)) 15 | self.msk_filenames = msk_filenames 16 | else: 17 | self.msk_filenames = None 18 | 19 | 20 | def object_shadow(self, h, w, blur_kernel=7, noise_loc=0.5, noise_range=0.05): 21 | """ 22 | img: rgb numpy 23 | return: rgb numpy 24 | """ 25 | mask = self.object_mask(h, w) 26 | kernel = np.ones((blur_kernel+3,blur_kernel+3),np.float32) 27 | expand_mask = cv2.dilate(mask,kernel,iterations = 1) 28 | noise = np.random.normal(noise_loc, noise_range, mask.shape) 29 | noise[noise>1] = 1 
30 | mask = mask*noise 31 | mask = mask + (mask==0) 32 | kernel = np.ones((blur_kernel,blur_kernel),np.float32)/(blur_kernel*blur_kernel) 33 | mask = cv2.filter2D(mask,-1,kernel) 34 | return mask, expand_mask 35 | 36 | 37 | def object_mask(self, image_height=256, image_width=256): 38 | if self.msk_filenames is None: 39 | raise NotImplementedError 40 | hb, wb = image_height, image_width 41 | # object mask as hole 42 | mask = Image.open(random.choice(self.msk_filenames)) 43 | ## randomly resize 44 | wm, hm = mask.size 45 | if self.match_size: 46 | r = float(min(hb, wb)) / max(wm, hm) 47 | r = r /2 48 | else: 49 | r = 1 50 | scale = random.gauss(r, 0.5) 51 | scale = scale if scale > 0.5 else 0.5 52 | scale = scale if scale < 2 else 2.0 53 | wm, hm = int(wm*scale), int(hm*scale) 54 | mask = mask.resize((wm, hm)) 55 | mask = np.array(mask) 56 | mask = (mask>0) 57 | if mask.sum() > 0: 58 | ## crop object region 59 | col_nz = mask.sum(0) 60 | row_nz = mask.sum(1) 61 | col_nz = np.where(col_nz!=0)[0] 62 | left = col_nz[0] 63 | right = col_nz[-1] 64 | row_nz = np.where(row_nz!=0)[0] 65 | top = row_nz[0] 66 | bot = row_nz[-1] 67 | mask = mask[top:bot, left:right] 68 | else: 69 | return self.object_mask(image_height, image_width) 70 | ## place in a random location on the extended canvas 71 | hm, wm = mask.shape 72 | canvas = np.zeros((hm+hb, wm+wb)) 73 | y = random.randint(0, hb-1) 74 | x = random.randint(0, wb-1) 75 | canvas[y:y+hm, x:x+wm] = mask 76 | hole = canvas[int(hm/2):int(hm/2)+hb, int(wm/2):int(wm/2)+wb] 77 | th = 100 if self.match_size else 1000 78 | if hole.sum() < hb*wb / th: 79 | return self.object_mask(image_height, image_width) 80 | else: 81 | return hole.astype(np.float) 82 | 83 | def rectangle_mask(self, image_height=256, image_width=256, min_hole_size=64, max_hole_size=128): 84 | mask = np.zeros((image_height, image_width)) 85 | hole_size = random.randint(min_hole_size, max_hole_size) 86 | hole_size = min(int(image_width*0.8), int(image_height*0.8), 
hole_size) 87 | x = random.randint(0, image_width-hole_size-1) 88 | y = random.randint(0, image_height-hole_size-1) 89 | mask[x:x+hole_size, y:y+hole_size] = 1 90 | return mask 91 | 92 | def random_brush( 93 | self, 94 | max_tries, 95 | image_height=256, 96 | image_width=256, 97 | min_num_vertex = 4, 98 | max_num_vertex = 18, 99 | mean_angle = 2*math.pi / 5, 100 | angle_range = 2*math.pi / 15, 101 | min_width = 12, 102 | max_width = 48): 103 | H, W = image_height, image_width 104 | average_radius = math.sqrt(H*H+W*W) / 8 105 | mask = Image.new('L', (W, H), 0) 106 | for _ in range(np.random.randint(max_tries)): 107 | num_vertex = np.random.randint(min_num_vertex, max_num_vertex) 108 | angle_min = mean_angle - np.random.uniform(0, angle_range) 109 | angle_max = mean_angle + np.random.uniform(0, angle_range) 110 | angles = [] 111 | vertex = [] 112 | for i in range(num_vertex): 113 | if i % 2 == 0: 114 | angles.append(2*math.pi - np.random.uniform(angle_min, angle_max)) 115 | else: 116 | angles.append(np.random.uniform(angle_min, angle_max)) 117 | 118 | h, w = mask.size 119 | vertex.append((int(np.random.randint(0, w)), int(np.random.randint(0, h)))) 120 | for i in range(num_vertex): 121 | r = np.clip( 122 | np.random.normal(loc=average_radius, scale=average_radius//2), 123 | 0, 2*average_radius) 124 | new_x = np.clip(vertex[-1][0] + r * math.cos(angles[i]), 0, w) 125 | new_y = np.clip(vertex[-1][1] + r * math.sin(angles[i]), 0, h) 126 | vertex.append((int(new_x), int(new_y))) 127 | 128 | draw = ImageDraw.Draw(mask) 129 | width = int(np.random.uniform(min_width, max_width)) 130 | draw.line(vertex, fill=1, width=width) 131 | for v in vertex: 132 | draw.ellipse((v[0] - width//2, 133 | v[1] - width//2, 134 | v[0] + width//2, 135 | v[1] + width//2), 136 | fill=1) 137 | if np.random.random() > 0.5: 138 | mask.transpose(Image.FLIP_LEFT_RIGHT) 139 | if np.random.random() > 0.5: 140 | mask.transpose(Image.FLIP_TOP_BOTTOM) 141 | mask = np.asarray(mask, np.uint8) 142 | if 
np.random.random() > 0.5: 143 | mask = np.flip(mask, 0) 144 | if np.random.random() > 0.5: 145 | mask = np.flip(mask, 1) 146 | return mask 147 | 148 | def random_mask(self, image_height=256, image_width=256, hole_range=[0,1]): 149 | coef = min(hole_range[0] + hole_range[1], 1.0) 150 | #mask = self.random_brush(int(20 * coef), image_height, image_width) 151 | while True: 152 | mask = np.ones((image_height, image_width), np.uint8) 153 | def Fill(max_size): 154 | w, h = np.random.randint(max_size), np.random.randint(max_size) 155 | ww, hh = w // 2, h // 2 156 | x, y = np.random.randint(-ww, image_width - w + ww), np.random.randint(-hh, image_height - h + hh) 157 | mask[max(y, 0): min(y + h, image_height), max(x, 0): min(x + w, image_width)] = 0 158 | def MultiFill(max_tries, max_size): 159 | for _ in range(np.random.randint(max_tries)): 160 | Fill(max_size) 161 | MultiFill(int(10 * coef), max(image_height, image_width) // 2) 162 | MultiFill(int(5 * coef), max(image_height, image_width)) 163 | mask = np.logical_and(mask, 1 - self.random_brush(int(20 * coef), image_height, image_width)) 164 | hole_ratio = 1 - np.mean(mask) 165 | if hole_ratio >= hole_range[0] and hole_ratio <= hole_range[1]: 166 | break 167 | return 1-mask 168 | 169 | def stroke_mask(self, image_height=256, image_width=256, max_vertex=5, max_mask=5, max_length=128): 170 | max_angle = np.pi 171 | max_brush_width = max(1, int(max_length*0.4)) 172 | min_brush_width = max(1, int(max_length*0.1)) 173 | 174 | mask = np.zeros((image_height, image_width)) 175 | for k in range(random.randint(1, max_mask)): 176 | num_vertex = random.randint(1, max_vertex) 177 | start_x = random.randint(0, image_width-1) 178 | start_y = random.randint(0, image_height-1) 179 | for i in range(num_vertex): 180 | angle = random.uniform(0, max_angle) 181 | if i % 2 == 0: 182 | angle = 2*np.pi - angle 183 | length = random.uniform(0, max_length) 184 | brush_width = random.randint(min_brush_width, max_brush_width) 185 | end_x = 
min(int(start_x + length * np.cos(angle)), image_width) 186 | end_y = min(int(start_y + length * np.sin(angle)), image_height) 187 | mask = cv2.line(mask, (start_x, start_y), (end_x, end_y), color=1, thickness=brush_width) 188 | start_x, start_y = end_x, end_y 189 | mask = cv2.circle(mask, (start_x, start_y), int(brush_width/2), 1) 190 | if random.randint(0, 1): 191 | mask = mask[:, ::-1].copy() 192 | if random.randint(0, 1): 193 | mask = mask[::-1, :].copy() 194 | return mask 195 | 196 | 197 | def get_spatial_discount(mask): 198 | H, W = mask.shape 199 | shift_up = np.zeros((H, W)) 200 | shift_up[:-1, :] = mask[1:, :] 201 | shift_left = np.zeros((H, W)) 202 | shift_left[:, :-1] = mask[:, 1:] 203 | 204 | boundary_y = mask - shift_up 205 | boundary_x = mask - shift_left 206 | 207 | boundary_y = np.abs(boundary_y) 208 | boundary_x = np.abs(boundary_x) 209 | boundary = boundary_x + boundary_y 210 | boundary[boundary != 0 ] = 1 211 | # plt.imshow(boundary) 212 | # plt.show() 213 | 214 | xx, yy = np.meshgrid(range(W), range(H)) 215 | bd_x = xx[boundary==1] 216 | bd_y = yy[boundary==1] 217 | dis_x = xx[..., None] - bd_x[None, None, ...] 218 | dis_y = yy[..., None] - bd_y[None, None, ...] 219 | dis = np.sqrt(dis_x*dis_x + dis_y*dis_y) 220 | min_dis = dis.min(2) 221 | gamma = 0.9 222 | discount_map = (gamma**min_dis)*mask 223 | return discount_map 224 | 225 | 226 | 227 | 228 | if __name__ == "__main__": 229 | import os 230 | from tqdm import tqdm 231 | import pdb 232 | mask_creator = MaskCreator() 233 | mask = mask_creator.random_mask(image_height=512, image_width=512) 234 | Image.fromarray((mask*255).astype(np.uint8)).save("output/mask.png") 235 | -------------------------------------------------------------------------------- /models/networks/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | import torch 7 | from models.networks.base_network import BaseNetwork 8 | from models.networks.loss import * 9 | from models.networks.discriminator import * 10 | from models.networks.generator import * 11 | import util.util as util 12 | import pdb 13 | 14 | 15 | def find_network_using_name(target_network_name, filename): 16 | target_class_name = target_network_name + filename 17 | module_name = 'models.networks.' + filename 18 | network = util.find_class_in_module(target_class_name, module_name) 19 | 20 | assert issubclass(network, BaseNetwork), \ 21 | "Class %s should be a subclass of BaseNetwork" % network 22 | 23 | return network 24 | 25 | 26 | def modify_commandline_options(parser, is_train): 27 | opt, _ = parser.parse_known_args() 28 | 29 | netG_cls = find_network_using_name(opt.netG, 'generator') 30 | parser = netG_cls.modify_commandline_options(parser, is_train) 31 | if is_train: 32 | netD_cls = find_network_using_name(opt.netD, 'discriminator') 33 | parser = netD_cls.modify_commandline_options(parser, is_train) 34 | return parser 35 | 36 | 37 | def create_network(cls, opt): 38 | net = cls(opt) 39 | net.print_network() 40 | if len(opt.gpu_ids) > 0: 41 | assert(torch.cuda.is_available()) 42 | net.cuda() 43 | if opt.init_type is not None: 44 | net.init_weights(opt.init_type, opt.init_variance) 45 | return net 46 | 47 | 48 | def define_G(opt): 49 | netG_cls = find_network_using_name(opt.netG, 'generator') 50 | return create_network(netG_cls, opt) 51 | 52 | 53 | def define_D(opt): 54 | netD_cls = find_network_using_name(opt.netD, 'discriminator') 55 | return create_network(netD_cls, opt) 56 | 57 | -------------------------------------------------------------------------------- /models/networks/architecture.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA 
Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torchvision 10 | 11 | # VGG architecter, used for the perceptual loss using a pretrained VGG network 12 | class VGG19(torch.nn.Module): 13 | def __init__(self, requires_grad=False): 14 | super().__init__() 15 | vgg_pretrained_features = torchvision.models.vgg19(pretrained=True).features 16 | self.slice1 = torch.nn.Sequential() 17 | self.slice2 = torch.nn.Sequential() 18 | self.slice3 = torch.nn.Sequential() 19 | self.slice4 = torch.nn.Sequential() 20 | self.slice5 = torch.nn.Sequential() 21 | for x in range(2): 22 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 23 | for x in range(2, 7): 24 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 25 | for x in range(7, 12): 26 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 27 | for x in range(12, 21): 28 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 29 | for x in range(21, 30): 30 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 31 | if not requires_grad: 32 | for param in self.parameters(): 33 | param.requires_grad = False 34 | 35 | def forward(self, X): 36 | h_relu1 = self.slice1(X) 37 | h_relu2 = self.slice2(h_relu1) 38 | h_relu3 = self.slice3(h_relu2) 39 | h_relu4 = self.slice4(h_relu3) 40 | h_relu5 = self.slice5(h_relu4) 41 | out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5] 42 | return out 43 | -------------------------------------------------------------------------------- /models/networks/base_network.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
4 | """ 5 | 6 | import torch.nn as nn 7 | from torch.nn import init 8 | 9 | 10 | class BaseNetwork(nn.Module): 11 | def __init__(self): 12 | super(BaseNetwork, self).__init__() 13 | 14 | @staticmethod 15 | def modify_commandline_options(parser, is_train): 16 | return parser 17 | 18 | def print_network(self): 19 | if isinstance(self, list): 20 | self = self[0] 21 | num_params = 0 22 | for param in self.parameters(): 23 | num_params += param.numel() 24 | print('Network [%s] was created. Total number of parameters: %.1f million. ' 25 | 'To see the architecture, do print(network).' 26 | % (type(self).__name__, num_params / 1000000)) 27 | 28 | def init_weights(self, init_type='normal', gain=0.02): 29 | def init_func(m): 30 | classname = m.__class__.__name__ 31 | if classname.find('BatchNorm2d') != -1: 32 | if hasattr(m, 'weight') and m.weight is not None: 33 | init.normal_(m.weight.data, 1.0, gain) 34 | if hasattr(m, 'bias') and m.bias is not None: 35 | init.constant_(m.bias.data, 0.0) 36 | elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 37 | if init_type == 'normal': 38 | init.normal_(m.weight.data, 0.0, gain) 39 | elif init_type == 'xavier': 40 | init.xavier_normal_(m.weight.data, gain=gain) 41 | elif init_type == 'xavier_uniform': 42 | init.xavier_uniform_(m.weight.data, gain=1.0) 43 | elif init_type == 'kaiming': 44 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 45 | elif init_type == 'orthogonal': 46 | init.orthogonal_(m.weight.data, gain=gain) 47 | elif init_type == 'none': # uses pytorch's default init method 48 | m.reset_parameters() 49 | else: 50 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 51 | if hasattr(m, 'bias') and m.bias is not None: 52 | init.constant_(m.bias.data, 0.0) 53 | 54 | self.apply(init_func) 55 | 56 | # propagate to children 57 | for m in self.children(): 58 | if hasattr(m, 'init_weights'): 59 | m.init_weights(init_type, gain) 60 | 61 | def 
get_param_list(self, label): 62 | print("updating all params") 63 | return self.parameters() 64 | #raise NotImplementedError 65 | -------------------------------------------------------------------------------- /models/networks/co_mod_gan.py: -------------------------------------------------------------------------------- 1 | import pdb 2 | import random 3 | from collections import OrderedDict 4 | import torch 5 | from torch import nn 6 | from torch.nn import functional as F 7 | from models.networks.op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d 8 | from models.networks.stylegan2 import PixelNorm, EqualLinear, EqualConv2d,ConvLayer,StyledConv,ToRGB,ConvToRGB,TransConvLayer 9 | import numpy as np 10 | 11 | from models.networks.base_network import BaseNetwork 12 | 13 | #---------------------------------------------------------------------------- 14 | # Mapping network. 15 | # Transforms the input latent code (z) to the disentangled latent code (w). 16 | # Used in configs B-F (Table 1). 17 | 18 | class G_mapping(nn.Module): 19 | def __init__(self, 20 | opt 21 | ): 22 | latent_size = 512 # Latent vector (Z) dimensionality. 23 | label_size = 0 # Label dimensionality, 0 if no labels. 24 | dlatent_broadcast = None # Output disentangled latent (W) as [minibatch, dlatent_size] or [minibatch, dlatent_broadcast, dlatent_size]. 25 | mapping_layers = 8 # Number of mapping layers. 26 | mapping_fmaps = 512 # Number of activations in the mapping layers. 27 | mapping_lrmul = 0.01 # Learning rate multiplier for the mapping layers. 28 | mapping_nonlinearity = 'lrelu' # Activation function: 'relu', 'lrelu', etc. 29 | normalize_latents = True # Normalize latent vectors (Z) before feeding them to the mapping layers? 30 | super().__init__() 31 | layers = [] 32 | 33 | # Embed labels and concatenate them with latents. 34 | if label_size: 35 | raise NotImplementedError 36 | 37 | # Normalize latents. 
38 | if normalize_latents: 39 | layers.append( 40 | ('Normalize', PixelNorm())) 41 | # Mapping layers. 42 | dim_in = latent_size 43 | for layer_idx in range(mapping_layers): 44 | fmaps = opt.dlatent_size if layer_idx == mapping_layers - 1 else mapping_fmaps 45 | layers.append( 46 | ( 47 | 'Dense%d' % layer_idx, 48 | EqualLinear( 49 | dim_in, 50 | fmaps, 51 | lr_mul=mapping_lrmul, 52 | activation="fused_lrelu") 53 | )) 54 | dim_in = fmaps 55 | # Broadcast. 56 | if dlatent_broadcast is not None: 57 | raise NotImplementedError 58 | self.G_mapping = nn.Sequential(OrderedDict(layers)) 59 | 60 | def forward( 61 | self, 62 | latents_in): 63 | styles = self.G_mapping(latents_in) 64 | return styles 65 | 66 | #---------------------------------------------------------------------------- 67 | # CoModGAN synthesis network. 68 | 69 | class G_synthesis_co_mod_gan(nn.Module): 70 | def __init__( 71 | self, 72 | opt 73 | ): 74 | resolution_log2 = int(np.log2(opt.crop_size)) 75 | assert opt.crop_size == 2**resolution_log2 and opt.crop_size >= 4 76 | def nf(stage): return np.clip(int(opt.fmap_base / (2.0 ** (stage * opt.fmap_decay))), opt.fmap_min, opt.fmap_max) 77 | assert opt.architecture in ['skip'] 78 | assert opt.nonlinearity == 'lrelu' 79 | assert opt.fused_modconv 80 | assert not opt.pix2pix 81 | self.nf = nf 82 | super().__init__() 83 | act = opt.nonlinearity 84 | self.num_layers = resolution_log2 * 2 - 2 85 | self.resolution_log2 = resolution_log2 86 | 87 | class E_fromrgb(nn.Module): # res = 2..resolution_log2 88 | def __init__(self, res, channel_in=opt.num_channels+1): 89 | super().__init__() 90 | self.FromRGB = ConvLayer( 91 | channel_in, 92 | nf(res-1), 93 | 1, 94 | blur_kernel=opt.resample_kernel, 95 | activate=True) 96 | def forward(self, data): 97 | y, E_features = data 98 | t = self.FromRGB(y) 99 | return t, E_features 100 | class E_block(nn.Module): # res = 2..resolution_log2 101 | def __init__(self, res): 102 | super().__init__() 103 | self.Conv0 = ConvLayer( 104 | 
nf(res-1), 105 | nf(res-1), 106 | kernel_size=3, 107 | activate=True) 108 | self.Conv1_down = ConvLayer( 109 | nf(res-1), 110 | nf(res-2), 111 | kernel_size=3, 112 | downsample=True, 113 | blur_kernel=opt.resample_kernel, 114 | activate=True) 115 | self.res = res 116 | def forward(self, data): 117 | x, E_features = data 118 | x = self.Conv0(x) 119 | E_features[self.res] = x 120 | x = self.Conv1_down(x) 121 | return x, E_features 122 | class E_block_final(nn.Module): # res = 2..resolution_log2 123 | def __init__(self): 124 | super().__init__() 125 | self.Conv = ConvLayer( 126 | nf(2), 127 | nf(1), 128 | kernel_size=3, 129 | activate=True) 130 | self.Dense0 = EqualLinear(nf(1)*4*4, nf(1)*2, 131 | activation="fused_lrelu") 132 | self.dropout = nn.Dropout(opt.dropout_rate) 133 | def forward(self, data): 134 | x, E_features = data 135 | x = self.Conv(x) 136 | E_features[2] = x 137 | bsize = x.size(0) 138 | x = x.view(bsize, -1) 139 | x = self.Dense0(x) 140 | x = self.dropout(x) 141 | return x, E_features 142 | def make_encoder(channel_in=opt.num_channels+1): 143 | Es = [] 144 | for res in range(self.resolution_log2, 2, -1): 145 | if res == self.resolution_log2: 146 | Es.append( 147 | ( 148 | '%dx%d_0' % (2**res, 2**res), 149 | E_fromrgb(res, channel_in) 150 | )) 151 | Es.append( 152 | ( 153 | '%dx%d' % (2**res, 2**res), 154 | E_block(res) 155 | 156 | )) 157 | # Final layers. 158 | Es.append( 159 | ( 160 | '4x4', 161 | E_block_final() 162 | 163 | )) 164 | Es = nn.Sequential(OrderedDict(Es)) 165 | return Es 166 | self.make_encoder = make_encoder 167 | 168 | # Main layers. 169 | c_in = opt.num_channels+1 170 | self.E = self.make_encoder(channel_in=c_in) 171 | 172 | # Single convolution layer with all the bells and whistles. 173 | # Building blocks for main layers. 
174 | mod_size = 0 175 | if opt.style_mod: 176 | mod_size += opt.dlatent_size 177 | if opt.cond_mod: 178 | mod_size += nf(1)*2 179 | assert mod_size > 0 180 | self.mod_size = mod_size 181 | def get_mod(latent, idx, x_global): 182 | if isinstance(latent, list): 183 | latent = latent[:][idx] 184 | else: 185 | latent = latent[:,idx] 186 | mod_vector = [] 187 | if opt.style_mod: 188 | mod_vector.append(latent) 189 | if opt.cond_mod: 190 | mod_vector.append(x_global) 191 | mod_vector = torch.cat(mod_vector, 1) 192 | return mod_vector 193 | self.get_mod = get_mod 194 | class Block(nn.Module): 195 | def __init__(self, res): 196 | super().__init__() 197 | self.res = res 198 | self.Conv0_up = StyledConv( 199 | nf(res-2), 200 | nf(res-1), 201 | kernel_size=3, 202 | style_dim=mod_size, 203 | upsample=True, 204 | blur_kernel=opt.resample_kernel) 205 | self.Conv1 = StyledConv( 206 | nf(res-1), 207 | nf(res-1), 208 | kernel_size=3, 209 | style_dim=mod_size, 210 | upsample=False) 211 | self.ToRGB = ToRGB( 212 | nf(res-1), 213 | mod_size, out_channel=opt.num_channels) 214 | def forward(self, x, y, dlatents_in, x_global, E_features): 215 | x_skip = E_features[self.res] 216 | mod_vector = get_mod(dlatents_in, res*2-5, x_global) 217 | if opt.noise_injection: 218 | noise = None 219 | else: 220 | noise = 0 221 | x = self.Conv0_up(x, mod_vector, noise, x_skip=x_skip) 222 | x = x + x_skip 223 | mod_vector = get_mod(dlatents_in, self.res*2-4, x_global) 224 | x = self.Conv1(x, mod_vector, noise, x_skip=x_skip) 225 | mod_vector = get_mod(dlatents_in, self.res*2-3, x_global) 226 | y = self.ToRGB(x, mod_vector, skip=y, x_skip=x_skip) 227 | return x, y 228 | self.Block = Block 229 | class Block0(nn.Module): 230 | def __init__(self): 231 | super().__init__() 232 | self.Dense = EqualLinear( 233 | nf(1)*2, 234 | nf(1)*4*4, 235 | activation="fused_lrelu") 236 | self.Conv = StyledConv( 237 | nf(1), 238 | nf(1), 239 | kernel_size=3, 240 | style_dim=mod_size, 241 | ) 242 | self.ToRGB = ToRGB( 243 | 
nf(1), 244 | style_dim=mod_size, 245 | upsample=False, out_channel=opt.num_channels) 246 | def forward(self, x, dlatents_in, x_global): 247 | x = self.Dense(x) 248 | x = x.view(-1, nf(1), 4, 4) 249 | mod_vector = get_mod(dlatents_in, 0, x_global) 250 | if opt.noise_injection: 251 | noise = None 252 | else: 253 | noise = 0 254 | x = self.Conv(x, mod_vector, noise) 255 | mod_vector = get_mod(dlatents_in, 1, x_global) 256 | y = self.ToRGB(x, mod_vector) 257 | return x, y 258 | # Early layers. 259 | self.G_4x4 = Block0() 260 | # Main layers. 261 | for res in range(3, resolution_log2 + 1): 262 | setattr(self, 'G_%dx%d' % (2**res, 2**res), 263 | Block(res)) 264 | 265 | def forward(self, images_in, masks_in, dlatents_in): 266 | y = torch.cat([1-masks_in - 0.5, images_in * (1-masks_in)], 1) 267 | E_features = {} 268 | x_global, E_features = self.E((y, E_features)) 269 | x = x_global 270 | x, y = self.G_4x4(x, dlatents_in, x_global) 271 | for res in range(3, self.resolution_log2 + 1): 272 | block = getattr(self, 'G_%dx%d' % (2**res, 2**res)) 273 | x, y = block(x, y, dlatents_in, x_global, E_features) 274 | raw_out = y 275 | images_out = y * masks_in + images_in * (1-masks_in) 276 | return images_out, raw_out 277 | 278 | #---------------------------------------------------------------------------- 279 | # Main generator network. 280 | # Composed of two sub-networks (mapping and synthesis) that are defined below. 281 | # Used in configs B-F (Table 1). 282 | 283 | class Generator(BaseNetwork): 284 | @staticmethod 285 | def modify_commandline_options(parser, is_train): 286 | parser.add_argument('--dlatent_size', type=int, default= 512 )# Disentangled latent (W) dimensionality. 287 | parser.add_argument('--num_channels', type=int, default= 3, )# Number of output color channels. 288 | parser.add_argument('--fmap_base', type=int, default= 16 << 10, )# Overall multiplier for the number of feature maps. 
289 | parser.add_argument('--fmap_decay', type=int, default= 1.0, )# log2 feature map reduction when doubling the resolution. 290 | parser.add_argument('--fmap_min', type=int, default= 1, )# Minimum number of feature maps in any layer. 291 | parser.add_argument('--fmap_max', type=int, default= 512, )# Maximum number of feature maps in any layer. 292 | parser.add_argument('--randomize_noise', type=bool, default= True, )# True = randomize noise inputs every time (non-deterministic), False = read noise inputs from variables. 293 | parser.add_argument('--architecture', type=str, default= 'skip', )# Architecture: 'orig', 'skip', 'resnet'. 294 | parser.add_argument('--nonlinearity', type=str, default= 'lrelu', )# Activation function: 'relu', 'lrelu', etc. 295 | parser.add_argument('--resample_kernel', type=list, default= [1,3,3,1], )# Low-pass filter to apply when resampling activations. None = no filtering. 296 | parser.add_argument('--fused_modconv', type=bool, default= True, )# Implement modulated_conv2d_layer() as a single fused op? 297 | parser.add_argument('--pix2pix', type=bool, default= False) 298 | parser.add_argument('--dropout_rate', type=float, default= 0.5) 299 | parser.add_argument('--cond_mod', type=bool, default= True,) 300 | parser.add_argument('--style_mod', type=bool, default= True,) 301 | parser.add_argument('--noise_injection', type=bool, default= True,) 302 | return parser 303 | def __init__( 304 | self, 305 | opt=None): # Arguments for sub-networks (mapping and synthesis). 
306 | super().__init__() 307 | self.G_mapping = G_mapping(opt) 308 | self.G_synthesis = G_synthesis_co_mod_gan(opt) 309 | 310 | def forward( 311 | self, 312 | images_in=None, 313 | masks_in=None, 314 | latents_in=None, 315 | return_latents=False, 316 | inject_index=None, 317 | truncation=None, 318 | truncation_latent=None, 319 | input_is_latent=False, 320 | get_latent=False, 321 | ): 322 | #assert isinstance(latents_in, list) 323 | if not input_is_latent: 324 | dlatents_in = [self.G_mapping(s) for s in latents_in] 325 | else: 326 | dlatents_in = latents_in 327 | if get_latent: 328 | return dlatents_in 329 | if truncation is not None: 330 | dlatents_t = [] 331 | for style in dlatents_in: 332 | dlatents_t.append( 333 | truncation_latent + truncation * (style - truncation_latent) 334 | ) 335 | dlatents_in = dlatents_t 336 | if len(dlatents_in) < 2: 337 | inject_index = self.G_synthesis.num_layers 338 | if dlatents_in[0].ndim < 3: 339 | dlatent = dlatents_in[0].unsqueeze(1).repeat(1, inject_index, 1) 340 | else: 341 | dlatent = dlatents_in[0] 342 | else: 343 | if inject_index is None: 344 | inject_index = random.randint(1, self.G_synthesis.num_layers - 1) 345 | dlatent = dlatents_in[0].unsqueeze(1).repeat(1, inject_index, 1) 346 | dlatent2 = dlatents_in[1].unsqueeze(1).repeat(1, self.G_synthesis.num_layers - inject_index, 1) 347 | 348 | dlatent = torch.cat([dlatent, dlatent2], 1) 349 | output, raw_out = self.G_synthesis(images_in, masks_in, dlatent) 350 | if return_latents: 351 | return output, raw_out, dlatent 352 | else: 353 | return output, raw_out, None 354 | 355 | #---------------------------------------------------------------------------- 356 | # CoModGAN discriminator. 357 | 358 | class Discriminator(BaseNetwork): 359 | @staticmethod 360 | def modify_commandline_options(parser, is_train): 361 | parser.add_argument('--mbstd_num_features', type=int, default= 1, )# Number of features for the minibatch standard deviation layer. 
362 | parser.add_argument('--mbstd_group_size', type=int, default= 4, )# Group size for the minibatch standard deviation layer, 0 = disable. 363 | return parser 364 | def __init__( 365 | self, 366 | opt): 367 | label_size = 0, # Dimensionality of the labels, 0 if no labels. Overridden based on dataset. 368 | architecture = 'resnet' # Architecture: 'orig', 'skip', 'resnet'. 369 | pix2pix = False 370 | assert not pix2pix 371 | assert opt.nonlinearity == 'lrelu' 372 | assert architecture == 'resnet' 373 | if opt is not None: 374 | resolution = opt.crop_size 375 | 376 | resolution_log2 = int(np.log2(resolution)) 377 | assert resolution == 2**resolution_log2 and resolution >= 4 378 | def nf(stage): return np.clip(int(opt.fmap_base / (2.0 ** (stage * opt.fmap_decay))), opt.fmap_min, opt.fmap_max) 379 | #assert architecture in ['orig', 'skip', 'resnet'] 380 | 381 | # Building blocks for main layers. 382 | super().__init__() 383 | layers = [] 384 | c_in = opt.num_channels+1 385 | layers.append( 386 | ( 387 | "ToRGB", 388 | ConvLayer( 389 | c_in, 390 | nf(resolution_log2-1), 391 | kernel_size=3, 392 | activate=True) 393 | ) 394 | ) 395 | 396 | class Block(nn.Module): 397 | def __init__(self, res): 398 | super().__init__() 399 | self.Conv0 = ConvLayer( 400 | nf(res-1), 401 | nf(res-1), 402 | kernel_size=3, 403 | activate=True) 404 | self.Conv1_down = ConvLayer( 405 | nf(res-1), 406 | nf(res-2), 407 | kernel_size=3, 408 | downsample=True, 409 | blur_kernel=opt.resample_kernel, 410 | activate=True) 411 | self.Skip = ConvLayer( 412 | nf(res-1), 413 | nf(res-2), 414 | kernel_size=1, 415 | downsample=True, 416 | blur_kernel=opt.resample_kernel, 417 | activate=False, 418 | bias=False) 419 | def forward(self, x): 420 | t = x 421 | x = self.Conv0(x) 422 | x = self.Conv1_down(x) 423 | t = self.Skip(t) 424 | x = (x + t) * (1/np.sqrt(2)) 425 | return x 426 | # Main layers. 
427 | for res in range(resolution_log2, 2, -1): 428 | layers.append( 429 | ( 430 | '%dx%d' % (2**res, 2**res), 431 | Block(res) 432 | ) 433 | ) 434 | self.convs = nn.Sequential(OrderedDict(layers)) 435 | # Final layers. 436 | self.mbstd_group_size = opt.mbstd_group_size 437 | self.mbstd_num_features = opt.mbstd_num_features 438 | 439 | self.Conv4x4 = ConvLayer(nf(1)+1, nf(1), kernel_size=3, activate=True) 440 | self.Dense0 = EqualLinear(nf(1)*4*4, nf(0), activation='fused_lrelu') 441 | self.Output = EqualLinear(nf(0), 1) 442 | 443 | def forward(self, images_in, masks_in): 444 | masks_in = 1-masks_in 445 | y = torch.cat([masks_in - 0.5, images_in], 1) 446 | out = self.convs(y) 447 | batch, channel, height, width = out.shape 448 | group_size = min(batch, self.mbstd_group_size) 449 | #print(out.shape) 450 | #pdb.set_trace() 451 | stddev = out.view( 452 | group_size, 453 | -1, 454 | self.mbstd_num_features, 455 | channel // self.mbstd_num_features, 456 | height, width 457 | ) 458 | stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8) 459 | stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2) 460 | stddev = stddev.repeat(group_size, 1, height, width) 461 | out = torch.cat([out, stddev], 1) 462 | out = self.Conv4x4(out) 463 | out = out.view(batch, -1) 464 | out = self.Dense0(out) 465 | out = self.Output(out) 466 | return out 467 | 468 | 469 | 470 | if __name__ == "__main__": 471 | import cv2 472 | from PIL import Image 473 | path_img = "/home/zeng/co-mod-gan/imgs/example_image.jpg" 474 | path_mask = "/home/zeng/co-mod-gan/imgs/example_mask.jpg" 475 | 476 | real = np.asarray(Image.open(path_img)).transpose([2, 0, 1])/255.0 477 | 478 | masks = np.asarray(Image.open(path_mask).convert('1'), dtype=np.float32) 479 | 480 | images = torch.Tensor(real.copy())[None,...]*2-1 481 | masks = torch.Tensor(masks)[None,None,...].float() 482 | masks = (masks==0).float() 483 | 484 | net = Discriminator() 485 | hh = net(images, masks) 486 | pdb.set_trace() 487 | 488 | #net = 
Generator() 489 | #net.G_mapping.load_from_tf_dict("/home/zeng/co-mod-gan/co-mod-gan-ffhq-9-025000.npz") 490 | #net.G_synthesis.load_from_tf_dict("/home/zeng/co-mod-gan/co-mod-gan-ffhq-9-025000.npz") 491 | #net.eval() 492 | #torch.save(net.state_dict(), "co-mod-gan-ffhq-9-025000.pth") 493 | 494 | #latents_in = torch.randn(1, 512) 495 | 496 | #hh = net(images, masks, [latents_in], truncation=None) 497 | #hh = hh.detach().cpu().numpy() 498 | #hh = (hh+1)/2 499 | #hh = (hh[0].transpose((1,2,0)))*255 500 | #cv2.imwrite("hh.png", hh[:,:,::-1].clip(0,255)) 501 | #pdb.set_trace() 502 | -------------------------------------------------------------------------------- /models/networks/discriminator.py: -------------------------------------------------------------------------------- 1 | from models.networks.co_mod_gan import Discriminator as CoModGANDiscriminator 2 | -------------------------------------------------------------------------------- /models/networks/generator.py: -------------------------------------------------------------------------------- 1 | from models.networks.co_mod_gan import Generator as CoModGANGenerator 2 | -------------------------------------------------------------------------------- /models/networks/loss.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from models.networks.architecture import VGG19 10 | 11 | 12 | # Defines the GAN loss which uses either LSGAN or the regular GAN. 
13 | # When LSGAN is used, it is basically same as MSELoss, 14 | # but it abstracts away the need to create the target label tensor 15 | # that has the same size as the input 16 | class GANLoss(nn.Module): 17 | def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0, 18 | tensor=torch.FloatTensor, opt=None): 19 | super(GANLoss, self).__init__() 20 | self.real_label = target_real_label 21 | self.fake_label = target_fake_label 22 | self.real_label_tensor = None 23 | self.fake_label_tensor = None 24 | self.zero_tensor = None 25 | self.Tensor = tensor 26 | self.gan_mode = gan_mode 27 | self.opt = opt 28 | if gan_mode == 'ls': 29 | pass 30 | elif gan_mode == 'original': 31 | pass 32 | elif gan_mode == 'w': 33 | pass 34 | elif gan_mode == 'hinge': 35 | pass 36 | elif gan_mode == 'softplus': 37 | pass 38 | else: 39 | raise ValueError('Unexpected gan_mode {}'.format(gan_mode)) 40 | 41 | def get_target_tensor(self, input, target_is_real): 42 | if target_is_real: 43 | if self.real_label_tensor is None: 44 | self.real_label_tensor = self.Tensor(1).fill_(self.real_label) 45 | self.real_label_tensor.requires_grad_(False) 46 | return self.real_label_tensor.expand_as(input) 47 | else: 48 | if self.fake_label_tensor is None: 49 | self.fake_label_tensor = self.Tensor(1).fill_(self.fake_label) 50 | self.fake_label_tensor.requires_grad_(False) 51 | return self.fake_label_tensor.expand_as(input) 52 | 53 | def get_zero_tensor(self, input): 54 | if self.zero_tensor is None: 55 | self.zero_tensor = self.Tensor(1).fill_(0) 56 | self.zero_tensor.requires_grad_(False) 57 | return self.zero_tensor.expand_as(input) 58 | 59 | def loss(self, input, target_is_real, for_discriminator=True): 60 | if self.gan_mode == 'original': # cross entropy loss 61 | target_tensor = self.get_target_tensor(input, target_is_real) 62 | loss = F.binary_cross_entropy_with_logits(input, target_tensor) 63 | return loss 64 | elif self.gan_mode == 'ls': 65 | target_tensor = 
self.get_target_tensor(input, target_is_real) 66 | return F.mse_loss(input, target_tensor) 67 | elif self.gan_mode == 'hinge': 68 | if for_discriminator: 69 | if target_is_real: 70 | minval = torch.min(input - 1, self.get_zero_tensor(input)) 71 | loss = -torch.mean(minval) 72 | else: 73 | minval = torch.min(-input - 1, self.get_zero_tensor(input)) 74 | loss = -torch.mean(minval) 75 | else: 76 | assert target_is_real, "The generator's hinge loss must be aiming for real" 77 | loss = -torch.mean(input) 78 | return loss 79 | elif self.gan_mode == 'softplus': 80 | # wgan 81 | if target_is_real: 82 | return F.softplus(-input).mean() 83 | else: 84 | return F.softplus(input).mean() 85 | else: 86 | # wgan 87 | if target_is_real: 88 | return -input.mean() 89 | else: 90 | return input.mean() 91 | 92 | def __call__(self, input, target_is_real, for_discriminator=True): 93 | # computing loss is a bit complicated because |input| may not be 94 | # a tensor, but list of tensors in case of multiscale discriminator 95 | if isinstance(input, list): 96 | loss = 0 97 | for pred_i in input: 98 | if isinstance(pred_i, list): 99 | pred_i = pred_i[-1] 100 | loss_tensor = self.loss(pred_i, target_is_real, for_discriminator) 101 | bs = 1 if len(loss_tensor.size()) == 0 else loss_tensor.size(0) 102 | new_loss = torch.mean(loss_tensor.view(bs, -1), dim=1) 103 | loss += new_loss 104 | return loss / len(input) 105 | else: 106 | return self.loss(input, target_is_real, for_discriminator) 107 | 108 | 109 | # Perceptual loss that uses a pretrained VGG network 110 | class VGGLoss(nn.Module): 111 | def __init__(self, gpu_ids): 112 | super(VGGLoss, self).__init__() 113 | self.vgg = VGG19().cuda() 114 | self.criterion = nn.L1Loss() 115 | self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0] 116 | 117 | def forward(self, x, y, **kwargs): 118 | x_vgg, y_vgg = self.vgg(x), self.vgg(y) 119 | loss = 0 120 | for i in range(len(x_vgg)): 121 | loss += self.weights[i] * self.criterion(x_vgg[i], 
y_vgg[i].detach()) 122 | return loss 123 | 124 | class MaskedVGGLoss(VGGLoss): 125 | def forward(self, x, y, mask, **kwargs): 126 | x_vgg, y_vgg = self.vgg(x*mask), self.vgg(y*mask) 127 | loss = 0 128 | for i in range(len(x_vgg)): 129 | loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach()) 130 | return loss 131 | 132 | 133 | class VGGFaceLoss(nn.Module): 134 | def __init__(self, gpu_ids, weights_path): 135 | super(VGGFaceLoss, self).__init__() 136 | self.vgg = VGGFace(weights_path=weights_path).cuda() 137 | self.criterion = nn.L1Loss() 138 | self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0] 139 | 140 | def forward(self, x, y): 141 | x_vgg, y_vgg = self.vgg(x), self.vgg(y) 142 | loss = 0 143 | for i in range(len(x_vgg)): 144 | loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach()) 145 | return loss 146 | 147 | 148 | # KL Divergence loss used in VAE with an image encoder 149 | class KLDLoss(nn.Module): 150 | def forward(self, mu, logvar): 151 | return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) 152 | -------------------------------------------------------------------------------- /models/networks/op/__init__.py: -------------------------------------------------------------------------------- 1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu 2 | from .upfirdn2d import upfirdn2d 3 | -------------------------------------------------------------------------------- /models/networks/op/fused_act.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | from torch.autograd import Function 7 | from torch.utils.cpp_extension import load 8 | 9 | 10 | module_path = os.path.dirname(__file__) 11 | fused = load( 12 | "fused", 13 | sources=[ 14 | os.path.join(module_path, "fused_bias_act.cpp"), 15 | os.path.join(module_path, "fused_bias_act_kernel.cu"), 16 | ], 17 | ) 18 | 19 | 20 | class 
FusedLeakyReLUFunctionBackward(Function): 21 | @staticmethod 22 | def forward(ctx, grad_output, out, bias, negative_slope, scale): 23 | ctx.save_for_backward(out) 24 | ctx.negative_slope = negative_slope 25 | ctx.scale = scale 26 | 27 | empty = grad_output.new_empty(0) 28 | 29 | grad_input = fused.fused_bias_act( 30 | grad_output, empty, out, 3, 1, negative_slope, scale 31 | ) 32 | 33 | dim = [0] 34 | 35 | if grad_input.ndim > 2: 36 | dim += list(range(2, grad_input.ndim)) 37 | 38 | if bias: 39 | grad_bias = grad_input.sum(dim).detach() 40 | 41 | else: 42 | grad_bias = empty 43 | 44 | return grad_input, grad_bias 45 | 46 | @staticmethod 47 | def backward(ctx, gradgrad_input, gradgrad_bias): 48 | out, = ctx.saved_tensors 49 | gradgrad_out = fused.fused_bias_act( 50 | gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale 51 | ) 52 | 53 | return gradgrad_out, None, None, None, None 54 | 55 | 56 | class FusedLeakyReLUFunction(Function): 57 | @staticmethod 58 | def forward(ctx, input, bias, negative_slope, scale): 59 | empty = input.new_empty(0) 60 | 61 | ctx.bias = bias is not None 62 | 63 | if bias is None: 64 | bias = empty 65 | 66 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) 67 | ctx.save_for_backward(out) 68 | ctx.negative_slope = negative_slope 69 | ctx.scale = scale 70 | 71 | return out 72 | 73 | @staticmethod 74 | def backward(ctx, grad_output): 75 | out, = ctx.saved_tensors 76 | 77 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( 78 | grad_output, out, ctx.bias, ctx.negative_slope, ctx.scale 79 | ) 80 | 81 | if not ctx.bias: 82 | grad_bias = None 83 | 84 | return grad_input, grad_bias, None, None 85 | 86 | 87 | class FusedLeakyReLU(nn.Module): 88 | def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5): 89 | super().__init__() 90 | 91 | if bias: 92 | self.bias = nn.Parameter(torch.zeros(channel)) 93 | 94 | else: 95 | self.bias = None 96 | 97 | self.negative_slope = 
negative_slope 98 | self.scale = scale 99 | 100 | def forward(self, input): 101 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 102 | 103 | 104 | def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5): 105 | if input.device.type == "cpu": 106 | if bias is not None: 107 | rest_dim = [1] * (input.ndim - bias.ndim - 1) 108 | return ( 109 | F.leaky_relu( 110 | input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2 111 | ) 112 | * scale 113 | ) 114 | 115 | else: 116 | return F.leaky_relu(input, negative_slope=0.2) * scale 117 | 118 | else: 119 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) 120 | -------------------------------------------------------------------------------- /models/networks/op/fused_bias_act.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | 4 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 5 | int act, int grad, float alpha, float scale); 6 | 7 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 8 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 9 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 10 | 11 | torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 12 | int act, int grad, float alpha, float scale) { 13 | CHECK_CUDA(input); 14 | CHECK_CUDA(bias); 15 | 16 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); 17 | } 18 | 19 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 20 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); 21 | } -------------------------------------------------------------------------------- /models/networks/op/fused_bias_act_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, 
NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | 18 | template 19 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, 20 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { 21 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; 22 | 23 | scalar_t zero = 0.0; 24 | 25 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { 26 | scalar_t x = p_x[xi]; 27 | 28 | if (use_bias) { 29 | x += p_b[(xi / step_b) % size_b]; 30 | } 31 | 32 | scalar_t ref = use_ref ? p_ref[xi] : zero; 33 | 34 | scalar_t y; 35 | 36 | switch (act * 10 + grad) { 37 | default: 38 | case 10: y = x; break; 39 | case 11: y = x; break; 40 | case 12: y = 0.0; break; 41 | 42 | case 30: y = (x > 0.0) ? x : x * alpha; break; 43 | case 31: y = (ref > 0.0) ? x : x * alpha; break; 44 | case 32: y = 0.0; break; 45 | } 46 | 47 | out[xi] = y * scale; 48 | } 49 | } 50 | 51 | 52 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 53 | int act, int grad, float alpha, float scale) { 54 | int curDevice = -1; 55 | cudaGetDevice(&curDevice); 56 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 57 | 58 | auto x = input.contiguous(); 59 | auto b = bias.contiguous(); 60 | auto ref = refer.contiguous(); 61 | 62 | int use_bias = b.numel() ? 1 : 0; 63 | int use_ref = ref.numel() ? 
1 : 0; 64 | 65 | int size_x = x.numel(); 66 | int size_b = b.numel(); 67 | int step_b = 1; 68 | 69 | for (int i = 1 + 1; i < x.dim(); i++) { 70 | step_b *= x.size(i); 71 | } 72 | 73 | int loop_x = 4; 74 | int block_size = 4 * 32; 75 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1; 76 | 77 | auto y = torch::empty_like(x); 78 | 79 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { 80 | fused_bias_act_kernel<<>>( 81 | y.data_ptr(), 82 | x.data_ptr(), 83 | b.data_ptr(), 84 | ref.data_ptr(), 85 | act, 86 | grad, 87 | alpha, 88 | scale, 89 | loop_x, 90 | size_x, 91 | step_b, 92 | size_b, 93 | use_bias, 94 | use_ref 95 | ); 96 | }); 97 | 98 | return y; 99 | } -------------------------------------------------------------------------------- /models/networks/op/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | 4 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, 5 | int up_x, int up_y, int down_x, int down_y, 6 | int pad_x0, int pad_x1, int pad_y0, int pad_y1); 7 | 8 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 9 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 10 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 11 | 12 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, 13 | int up_x, int up_y, int down_x, int down_y, 14 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) { 15 | CHECK_CUDA(input); 16 | CHECK_CUDA(kernel); 17 | 18 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); 19 | } 20 | 21 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 22 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); 23 | } -------------------------------------------------------------------------------- /models/networks/op/upfirdn2d.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch.nn import functional as F 5 | from torch.autograd import Function 6 | from torch.utils.cpp_extension import load 7 | 8 | 9 | module_path = os.path.dirname(__file__) 10 | upfirdn2d_op = load( 11 | "upfirdn2d", 12 | sources=[ 13 | os.path.join(module_path, "upfirdn2d.cpp"), 14 | os.path.join(module_path, "upfirdn2d_kernel.cu"), 15 | ], 16 | ) 17 | 18 | 19 | class UpFirDn2dBackward(Function): 20 | @staticmethod 21 | def forward( 22 | ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size 23 | ): 24 | 25 | up_x, up_y = up 26 | down_x, down_y = down 27 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad 28 | 29 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) 30 | 31 | grad_input = upfirdn2d_op.upfirdn2d( 32 | grad_output, 33 | grad_kernel, 34 | down_x, 35 | down_y, 36 | up_x, 37 | up_y, 38 | g_pad_x0, 39 | g_pad_x1, 40 | g_pad_y0, 41 | g_pad_y1, 42 | ) 43 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) 44 | 45 | ctx.save_for_backward(kernel) 46 | 47 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 48 | 49 | ctx.up_x = up_x 50 | ctx.up_y = up_y 51 | ctx.down_x = down_x 52 | ctx.down_y = down_y 53 | ctx.pad_x0 = pad_x0 54 | ctx.pad_x1 = pad_x1 55 | ctx.pad_y0 = pad_y0 56 | ctx.pad_y1 = pad_y1 57 | ctx.in_size = in_size 58 | ctx.out_size = out_size 59 | 60 | return grad_input 61 | 62 | @staticmethod 63 | def backward(ctx, gradgrad_input): 64 | kernel, = ctx.saved_tensors 65 | 66 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) 67 | 68 | gradgrad_out = upfirdn2d_op.upfirdn2d( 69 | gradgrad_input, 70 | kernel, 71 | ctx.up_x, 72 | ctx.up_y, 73 | ctx.down_x, 74 | ctx.down_y, 75 | ctx.pad_x0, 76 | ctx.pad_x1, 77 | ctx.pad_y0, 78 | ctx.pad_y1, 79 | ) 80 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3]) 81 
| gradgrad_out = gradgrad_out.view( 82 | ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1] 83 | ) 84 | 85 | return gradgrad_out, None, None, None, None, None, None, None, None 86 | 87 | 88 | class UpFirDn2d(Function): 89 | @staticmethod 90 | def forward(ctx, input, kernel, up, down, pad): 91 | up_x, up_y = up 92 | down_x, down_y = down 93 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 94 | 95 | kernel_h, kernel_w = kernel.shape 96 | batch, channel, in_h, in_w = input.shape 97 | ctx.in_size = input.shape 98 | 99 | input = input.reshape(-1, in_h, in_w, 1) 100 | 101 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) 102 | 103 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 104 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 105 | ctx.out_size = (out_h, out_w) 106 | 107 | ctx.up = (up_x, up_y) 108 | ctx.down = (down_x, down_y) 109 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) 110 | 111 | g_pad_x0 = kernel_w - pad_x0 - 1 112 | g_pad_y0 = kernel_h - pad_y0 - 1 113 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 114 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 115 | 116 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) 117 | 118 | out = upfirdn2d_op.upfirdn2d( 119 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 120 | ) 121 | # out = out.view(major, out_h, out_w, minor) 122 | out = out.view(-1, channel, out_h, out_w) 123 | 124 | return out 125 | 126 | @staticmethod 127 | def backward(ctx, grad_output): 128 | kernel, grad_kernel = ctx.saved_tensors 129 | 130 | grad_input = UpFirDn2dBackward.apply( 131 | grad_output, 132 | kernel, 133 | grad_kernel, 134 | ctx.up, 135 | ctx.down, 136 | ctx.pad, 137 | ctx.g_pad, 138 | ctx.in_size, 139 | ctx.out_size, 140 | ) 141 | 142 | return grad_input, None, None, None, None 143 | 144 | 145 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): 146 | if input.device.type == "cpu": 147 | out = upfirdn2d_native( 148 | input, 
kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1] 149 | ) 150 | 151 | else: 152 | out = UpFirDn2d.apply( 153 | input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1]) 154 | ) 155 | 156 | return out 157 | 158 | 159 | def upfirdn2d_native( 160 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 161 | ): 162 | _, channel, in_h, in_w = input.shape 163 | input = input.reshape(-1, in_h, in_w, 1) 164 | 165 | _, in_h, in_w, minor = input.shape 166 | kernel_h, kernel_w = kernel.shape 167 | 168 | out = input.view(-1, in_h, 1, in_w, 1, minor) 169 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) 170 | out = out.view(-1, in_h * up_y, in_w * up_x, minor) 171 | 172 | out = F.pad( 173 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] 174 | ) 175 | out = out[ 176 | :, 177 | max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), 178 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), 179 | :, 180 | ] 181 | 182 | out = out.permute(0, 3, 1, 2) 183 | out = out.reshape( 184 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] 185 | ) 186 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) 187 | out = F.conv2d(out, w) 188 | out = out.reshape( 189 | -1, 190 | minor, 191 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, 192 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, 193 | ) 194 | out = out.permute(0, 2, 3, 1) 195 | out = out[:, ::down_y, ::down_x, :] 196 | 197 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 198 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 199 | 200 | return out.view(-1, channel, out_h, out_w) 201 | -------------------------------------------------------------------------------- /models/networks/op/upfirdn2d_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 
2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | static __host__ __device__ __forceinline__ int floor_div(int a, int b) { 18 | int c = a / b; 19 | 20 | if (c * b > a) { 21 | c--; 22 | } 23 | 24 | return c; 25 | } 26 | 27 | struct UpFirDn2DKernelParams { 28 | int up_x; 29 | int up_y; 30 | int down_x; 31 | int down_y; 32 | int pad_x0; 33 | int pad_x1; 34 | int pad_y0; 35 | int pad_y1; 36 | 37 | int major_dim; 38 | int in_h; 39 | int in_w; 40 | int minor_dim; 41 | int kernel_h; 42 | int kernel_w; 43 | int out_h; 44 | int out_w; 45 | int loop_major; 46 | int loop_x; 47 | }; 48 | 49 | template 50 | __global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input, 51 | const scalar_t *kernel, 52 | const UpFirDn2DKernelParams p) { 53 | int minor_idx = blockIdx.x * blockDim.x + threadIdx.x; 54 | int out_y = minor_idx / p.minor_dim; 55 | minor_idx -= out_y * p.minor_dim; 56 | int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y; 57 | int major_idx_base = blockIdx.z * p.loop_major; 58 | 59 | if (out_x_base >= p.out_w || out_y >= p.out_h || 60 | major_idx_base >= p.major_dim) { 61 | return; 62 | } 63 | 64 | int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0; 65 | int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h); 66 | int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y; 67 | int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y; 68 | 69 | for (int loop_major = 0, major_idx = major_idx_base; 70 | loop_major < p.loop_major && major_idx < p.major_dim; 71 | loop_major++, major_idx++) { 72 | for (int loop_x = 0, out_x = out_x_base; 73 | loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) { 74 | int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0; 75 | int 
in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w); 76 | int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x; 77 | int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x; 78 | 79 | const scalar_t *x_p = 80 | &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + 81 | minor_idx]; 82 | const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x]; 83 | int x_px = p.minor_dim; 84 | int k_px = -p.up_x; 85 | int x_py = p.in_w * p.minor_dim; 86 | int k_py = -p.up_y * p.kernel_w; 87 | 88 | scalar_t v = 0.0f; 89 | 90 | for (int y = 0; y < h; y++) { 91 | for (int x = 0; x < w; x++) { 92 | v += static_cast(*x_p) * static_cast(*k_p); 93 | x_p += x_px; 94 | k_p += k_px; 95 | } 96 | 97 | x_p += x_py - w * x_px; 98 | k_p += k_py - w * k_px; 99 | } 100 | 101 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + 102 | minor_idx] = v; 103 | } 104 | } 105 | } 106 | 107 | template 109 | __global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input, 110 | const scalar_t *kernel, 111 | const UpFirDn2DKernelParams p) { 112 | const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1; 113 | const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1; 114 | 115 | __shared__ volatile float sk[kernel_h][kernel_w]; 116 | __shared__ volatile float sx[tile_in_h][tile_in_w]; 117 | 118 | int minor_idx = blockIdx.x; 119 | int tile_out_y = minor_idx / p.minor_dim; 120 | minor_idx -= tile_out_y * p.minor_dim; 121 | tile_out_y *= tile_out_h; 122 | int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w; 123 | int major_idx_base = blockIdx.z * p.loop_major; 124 | 125 | if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | 126 | major_idx_base >= p.major_dim) { 127 | return; 128 | } 129 | 130 | for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; 131 | tap_idx += blockDim.x) { 132 | int ky = tap_idx / kernel_w; 133 | int kx = tap_idx - ky * kernel_w; 134 | scalar_t v = 0.0; 135 | 136 | if 
(kx < p.kernel_w & ky < p.kernel_h) { 137 | v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)]; 138 | } 139 | 140 | sk[ky][kx] = v; 141 | } 142 | 143 | for (int loop_major = 0, major_idx = major_idx_base; 144 | loop_major < p.loop_major & major_idx < p.major_dim; 145 | loop_major++, major_idx++) { 146 | for (int loop_x = 0, tile_out_x = tile_out_x_base; 147 | loop_x < p.loop_x & tile_out_x < p.out_w; 148 | loop_x++, tile_out_x += tile_out_w) { 149 | int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0; 150 | int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0; 151 | int tile_in_x = floor_div(tile_mid_x, up_x); 152 | int tile_in_y = floor_div(tile_mid_y, up_y); 153 | 154 | __syncthreads(); 155 | 156 | for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; 157 | in_idx += blockDim.x) { 158 | int rel_in_y = in_idx / tile_in_w; 159 | int rel_in_x = in_idx - rel_in_y * tile_in_w; 160 | int in_x = rel_in_x + tile_in_x; 161 | int in_y = rel_in_y + tile_in_y; 162 | 163 | scalar_t v = 0.0; 164 | 165 | if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) { 166 | v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * 167 | p.minor_dim + 168 | minor_idx]; 169 | } 170 | 171 | sx[rel_in_y][rel_in_x] = v; 172 | } 173 | 174 | __syncthreads(); 175 | for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; 176 | out_idx += blockDim.x) { 177 | int rel_out_y = out_idx / tile_out_w; 178 | int rel_out_x = out_idx - rel_out_y * tile_out_w; 179 | int out_x = rel_out_x + tile_out_x; 180 | int out_y = rel_out_y + tile_out_y; 181 | 182 | int mid_x = tile_mid_x + rel_out_x * down_x; 183 | int mid_y = tile_mid_y + rel_out_y * down_y; 184 | int in_x = floor_div(mid_x, up_x); 185 | int in_y = floor_div(mid_y, up_y); 186 | int rel_in_x = in_x - tile_in_x; 187 | int rel_in_y = in_y - tile_in_y; 188 | int kernel_x = (in_x + 1) * up_x - mid_x - 1; 189 | int kernel_y = (in_y + 1) * up_y - mid_y - 1; 190 | 191 | scalar_t v = 0.0; 192 | 193 | 
#pragma unroll 194 | for (int y = 0; y < kernel_h / up_y; y++) 195 | #pragma unroll 196 | for (int x = 0; x < kernel_w / up_x; x++) 197 | v += sx[rel_in_y + y][rel_in_x + x] * 198 | sk[kernel_y + y * up_y][kernel_x + x * up_x]; 199 | 200 | if (out_x < p.out_w & out_y < p.out_h) { 201 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + 202 | minor_idx] = v; 203 | } 204 | } 205 | } 206 | } 207 | } 208 | 209 | torch::Tensor upfirdn2d_op(const torch::Tensor &input, 210 | const torch::Tensor &kernel, int up_x, int up_y, 211 | int down_x, int down_y, int pad_x0, int pad_x1, 212 | int pad_y0, int pad_y1) { 213 | int curDevice = -1; 214 | cudaGetDevice(&curDevice); 215 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 216 | 217 | UpFirDn2DKernelParams p; 218 | 219 | auto x = input.contiguous(); 220 | auto k = kernel.contiguous(); 221 | 222 | p.major_dim = x.size(0); 223 | p.in_h = x.size(1); 224 | p.in_w = x.size(2); 225 | p.minor_dim = x.size(3); 226 | p.kernel_h = k.size(0); 227 | p.kernel_w = k.size(1); 228 | p.up_x = up_x; 229 | p.up_y = up_y; 230 | p.down_x = down_x; 231 | p.down_y = down_y; 232 | p.pad_x0 = pad_x0; 233 | p.pad_x1 = pad_x1; 234 | p.pad_y0 = pad_y0; 235 | p.pad_y1 = pad_y1; 236 | 237 | p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / 238 | p.down_y; 239 | p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / 240 | p.down_x; 241 | 242 | auto out = 243 | at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options()); 244 | 245 | int mode = -1; 246 | 247 | int tile_out_h = -1; 248 | int tile_out_w = -1; 249 | 250 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && 251 | p.kernel_h <= 4 && p.kernel_w <= 4) { 252 | mode = 1; 253 | tile_out_h = 16; 254 | tile_out_w = 64; 255 | } 256 | 257 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && 258 | p.kernel_h <= 3 && p.kernel_w <= 3) { 259 | mode = 2; 260 | tile_out_h = 16; 261 | 
tile_out_w = 64; 262 | } 263 | 264 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && 265 | p.kernel_h <= 4 && p.kernel_w <= 4) { 266 | mode = 3; 267 | tile_out_h = 16; 268 | tile_out_w = 64; 269 | } 270 | 271 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && 272 | p.kernel_h <= 2 && p.kernel_w <= 2) { 273 | mode = 4; 274 | tile_out_h = 16; 275 | tile_out_w = 64; 276 | } 277 | 278 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && 279 | p.kernel_h <= 4 && p.kernel_w <= 4) { 280 | mode = 5; 281 | tile_out_h = 8; 282 | tile_out_w = 32; 283 | } 284 | 285 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && 286 | p.kernel_h <= 2 && p.kernel_w <= 2) { 287 | mode = 6; 288 | tile_out_h = 8; 289 | tile_out_w = 32; 290 | } 291 | 292 | dim3 block_size; 293 | dim3 grid_size; 294 | 295 | if (tile_out_h > 0 && tile_out_w > 0) { 296 | p.loop_major = (p.major_dim - 1) / 16384 + 1; 297 | p.loop_x = 1; 298 | block_size = dim3(32 * 8, 1, 1); 299 | grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim, 300 | (p.out_w - 1) / (p.loop_x * tile_out_w) + 1, 301 | (p.major_dim - 1) / p.loop_major + 1); 302 | } else { 303 | p.loop_major = (p.major_dim - 1) / 16384 + 1; 304 | p.loop_x = 4; 305 | block_size = dim3(4, 32, 1); 306 | grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1, 307 | (p.out_w - 1) / (p.loop_x * block_size.y) + 1, 308 | (p.major_dim - 1) / p.loop_major + 1); 309 | } 310 | 311 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] { 312 | switch (mode) { 313 | case 1: 314 | upfirdn2d_kernel 315 | <<>>(out.data_ptr(), 316 | x.data_ptr(), 317 | k.data_ptr(), p); 318 | 319 | break; 320 | 321 | case 2: 322 | upfirdn2d_kernel 323 | <<>>(out.data_ptr(), 324 | x.data_ptr(), 325 | k.data_ptr(), p); 326 | 327 | break; 328 | 329 | case 3: 330 | upfirdn2d_kernel 331 | <<>>(out.data_ptr(), 332 | x.data_ptr(), 333 | k.data_ptr(), p); 334 | 335 | break; 336 | 337 | case 
4: 338 | upfirdn2d_kernel 339 | <<>>(out.data_ptr(), 340 | x.data_ptr(), 341 | k.data_ptr(), p); 342 | 343 | break; 344 | 345 | case 5: 346 | upfirdn2d_kernel 347 | <<>>(out.data_ptr(), 348 | x.data_ptr(), 349 | k.data_ptr(), p); 350 | 351 | break; 352 | 353 | case 6: 354 | upfirdn2d_kernel 355 | <<>>(out.data_ptr(), 356 | x.data_ptr(), 357 | k.data_ptr(), p); 358 | 359 | break; 360 | 361 | default: 362 | upfirdn2d_kernel_large<<>>( 363 | out.data_ptr(), x.data_ptr(), 364 | k.data_ptr(), p); 365 | } 366 | }); 367 | 368 | return out; 369 | } -------------------------------------------------------------------------------- /models/networks/stylegan2.py: -------------------------------------------------------------------------------- 1 | import math 2 | import pdb 3 | import random 4 | import functools 5 | import operator 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | from torch.autograd import Function 11 | try: 12 | from models.networks.sync_batchnorm import SynchronizedBatchNorm2d 13 | except: 14 | pass 15 | 16 | from models.networks.base_network import BaseNetwork 17 | from models.networks.op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d 18 | 19 | #from base_network import BaseNetwork 20 | #from op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d 21 | 22 | 23 | class PixelNorm(nn.Module): 24 | def __init__(self): 25 | super().__init__() 26 | 27 | def forward(self, input): 28 | return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8) 29 | 30 | 31 | def make_kernel(k): 32 | k = torch.tensor(k, dtype=torch.float32) 33 | 34 | if k.ndim == 1: 35 | k = k[None, :] * k[:, None] 36 | 37 | k /= k.sum() 38 | 39 | return k 40 | 41 | 42 | class Upsample(nn.Module): 43 | def __init__(self, kernel, factor=2): 44 | super().__init__() 45 | 46 | self.factor = factor 47 | kernel = make_kernel(kernel) * (factor ** 2) 48 | self.register_buffer("kernel", kernel) 49 | 50 | p = kernel.shape[0] - factor 51 | 52 | 
pad0 = (p + 1) // 2 + factor - 1 53 | pad1 = p // 2 54 | 55 | self.pad = (pad0, pad1) 56 | 57 | def forward(self, input): 58 | out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad) 59 | 60 | return out 61 | 62 | 63 | class Downsample(nn.Module): 64 | def __init__(self, kernel, factor=2): 65 | super().__init__() 66 | 67 | self.factor = factor 68 | kernel = make_kernel(kernel) 69 | self.register_buffer("kernel", kernel) 70 | 71 | p = kernel.shape[0] - factor 72 | 73 | pad0 = (p + 1) // 2 74 | pad1 = p // 2 75 | 76 | self.pad = (pad0, pad1) 77 | 78 | def forward(self, input): 79 | out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad) 80 | 81 | return out 82 | 83 | 84 | class Blur(nn.Module): 85 | def __init__(self, kernel, pad, upsample_factor=1): 86 | super().__init__() 87 | 88 | kernel = make_kernel(kernel) 89 | 90 | if upsample_factor > 1: 91 | kernel = kernel * (upsample_factor ** 2) 92 | 93 | self.register_buffer("kernel", kernel) 94 | 95 | self.pad = pad 96 | 97 | def forward(self, input): 98 | out = upfirdn2d(input, self.kernel, pad=self.pad) 99 | 100 | return out 101 | 102 | 103 | class EqualConv2d(nn.Module): 104 | def __init__( 105 | self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True 106 | ): 107 | super().__init__() 108 | 109 | self.weight = nn.Parameter( 110 | torch.randn(out_channel, in_channel, kernel_size, kernel_size) 111 | ) 112 | self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2) 113 | 114 | self.stride = stride 115 | self.padding = padding 116 | 117 | if bias: 118 | self.bias = nn.Parameter(torch.zeros(out_channel)) 119 | 120 | else: 121 | self.bias = None 122 | 123 | def forward(self, input): 124 | out = F.conv2d( 125 | input, 126 | self.weight * self.scale, 127 | bias=self.bias, 128 | stride=self.stride, 129 | padding=self.padding, 130 | ) 131 | 132 | return out 133 | 134 | def __repr__(self): 135 | return ( 136 | f"{self.__class__.__name__}({self.weight.shape[1]}, 
{self.weight.shape[0]}," 137 | f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})" 138 | ) 139 | 140 | 141 | class EqualLinear(nn.Module): 142 | def __init__( 143 | self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None 144 | ): 145 | super().__init__() 146 | 147 | self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul)) 148 | 149 | if bias: 150 | self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) 151 | 152 | else: 153 | self.bias = None 154 | 155 | self.activation = activation 156 | 157 | self.scale = (1 / math.sqrt(in_dim)) * lr_mul 158 | self.lr_mul = lr_mul 159 | 160 | def forward(self, input): 161 | if self.activation: 162 | out = F.linear(input, self.weight * self.scale) 163 | out = fused_leaky_relu(out, self.bias * self.lr_mul) 164 | 165 | else: 166 | out = F.linear( 167 | input, self.weight * self.scale, bias=self.bias * self.lr_mul 168 | ) 169 | 170 | return out 171 | 172 | def __repr__(self): 173 | return ( 174 | f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})" 175 | ) 176 | 177 | 178 | class ModulatedConv2d(nn.Module): 179 | def __init__( 180 | self, 181 | in_channel, 182 | out_channel, 183 | kernel_size, 184 | style_dim, 185 | demodulate=True, 186 | upsample=False, 187 | downsample=False, 188 | blur_kernel=[1, 3, 3, 1], 189 | ): 190 | super().__init__() 191 | 192 | self.eps = 1e-8 193 | self.kernel_size = kernel_size 194 | self.in_channel = in_channel 195 | self.out_channel = out_channel 196 | self.upsample = upsample 197 | self.downsample = downsample 198 | 199 | if upsample: 200 | factor = 2 201 | p = (len(blur_kernel) - factor) - (kernel_size - 1) 202 | pad0 = (p + 1) // 2 + factor - 1 203 | pad1 = p // 2 + 1 204 | 205 | self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor) 206 | 207 | if downsample: 208 | factor = 2 209 | p = (len(blur_kernel) - factor) + (kernel_size - 1) 210 | pad0 = (p + 1) // 2 211 | pad1 = p // 2 212 | 213 | self.blur 
= Blur(blur_kernel, pad=(pad0, pad1)) 214 | 215 | fan_in = in_channel * kernel_size ** 2 216 | self.scale = 1 / math.sqrt(fan_in) 217 | self.padding = kernel_size // 2 218 | 219 | self.weight = nn.Parameter( 220 | torch.randn(1, out_channel, in_channel, kernel_size, kernel_size) 221 | ) 222 | 223 | self.modulation = EqualLinear(style_dim, in_channel, bias_init=1) 224 | 225 | self.demodulate = demodulate 226 | 227 | def __repr__(self): 228 | return ( 229 | f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, " 230 | f"upsample={self.upsample}, downsample={self.downsample})" 231 | ) 232 | 233 | def forward(self, input, style, **kwargs): 234 | batch, in_channel, height, width = input.shape 235 | 236 | style = self.modulation(style).view(batch, 1, in_channel, 1, 1) 237 | weight = self.scale * self.weight * style 238 | 239 | if self.demodulate: 240 | demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8) 241 | weight = weight * demod.view(batch, self.out_channel, 1, 1, 1) 242 | 243 | weight = weight.view( 244 | batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size 245 | ) 246 | 247 | if self.upsample: 248 | input = input.view(1, batch * in_channel, height, width) 249 | weight = weight.view( 250 | batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size 251 | ) 252 | weight = weight.transpose(1, 2).reshape( 253 | batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size 254 | ) 255 | out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch) 256 | _, _, height, width = out.shape 257 | out = out.view(batch, self.out_channel, height, width) 258 | out = self.blur(out) 259 | 260 | elif self.downsample: 261 | input = self.blur(input) 262 | _, _, height, width = input.shape 263 | input = input.view(1, batch * in_channel, height, width) 264 | out = F.conv2d(input, weight, padding=0, stride=2, groups=batch) 265 | _, _, height, width = out.shape 266 | out = out.view(batch, 
self.out_channel, height, width) 267 | 268 | else: 269 | input = input.view(1, batch * in_channel, height, width) 270 | out = F.conv2d(input, weight, padding=self.padding, groups=batch) 271 | _, _, height, width = out.shape 272 | out = out.view(batch, self.out_channel, height, width) 273 | 274 | return out 275 | 276 | def PositionalNorm2d(x, epsilon=1e-5): 277 | # x: B*C*W*H normalize in C dim 278 | mean = x.mean(dim=1, keepdim=True) 279 | std = x.var(dim=1, keepdim=True).add(epsilon).sqrt() 280 | output = (x - mean) / std 281 | return output 282 | 283 | class WeightedConv2d(ModulatedConv2d): 284 | def __init__( 285 | self, *args, **kwargs 286 | ): 287 | num_weight = kwargs.pop("num_weight") 288 | super().__init__(*args, **kwargs) 289 | self.param_free_norm = PositionalNorm2d 290 | ks = 1 291 | pw = ks // 2 292 | # stylein seg or feature 293 | #nhidden = 128 294 | #self.mpl_shared = nn.Sequential( 295 | # EqualConv2d( 296 | # num_weight, 297 | # nhidden, 298 | # 1), 299 | # FusedLeakyReLU(nhidden)) 300 | nhidden = num_weight 301 | self.mpl_gamma = EqualConv2d( 302 | nhidden, 303 | self.out_channel, 304 | ks, 305 | padding=pw, 306 | ) 307 | self.mpl_beta = EqualConv2d( 308 | nhidden, 309 | self.out_channel, 310 | ks, 311 | padding=pw, 312 | ) 313 | 314 | def forward(self, input, style, skip): 315 | out = super().forward(input, style) 316 | # Part 1. generate parameter-free normalized activations 317 | normalized = self.param_free_norm(out) 318 | # Part 2. 
produce scaling and bias conditioned on semantic map 319 | #hidden = self.mpl_shared(skip) 320 | hidden = skip 321 | gamma = self.mpl_gamma(hidden) 322 | beta = self.mpl_beta(hidden) 323 | # apply scale and bias 324 | out = normalized * (1 + gamma) + beta 325 | 326 | return out 327 | 328 | 329 | class NoiseInjection(nn.Module): 330 | def __init__(self): 331 | super().__init__() 332 | 333 | self.weight = nn.Parameter(torch.zeros(1)) 334 | 335 | def forward(self, image, noise=None): 336 | if noise is None: 337 | batch, _, height, width = image.shape 338 | noise = image.new_empty(batch, 1, height, width).normal_() 339 | 340 | return image + self.weight * noise 341 | 342 | 343 | class ConstantInput(nn.Module): 344 | def __init__(self, channel, size=4): 345 | super().__init__() 346 | 347 | self.input = nn.Parameter(torch.randn(1, channel, size, size)) 348 | 349 | def forward(self, input): 350 | batch = input.shape[0] 351 | out = self.input.repeat(batch, 1, 1, 1) 352 | 353 | return out 354 | 355 | 356 | class StyledConv(nn.Module): 357 | def __init__( 358 | self, 359 | in_channel, 360 | out_channel, 361 | kernel_size, 362 | style_dim, 363 | upsample=False, 364 | blur_kernel=[1, 3, 3, 1], 365 | demodulate=True, 366 | weightedconv=False, 367 | num_weight=None, 368 | ): 369 | super().__init__() 370 | 371 | if weightedconv: 372 | self.conv = WeightedConv2d( 373 | in_channel, 374 | out_channel, 375 | kernel_size, 376 | style_dim, 377 | upsample=upsample, 378 | blur_kernel=blur_kernel, 379 | demodulate=demodulate, 380 | num_weight=num_weight 381 | ) 382 | else: 383 | self.conv = ModulatedConv2d( 384 | in_channel, 385 | out_channel, 386 | kernel_size, 387 | style_dim, 388 | upsample=upsample, 389 | blur_kernel=blur_kernel, 390 | demodulate=demodulate, 391 | ) 392 | 393 | self.noise = NoiseInjection() 394 | # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1)) 395 | # self.activate = ScaledLeakyReLU(0.2) 396 | self.activate = FusedLeakyReLU(out_channel) 397 | 398 | def 
forward(self, input, style, noise=None, x_skip=None): 399 | out = self.conv(input, style, skip=x_skip) 400 | out = self.noise(out, noise=noise) 401 | # out = out + self.bias 402 | out = self.activate(out) 403 | 404 | return out 405 | 406 | class ConvToRGB(nn.Module): 407 | def __init__(self, in_channel, upsample=True, blur_kernel=[1, 3, 3, 1], out_channel=3): 408 | super().__init__() 409 | 410 | if upsample: 411 | self.upsample = Upsample(blur_kernel) 412 | 413 | self.conv = ConvLayer(in_channel, out_channel, 1) 414 | 415 | 416 | def forward(self, input, skip=None): 417 | out = self.conv(input) 418 | 419 | if skip is not None: 420 | skip = self.upsample(skip) 421 | 422 | out = out + skip 423 | 424 | return out 425 | 426 | 427 | class ToRGB(nn.Module): 428 | def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1], out_channel=3,weightedconv=False, 429 | num_weight=None): 430 | super().__init__() 431 | 432 | if upsample: 433 | self.upsample = Upsample(blur_kernel) 434 | if weightedconv: 435 | self.conv = WeightedConv2d( 436 | in_channel, out_channel, 1, 437 | style_dim, demodulate=False, num_weight=num_weight) 438 | else: 439 | self.conv = ModulatedConv2d(in_channel, out_channel, 1, style_dim, demodulate=False) 440 | self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1)) 441 | 442 | def forward(self, input, style, skip=None, x_skip=None): 443 | out = self.conv(input, style, skip=x_skip) 444 | out = out + self.bias 445 | 446 | if skip is not None: 447 | skip = self.upsample(skip) 448 | 449 | out = out + skip 450 | 451 | return out 452 | 453 | 454 | class Generator(BaseNetwork): 455 | def __init__( 456 | self, 457 | opt 458 | ): 459 | super().__init__() 460 | size = opt.crop_size 461 | style_dim = opt.z_dim 462 | n_mlp = 8 463 | channel_multiplier=2 464 | blur_kernel=[1, 3, 3, 1] 465 | lr_mlp=0.01 466 | 467 | self.size = size 468 | 469 | self.style_dim = style_dim 470 | 471 | layers = [PixelNorm()] 472 | 473 | for i in range(n_mlp): 474 | 
layers.append( 475 | EqualLinear( 476 | style_dim, style_dim, lr_mul=lr_mlp, activation="fused_lrelu" 477 | ) 478 | ) 479 | 480 | self.style = nn.Sequential(*layers) 481 | 482 | self.channels = { 483 | 4: 512, 484 | 8: 512, 485 | 16: 512, 486 | 32: 512, 487 | 64: 256 * channel_multiplier, 488 | 128: 128 * channel_multiplier, 489 | 256: 64 * channel_multiplier, 490 | 512: 32 * channel_multiplier, 491 | 1024: 16 * channel_multiplier, 492 | } 493 | 494 | self.input = ConstantInput(self.channels[4]) 495 | self.conv1 = StyledConv( 496 | self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel 497 | ) 498 | self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False) 499 | 500 | self.log_size = int(math.log(size, 2)) 501 | self.num_layers = (self.log_size - 2) * 2 + 1 502 | 503 | self.convs = nn.ModuleList() 504 | self.upsamples = nn.ModuleList() 505 | self.to_rgbs = nn.ModuleList() 506 | self.noises = nn.Module() 507 | 508 | in_channel = self.channels[4] 509 | 510 | for layer_idx in range(self.num_layers): 511 | res = (layer_idx + 5) // 2 512 | shape = [1, 1, 2 ** res, 2 ** res] 513 | self.noises.register_buffer(f"noise_{layer_idx}", torch.randn(*shape)) 514 | 515 | for i in range(3, self.log_size + 1): 516 | out_channel = self.channels[2 ** i] 517 | 518 | self.convs.append( 519 | StyledConv( 520 | in_channel, 521 | out_channel, 522 | 3, 523 | style_dim, 524 | upsample=True, 525 | blur_kernel=blur_kernel, 526 | ) 527 | ) 528 | 529 | self.convs.append( 530 | StyledConv( 531 | out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel 532 | ) 533 | ) 534 | 535 | self.to_rgbs.append(ToRGB(out_channel, style_dim)) 536 | 537 | in_channel = out_channel 538 | 539 | self.n_latent = self.log_size * 2 - 2 540 | 541 | def make_noise(self): 542 | device = self.input.input.device 543 | 544 | noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)] 545 | 546 | for i in range(3, self.log_size + 1): 547 | for _ in range(2): 548 | 
noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device)) 549 | 550 | return noises 551 | 552 | def mean_latent(self, n_latent): 553 | latent_in = torch.randn( 554 | n_latent, self.style_dim, device=self.input.input.device 555 | ) 556 | latent = self.style(latent_in).mean(0, keepdim=True) 557 | 558 | return latent 559 | 560 | def get_latent(self, input): 561 | return self.style(input) 562 | 563 | def forward( 564 | self, 565 | styles, 566 | return_latents=False, 567 | inject_index=None, 568 | truncation=None, 569 | truncation_latent=None, 570 | input_is_latent=False, 571 | noise=None, 572 | randomize_noise=True, 573 | get_latent=False, 574 | ): 575 | if not input_is_latent: 576 | styles = [self.style(s) for s in styles] 577 | if get_latent: 578 | return styles 579 | 580 | if noise is None: 581 | if randomize_noise: 582 | noise = [None] * self.num_layers 583 | else: 584 | noise = [ 585 | getattr(self.noises, f"noise_{i}") for i in range(self.num_layers) 586 | ] 587 | 588 | if truncation is not None: 589 | assert 025}: {:<30}{}\n'.format(str(k), str(v), comment)) 122 | 123 | with open(file_name + '.pkl', 'wb') as opt_file: 124 | pickle.dump(opt, opt_file) 125 | 126 | def update_options_from_file(self, parser, opt): 127 | new_opt = self.load_options(opt) 128 | for k, v in sorted(vars(opt).items()): 129 | if hasattr(new_opt, k) and v != getattr(new_opt, k): 130 | new_val = getattr(new_opt, k) 131 | parser.set_defaults(**{k: new_val}) 132 | return parser 133 | 134 | def load_options(self, opt): 135 | file_name = self.option_file_path(opt, makedir=False) 136 | new_opt = pickle.load(open(file_name + '.pkl', 'rb')) 137 | return new_opt 138 | 139 | def parse(self, save=False): 140 | 141 | opt = self.gather_options() 142 | opt.isTrain = self.isTrain # train or test 143 | 144 | self.print_options(opt) 145 | if opt.isTrain: 146 | self.save_options(opt) 147 | 148 | # set gpu ids 149 | str_ids = opt.gpu_ids.split(',') 150 | opt.gpu_ids = [] 151 | for str_id in str_ids: 152 
| id = int(str_id) 153 | if id >= 0: 154 | opt.gpu_ids.append(id) 155 | if len(opt.gpu_ids) > 0: 156 | torch.cuda.set_device(opt.gpu_ids[0]) 157 | 158 | assert len(opt.gpu_ids) == 0 or opt.batchSize % len(opt.gpu_ids) == 0, \ 159 | "Batch size %d is wrong. It must be a multiple of # GPUs %d." \ 160 | % (opt.batchSize, len(opt.gpu_ids)) 161 | 162 | self.opt = opt 163 | return self.opt 164 | -------------------------------------------------------------------------------- /options/test_options.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | from .base_options import BaseOptions 7 | 8 | 9 | class TestOptions(BaseOptions): 10 | def initialize(self, parser): 11 | BaseOptions.initialize(self, parser) 12 | parser.add_argument('--dataset_mode', type=str, default='coco') 13 | parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.') 14 | parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model') 15 | parser.add_argument('--how_many', type=int, default=float("inf"), help='how many test images to run') 16 | 17 | parser.set_defaults(preprocess_mode='scale_width_and_crop', crop_size=256, load_size=256, display_winsize=256) 18 | parser.set_defaults(serial_batches=True) 19 | parser.set_defaults(no_flip=True) 20 | parser.set_defaults(phase='test') 21 | self.isTrain = False 22 | return parser 23 | -------------------------------------------------------------------------------- /options/train_options.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | from .base_options import BaseOptions 7 | 8 | 9 | class TrainOptions(BaseOptions): 10 | def initialize(self, parser): 11 | BaseOptions.initialize(self, parser) 12 | parser.add_argument('--save_remote_gs', type=str, required=False) 13 | parser.add_argument('--trainer', type=str, default='stylegan2') 14 | # for displays 15 | parser.add_argument('--display_freq', type=int, default=101, help='frequency of showing training results on screen') 16 | parser.add_argument('--print_freq', type=int, default=101, help='frequency of showing training results on console') 17 | parser.add_argument('--save_latest_freq', type=int, default=50000, help='frequency of saving the latest results') 18 | parser.add_argument('--validation_freq', type=int, default=50000, help='frequency of saving the latest results') 19 | parser.add_argument('--save_epoch_freq', type=int, default=10, help='frequency of saving checkpoints at the end of epochs') 20 | parser.add_argument('--debug', action='store_true', help='only do one epoch and displays at each iteration') 21 | # datast 22 | parser.add_argument('--dataset_mode_train', type=str, default='coco') 23 | parser.add_argument('--dataset_mode', type=str, default='coco') 24 | parser.add_argument('--dataset_mode_val', type=str, required=False) 25 | 26 | # for training 27 | parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model') 28 | parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model') 29 | parser.add_argument('--niter', type=int, default=50, help='# of iter at starting learning rate. This is NOT the total #epochs. 
Totla #epochs is niter + niter_decay') 30 | parser.add_argument('--niter_decay', type=int, default=0, help='# of iter to linearly decay learning rate to zero') 31 | parser.add_argument('--optimizer', type=str, default='adam') 32 | parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam') 33 | parser.add_argument('--D_steps_per_G', type=int, default=1, help='number of discriminator iterations per generator iterations.') 34 | 35 | # for discriminators 36 | parser.add_argument('--lambda_vgg', type=float, default=10.0, help='weight for vgg loss') 37 | parser.add_argument('--lambda_l1', type=float, default=1.0, help='weight for l1 loss') 38 | parser.add_argument('--no_l1_loss', action='store_true', help='if specified, do *not* use l1 loss') 39 | parser.add_argument('--no_vgg_loss', action='store_true', help='if specified, do *not* use VGG feature matching loss') 40 | parser.add_argument('--gan_mode', type=str, default='hinge', help='(ls|original|hinge)') 41 | parser.add_argument('--netD', type=str, default='comodgan') 42 | parser.add_argument('--freeze_D', action='store_true', help='do not update D') 43 | self.isTrain = True 44 | return parser 45 | -------------------------------------------------------------------------------- /output/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zengxianyu/co-mod-gan-pytorch/a178bfd15d2675214532151037ef8f6e34b3fd91/output/.gitkeep -------------------------------------------------------------------------------- /save_remote_gs.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pdb 3 | from datetime import datetime 4 | 5 | def init_remote(opt): 6 | os.system(f"rm -rf output/{opt.name}") 7 | cwd = os.getcwd() 8 | os.system(f"gsutil cp -r {opt.save_remote_gs}/{opt.name} ./output/") 9 | if os.path.exists(f"output/{opt.name}/iter.txt") and not 
os.path.exists(f"checkpoints/{opt.name}/iter.txt"): 10 | os.system(f"cp output/{opt.name}/latest_net_*.pth checkpoints/{opt.name}/") 11 | os.system(f"cp output/{opt.name}/iter.txt checkpoints/{opt.name}/") 12 | 13 | def upload_remote(opt): 14 | os.system(f"gsutil cp -r {opt.save_remote_gs}/{opt.name}/savemodel ./output/{opt.name}") 15 | now = datetime.now() 16 | dt_string = now.strftime("%d/%m/%Y %H:%M:%S") 17 | os.system(f"cp output/{opt.name}.html output/{opt.name}/") 18 | os.system(f"cp checkpoints/{opt.name}/opt.txt output/{opt.name}/") 19 | os.system(f"cp checkpoints/{opt.name}/iter.txt output/{opt.name}/") 20 | os.system(f"echo {dt_string} > output/{opt.name}/time.txt") 21 | with open(f"output/{opt.name}/savemodel","r") as f: 22 | line = f.readlines() 23 | if line[0].startswith("y"): 24 | os.system(f"gsutil cp -r ./checkpoints/{opt.name}/latest_net_*.pth {opt.save_remote_gs}/{opt.name}/") 25 | with open(f"output/{opt.name}/savemodel", "w") as f: 26 | f.writelines("n") 27 | os.system(f"gsutil cp -r ./output/{opt.name} {opt.save_remote_gs}/") 28 | 29 | 30 | 31 | if __name__ == "__main__": 32 | class Temp: 33 | pass 34 | opt = Temp() 35 | opt.save_remote_gs = "gs://zengxianyu" 36 | opt.name = "cline" 37 | init_remote(opt) 38 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
4 | """ 5 | 6 | import pdb 7 | import cv2 8 | import os 9 | from collections import OrderedDict 10 | import json 11 | from tqdm import tqdm 12 | import numpy as np 13 | import torch 14 | import data 15 | from options.test_options import TestOptions 16 | #from models.pix2pix_model import Pix2PixModel 17 | import models 18 | 19 | 20 | opt = TestOptions().parse() 21 | 22 | dataloader = data.create_dataloader(opt) 23 | 24 | model = models.create_model(opt) 25 | model.eval() 26 | 27 | for i, data_i in tqdm(enumerate(dataloader)): 28 | if i * opt.batchSize >= opt.how_many: 29 | break 30 | with torch.no_grad(): 31 | generated,_ = model(data_i, mode='inference') 32 | generated = torch.clamp(generated, -1, 1) 33 | generated = (generated+1)/2*255 34 | generated = generated.cpu().numpy().astype(np.uint8) 35 | img_path = data_i['path'] 36 | for b in range(generated.shape[0]): 37 | pred_im = generated[b].transpose((1,2,0)) 38 | print('process image... %s' % img_path[b]) 39 | cv2.imwrite(img_path[b], pred_im[:,:,::-1]) 40 | -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | python test.py \ 2 | --mixing 0 \ 3 | --batchSize 1 \ 4 | --nThreads 1 \ 5 | --name comod-ffhq-512 \ 6 | --dataset_mode testimage \ 7 | --image_dir ./ffhq_debug/images \ 8 | --mask_dir ./ffhq_debug/masks \ 9 | --output_dir ./ffhq_debug \ 10 | --load_size 512 \ 11 | --crop_size 512 \ 12 | --z_dim 512 \ 13 | --model comod \ 14 | --netG comodgan \ 15 | --which_epoch co-mod-gan-ffhq-9-025000 \ 16 | ${EXTRA} \ 17 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
4 | """ 5 | import pdb 6 | import sys 7 | import torch 8 | import numpy as np 9 | from collections import OrderedDict 10 | from options.train_options import TrainOptions 11 | import data 12 | from util.iter_counter import IterationCounter 13 | from logger import Logger 14 | from torchvision.utils import make_grid 15 | from trainers import create_trainer 16 | from save_remote_gs import init_remote, upload_remote 17 | from models.networks.sync_batchnorm import DataParallelWithCallback 18 | from pytorch_fid.fid_model import FIDModel 19 | 20 | # parse options 21 | opt = TrainOptions().parse() 22 | 23 | # fid 24 | fid_model = FIDModel().cuda() 25 | fid_model.model = DataParallelWithCallback( 26 | fid_model.model, 27 | device_ids=opt.gpu_ids) 28 | 29 | 30 | # load remote 31 | if opt.save_remote_gs is not None: 32 | init_remote(opt) 33 | 34 | # print options to help debugging 35 | print(' '.join(sys.argv)) 36 | 37 | # load the dataset 38 | if opt.dataset_mode_val is not None: 39 | dataloader_train, dataloader_val = data.create_dataloader_trainval(opt) 40 | else: 41 | dataloader_train = data.create_dataloader(opt) 42 | dataloader_val = None 43 | 44 | # create trainer for our model 45 | trainer = create_trainer(opt) 46 | model = trainer.pix2pix_model 47 | 48 | # create tool for counting iterations 49 | iter_counter = IterationCounter(opt, len(dataloader_train)) 50 | 51 | # create tool for visualization 52 | writer = Logger(f"output/{opt.name}") 53 | with open(f"output/{opt.name}/savemodel", "w") as f: 54 | f.writelines("n") 55 | 56 | trainer.save('latest') 57 | 58 | def get_psnr(generated, gt): 59 | generated = (generated+1)/2*255 60 | bsize, c, h, w = gt.shape 61 | gt = (gt+1)/2*255 62 | mse = ((generated-gt)**2).sum(3).sum(2).sum(1) 63 | mse /= (c*h*w) 64 | psnr = 10*torch.log10(255.0*255.0 / (mse+1e-8)) 65 | return psnr.sum().item() 66 | 67 | def display_batch(epoch, data_i): 68 | losses = trainer.get_latest_losses() 69 | for k,v in losses.items(): 70 | 
writer.add_scalar(k,v.mean().item(), iter_counter.total_steps_so_far) 71 | writer.write_console(epoch, iter_counter.epoch_iter, iter_counter.time_per_iter) 72 | num_print = min(4, data_i['image'].size(0)) 73 | writer.add_single_image('inputs', 74 | (make_grid(trainer.get_latest_inputs()[:num_print])+1)/2, 75 | iter_counter.total_steps_so_far) 76 | infer_out,inp = trainer.pix2pix_model.forward(data_i, mode='inference') 77 | vis = (make_grid(inp[:num_print])+1)/2 78 | writer.add_single_image('infer_in', 79 | vis, 80 | iter_counter.total_steps_so_far) 81 | vis = (make_grid(infer_out[:num_print])+1)/2 82 | vis = torch.clamp(vis, 0,1) 83 | writer.add_single_image('infer_out', 84 | vis, 85 | iter_counter.total_steps_so_far) 86 | generated = trainer.get_latest_generated() 87 | for k,v in generated.items(): 88 | if v is None: 89 | continue 90 | if 'label' in k: 91 | vis = make_grid(v[:num_print].expand(-1,3,-1,-1))[0] 92 | writer.add_single_label(k, 93 | vis, 94 | iter_counter.total_steps_so_far) 95 | else: 96 | if v.size(1) == 3: 97 | vis = (make_grid(v[:num_print])+1)/2 98 | vis = torch.clamp(vis, 0,1) 99 | else: 100 | vis = make_grid(v[:num_print]) 101 | writer.add_single_image(k, 102 | vis, 103 | iter_counter.total_steps_so_far) 104 | writer.write_html() 105 | 106 | for epoch in iter_counter.training_epochs(): 107 | iter_counter.record_epoch_start(epoch) 108 | for i, data_i in enumerate(dataloader_train, start=iter_counter.epoch_iter): 109 | iter_counter.record_one_iteration() 110 | # train discriminator 111 | if not opt.freeze_D: 112 | trainer.run_discriminator_one_step(data_i, i) 113 | 114 | # Training 115 | # train generator 116 | if i % opt.D_steps_per_G == 0: 117 | trainer.run_generator_one_step(data_i, i) 118 | 119 | if iter_counter.needs_displaying(): 120 | display_batch(epoch, data_i) 121 | if opt.save_remote_gs is not None and iter_counter.needs_saving(): 122 | upload_remote(opt) 123 | if iter_counter.needs_validation(): 124 | print('saving the latest model 
(epoch %d, total_steps %d)' % 125 | (epoch, iter_counter.total_steps_so_far)) 126 | trainer.save('epoch%d_step%d'% 127 | (epoch, iter_counter.total_steps_so_far)) 128 | trainer.save('latest') 129 | iter_counter.record_current_iter() 130 | if dataloader_val is not None: 131 | print("doing validation") 132 | model.eval() 133 | num = 0 134 | psnr_total = 0 135 | for ii, data_ii in enumerate(dataloader_val): 136 | with torch.no_grad(): 137 | generated,_ = model(data_ii, mode='inference') 138 | generated = generated.cpu() 139 | gt = data_ii['image'] 140 | bsize = gt.size(0) 141 | psnr = get_psnr(generated, gt) 142 | psnr_total += psnr 143 | num += bsize 144 | fid_model.add_sample((generated+1)/2,(gt+1)/2) 145 | psnr_total /= num 146 | fid = fid_model.calculate_activation_statistics() 147 | writer.add_scalar("val.fid", fid, iter_counter.total_steps_so_far) 148 | writer.write_scalar("val.fid", fid, iter_counter.total_steps_so_far) 149 | writer.add_scalar("val.psnr", psnr_total, iter_counter.total_steps_so_far) 150 | writer.write_scalar("val.psnr", psnr_total, iter_counter.total_steps_so_far) 151 | writer.write_html() 152 | model.train() 153 | trainer.update_learning_rate(epoch) 154 | if epoch != 0 and epoch % 3 == 0 and opt.dataset_mode_train == 'cocomaskupdate': 155 | dataloader_train.dataset.update_dataset() 156 | iter_counter.record_epoch_end() 157 | 158 | print('Training was successfully finished.') 159 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | export CXX="g++" 2 | python train.py \ 3 | --batchSize 2 \ 4 | --nThreads 2 \ 5 | --name comod_places \ 6 | --train_image_dir ./datasets/places2sample1k_val/places2samples1k_crop256 \ 7 | --train_image_list ./datasets/places2sample1k_val/files.txt \ 8 | --train_image_postfix '.jpg' \ 9 | --val_image_dir ./datasets/places2sample1k_val/places2samples1k_crop256 \ 10 | --val_image_list 
./datasets/places2sample1k_val/files.txt \ 11 | --val_mask_dir ./datasets/places2sample1k_val/places2samples1k_256_mask_square128 \ 12 | --load_size 512 \ 13 | --crop_size 256 \ 14 | --z_dim 512 \ 15 | --validation_freq 10000 \ 16 | --niter 50 \ 17 | --dataset_mode trainimage \ 18 | --trainer stylegan2 \ 19 | --dataset_mode_train trainimage \ 20 | --dataset_mode_val valimage \ 21 | --model comod \ 22 | --netG comodgan \ 23 | --netD comodgan \ 24 | --no_l1_loss \ 25 | --no_vgg_loss \ 26 | --preprocess_mode scale_shortside_and_crop \ 27 | $EXTRA 28 | -------------------------------------------------------------------------------- /trainers/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | def find_trainer_using_name(model_name): 4 | model_filename = "trainers." + model_name + "_trainer" 5 | modellib = importlib.import_module(model_filename) 6 | 7 | # In the file, the class called ModelNameModel() will 8 | # be instantiated. It has to be a subclass of torch.nn.Module, 9 | # and it is case-insensitive. 10 | model = None 11 | target_model_name = model_name.replace('_', '') + 'trainer' 12 | for name, cls in modellib.__dict__.items(): 13 | if name.lower() == target_model_name.lower(): 14 | model = cls 15 | 16 | if model is None: 17 | print("In %s.py, there should be a subclass of torch.nn.Module with class name that matches %s in lowercase." 
% (model_filename, target_model_name)) 18 | exit(0) 19 | 20 | return model 21 | 22 | 23 | def create_trainer(opt): 24 | model = find_trainer_using_name(opt.trainer) 25 | instance = model(opt) 26 | print("model [%s] was created" % (type(instance).__name__)) 27 | 28 | return instance 29 | -------------------------------------------------------------------------------- /trainers/stylegan2_trainer.py: -------------------------------------------------------------------------------- 1 | import pdb 2 | import torch 3 | from models.networks.sync_batchnorm import DataParallelWithCallback 4 | import models 5 | #from models.pix2pix_model import Pix2PixModel 6 | 7 | 8 | class StyleGAN2Trainer(): 9 | def __init__(self, opt): 10 | self.opt = opt 11 | self.pix2pix_model = models.create_model(opt) 12 | if len(opt.gpu_ids) > 0: 13 | self.pix2pix_model = DataParallelWithCallback(self.pix2pix_model, 14 | device_ids=opt.gpu_ids) 15 | self.pix2pix_model_on_one_gpu = self.pix2pix_model.module 16 | else: 17 | self.pix2pix_model_on_one_gpu = self.pix2pix_model 18 | 19 | self.generated = None 20 | self.inputs = None 21 | self.mean_path_length = torch.Tensor([0]) 22 | if opt.isTrain: 23 | self.optimizer_G, self.optimizer_D = \ 24 | self.pix2pix_model_on_one_gpu.create_optimizers(opt) 25 | self.old_lr = opt.lr 26 | 27 | def run_generator_one_step(self, data, i): 28 | self.optimizer_G.zero_grad() 29 | g_losses, inputs, generated = self.pix2pix_model(data, mode='generator') 30 | g_loss = sum(g_losses.values()).mean() 31 | g_loss.backward() 32 | self.optimizer_G.step() 33 | self.g_losses = g_losses 34 | self.generated = generated 35 | self.inputs = inputs 36 | g_regularize = (i % self.opt.g_reg_every == 0) and not (self.opt.no_g_reg) 37 | if g_regularize: 38 | self.optimizer_G.zero_grad() 39 | bsize = data['image'].size(0) 40 | data['mean_path_length'] = self.mean_path_length.expand(bsize) 41 | g_regs, self.mean_path_length \ 42 | = self.pix2pix_model(data, mode='g_reg') 43 | g_reg = 
sum(g_regs.values()).mean() 44 | g_reg.backward() 45 | self.optimizer_G.step() 46 | self.g_losses = { 47 | **g_losses, 48 | **g_regs} 49 | bsize = inputs.size(0) 50 | accum = 0.5 ** (bsize / (10 * 1000)) # 32 51 | self.pix2pix_model_on_one_gpu.accumulate(accum) 52 | 53 | def run_discriminator_one_step(self, data, i): 54 | self.optimizer_D.zero_grad() 55 | d_losses_real = self.pix2pix_model(data, mode='dreal') 56 | d_loss_real = sum(d_losses_real.values()).mean() 57 | d_loss_real.backward() 58 | d_losses_fake = self.pix2pix_model(data, mode='dfake') 59 | d_loss_fake = sum(d_losses_fake.values()).mean() 60 | d_loss_fake.backward() 61 | self.d_losses = { 62 | **d_losses_real, 63 | **d_losses_fake} 64 | self.optimizer_D.step() 65 | d_regularize = i % self.opt.d_reg_every == 0 66 | if d_regularize: 67 | self.optimizer_D.zero_grad() 68 | d_regs = self.pix2pix_model(data, mode='d_reg') 69 | d_reg = sum(d_regs.values()).mean() 70 | d_reg.backward() 71 | self.optimizer_D.step() 72 | self.d_losses = { 73 | **self.d_losses, 74 | **d_regs} 75 | 76 | def get_latest_losses(self): 77 | if not self.opt.freeze_D: 78 | return {**self.g_losses, **self.d_losses} 79 | else: 80 | return self.g_losses 81 | 82 | def get_latest_generated(self): 83 | return self.generated 84 | def get_latest_inputs(self): 85 | return self.inputs 86 | 87 | def update_learning_rate(self, epoch): 88 | self.update_learning_rate(epoch) 89 | 90 | def save(self, epoch): 91 | self.pix2pix_model_on_one_gpu.save(epoch) 92 | 93 | ################################################################## 94 | # Helper functions 95 | ################################################################## 96 | 97 | def update_learning_rate(self, epoch): 98 | if epoch > self.opt.niter: 99 | lrd = self.opt.lr / self.opt.niter_decay 100 | new_lr = self.old_lr - lrd 101 | else: 102 | new_lr = self.old_lr 103 | 104 | if new_lr != self.old_lr: 105 | if self.opt.no_TTUR: 106 | new_lr_G = new_lr 107 | new_lr_D = new_lr 108 | else: 109 | 
new_lr_G = new_lr / 2 110 | new_lr_D = new_lr * 2 111 | 112 | for param_group in self.optimizer_D.param_groups: 113 | param_group['lr'] = new_lr_D 114 | for param_group in self.optimizer_G.param_groups: 115 | param_group['lr'] = new_lr_G 116 | print('update learning rate: %f -> %f' % (self.old_lr, new_lr)) 117 | self.old_lr = new_lr 118 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | -------------------------------------------------------------------------------- /util/coco.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
4 | """ 5 | 6 | 7 | def id2label(id): 8 | if id == 182: 9 | id = 0 10 | else: 11 | id = id + 1 12 | labelmap = \ 13 | {0: 'unlabeled', 14 | 1: 'person', 15 | 2: 'bicycle', 16 | 3: 'car', 17 | 4: 'motorcycle', 18 | 5: 'airplane', 19 | 6: 'bus', 20 | 7: 'train', 21 | 8: 'truck', 22 | 9: 'boat', 23 | 10: 'traffic light', 24 | 11: 'fire hydrant', 25 | 12: 'street sign', 26 | 13: 'stop sign', 27 | 14: 'parking meter', 28 | 15: 'bench', 29 | 16: 'bird', 30 | 17: 'cat', 31 | 18: 'dog', 32 | 19: 'horse', 33 | 20: 'sheep', 34 | 21: 'cow', 35 | 22: 'elephant', 36 | 23: 'bear', 37 | 24: 'zebra', 38 | 25: 'giraffe', 39 | 26: 'hat', 40 | 27: 'backpack', 41 | 28: 'umbrella', 42 | 29: 'shoe', 43 | 30: 'eye glasses', 44 | 31: 'handbag', 45 | 32: 'tie', 46 | 33: 'suitcase', 47 | 34: 'frisbee', 48 | 35: 'skis', 49 | 36: 'snowboard', 50 | 37: 'sports ball', 51 | 38: 'kite', 52 | 39: 'baseball bat', 53 | 40: 'baseball glove', 54 | 41: 'skateboard', 55 | 42: 'surfboard', 56 | 43: 'tennis racket', 57 | 44: 'bottle', 58 | 45: 'plate', 59 | 46: 'wine glass', 60 | 47: 'cup', 61 | 48: 'fork', 62 | 49: 'knife', 63 | 50: 'spoon', 64 | 51: 'bowl', 65 | 52: 'banana', 66 | 53: 'apple', 67 | 54: 'sandwich', 68 | 55: 'orange', 69 | 56: 'broccoli', 70 | 57: 'carrot', 71 | 58: 'hot dog', 72 | 59: 'pizza', 73 | 60: 'donut', 74 | 61: 'cake', 75 | 62: 'chair', 76 | 63: 'couch', 77 | 64: 'potted plant', 78 | 65: 'bed', 79 | 66: 'mirror', 80 | 67: 'dining table', 81 | 68: 'window', 82 | 69: 'desk', 83 | 70: 'toilet', 84 | 71: 'door', 85 | 72: 'tv', 86 | 73: 'laptop', 87 | 74: 'mouse', 88 | 75: 'remote', 89 | 76: 'keyboard', 90 | 77: 'cell phone', 91 | 78: 'microwave', 92 | 79: 'oven', 93 | 80: 'toaster', 94 | 81: 'sink', 95 | 82: 'refrigerator', 96 | 83: 'blender', 97 | 84: 'book', 98 | 85: 'clock', 99 | 86: 'vase', 100 | 87: 'scissors', 101 | 88: 'teddy bear', 102 | 89: 'hair drier', 103 | 90: 'toothbrush', 104 | 91: 'hair brush', # Last class of Thing 105 | 92: 'banner', # Beginning of Stuff 106 | 93: 
'blanket', 107 | 94: 'branch', 108 | 95: 'bridge', 109 | 96: 'building-other', 110 | 97: 'bush', 111 | 98: 'cabinet', 112 | 99: 'cage', 113 | 100: 'cardboard', 114 | 101: 'carpet', 115 | 102: 'ceiling-other', 116 | 103: 'ceiling-tile', 117 | 104: 'cloth', 118 | 105: 'clothes', 119 | 106: 'clouds', 120 | 107: 'counter', 121 | 108: 'cupboard', 122 | 109: 'curtain', 123 | 110: 'desk-stuff', 124 | 111: 'dirt', 125 | 112: 'door-stuff', 126 | 113: 'fence', 127 | 114: 'floor-marble', 128 | 115: 'floor-other', 129 | 116: 'floor-stone', 130 | 117: 'floor-tile', 131 | 118: 'floor-wood', 132 | 119: 'flower', 133 | 120: 'fog', 134 | 121: 'food-other', 135 | 122: 'fruit', 136 | 123: 'furniture-other', 137 | 124: 'grass', 138 | 125: 'gravel', 139 | 126: 'ground-other', 140 | 127: 'hill', 141 | 128: 'house', 142 | 129: 'leaves', 143 | 130: 'light', 144 | 131: 'mat', 145 | 132: 'metal', 146 | 133: 'mirror-stuff', 147 | 134: 'moss', 148 | 135: 'mountain', 149 | 136: 'mud', 150 | 137: 'napkin', 151 | 138: 'net', 152 | 139: 'paper', 153 | 140: 'pavement', 154 | 141: 'pillow', 155 | 142: 'plant-other', 156 | 143: 'plastic', 157 | 144: 'platform', 158 | 145: 'playingfield', 159 | 146: 'railing', 160 | 147: 'railroad', 161 | 148: 'river', 162 | 149: 'road', 163 | 150: 'rock', 164 | 151: 'roof', 165 | 152: 'rug', 166 | 153: 'salad', 167 | 154: 'sand', 168 | 155: 'sea', 169 | 156: 'shelf', 170 | 157: 'sky-other', 171 | 158: 'skyscraper', 172 | 159: 'snow', 173 | 160: 'solid-other', 174 | 161: 'stairs', 175 | 162: 'stone', 176 | 163: 'straw', 177 | 164: 'structural-other', 178 | 165: 'table', 179 | 166: 'tent', 180 | 167: 'textile-other', 181 | 168: 'towel', 182 | 169: 'tree', 183 | 170: 'vegetable', 184 | 171: 'wall-brick', 185 | 172: 'wall-concrete', 186 | 173: 'wall-other', 187 | 174: 'wall-panel', 188 | 175: 'wall-stone', 189 | 176: 'wall-tile', 190 | 177: 'wall-wood', 191 | 178: 'water-other', 192 | 179: 'waterdrops', 193 | 180: 'window-blind', 194 | 181: 'window-other', 195 | 182: 
'wood'} 196 | if id in labelmap: 197 | return labelmap[id] 198 | else: 199 | return 'unknown' 200 | -------------------------------------------------------------------------------- /util/html.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | import datetime 7 | import dominate 8 | from dominate.tags import * 9 | import os 10 | 11 | 12 | class HTML: 13 | def __init__(self, web_dir, title, refresh=0): 14 | if web_dir.endswith('.html'): 15 | web_dir, html_name = os.path.split(web_dir) 16 | else: 17 | web_dir, html_name = web_dir, 'index.html' 18 | self.title = title 19 | self.web_dir = web_dir 20 | self.html_name = html_name 21 | self.img_dir = os.path.join(self.web_dir, 'images') 22 | if len(self.web_dir) > 0 and not os.path.exists(self.web_dir): 23 | os.makedirs(self.web_dir) 24 | if len(self.web_dir) > 0 and not os.path.exists(self.img_dir): 25 | os.makedirs(self.img_dir) 26 | 27 | self.doc = dominate.document(title=title) 28 | with self.doc: 29 | h1(datetime.datetime.now().strftime("%I:%M%p on %B %d, %Y")) 30 | if refresh > 0: 31 | with self.doc.head: 32 | meta(http_equiv="refresh", content=str(refresh)) 33 | 34 | def get_image_dir(self): 35 | return self.img_dir 36 | 37 | def add_header(self, str): 38 | with self.doc: 39 | h3(str) 40 | 41 | def add_table(self, border=1): 42 | self.t = table(border=border, style="table-layout: fixed;") 43 | self.doc.add(self.t) 44 | 45 | def add_images(self, ims, txts, links, width=512): 46 | self.add_table() 47 | with self.t: 48 | with tr(): 49 | for im, txt, link in zip(ims, txts, links): 50 | with td(style="word-wrap: break-word;", halign="center", valign="top"): 51 | with p(): 52 | with a(href=os.path.join('images', link)): 53 | img(style="width:%dpx" % (width), src=os.path.join('images', im)) 54 | 
br() 55 | p(txt.encode('utf-8')) 56 | 57 | def save(self): 58 | html_file = os.path.join(self.web_dir, self.html_name) 59 | f = open(html_file, 'wt') 60 | f.write(self.doc.render()) 61 | f.close() 62 | 63 | 64 | if __name__ == '__main__': 65 | html = HTML('web/', 'test_html') 66 | html.add_header('hello world') 67 | 68 | ims = [] 69 | txts = [] 70 | links = [] 71 | for n in range(4): 72 | ims.append('image_%d.jpg' % n) 73 | txts.append('text_%d' % n) 74 | links.append('image_%d.jpg' % n) 75 | html.add_images(ims, txts, links) 76 | html.save() 77 | -------------------------------------------------------------------------------- /util/iter_counter.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | import pdb 7 | import os 8 | import time 9 | import numpy as np 10 | 11 | 12 | # Helper class that keeps track of training iterations 13 | class IterationCounter(): 14 | def __init__(self, opt, dataset_size): 15 | self.opt = opt 16 | self.dataset_size = dataset_size 17 | 18 | self.first_epoch = 1 19 | self.total_epochs = opt.niter + opt.niter_decay 20 | self.epoch_iter = 0 # iter number within each epoch 21 | self.iter_record_path = os.path.join(self.opt.checkpoints_dir, self.opt.name, 'iter.txt') 22 | if opt.isTrain and opt.continue_train: 23 | try: 24 | self.first_epoch, self.epoch_iter = np.loadtxt( 25 | self.iter_record_path, delimiter=',', dtype=int) 26 | print('Resuming from epoch %d at iteration %d' % (self.first_epoch, self.epoch_iter)) 27 | except: 28 | print('Could not load iteration record at %s. Starting from beginning.' 
% 29 | self.iter_record_path) 30 | 31 | self.total_steps_so_far = (self.first_epoch - 1) * dataset_size + self.epoch_iter 32 | 33 | # return the iterator of epochs for the training 34 | def training_epochs(self): 35 | return range(self.first_epoch, self.total_epochs + 1) 36 | 37 | def record_epoch_start(self, epoch): 38 | self.epoch_start_time = time.time() 39 | self.last_iter_time = time.time() 40 | self.current_epoch = epoch 41 | 42 | def record_one_iteration(self): 43 | current_time = time.time() 44 | 45 | # the last remaining batch is dropped (see data/__init__.py), 46 | # so we can assume batch size is always opt.batchSize 47 | self.time_per_iter = (current_time - self.last_iter_time) / self.opt.batchSize 48 | self.last_iter_time = current_time 49 | self.total_steps_so_far += self.opt.batchSize 50 | self.epoch_iter += self.opt.batchSize 51 | 52 | def record_epoch_end(self): 53 | self.epoch_iter = 0 54 | current_time = time.time() 55 | self.time_per_epoch = current_time - self.epoch_start_time 56 | print('End of epoch %d / %d \t Time Taken: %d sec' % 57 | (self.current_epoch, self.total_epochs, self.time_per_epoch)) 58 | if self.current_epoch % self.opt.save_epoch_freq == 0: 59 | np.savetxt(self.iter_record_path, (self.current_epoch + 1, 0), 60 | delimiter=',', fmt='%d') 61 | print('Saved current iteration count at %s.' % self.iter_record_path) 62 | 63 | def record_current_iter(self): 64 | np.savetxt(self.iter_record_path, (self.current_epoch, self.epoch_iter), 65 | delimiter=',', fmt='%d') 66 | print('Saved current iteration count at %s.' 
% self.iter_record_path) 67 | 68 | def needs_saving(self): 69 | return (self.total_steps_so_far % self.opt.save_latest_freq) < self.opt.batchSize 70 | 71 | def needs_validation(self): 72 | return (self.total_steps_so_far % self.opt.validation_freq) < self.opt.batchSize 73 | 74 | def needs_printing(self): 75 | return (self.total_steps_so_far % self.opt.print_freq) < self.opt.batchSize 76 | 77 | def needs_displaying(self): 78 | return (self.total_steps_so_far % self.opt.display_freq) < self.opt.batchSize 79 | -------------------------------------------------------------------------------- /util/util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | import re 7 | import pdb 8 | import importlib 9 | import torch 10 | from argparse import Namespace 11 | import numpy as np 12 | from PIL import Image 13 | import os 14 | import argparse 15 | import dill as pickle 16 | import util.coco 17 | 18 | 19 | def save_obj(obj, name): 20 | with open(name, 'wb') as f: 21 | pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL) 22 | 23 | 24 | def load_obj(name): 25 | with open(name, 'rb') as f: 26 | return pickle.load(f) 27 | 28 | # returns a configuration for creating a generator 29 | # |default_opt| should be the opt of the current experiment 30 | # |**kwargs|: if any configuration should be overriden, it can be specified here 31 | 32 | 33 | def copyconf(default_opt, **kwargs): 34 | conf = argparse.Namespace(**vars(default_opt)) 35 | for key in kwargs: 36 | print(key, kwargs[key]) 37 | setattr(conf, key, kwargs[key]) 38 | return conf 39 | 40 | 41 | def tile_images(imgs, picturesPerRow=4): 42 | """ Code borrowed from 43 | https://stackoverflow.com/questions/26521365/cleanly-tile-numpy-array-of-images-stored-in-a-flattened-1d-format/26521997 44 | """ 45 | 46 | # 
Padding 47 | if imgs.shape[0] % picturesPerRow == 0: 48 | rowPadding = 0 49 | else: 50 | rowPadding = picturesPerRow - imgs.shape[0] % picturesPerRow 51 | if rowPadding > 0: 52 | imgs = np.concatenate([imgs, np.zeros((rowPadding, *imgs.shape[1:]), dtype=imgs.dtype)], axis=0) 53 | 54 | # Tiling Loop (The conditionals are not necessary anymore) 55 | tiled = [] 56 | for i in range(0, imgs.shape[0], picturesPerRow): 57 | tiled.append(np.concatenate([imgs[j] for j in range(i, i + picturesPerRow)], axis=1)) 58 | 59 | tiled = np.concatenate(tiled, axis=0) 60 | return tiled 61 | 62 | 63 | # Converts a Tensor into a Numpy array 64 | # |imtype|: the desired type of the converted numpy array 65 | def tensor2im(image_tensor, imtype=np.uint8, normalize=True, tile=False): 66 | if isinstance(image_tensor, list): 67 | image_numpy = [] 68 | for i in range(len(image_tensor)): 69 | image_numpy.append(tensor2im(image_tensor[i], imtype, normalize)) 70 | return image_numpy 71 | 72 | if image_tensor.dim() == 4: 73 | # transform each image in the batch 74 | images_np = [] 75 | for b in range(image_tensor.size(0)): 76 | one_image = image_tensor[b] 77 | one_image_np = tensor2im(one_image) 78 | images_np.append(one_image_np.reshape(1, *one_image_np.shape)) 79 | images_np = np.concatenate(images_np, axis=0) 80 | if tile: 81 | images_tiled = tile_images(images_np) 82 | return images_tiled 83 | else: 84 | return images_np 85 | 86 | if image_tensor.dim() == 2: 87 | image_tensor = image_tensor.unsqueeze(0) 88 | image_numpy = image_tensor.detach().cpu().float().numpy() 89 | if normalize: 90 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 91 | else: 92 | image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0 93 | image_numpy = np.clip(image_numpy, 0, 255) 94 | if image_numpy.shape[2] == 1: 95 | image_numpy = image_numpy[:, :, 0] 96 | return image_numpy.astype(imtype) 97 | 98 | 99 | # Converts a one-hot tensor into a colorful label map 100 | def 
tensor2label(label_tensor, n_label, imtype=np.uint8, tile=False): 101 | if label_tensor.dim() == 4: 102 | # transform each image in the batch 103 | images_np = [] 104 | for b in range(label_tensor.size(0)): 105 | one_image = label_tensor[b] 106 | one_image_np = tensor2label(one_image, n_label, imtype) 107 | images_np.append(one_image_np.reshape(1, *one_image_np.shape)) 108 | images_np = np.concatenate(images_np, axis=0) 109 | if tile: 110 | images_tiled = tile_images(images_np) 111 | return images_tiled 112 | else: 113 | images_np = images_np[0] 114 | return images_np 115 | 116 | if label_tensor.dim() == 1: 117 | return np.zeros((64, 64, 3), dtype=np.uint8) 118 | if n_label == 0: 119 | return tensor2im(label_tensor, imtype) 120 | label_tensor = label_tensor.cpu().float() 121 | if label_tensor.size()[0] > 1: 122 | label_tensor = label_tensor.max(0, keepdim=True)[1] 123 | label_tensor = Colorize(n_label)(label_tensor) 124 | label_numpy = np.transpose(label_tensor.numpy(), (1, 2, 0)) 125 | result = label_numpy.astype(imtype) 126 | return result 127 | 128 | 129 | def save_image(image_numpy, image_path, create_dir=False): 130 | if create_dir: 131 | os.makedirs(os.path.dirname(image_path), exist_ok=True) 132 | if len(image_numpy.shape) == 2: 133 | image_numpy = np.expand_dims(image_numpy, axis=2) 134 | if image_numpy.shape[2] == 1: 135 | image_numpy = np.repeat(image_numpy, 3, 2) 136 | image_pil = Image.fromarray(image_numpy) 137 | 138 | # save to png 139 | image_pil.save(image_path.replace('.jpg', '.png')) 140 | 141 | 142 | def mkdirs(paths): 143 | if isinstance(paths, list) and not isinstance(paths, str): 144 | for path in paths: 145 | mkdir(path) 146 | else: 147 | mkdir(paths) 148 | 149 | 150 | def mkdir(path): 151 | if not os.path.exists(path): 152 | os.makedirs(path) 153 | 154 | 155 | def atoi(text): 156 | return int(text) if text.isdigit() else text 157 | 158 | 159 | def natural_keys(text): 160 | ''' 161 | alist.sort(key=natural_keys) sorts in human order 162 | 
http://nedbatchelder.com/blog/200712/human_sorting.html 163 | (See Toothy's implementation in the comments) 164 | ''' 165 | return [atoi(c) for c in re.split('(\d+)', text)] 166 | 167 | 168 | def natural_sort(items): 169 | items.sort(key=natural_keys) 170 | 171 | 172 | def str2bool(v): 173 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 174 | return True 175 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 176 | return False 177 | else: 178 | raise argparse.ArgumentTypeError('Boolean value expected.') 179 | 180 | 181 | def find_class_in_module(target_cls_name, module): 182 | target_cls_name = target_cls_name.replace('_', '').lower() 183 | clslib = importlib.import_module(module) 184 | cls = None 185 | for name, clsobj in clslib.__dict__.items(): 186 | if name.lower() == target_cls_name: 187 | cls = clsobj 188 | 189 | if cls is None: 190 | print("In %s, there should be a class whose name matches %s in lowercase without underscore(_)" % (module, target_cls_name)) 191 | exit(0) 192 | 193 | return cls 194 | 195 | 196 | def save_network(net, label, epoch, opt): 197 | save_filename = '%s_net_%s.pth' % (epoch, label) 198 | save_path = os.path.join(opt.checkpoints_dir, opt.name, save_filename) 199 | torch.save(net.cpu().state_dict(), save_path) 200 | if len(opt.gpu_ids) and torch.cuda.is_available(): 201 | net.cuda() 202 | 203 | def load_network_path(net, save_path): 204 | weights = torch.load(save_path) 205 | new_dict = {} 206 | for k,v in weights.items(): 207 | #if k.startswith("module.conv16") or k.startswith("module.conv17"): 208 | # continue 209 | if k.startswith("module."): 210 | k=k.replace("module.","") 211 | new_dict[k] = v 212 | net.load_state_dict(new_dict, strict=False) 213 | #net.load_state_dict(new_dict) 214 | return net 215 | 216 | 217 | def load_network(net, label, epoch, opt): 218 | save_filename = '%s_net_%s.pth' % (epoch, label) 219 | save_dir = os.path.join(opt.checkpoints_dir, opt.name) 220 | save_path = os.path.join(save_dir, save_filename) 221 | 
weights = torch.load(save_path) 222 | print("==============load path: =================") 223 | print(save_path) 224 | new_dict = {} 225 | for k,v in weights.items(): 226 | if k.startswith("module."): 227 | k=k.replace("module.","") 228 | new_dict[k] = v 229 | net.load_state_dict(new_dict, strict=False) 230 | return net 231 | 232 | 233 | ############################################################################### 234 | # Code from 235 | # https://github.com/ycszen/pytorch-seg/blob/master/transform.py 236 | # Modified so it complies with the Citscape label map colors 237 | ############################################################################### 238 | def uint82bin(n, count=8): 239 | """returns the binary of integer n, count refers to amount of bits""" 240 | return ''.join([str((n >> y) & 1) for y in range(count - 1, -1, -1)]) 241 | 242 | 243 | def labelcolormap(N): 244 | if N == 35: # cityscape 245 | cmap = np.array([(0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0), (111, 74, 0), (81, 0, 81), 246 | (128, 64, 128), (244, 35, 232), (250, 170, 160), (230, 150, 140), (70, 70, 70), (102, 102, 156), (190, 153, 153), 247 | (180, 165, 180), (150, 100, 100), (150, 120, 90), (153, 153, 153), (153, 153, 153), (250, 170, 30), (220, 220, 0), 248 | (107, 142, 35), (152, 251, 152), (70, 130, 180), (220, 20, 60), (255, 0, 0), (0, 0, 142), (0, 0, 70), 249 | (0, 60, 100), (0, 0, 90), (0, 0, 110), (0, 80, 100), (0, 0, 230), (119, 11, 32), (0, 0, 142)], 250 | dtype=np.uint8) 251 | else: 252 | cmap = np.zeros((N, 3), dtype=np.uint8) 253 | for i in range(N): 254 | r, g, b = 0, 0, 0 255 | id = i + 1 # let's give 0 a color 256 | for j in range(7): 257 | str_id = uint82bin(id) 258 | r = r ^ (np.uint8(str_id[-1]) << (7 - j)) 259 | g = g ^ (np.uint8(str_id[-2]) << (7 - j)) 260 | b = b ^ (np.uint8(str_id[-3]) << (7 - j)) 261 | id = id >> 3 262 | cmap[i, 0] = r 263 | cmap[i, 1] = g 264 | cmap[i, 2] = b 265 | 266 | if N == 182: # COCO 267 | important_colors = { 268 | 'sea': (54, 
62, 167), 269 | 'sky-other': (95, 219, 255), 270 | 'tree': (140, 104, 47), 271 | 'clouds': (170, 170, 170), 272 | 'grass': (29, 195, 49) 273 | } 274 | for i in range(N): 275 | name = util.coco.id2label(i) 276 | if name in important_colors: 277 | color = important_colors[name] 278 | cmap[i] = np.array(list(color)) 279 | 280 | return cmap 281 | 282 | 283 | class Colorize(object): 284 | def __init__(self, n=35): 285 | self.cmap = labelcolormap(n) 286 | self.cmap = torch.from_numpy(self.cmap[:n]) 287 | 288 | def __call__(self, gray_image): 289 | size = gray_image.size() 290 | color_image = torch.ByteTensor(3, size[1], size[2]).fill_(0) 291 | 292 | for label in range(0, len(self.cmap)): 293 | mask = (label == gray_image[0]).cpu() 294 | color_image[0][mask] = self.cmap[label][0] 295 | color_image[1][mask] = self.cmap[label][1] 296 | color_image[2][mask] = self.cmap[label][2] 297 | 298 | return color_image 299 | -------------------------------------------------------------------------------- /util/visualizer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | import os 7 | import ntpath 8 | import time 9 | from . import util 10 | from . 
import html 11 | import scipy.misc 12 | try: 13 | from StringIO import StringIO # Python 2.7 14 | except ImportError: 15 | from io import BytesIO # Python 3.x 16 | 17 | class Visualizer(): 18 | def __init__(self, opt): 19 | self.opt = opt 20 | self.tf_log = opt.isTrain and opt.tf_log 21 | self.use_html = opt.isTrain and not opt.no_html 22 | self.win_size = opt.display_winsize 23 | self.name = opt.name 24 | if self.tf_log: 25 | import tensorflow as tf 26 | self.tf = tf 27 | self.log_dir = os.path.join(opt.checkpoints_dir, opt.name, 'logs') 28 | self.writer = tf.summary.FileWriter(self.log_dir) 29 | 30 | if self.use_html: 31 | self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web') 32 | self.img_dir = os.path.join(self.web_dir, 'images') 33 | print('create web directory %s...' % self.web_dir) 34 | util.mkdirs([self.web_dir, self.img_dir]) 35 | if opt.isTrain: 36 | self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt') 37 | with open(self.log_name, "a") as log_file: 38 | now = time.strftime("%c") 39 | log_file.write('================ Training Loss (%s) ================\n' % now) 40 | 41 | # |visuals|: dictionary of images to display or save 42 | def display_current_results(self, visuals, epoch, step): 43 | 44 | ## convert tensors to numpy arrays 45 | visuals = self.convert_visuals_to_numpy(visuals) 46 | 47 | if self.tf_log: # show images in tensorboard output 48 | img_summaries = [] 49 | for label, image_numpy in visuals.items(): 50 | # Write the image to a string 51 | try: 52 | s = StringIO() 53 | except: 54 | s = BytesIO() 55 | if len(image_numpy.shape) >= 4: 56 | image_numpy = image_numpy[0] 57 | scipy.misc.toimage(image_numpy).save(s, format="jpeg") 58 | # Create an Image object 59 | img_sum = self.tf.Summary.Image(encoded_image_string=s.getvalue(), height=image_numpy.shape[0], width=image_numpy.shape[1]) 60 | # Create a Summary value 61 | img_summaries.append(self.tf.Summary.Value(tag=label, image=img_sum)) 62 | 63 | # Create and 
write Summary 64 | summary = self.tf.Summary(value=img_summaries) 65 | self.writer.add_summary(summary, step) 66 | 67 | if self.use_html: # save images to a html file 68 | for label, image_numpy in visuals.items(): 69 | if isinstance(image_numpy, list): 70 | for i in range(len(image_numpy)): 71 | img_path = os.path.join(self.img_dir, 'epoch%.3d_iter%.3d_%s_%d.png' % (epoch, step, label, i)) 72 | util.save_image(image_numpy[i], img_path) 73 | else: 74 | img_path = os.path.join(self.img_dir, 'epoch%.3d_iter%.3d_%s.png' % (epoch, step, label)) 75 | if len(image_numpy.shape) >= 4: 76 | image_numpy = image_numpy[0] 77 | util.save_image(image_numpy, img_path) 78 | 79 | # update website 80 | webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=5) 81 | for n in range(epoch, 0, -1): 82 | webpage.add_header('epoch [%d]' % n) 83 | ims = [] 84 | txts = [] 85 | links = [] 86 | 87 | for label, image_numpy in visuals.items(): 88 | if isinstance(image_numpy, list): 89 | for i in range(len(image_numpy)): 90 | img_path = 'epoch%.3d_iter%.3d_%s_%d.png' % (n, step, label, i) 91 | ims.append(img_path) 92 | txts.append(label+str(i)) 93 | links.append(img_path) 94 | else: 95 | img_path = 'epoch%.3d_iter%.3d_%s.png' % (n, step, label) 96 | ims.append(img_path) 97 | txts.append(label) 98 | links.append(img_path) 99 | if len(ims) < 10: 100 | webpage.add_images(ims, txts, links, width=self.win_size) 101 | else: 102 | num = int(round(len(ims)/2.0)) 103 | webpage.add_images(ims[:num], txts[:num], links[:num], width=self.win_size) 104 | webpage.add_images(ims[num:], txts[num:], links[num:], width=self.win_size) 105 | webpage.save() 106 | 107 | # errors: dictionary of error labels and values 108 | def plot_current_errors(self, errors, step): 109 | if self.tf_log: 110 | for tag, value in errors.items(): 111 | value = value.mean().float() 112 | summary = self.tf.Summary(value=[self.tf.Summary.Value(tag=tag, simple_value=value)]) 113 | self.writer.add_summary(summary, step) 
114 | 115 | # errors: same format as |errors| of plotCurrentErrors 116 | def print_current_errors(self, epoch, i, errors, t): 117 | message = '(epoch: %d, iters: %d, time: %.3f) ' % (epoch, i, t) 118 | for k, v in errors.items(): 119 | #print(v) 120 | #if v != 0: 121 | v = v.mean().float() 122 | message += '%s: %.3f ' % (k, v) 123 | 124 | print(message) 125 | with open(self.log_name, "a") as log_file: 126 | log_file.write('%s\n' % message) 127 | 128 | def convert_visuals_to_numpy(self, visuals): 129 | for key, t in visuals.items(): 130 | tile = self.opt.batchSize > 8 131 | if 'input_label' == key: 132 | t = util.tensor2label(t, self.opt.label_nc + 2, tile=tile) 133 | else: 134 | t = util.tensor2im(t, tile=tile) 135 | visuals[key] = t 136 | return visuals 137 | 138 | # save image to the disk 139 | def save_images(self, webpage, visuals, image_path): 140 | visuals = self.convert_visuals_to_numpy(visuals) 141 | 142 | image_dir = webpage.get_image_dir() 143 | short_path = ntpath.basename(image_path[0]) 144 | name = os.path.splitext(short_path)[0] 145 | 146 | webpage.add_header(name) 147 | ims = [] 148 | txts = [] 149 | links = [] 150 | 151 | for label, image_numpy in visuals.items(): 152 | image_name = os.path.join(label, '%s.png' % (name)) 153 | save_path = os.path.join(image_dir, image_name) 154 | util.save_image(image_numpy, save_path, create_dir=True) 155 | 156 | ims.append(image_name) 157 | txts.append(label) 158 | links.append(image_name) 159 | webpage.add_images(ims, txts, links, width=self.win_size) 160 | --------------------------------------------------------------------------------