├── .gitmodules
├── README.md
├── checkpoints
│   ├── .gitkeep
│   ├── comod-ffhq-1024
│   │   └── .gitkeep
│   ├── comod-ffhq-512
│   │   └── .gitkeep
│   └── comod-places-512
│       └── .gitkeep
├── data
│   ├── __init__.py
│   ├── base_dataset.py
│   ├── image_folder.py
│   ├── testimage_dataset.py
│   ├── trainimage_dataset.py
│   └── valimage_dataset.py
├── datasets
│   └── .gitkeep
├── download
│   ├── data.sh
│   ├── ffhq1024.sh
│   ├── ffhq512.sh
│   └── places512.sh
├── ffhq_debug
│   ├── 1.png
│   ├── example_image.jpg
│   ├── images.txt
│   ├── images
│   │   └── 1.png
│   ├── masks
│   │   └── 1.png
│   └── masks_inv
│       └── 1.png
├── imgs
│   ├── example_image.jpg
│   ├── example_mask.jpg
│   ├── example_output.jpg
│   ├── ffhq_in.png
│   └── ffhq_m.png
├── models
│   ├── __init__.py
│   ├── comod_model.py
│   ├── create_mask.py
│   └── networks
│       ├── __init__.py
│       ├── architecture.py
│       ├── base_network.py
│       ├── co_mod_gan.py
│       ├── discriminator.py
│       ├── generator.py
│       ├── loss.py
│       ├── op
│       │   ├── __init__.py
│       │   ├── fused_act.py
│       │   ├── fused_bias_act.cpp
│       │   ├── fused_bias_act_kernel.cu
│       │   ├── upfirdn2d.cpp
│       │   ├── upfirdn2d.py
│       │   └── upfirdn2d_kernel.cu
│       └── stylegan2.py
├── options
│   ├── __init__.py
│   ├── base_options.py
│   ├── test_options.py
│   └── train_options.py
├── output
│   └── .gitkeep
├── save_remote_gs.py
├── test.py
├── test.sh
├── train.py
├── train.sh
├── trainers
│   ├── __init__.py
│   └── stylegan2_trainer.py
└── util
    ├── __init__.py
    ├── coco.py
    ├── html.py
    ├── iter_counter.py
    ├── util.py
    └── visualizer.py
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "logger"]
2 | path = logger
3 | url = https://github.com/zengxianyu/logger
4 | [submodule "models/networks/sync_batchnorm"]
5 | path = models/networks/sync_batchnorm
6 | url = https://github.com/zengxianyu/sync_batchnorm
7 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # co-mod-gan-pytorch
2 | Implementation of the paper "Large Scale Image Completion via Co-Modulated Generative Adversarial Networks"
3 |
4 | Official TensorFlow version: https://github.com/zsyzzsoft/co-mod-gan
5 |
6 | Input image | Mask | Result
7 |
8 | ## Usage
9 |
10 | ### Requirements
11 | ```
12 | conda install pytorch torchvision cudatoolkit=11 -c pytorch
13 | conda install matplotlib jinja2 ninja dill
14 | pip install git+https://github.com/zengxianyu/pytorch-fid
15 | ```
16 |
17 | Download the code:
18 |
19 | ```
20 | git clone https://github.com/zengxianyu/co-mod-gan-pytorch
21 | cd co-mod-gan-pytorch && git checkout train
22 | git submodule init
23 | git submodule update
24 | ```
25 |
26 | ### Inference
27 |
28 | 1. Download a pretrained model using the corresponding "download/*.sh" script (converted from the TensorFlow pretrained models)
29 |
30 | e.g., FFHQ 512:
31 |
32 | ```
33 | ./download/ffhq512.sh
34 | ```
35 |
36 | Converted models:
37 | * FFHQ 512 checkpoints/comod-ffhq-512/co-mod-gan-ffhq-9-025000_net_G_ema.pth
38 | * FFHQ 1024 checkpoints/comod-ffhq-1024/co-mod-gan-ffhq-10-025000_net_G_ema.pth
39 | * Places 512 checkpoints/comod-places-512/co-mod-gan-places2-050000_net_G_ema.pth
40 |
41 | 2. Use the following command as a minimal example of usage:
42 |
43 | ```
44 | ./test.sh
45 | ```
46 |
47 | ### Training
48 | 1. Download the example datasets for training and validation:
49 |
50 | ```
51 | ./download/data.sh
52 | ```
53 |
54 | 2. Use the following command as a minimal example of usage:
55 |
56 | ```
57 | ./train.sh
58 | ```
59 |
60 | ### Demo
61 | Coming soon
62 |
63 | ## Reference
64 |
65 | [1] official tensorflow version: https://github.com/zsyzzsoft/co-mod-gan
66 |
67 | [2] stylegan2-pytorch https://github.com/rosinality/stylegan2-pytorch
68 |
69 | [3] pix2pixHD https://github.com/NVIDIA/pix2pixHD
70 |
71 | [4] SPADE https://github.com/NVlabs/SPADE
72 |
--------------------------------------------------------------------------------
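For orientation, the sketch below shows how the inference entry point plausibly wires together the pieces listed in this dump. test.py and options/ are not reproduced here, so `TestOptions`, `util.tensor2im` and `util.save_image` are assumptions; `create_dataloader`, `create_model` and the `'inference'` forward mode are defined in data/__init__.py and models/ below.

```python
# Hypothetical outline of the inference flow driven by test.sh; names not defined
# in this dump (TestOptions, util.tensor2im, util.save_image) are assumptions.
import torch
from options.test_options import TestOptions   # module exists in the tree; class name assumed
from data import create_dataloader
from models import create_model
from util import util

opt = TestOptions().parse()          # supplies --image_dir, --mask_dir, --output_dir, ...
dataloader = create_dataloader(opt)  # TestImageDataset when --dataset_mode testimage
model = create_model(opt)            # CoModModel when --model comod

for data in dataloader:
    with torch.no_grad():
        fake, masked_input = model(data, mode='inference')
    for out_path, img in zip(data['path'], fake):
        util.save_image(util.tensor2im(img), out_path)   # helper names assumed
```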
/checkpoints/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zengxianyu/co-mod-gan-pytorch/a178bfd15d2675214532151037ef8f6e34b3fd91/checkpoints/.gitkeep
--------------------------------------------------------------------------------
/checkpoints/comod-ffhq-1024/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zengxianyu/co-mod-gan-pytorch/a178bfd15d2675214532151037ef8f6e34b3fd91/checkpoints/comod-ffhq-1024/.gitkeep
--------------------------------------------------------------------------------
/checkpoints/comod-ffhq-512/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zengxianyu/co-mod-gan-pytorch/a178bfd15d2675214532151037ef8f6e34b3fd91/checkpoints/comod-ffhq-512/.gitkeep
--------------------------------------------------------------------------------
/checkpoints/comod-places-512/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zengxianyu/co-mod-gan-pytorch/a178bfd15d2675214532151037ef8f6e34b3fd91/checkpoints/comod-places-512/.gitkeep
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """
5 |
6 | import importlib
7 | import torch.utils.data
8 | from data.base_dataset import BaseDataset
9 |
10 |
11 | def find_dataset_using_name(dataset_name):
12 | # Given the option --dataset_mode [datasetname],
13 | # the file "data/datasetname_dataset.py"
14 | # will be imported.
15 | dataset_filename = "data." + dataset_name + "_dataset"
16 | datasetlib = importlib.import_module(dataset_filename)
17 |
18 | # In the file, the class called DatasetNameDataset() will
19 | # be instantiated. It has to be a subclass of BaseDataset,
20 | # and it is case-insensitive.
21 | dataset = None
22 | target_dataset_name = dataset_name.replace('_', '') + 'dataset'
23 | for name, cls in datasetlib.__dict__.items():
24 | if name.lower() == target_dataset_name.lower() \
25 | and issubclass(cls, BaseDataset):
26 | dataset = cls
27 |
28 | if dataset is None:
29 | raise ValueError("In %s.py, there should be a subclass of BaseDataset "
30 | "with class name that matches %s in lowercase." %
31 | (dataset_filename, target_dataset_name))
32 |
33 | return dataset
34 |
35 |
36 | def get_option_setter(dataset_name):
37 | dataset_class = find_dataset_using_name(dataset_name)
38 | return dataset_class.modify_commandline_options
39 |
40 |
41 | def create_dataloader(opt):
42 | dataset = find_dataset_using_name(opt.dataset_mode)
43 | instance = dataset()
44 | instance.initialize(opt)
45 | print("dataset [%s] of size %d was created" %
46 | (type(instance).__name__, len(instance)))
47 | dataloader = torch.utils.data.DataLoader(
48 | instance,
49 | batch_size=opt.batchSize,
50 | shuffle=not opt.serial_batches,
51 | num_workers=int(opt.nThreads),
52 | drop_last=opt.isTrain
53 | )
54 | return dataloader
55 |
56 | def create_dataloader_trainval(opt):
57 | assert opt.isTrain
58 | dataset = find_dataset_using_name(opt.dataset_mode_train)
59 | instance = dataset()
60 | instance.initialize(opt)
61 | print("dataset [%s] of size %d was created" %
62 | (type(instance).__name__, len(instance)))
63 | dataloader_train = torch.utils.data.DataLoader(
64 | instance,
65 | batch_size=opt.batchSize,
66 | shuffle=not opt.serial_batches,
67 | num_workers=int(opt.nThreads),
68 | drop_last=True
69 | )
70 | dataset = find_dataset_using_name(opt.dataset_mode_val)
71 | instance = dataset()
72 | instance.initialize(opt)
73 | print("dataset [%s] of size %d was created" %
74 | (type(instance).__name__, len(instance)))
75 | dataloader_val = torch.utils.data.DataLoader(
76 | instance,
77 | batch_size=opt.batchSize,
78 | shuffle=False,
79 | num_workers=int(opt.nThreads),
80 | drop_last=False
81 | )
82 | return dataloader_train, dataloader_val
83 |
--------------------------------------------------------------------------------
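`find_dataset_using_name` resolves `--dataset_mode` purely by naming convention ("foo" maps to data/foo_dataset.py and class FooDataset). A minimal sketch of driving `create_dataloader` directly; the attribute names are exactly the ones read above and in TestImageDataset, while the concrete values are placeholders.

```python
# Sketch only: a bare Namespace carrying the attributes create_dataloader and
# TestImageDataset read; paths and values are placeholders.
from argparse import Namespace
from data import create_dataloader

opt = Namespace(
    dataset_mode='testimage',        # -> data/testimage_dataset.py -> TestImageDataset
    batchSize=1,
    serial_batches=True,             # shuffle = not serial_batches
    nThreads=2,
    isTrain=False,                   # drop_last follows isTrain
    # consumed by TestImageDataset.initialize():
    image_dir='ffhq_debug/images',
    mask_dir='ffhq_debug/masks',
    output_dir='output',
    list_dir='ffhq_debug/images.txt',
)
loader = create_dataloader(opt)
batch = next(iter(loader))           # dict with 'image', 'mask' and output 'path'
```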
/data/base_dataset.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """
5 |
6 | import torch.utils.data as data
7 | from PIL import Image
8 | import torchvision.transforms as transforms
9 | import numpy as np
10 | import random
11 |
12 |
13 | class BaseDataset(data.Dataset):
14 | def __init__(self):
15 | super(BaseDataset, self).__init__()
16 |
17 | @staticmethod
18 | def modify_commandline_options(parser, is_train):
19 | return parser
20 |
21 | def initialize(self, opt):
22 | pass
23 |
24 |
25 | def get_params(opt, size):
26 | w, h = size
27 | new_h = h
28 | new_w = w
29 | if opt.preprocess_mode == 'resize_and_crop':
30 | new_h = new_w = opt.load_size
31 | elif opt.preprocess_mode == 'scale_width_and_crop':
32 | new_w = opt.load_size
33 | new_h = opt.load_size * h // w
34 | elif opt.preprocess_mode == 'scale_shortside_and_crop':
35 | ss, ls = min(w, h), max(w, h) # shortside and longside
36 | width_is_shorter = w == ss
37 | ls = int(opt.load_size * ls / ss)
38 | new_w, new_h = (ss, ls) if width_is_shorter else (ls, ss)
39 |
40 | x = random.randint(0, np.maximum(0, new_w - opt.crop_size))
41 | y = random.randint(0, np.maximum(0, new_h - opt.crop_size))
42 |
43 | flip = random.random() > 0.5
44 | return {'crop_pos': (x, y), 'flip': flip}
45 |
46 | def bbox_transform(opt, params, bbox, size, force_flip=False):
47 | w,h = size
48 | if opt.isTrain and (not opt.no_flip or force_flip):
49 | if params['flip']:
50 | bbox[0] = w-bbox[0]-bbox[2]
51 | rate_h = rate_w = 1
52 | if 'resize' in opt.preprocess_mode:
53 | rate_h = float(opt.load_size)/h
54 | rate_w = float(opt.load_size)/w
55 | elif 'scale_width' in opt.preprocess_mode:
56 | rate_w = rate_h = float(opt.load_size)/w
57 | elif 'scale_shortside' in opt.preprocess_mode:
58 | ss, ls = min(w, h), max(w, h) # shortside and longside
59 | width_is_shorter = w == ss
60 | if (ss == opt.load_size):
61 | rate_w = rate_h = 1
62 | ls = int(opt.load_size * ls / ss)
63 | nw, nh = (ss, ls) if width_is_shorter else (ls, ss)
64 | rate_h = float(nh)/h
65 | rate_w = float(nw)/w
66 | if opt.preprocess_mode == 'fixed':
67 | rate_w = float(opt.crop_size)/w
68 | ht = round(opt.crop_size / opt.aspect_ratio)
69 | rate_h = float(opt.crop_size)/ht
70 |
71 | bbox[0] *= rate_w
72 | bbox[2] *= rate_w
73 | bbox[1] *= rate_h
74 | bbox[3] *= rate_h
75 | w *= rate_w
76 | h *= rate_h
77 |
78 | if 'crop' in opt.preprocess_mode:
79 | x,y = params['crop_pos']
80 | bbox[0] -= x
81 | bbox[1] -= y
82 | w -= x
83 | h -= y
84 | y2 = bbox[1]+bbox[3]
85 | x2 = bbox[0]+bbox[2]
86 | y2 = min(opt.crop_size,y2)
87 | x2 = min(opt.crop_size,x2)
88 | x1 = max(0,bbox[0])
89 | y1 = max(0,bbox[1])
90 | bbox[0] = x1
91 | bbox[1] = y1
92 | bbox[2] = x2-x1
93 | bbox[3] = y2-y1
94 | if opt.preprocess_mode == 'none':
95 | base = 32
96 | _h = int(round(h / base) * base)
97 | _w = int(round(w / base) * base)
98 | rate_h = float(_h)/h
99 | rate_w = float(_w)/w
100 | bbox[0] *= rate_w
101 | bbox[2] *= rate_w
102 | bbox[1] *= rate_h
103 | bbox[3] *= rate_h
104 | return bbox
105 |
106 |
107 | def get_transform(opt, params, method=Image.BICUBIC, normalize=True, toTensor=True, force_flip=False):
108 | transform_list = []
109 | if opt.isTrain and (not opt.no_flip or force_flip):
110 | transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
111 |
112 | if 'resize' in opt.preprocess_mode:
113 | osize = [opt.load_size, opt.load_size]
114 | transform_list.append(transforms.Resize(osize, interpolation=method))
115 | elif 'scale_width' in opt.preprocess_mode:
116 | transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, method)))
117 | elif 'scale_shortside' in opt.preprocess_mode:
118 | transform_list.append(transforms.Lambda(lambda img: __scale_shortside(img, opt.load_size, method)))
119 |
120 | if 'crop' in opt.preprocess_mode:
121 | transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size)))
122 |
123 | if opt.preprocess_mode == 'none':
124 | base = 32
125 | transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method)))
126 |
127 | if opt.preprocess_mode == 'fixed':
128 | w = opt.crop_size
129 | h = round(opt.crop_size / opt.aspect_ratio)
130 | transform_list.append(transforms.Lambda(lambda img: __resize(img, w, h, method)))
131 |
132 | if toTensor:
133 | transform_list += [transforms.ToTensor()]
134 |
135 | if normalize:
136 | transform_list += [transforms.Normalize((0.5, 0.5, 0.5),
137 | (0.5, 0.5, 0.5))]
138 | return transforms.Compose(transform_list)
139 |
140 |
141 | def normalize():
142 | return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
143 |
144 |
145 | def __resize(img, w, h, method=Image.BICUBIC):
146 | return img.resize((w, h), method)
147 |
148 |
149 | def __make_power_2(img, base, method=Image.BICUBIC):
150 | ow, oh = img.size
151 | h = int(round(oh / base) * base)
152 | w = int(round(ow / base) * base)
153 | if (h == oh) and (w == ow):
154 | return img
155 | return img.resize((w, h), method)
156 |
157 |
158 | def __scale_width(img, target_width, method=Image.BICUBIC):
159 | ow, oh = img.size
160 | if (ow == target_width):
161 | return img
162 | w = target_width
163 | h = int(target_width * oh / ow)
164 | return img.resize((w, h), method)
165 |
166 |
167 | def __scale_shortside(img, target_width, method=Image.BICUBIC):
168 | ow, oh = img.size
169 | ss, ls = min(ow, oh), max(ow, oh) # shortside and longside
170 | width_is_shorter = ow == ss
171 | if (ss == target_width):
172 | return img
173 | ls = int(target_width * ls / ss)
174 | nw, nh = (ss, ls) if width_is_shorter else (ls, ss)
175 | return img.resize((nw, nh), method)
176 |
177 |
178 | def __crop(img, pos, size):
179 | ow, oh = img.size
180 | x1, y1 = pos
181 | tw = th = size
182 | return img.crop((x1, y1, x1 + tw, y1 + th))
183 |
184 |
185 | def __flip(img, flip):
186 | if flip:
187 | return img.transpose(Image.FLIP_LEFT_RIGHT)
188 | return img
189 |
--------------------------------------------------------------------------------
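`get_params` draws one crop position and one flip decision per image, and `get_transform` turns them plus `opt.preprocess_mode` into a torchvision pipeline, so paired inputs (e.g. an image and its mask) can share identical augmentation. A small usage sketch with assumed option values:

```python
# Usage sketch; load_size/crop_size are placeholder values, the preprocess_mode
# string is one of the branches handled in get_transform above.
from argparse import Namespace
from PIL import Image
from data.base_dataset import get_params, get_transform

opt = Namespace(preprocess_mode='resize_and_crop', load_size=286, crop_size=256,
                aspect_ratio=1.0, isTrain=True, no_flip=False)
img = Image.open('imgs/example_image.jpg').convert('RGB')
params = get_params(opt, img.size)        # {'crop_pos': (x, y), 'flip': bool}
transform = get_transform(opt, params)    # flip -> resize -> crop -> ToTensor -> Normalize
tensor = transform(img)                   # (3, 256, 256), values in [-1, 1]
```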
/data/image_folder.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """
5 |
6 | ###############################################################################
7 | # Code from
8 | # https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py
9 | # Modified the original code so that it also loads images from the current
10 | # directory as well as the subdirectories
11 | ###############################################################################
12 | import torch.utils.data as data
13 | from PIL import Image
14 | import os
15 |
16 | IMG_EXTENSIONS = [
17 | '.jpg', '.JPG', '.jpeg', '.JPEG',
18 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff', '.webp'
19 | ]
20 |
21 |
22 | def is_image_file(filename):
23 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
24 |
25 |
26 | def make_dataset_rec(dir, images):
27 | assert os.path.isdir(dir), '%s is not a valid directory' % dir
28 |
29 | for root, dnames, fnames in sorted(os.walk(dir, followlinks=True)):
30 | for fname in fnames:
31 | if is_image_file(fname):
32 | path = os.path.join(root, fname)
33 | images.append(path)
34 |
35 |
36 | def make_dataset(dir, recursive=False, read_cache=False, write_cache=False):
37 | images = []
38 |
39 | if read_cache:
40 | possible_filelist = os.path.join(dir, 'files.list')
41 | if os.path.isfile(possible_filelist):
42 | with open(possible_filelist, 'r') as f:
43 | images = f.read().splitlines()
44 | return images
45 |
46 | if recursive:
47 | make_dataset_rec(dir, images)
48 | else:
49 | assert os.path.isdir(dir) or os.path.islink(dir), '%s is not a valid directory' % dir
50 |
51 | for root, dnames, fnames in sorted(os.walk(dir)):
52 | for fname in fnames:
53 | if is_image_file(fname):
54 | path = os.path.join(root, fname)
55 | images.append(path)
56 |
57 | if write_cache:
58 | filelist_cache = os.path.join(dir, 'files.list')
59 | with open(filelist_cache, 'w') as f:
60 | for path in images:
61 | f.write("%s\n" % path)
62 | print('wrote filelist cache at %s' % filelist_cache)
63 |
64 | return images
65 |
66 |
67 | def default_loader(path):
68 | return Image.open(path).convert('RGB')
69 |
70 |
71 | class ImageFolder(data.Dataset):
72 |
73 | def __init__(self, root, transform=None, return_paths=False,
74 | loader=default_loader):
75 | imgs = make_dataset(root)
76 | if len(imgs) == 0:
77 | raise(RuntimeError("Found 0 images in: " + root + "\n"
78 | "Supported image extensions are: " +
79 | ",".join(IMG_EXTENSIONS)))
80 |
81 | self.root = root
82 | self.imgs = imgs
83 | self.transform = transform
84 | self.return_paths = return_paths
85 | self.loader = loader
86 |
87 | def __getitem__(self, index):
88 | path = self.imgs[index]
89 | img = self.loader(path)
90 | if self.transform is not None:
91 | img = self.transform(img)
92 | if self.return_paths:
93 | return img, path
94 | else:
95 | return img
96 |
97 | def __len__(self):
98 | return len(self.imgs)
99 |
--------------------------------------------------------------------------------
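A short usage sketch for `make_dataset`/`ImageFolder`; the directory is one that exists in this tree, the transform is an arbitrary example.

```python
import torchvision.transforms as transforms
from data.image_folder import make_dataset, ImageFolder

paths = make_dataset('ffhq_debug/images', recursive=True)   # list of image file paths
dataset = ImageFolder('ffhq_debug/images',
                      transform=transforms.ToTensor(),
                      return_paths=True)
img, path = dataset[0]                                       # tensor and its source path
```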
/data/testimage_dataset.py:
--------------------------------------------------------------------------------
1 | import torchvision.transforms as transforms
2 | import torch
3 | from data.base_dataset import get_params, get_transform, BaseDataset
4 | from PIL import Image
5 | from data.image_folder import make_dataset
6 | import os
7 | import pdb
8 |
9 |
10 | class TestImageDataset(BaseDataset):
11 | """ Dataset that loads images from directories
12 | Use options --image_dir, --mask_dir and --output_dir to specify the directories.
13 | Images and masks are paired by file name; mask names are taken from --list_dir (if given) or --mask_dir.
14 | """
15 |
16 | @staticmethod
17 | def modify_commandline_options(parser, is_train):
18 | parser.add_argument('--list_dir', type=str, required=False,
19 | help='path to a text file listing mask file names (one per line)')
20 | parser.add_argument('--image_dir', type=str, required=True,
21 | help='path to the directory that contains input images')
22 | parser.add_argument('--mask_dir', type=str, required=True,
23 | help='path to the directory that contains mask images')
24 | parser.add_argument('--output_dir', type=str, required=True,
25 | help='path to the directory where output images are saved')
26 | return parser
27 |
28 | def initialize(self, opt):
29 | self.opt = opt
30 | if not os.path.exists(opt.output_dir):
31 | os.mkdir(opt.output_dir)
32 |
33 | image_paths, mask_paths, output_paths = self.get_paths(opt)
34 |
35 | self.image_paths = image_paths
36 | self.mask_paths = mask_paths
37 | self.output_paths = output_paths
38 |
39 | size = len(self.image_paths)
40 | self.dataset_size = size
41 | transform_list = [
42 | transforms.ToTensor(),
43 | transforms.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5))
44 | ]
45 | self.image_transform = transforms.Compose(transform_list)
46 | self.mask_transform = transforms.Compose([
47 | transforms.ToTensor()
48 | ])
49 |
50 | def get_paths(self, opt):
51 | img_names = os.listdir(opt.image_dir)
52 | img_postfix = img_names[0].split(".")[-1]
53 | if opt.list_dir is not None:
54 | with open(opt.list_dir, "r") as f:
55 | msk_names = f.readlines()
56 | msk_names = [n.strip("\n") for n in msk_names]
57 | else:
58 | msk_names = os.listdir(opt.mask_dir)
59 | img_names = [n.replace("png", img_postfix) for n in msk_names]
60 | image_paths = [f"{opt.image_dir}/{n}" for n in img_names]
61 | output_paths = [f"{opt.output_dir}/{n}" for n in img_names]
62 | mask_paths = [f"{opt.mask_dir}/{n}" for n in msk_names]
63 |
64 | return image_paths, mask_paths, output_paths
65 |
66 | def __len__(self):
67 | return self.dataset_size
68 |
69 | def __getitem__(self, index):
70 | # input image (real images)
71 | output_path = self.output_paths[index]
72 | image_path = self.image_paths[index]
73 | image = Image.open(image_path)
74 | image = image.convert('RGB')
75 | w, h = image.size
76 | image_tensor = self.image_transform(image)
77 | # mask image
78 | mask_path = self.mask_paths[index]
79 | mask = Image.open(mask_path)
80 | mask = mask.convert("L")
81 | mask = mask.resize((w,h))
82 | mask_tensor = self.mask_transform(mask)
83 | mask_tensor = (mask_tensor>0).float()
84 | input_dict = {
85 | 'image': image_tensor,
86 | 'mask': mask_tensor,
87 | 'path': output_path,
88 | }
89 |
90 | return input_dict
91 |
--------------------------------------------------------------------------------
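The pairing rule in `TestImageDataset.get_paths` is name-based: mask names come from `--list_dir` (or a listing of `--mask_dir`), and image/output names are derived from them by swapping the "png" extension for whatever extension the images use. Worked through for the ffhq_debug example data shipped in this repository:

```python
# Illustration of get_paths with the ffhq_debug example data.
msk_names = ['1.png']                                   # contents of ffhq_debug/images.txt
img_postfix = 'png'                                     # extension of files in --image_dir
img_names = [n.replace('png', img_postfix) for n in msk_names]
image_paths = [f'ffhq_debug/images/{n}' for n in img_names]   # ffhq_debug/images/1.png
mask_paths = [f'ffhq_debug/masks/{n}' for n in msk_names]     # ffhq_debug/masks/1.png
output_paths = [f'output/{n}' for n in img_names]             # output/1.png
```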
/data/trainimage_dataset.py:
--------------------------------------------------------------------------------
1 | from data.base_dataset import get_params, get_transform, BaseDataset
2 | from PIL import Image
3 | from data.image_folder import make_dataset
4 | import os
5 | import pdb
6 |
7 |
8 | class TrainImageDataset(BaseDataset):
9 | """ Dataset that loads images from directories
10 | Use options --train_image_dir and --train_image_list to specify the training images.
11 | Image paths are read from the list file and joined with the image directory.
12 | """
13 |
14 | @staticmethod
15 | def modify_commandline_options(parser, is_train):
16 | parser.add_argument('--train_image_dir', type=str, required=True,
17 | help='path to the directory that contains training images')
18 | parser.add_argument('--train_image_list', type=str, required=True,
19 | help='path to a text file listing training image names (one per line)')
20 | parser.add_argument('--train_image_postfix', type=str, default="",
21 | help='postfix appended to each listed name to form the file name')
22 | return parser
23 |
24 | def initialize(self, opt):
25 | self.opt = opt
26 | image_paths = self.get_paths(opt)
27 |
28 | self.image_paths = image_paths
29 |
30 | size = len(self.image_paths)
31 | self.dataset_size = size
32 |
33 | def get_paths(self, opt):
34 | image_dir = opt.train_image_dir
35 | image_list = opt.train_image_list
36 | names = open(image_list).readlines()
37 | filenames = list(map(lambda x: x.strip('\n')+opt.train_image_postfix, names))
38 | image_paths = list(map(lambda x: os.path.join(image_dir, x), filenames))
39 | return image_paths
40 |
41 | def __len__(self):
42 | return self.dataset_size
43 |
44 | def __getitem__(self, index):
45 | # input image (real images)
46 | image_path = self.image_paths[index]
47 | image = Image.open(image_path)
48 | image = image.convert('RGB')
49 | params = get_params(self.opt, image.size)
50 | transform_image = get_transform(self.opt, params)
51 | image_tensor = transform_image(image)
52 | input_dict = {
53 | 'image': image_tensor,
54 | 'path': image_path,
55 | }
56 | return input_dict
57 | #except:
58 | # print(f"skip {image_path}")
59 | # return self.__getitem__((index+1)%self.__len__())
60 |
--------------------------------------------------------------------------------
/data/valimage_dataset.py:
--------------------------------------------------------------------------------
1 | import torchvision.transforms as transforms
2 | import torch
3 | from data.base_dataset import get_params, get_transform, BaseDataset
4 | from PIL import Image
5 | from data.image_folder import make_dataset
6 | import os
7 | import pdb
8 |
9 |
10 | class ValImageDataset(BaseDataset):
11 | """ Dataset that loads images from directories
12 | Use options --val_image_dir, --val_image_list and --val_mask_dir to specify the directories.
13 | Images and masks are paired by the names in the list file, with --val_image_postfix / --val_mask_postfix appended.
14 | """
15 |
16 | @staticmethod
17 | def modify_commandline_options(parser, is_train):
18 | parser.add_argument('--val_image_dir', type=str, required=True,
19 | help='path to the directory that contains validation images')
20 | parser.add_argument('--val_image_list', type=str, required=True,
21 | help='path to a text file listing validation image names (one per line)')
22 | parser.add_argument('--val_mask_dir', type=str, required=True,
23 | help='path to the directory that contains validation masks')
24 | parser.add_argument('--val_image_postfix', type=str, default=".jpg",
25 | help='postfix appended to each listed name to form the image file name')
26 | parser.add_argument('--val_mask_postfix', type=str, default=".png",
27 | help='postfix appended to each listed name to form the mask file name')
28 | return parser
29 |
30 | def initialize(self, opt):
31 | self.opt = opt
32 |
33 | image_paths, mask_paths = self.get_paths(opt)
34 |
35 | self.image_paths = image_paths
36 | self.mask_paths = mask_paths
37 |
38 | size = len(self.image_paths)
39 | self.dataset_size = size
40 | transform_list = [
41 | transforms.Resize((opt.crop_size, opt.crop_size),
42 | interpolation=Image.NEAREST),
43 | transforms.ToTensor(),
44 | transforms.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5))
45 | ]
46 | self.image_transform = transforms.Compose(transform_list)
47 | self.mask_transform = transforms.Compose([
48 | transforms.Resize((opt.crop_size, opt.crop_size),interpolation=Image.NEAREST),
49 | transforms.ToTensor()
50 | ])
51 |
52 | def get_paths(self, opt):
53 | image_dir = opt.val_image_dir
54 | image_list = opt.val_image_list
55 | names = open(image_list).readlines()
56 | filenames = list(map(lambda x: x.strip('\n')+opt.val_image_postfix, names))
57 | image_paths = list(map(lambda x: os.path.join(image_dir, x), filenames))
58 | filenames = list(map(lambda x: x.strip('\n')+opt.val_mask_postfix, names))
59 | mask_paths = list(map(lambda x: os.path.join(opt.val_mask_dir, x), filenames))
60 | return image_paths, mask_paths
61 |
62 | def __len__(self):
63 | return self.dataset_size
64 |
65 | def __getitem__(self, index):
66 | # input image (real images)
67 | image_path = self.image_paths[index]
68 | image = Image.open(image_path)
69 | image = image.convert('RGB')
70 | w, h = image.size
71 | image_tensor = self.image_transform(image)
72 | # mask image
73 | mask_path = self.mask_paths[index]
74 | mask = Image.open(mask_path)
75 | mask = mask.convert("L")
76 | mask = mask.resize((w,h))
77 | mask_tensor = self.mask_transform(mask)
78 | mask_tensor = (mask_tensor>0).float()
79 | input_dict = {
80 | 'image': image_tensor,
81 | 'mask': mask_tensor,
82 | 'path': image_path,
83 | }
84 |
85 | return input_dict
86 |
--------------------------------------------------------------------------------
/datasets/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zengxianyu/co-mod-gan-pytorch/a178bfd15d2675214532151037ef8f6e34b3fd91/datasets/.gitkeep
--------------------------------------------------------------------------------
/download/data.sh:
--------------------------------------------------------------------------------
1 | wget https://maildluteducn-my.sharepoint.com/:u:/g/personal/zengyu_mail_dlut_edu_cn/Ed6KS2wg-olJsLicvZOpUHkB4nak9nYtJPXxwvM8W_d9PQ?download=1
2 | mv Ed6KS2wg-olJsLicvZOpUHkB4nak9nYtJPXxwvM8W_d9PQ?download=1 ./datasets/places2sample1k_val.zip
3 | unzip datasets/places2sample1k_val.zip -d datasets/
4 |
--------------------------------------------------------------------------------
/download/ffhq1024.sh:
--------------------------------------------------------------------------------
1 | wget https://maildluteducn-my.sharepoint.com/:u:/g/personal/zengyu_mail_dlut_edu_cn/EXcZ9OHEFQFAqhe1GQVFZ_4BwWoTvWyDM429gP5XrPAdaQ?download=1
2 | mv EXcZ9OHEFQFAqhe1GQVFZ_4BwWoTvWyDM429gP5XrPAdaQ?download=1 ./checkpoints/comod-ffhq-1024/co-mod-gan-ffhq-10-025000_net_G_ema.pth
3 |
--------------------------------------------------------------------------------
/download/ffhq512.sh:
--------------------------------------------------------------------------------
1 | wget https://maildluteducn-my.sharepoint.com/:u:/g/personal/zengyu_mail_dlut_edu_cn/Ee1YPJG2Y7NDnUjJBf-SipoBBSlbv8QfFy6K7lsiiiiFHg?download=1
2 | mv Ee1YPJG2Y7NDnUjJBf-SipoBBSlbv8QfFy6K7lsiiiiFHg?download=1 ./checkpoints/comod-ffhq-512/co-mod-gan-ffhq-9-025000_net_G_ema.pth
3 |
--------------------------------------------------------------------------------
/download/places512.sh:
--------------------------------------------------------------------------------
1 | wget https://maildluteducn-my.sharepoint.com/:u:/g/personal/zengyu_mail_dlut_edu_cn/EQG9jJzkFLJDsOWmVJJSoqQB2jRDkXlYt3wnt9Fb9dJDsQ?download=1
2 | mv EQG9jJzkFLJDsOWmVJJSoqQB2jRDkXlYt3wnt9Fb9dJDsQ?download=1 ./checkpoints/comod-places-512/co-mod-gan-places2-050000_net_G_ema.pth
3 |
--------------------------------------------------------------------------------
/ffhq_debug/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zengxianyu/co-mod-gan-pytorch/a178bfd15d2675214532151037ef8f6e34b3fd91/ffhq_debug/1.png
--------------------------------------------------------------------------------
/ffhq_debug/example_image.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zengxianyu/co-mod-gan-pytorch/a178bfd15d2675214532151037ef8f6e34b3fd91/ffhq_debug/example_image.jpg
--------------------------------------------------------------------------------
/ffhq_debug/images.txt:
--------------------------------------------------------------------------------
1 | 1.png
2 |
--------------------------------------------------------------------------------
/ffhq_debug/images/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zengxianyu/co-mod-gan-pytorch/a178bfd15d2675214532151037ef8f6e34b3fd91/ffhq_debug/images/1.png
--------------------------------------------------------------------------------
/ffhq_debug/masks/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zengxianyu/co-mod-gan-pytorch/a178bfd15d2675214532151037ef8f6e34b3fd91/ffhq_debug/masks/1.png
--------------------------------------------------------------------------------
/ffhq_debug/masks_inv/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zengxianyu/co-mod-gan-pytorch/a178bfd15d2675214532151037ef8f6e34b3fd91/ffhq_debug/masks_inv/1.png
--------------------------------------------------------------------------------
/imgs/example_image.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zengxianyu/co-mod-gan-pytorch/a178bfd15d2675214532151037ef8f6e34b3fd91/imgs/example_image.jpg
--------------------------------------------------------------------------------
/imgs/example_mask.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zengxianyu/co-mod-gan-pytorch/a178bfd15d2675214532151037ef8f6e34b3fd91/imgs/example_mask.jpg
--------------------------------------------------------------------------------
/imgs/example_output.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zengxianyu/co-mod-gan-pytorch/a178bfd15d2675214532151037ef8f6e34b3fd91/imgs/example_output.jpg
--------------------------------------------------------------------------------
/imgs/ffhq_in.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zengxianyu/co-mod-gan-pytorch/a178bfd15d2675214532151037ef8f6e34b3fd91/imgs/ffhq_in.png
--------------------------------------------------------------------------------
/imgs/ffhq_m.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zengxianyu/co-mod-gan-pytorch/a178bfd15d2675214532151037ef8f6e34b3fd91/imgs/ffhq_m.png
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """
5 |
6 | import importlib
7 | import torch
8 |
9 |
10 | def find_model_using_name(model_name):
11 | # Given the option --model [modelname],
12 | # the file "models/modelname_model.py"
13 | # will be imported.
14 | model_filename = "models." + model_name + "_model"
15 | modellib = importlib.import_module(model_filename)
16 |
17 | # In the file, the class called ModelNameModel() will
18 | # be instantiated. It has to be a subclass of torch.nn.Module,
19 | # and it is case-insensitive.
20 | model = None
21 | target_model_name = model_name.replace('_', '') + 'model'
22 | for name, cls in modellib.__dict__.items():
23 | if name.lower() == target_model_name.lower() \
24 | and issubclass(cls, torch.nn.Module):
25 | model = cls
26 |
27 | if model is None:
28 | raise ValueError("In %s.py, there should be a subclass of torch.nn.Module "
29 | "with class name that matches %s in lowercase." % (model_filename, target_model_name))
30 |
31 | return model
32 |
33 |
34 | def get_option_setter(model_name):
35 | model_class = find_model_using_name(model_name)
36 | return model_class.modify_commandline_options
37 |
38 |
39 | def create_model(opt):
40 | model = find_model_using_name(opt.model)
41 | instance = model(opt)
42 | print("model [%s] was created" % (type(instance).__name__))
43 |
44 | return instance
45 |
--------------------------------------------------------------------------------
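`find_model_using_name` resolves `--model` the same way the data package resolves `--dataset_mode`; for this repository the relevant case is sketched below.

```python
# Name-resolution sketch: --model comod imports models/comod_model.py and picks the
# torch.nn.Module subclass whose lowercased name equals "comodmodel", i.e. CoModModel.
from models import find_model_using_name

model_cls = find_model_using_name('comod')   # -> models.comod_model.CoModModel
# create_model(opt) simply instantiates this class with the parsed options; the full
# option set CoModModel expects (netG, z_dim, isTrain, ...) comes from options/.
```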
/models/comod_model.py:
--------------------------------------------------------------------------------
1 | import pdb
2 | import torch.nn.functional as F
3 | import torchvision.ops as ops
4 | import math
5 | import torch
6 | import models.networks as networks
7 | import util.util as util
8 | import random
9 | import numpy as np
10 | from models.create_mask import MaskCreator
11 |
12 |
13 | class CoModModel(torch.nn.Module):
14 | @staticmethod
15 | def modify_commandline_options(parser, is_train):
16 | networks.modify_commandline_options(parser, is_train)
17 | parser.add_argument('--no_g_reg', action='store_true')
18 | parser.add_argument('--path_objectshape_base', type=str, required=False, help='path obj base')
19 | parser.add_argument('--path_objectshape_list', type=str, required=False, help='path obj list')
20 | parser.add_argument('--mixing', type=float, default=0.9)
21 | parser.add_argument('--r1', type=float, default=10)
22 | parser.add_argument('--d_reg_every', type=int, default=16)
23 | parser.add_argument('--g_reg_every', type=int, default=4)
24 | parser.add_argument('--path_batch_shrink', type=int, default=2)
25 | parser.add_argument('--truncation', type=float, required=False)
26 | parser.add_argument('--path_regularize', type=int, default=2)
27 | parser.set_defaults(init_type=None)
28 | parser.set_defaults(gan_mode='softplus')
29 | parser.set_defaults(lr=0.002)
30 | parser.set_defaults(z_dim=512)
31 | # factor
32 | parser.add_argument('--factor', type=str, required=False)
33 | parser.add_argument('--factor_d', type=int, default=5)
34 | parser.add_argument('--factor_i', type=int, default=0)
35 | parser.add_argument('--load_pretrained_g', type=str, required=False, help='load pt g')
36 | parser.add_argument('--load_pretrained_g_ema', type=str, required=False, help='load pt g')
37 | parser.add_argument('--load_pretrained_d', type=str, required=False, help='load pt d')
38 | return parser
39 |
40 | def __init__(self, opt):
41 | super().__init__()
42 | self.opt = opt
43 | self.truncation_mean = None
44 |
45 | self.device = torch.device("cuda") if self.use_gpu() \
46 | else torch.device("cpu")
47 | self.FloatTensor = torch.cuda.FloatTensor if self.use_gpu() \
48 | else torch.FloatTensor
49 | self.ByteTensor = torch.cuda.ByteTensor if self.use_gpu() \
50 | else torch.ByteTensor
51 |
52 | self.netG, self.netG_ema, self.netD = self.initialize_networks(opt)
53 | if opt.factor is not None:
54 | self.eigvec = torch.load(opt.factor)["eigvec"].to(self.device)
55 | # set loss functions
56 | if opt.isTrain:
57 | if not opt.continue_train:
58 | if opt.load_pretrained_g is not None:
59 | print(f"load {opt.load_pretrained_g}")
60 | self.netG = util.load_network_path(
61 | self.netG, opt.load_pretrained_g)
62 | if opt.load_pretrained_g_ema is not None:
63 | print(f"load {opt.load_pretrained_g_ema}")
64 | self.netG_ema = util.load_network_path(
65 | self.netG_ema, opt.load_pretrained_g_ema)
66 | if opt.load_pretrained_d is not None:
67 | print(f"load {opt.load_pretrained_d}")
68 | self.netD = util.load_network_path(
69 | self.netD, opt.load_pretrained_d)
70 | self.mask_creator = MaskCreator(opt.path_objectshape_list, opt.path_objectshape_base)
71 | self.criterionGAN = networks.GANLoss(
72 | opt.gan_mode, tensor=self.FloatTensor, opt=self.opt)
73 | if not opt.no_vgg_loss:
74 | self.criterionVGG = networks.VGGLoss(self.opt.gpu_ids)
75 | if opt.truncation is not None:
76 | self.truncation_mean = self.mean_latent(4096)
77 |
78 | def accumulate(self, decay=0.999):
79 | par1 = dict(self.netG_ema.named_parameters())
80 | par2 = dict(self.netG.named_parameters())
81 |
82 | for k in par1.keys():
83 | par1[k].data.mul_(decay).add_(par2[k].data, alpha=1 - decay)
84 |
85 | # set loss functions
86 |
87 | # Entry point for all calls involving forward pass
88 | # of deep networks. We used this approach since DataParallel module
89 | # can't parallelize custom functions, we branch to different
90 | # routines based on |mode|.
91 | def forward(self, data, mode):
92 | real_image, mask, mean_path_length = self.preprocess_input(data)
93 | bsize = real_image.size(0)
94 | if mode == 'generator':
95 | g_loss, fake_image = self.compute_generator_loss(real_image, mask)
96 | generated = {'fake':fake_image,
97 | 'input':real_image*(1-mask),
98 | 'gt':real_image,
99 | }
100 | return g_loss, real_image, generated
101 | elif mode == 'dreal':
102 | d_loss = self.compute_discriminator_loss(
103 | real_image, fake_image=None, mask=mask)
104 | return d_loss
105 | elif mode == 'dfake':
106 | with torch.no_grad():
107 | fake_image, uc_image,_ = self.generate_fake(real_image, mask)
108 | fake_image = fake_image.detach()
109 | fake_image.requires_grad_()
110 | d_loss = self.compute_discriminator_loss(
111 | real_image=None, fake_image=fake_image, mask=mask)
112 | return d_loss
113 | elif mode == 'd_reg':
114 | d_regs = self.compute_discriminator_reg(real_image, mask)
115 | return d_regs
116 | elif mode == 'g_reg':
117 | g_regs, path_lengths, mean_path_length = self.compute_generator_reg(
118 | real_image,
119 | mask,
120 | mean_path_length)
121 | return g_regs, mean_path_length
122 | elif mode == 'inference':
123 | with torch.no_grad():
124 | if self.opt.factor is None:
125 | fake_image, uc_image,_ = self.generate_fake(real_image, mask, ema=True)
126 | else:
127 | fake_image, _ = self.factorize_fake(real_image, mask)
128 | inp = real_image*(1-mask)
129 | return fake_image, inp
130 | else:
131 | raise ValueError("|mode| is invalid")
132 |
133 | def create_optimizers(self, opt):
134 | G_params = list(self.netG.parameters())
135 | #G_params = [p for name, p in self.netG.named_parameters() \
136 | # if (not name.startswith("coarse"))]
137 | if opt.isTrain:
138 | D_params = list(self.netD.parameters())
139 |
140 | g_reg_ratio = self.opt.g_reg_every / (self.opt.g_reg_every + 1)
141 | d_reg_ratio = self.opt.d_reg_every / (self.opt.d_reg_every + 1)
142 |
143 | g_optim = torch.optim.Adam(
144 | G_params,
145 | lr=self.opt.lr * g_reg_ratio,
146 | betas=(0 ** g_reg_ratio, 0.99 ** g_reg_ratio),
147 | )
148 | d_optim = torch.optim.Adam(
149 | D_params,
150 | lr=self.opt.lr * d_reg_ratio,
151 | betas=(0 ** d_reg_ratio, 0.99 ** d_reg_ratio),
152 | )
153 |
154 | return g_optim, d_optim
155 |
156 | def save(self, epoch):
157 | util.save_network(self.netG, 'G', epoch, self.opt)
158 | util.save_network(self.netG_ema, 'G_ema', epoch, self.opt)
159 | util.save_network(self.netD, 'D', epoch, self.opt)
160 |
161 | ############################################################################
162 | # Private helper methods
163 | ############################################################################
164 |
165 | def initialize_networks(self, opt):
166 | netG_ema = networks.define_G(opt)
167 | if opt.isTrain:
168 | netG = networks.define_G(opt)
169 | netD = networks.define_D(opt)
170 | else:
171 | netD=None
172 | netG=None
173 |
174 | if not opt.isTrain or opt.continue_train:
175 | netG_ema = util.load_network(netG_ema, 'G_ema', opt.which_epoch, opt)
176 | if opt.isTrain:
177 | netG = util.load_network(netG, 'G', opt.which_epoch, opt)
178 | netD = util.load_network(netD, 'D', opt.which_epoch, opt)
179 | return netG, netG_ema, netD
180 |
181 | # preprocess the input, such as moving the tensors to GPUs and
182 | # generating a random mask when the input does not provide one
183 | # |data|: dictionary of the input data
184 |
185 | def mean_latent(self, n_latent):
186 | self.netG_ema.eval()
187 | latent_in = torch.randn(n_latent, self.opt.z_dim, device=self.device)
188 | dlatent = self.netG_ema(latents_in=[latent_in], get_latent=True)[0]
189 | latent_mean = dlatent.mean(0, keepdim=True)
190 | self.truncation_mean = latent_mean
191 | return self.truncation_mean
192 |
193 | def make_noise(self, batch, n_noise):
194 | if n_noise == 1:
195 | return torch.randn(batch, self.opt.z_dim, device=self.device)
196 |
197 | noises = torch.randn(n_noise, batch, self.opt.z_dim,
198 | device=self.device).unbind(0)
199 |
200 | return noises
201 |
202 | def make_mask(self, data):
203 | b,c,h,w = data['image'].shape
204 | if self.opt.isTrain:
205 | # generate random stroke mask
206 | mask1 = self.mask_creator.stroke_mask(h, w, max_length=min(h,w)/2)
207 | # generate object/square mask
208 | ri = random.randint(0,3)
209 | if self.opt.path_objectshape_base is not None and (ri == 1 or ri == 0):
210 | mask2 = self.mask_creator.object_mask(h, w)
211 | else:
212 | mask2 = self.mask_creator.rectangle_mask(h, w,
213 | min(h,w)//4, min(h,w)//2)
214 | # use the mix of two masks
215 | mask = (mask1+mask2>0)
216 | mask = mask.astype(float)
217 | mask = self.FloatTensor(mask)[None, None,...].expand(b,-1,-1,-1)
218 | data['mask'] = mask
219 | else:
220 | if self.use_gpu():
221 | data['mask'] = data['mask'].cuda()
222 | mask = data['mask']
223 | return mask
224 |
225 | def mixing_noise(self, batch):
226 | if self.opt.mixing > 0 and random.random() < self.opt.mixing:
227 | noise = self.make_noise(batch, 2)
228 | return noise
229 | else:
230 | return [self.make_noise(batch, 1)]
231 |
232 | def preprocess_input(self, data):
233 | b,c,h,w = data['image'].shape
234 | if 'mask' in data:
235 | if self.use_gpu():
236 | data['mask'] = data['mask'].cuda()
237 | mask = data['mask']
238 | else:
239 | mask = self.make_mask(data)
240 | if self.use_gpu():
241 | data['image'] = data['image'].cuda()
242 | if 'mean_path_length' in data:
243 | mean_path_length = data['mean_path_length'].detach().cuda()
244 | else:
245 | mean_path_length = 0
246 | return data['image'], data['mask'], mean_path_length
247 |
248 | def g_path_regularize(self, fake_image, latents, mean_path_length, decay=0.01):
249 | noise = torch.randn_like(fake_image) / math.sqrt(
250 | fake_image.shape[2] * fake_image.shape[3]
251 | )
252 | grad, = torch.autograd.grad(
253 | outputs=(fake_image * noise).sum(), inputs=latents, create_graph=True
254 | )
255 | path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))
256 |
257 | path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length)
258 |
259 | path_penalty = (path_lengths - path_mean).pow(2).mean()
260 |
261 | return path_penalty, path_mean.detach(), path_lengths
262 |
263 | def compute_generator_reg(self, real_image, mask, mean_path_length):
264 | G_regs = {}
265 | bsize = real_image.size(0)
266 | path_batch_size = max(1, bsize // self.opt.path_batch_shrink)
267 | fake_image, _, latents = self.generate_fake(real_image, mask, True)
268 | path_loss, mean_path_length, path_lengths = self.g_path_regularize(
269 | fake_image, latents, mean_path_length
270 | )
271 | weighted_path_loss = self.opt.path_regularize * self.opt.g_reg_every * path_loss
272 |
273 | if self.opt.path_batch_shrink:
274 | weighted_path_loss += 0 * fake_image[0, 0, 0, 0]
275 | G_regs['path'] = weighted_path_loss
276 | return G_regs, path_lengths, mean_path_length
277 |
278 | def compute_generator_loss(self, real_image, mask):
279 | fake_image, uc_image, _ = self.generate_fake(real_image, mask)
280 | #pred_fake, pred_real = self.discriminate(
281 | # fake_image, real_image)
282 | pred_fake = self.netD(fake_image, mask)
283 |
284 | G_losses = {}
285 | G_losses['GAN'] = self.criterionGAN(pred_fake, True,
286 | for_discriminator=False)
287 | if not self.opt.no_vgg_loss:
288 | G_losses['VGG'] = self.criterionVGG(uc_image, real_image) \
289 | * self.opt.lambda_vgg
290 | if not self.opt.no_l1_loss:
291 | G_losses['L1'] = torch.nn.functional.l1_loss(uc_image, real_image) * self.opt.lambda_l1
292 | return G_losses, fake_image
293 |
294 | def compute_discriminator_reg(self, real_image, mask):
295 | real_image.requires_grad = True
296 | real_pred = self.netD(real_image, mask)
297 | grad_real, = torch.autograd.grad(
298 | outputs=real_pred.sum(), inputs=real_image, create_graph=True
299 | )
300 | r1_loss = grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean()
301 |
302 | r1_loss = self.opt.r1 / 2 * r1_loss * self.opt.d_reg_every + 0 * real_pred[0]
303 | D_regs = {'r1': r1_loss}
304 |
305 | return D_regs
306 |
307 | def compute_discriminator_loss(self, real_image, fake_image=None, mask=None):
308 | D_losses = {}
309 | assert mask is not None
310 | assert fake_image is not None or real_image is not None
311 | assert fake_image is None or real_image is None
312 | if fake_image is not None:
313 | fake_image = fake_image.detach()
314 | pred_fake = self.netD(fake_image, mask)
315 | D_losses['D_Fake'] = self.criterionGAN(pred_fake, False,
316 | for_discriminator=True)
317 | elif real_image is not None:
318 | pred_real = self.netD(real_image, mask)
319 | D_losses['D_real'] = self.criterionGAN(pred_real, True,
320 | for_discriminator=True)
321 |
322 | return D_losses
323 |
324 | def factorize_fake(self, real_image, mask, return_latents=False):
325 | self.netG_ema.eval()
326 | bsize = real_image.size(0)
327 | latent_in = torch.randn(bsize, self.opt.z_dim, device=self.device)
328 | dlatent = self.netG_ema(latents_in=[latent_in], get_latent=True)[0]
329 |
330 | direction = self.opt.factor_d * self.eigvec[:, self.opt.factor_i].unsqueeze(0)
331 | img1, _, latent = self.netG_ema(
332 | real_image,
333 | mask,
334 | [dlatent-direction],
335 | return_latents=return_latents,
336 | truncation=self.opt.truncation,
337 | truncation_latent=self.truncation_mean,
338 | input_is_latent=True)
339 | img2, _, latent = self.netG_ema(
340 | real_image,
341 | mask,
342 | [dlatent],
343 | return_latents=return_latents,
344 | truncation=self.opt.truncation,
345 | truncation_latent=self.truncation_mean,
346 | input_is_latent=True)
347 | img3, _, latent = self.netG_ema(
348 | real_image,
349 | mask,
350 | [dlatent+direction],
351 | return_latents=return_latents,
352 | truncation=self.opt.truncation,
353 | truncation_latent=self.truncation_mean,
354 | input_is_latent=True)
355 | fake_image = torch.cat((img1,img2,img3),3)
356 | return fake_image, dlatent
357 |
358 | def generate_fake(self, real_image, mask, return_latents=False, ema=False):
359 | bsize = real_image.size(0)
360 | noise = self.mixing_noise(bsize)
361 | if ema:
362 | self.netG_ema.eval()
363 | fake_image, uc_image, latent = self.netG_ema(
364 | real_image,
365 | mask,
366 | noise,
367 | return_latents=return_latents,
368 | truncation=self.opt.truncation,
369 | truncation_latent=self.truncation_mean,
370 | )
371 | else:
372 | fake_image, uc_image, latent = self.netG(
373 | real_image,
374 | mask,
375 | noise,
376 | return_latents=return_latents,
377 | )
378 | return fake_image, uc_image, latent
379 |
380 | # Given fake and real image, return the prediction of discriminator
381 | # for each fake and real image.
382 |
383 | def discriminate(self, fake_image, real_image):
384 | raise NotImplementedError
385 | fake_concat = fake_image
386 | real_concat = real_image
387 |
388 | # In Batch Normalization, the fake and real images are
389 | # recommended to be in the same batch to avoid disparate
390 | # statistics in fake and real images.
391 | # So both fake and real images are fed to D all at once.
392 | fake_and_real = torch.cat([fake_concat, real_concat], dim=0)
393 | discriminator_out = self.netD(fake_and_real)
394 |
395 | pred_fake, pred_real = self.divide_pred(discriminator_out)
396 |
397 | return pred_fake, pred_real
398 |
399 | # Take the prediction of fake and real images from the combined batch
400 | def divide_pred(self, pred):
401 | # the prediction contains the intermediate outputs of multiscale GAN,
402 | # so it's usually a list
403 | if type(pred) == list:
404 | fake = []
405 | real = []
406 | for p in pred:
407 | fake.append([tensor[:tensor.size(0) // 2] for tensor in p])
408 | real.append([tensor[tensor.size(0) // 2:] for tensor in p])
409 | else:
410 | fake = pred[:pred.size(0) // 2]
411 | real = pred[pred.size(0) // 2:]
412 |
413 | return fake, real
414 |
415 | def use_gpu(self):
416 | return len(self.opt.gpu_ids) > 0
417 |
--------------------------------------------------------------------------------
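CoModModel.forward is a pure dispatcher; the optimizer stepping, lazy regularization and EMA update happen in the trainer (trainers/stylegan2_trainer.py, not included in this dump). A hypothetical outline of one training iteration, using only the modes and options defined above; `model`, `g_optim`, `d_optim`, `opt`, `data` and the iteration counter `i` are assumed to exist.

```python
# Hypothetical training-iteration outline; the real loop lives in
# trainers/stylegan2_trainer.py, which is not part of this dump.

# discriminator: one step on real images, one on generated images
d_optim.zero_grad()
sum(model(data, mode='dreal').values()).backward()
sum(model(data, mode='dfake').values()).backward()
d_optim.step()

# lazy R1 regularization every opt.d_reg_every iterations
if i % opt.d_reg_every == 0:
    d_optim.zero_grad()
    sum(model(data, mode='d_reg').values()).backward()
    d_optim.step()

# generator step, then update the EMA copy used at inference time
g_optim.zero_grad()
g_losses, _, generated = model(data, mode='generator')
sum(g_losses.values()).backward()
g_optim.step()
model.accumulate(decay=0.999)        # netG_ema <- 0.999 * netG_ema + 0.001 * netG

# lazy path-length regularization every opt.g_reg_every iterations
if i % opt.g_reg_every == 0:
    g_optim.zero_grad()
    g_regs, mean_path_length = model(data, mode='g_reg')
    sum(g_regs.values()).backward()
    g_optim.step()
    data['mean_path_length'] = mean_path_length   # fed back in on the next call
```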
/models/create_mask.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import random
4 | from PIL import Image, ImageDraw
5 | import os
6 | import pdb
7 | import math
8 |
9 | class MaskCreator:
10 | def __init__(self, list_mask_path=None, base_mask_path=None, match_size=False):
11 | self.match_size = match_size
12 | if list_mask_path is not None:
13 | filenames = open(list_mask_path).readlines()
14 | msk_filenames = list(map(lambda x: os.path.join(base_mask_path, x.strip('\n')), filenames))
15 | self.msk_filenames = msk_filenames
16 | else:
17 | self.msk_filenames = None
18 |
19 |
20 | def object_shadow(self, h, w, blur_kernel=7, noise_loc=0.5, noise_range=0.05):
21 | """
22 | h, w: height and width of the mask to generate
23 | return: (shadow mask, dilated mask), two (h, w) numpy arrays
24 | """
25 | mask = self.object_mask(h, w)
26 | kernel = np.ones((blur_kernel+3,blur_kernel+3),np.float32)
27 | expand_mask = cv2.dilate(mask,kernel,iterations = 1)
28 | noise = np.random.normal(noise_loc, noise_range, mask.shape)
29 | noise[noise>1] = 1
30 | mask = mask*noise
31 | mask = mask + (mask==0)
32 | kernel = np.ones((blur_kernel,blur_kernel),np.float32)/(blur_kernel*blur_kernel)
33 | mask = cv2.filter2D(mask,-1,kernel)
34 | return mask, expand_mask
35 |
36 |
37 | def object_mask(self, image_height=256, image_width=256):
38 | if self.msk_filenames is None:
39 | raise NotImplementedError
40 | hb, wb = image_height, image_width
41 | # object mask as hole
42 | mask = Image.open(random.choice(self.msk_filenames))
43 | ## randomly resize
44 | wm, hm = mask.size
45 | if self.match_size:
46 | r = float(min(hb, wb)) / max(wm, hm)
47 | r = r /2
48 | else:
49 | r = 1
50 | scale = random.gauss(r, 0.5)
51 | scale = scale if scale > 0.5 else 0.5
52 | scale = scale if scale < 2 else 2.0
53 | wm, hm = int(wm*scale), int(hm*scale)
54 | mask = mask.resize((wm, hm))
55 | mask = np.array(mask)
56 | mask = (mask>0)
57 | if mask.sum() > 0:
58 | ## crop object region
59 | col_nz = mask.sum(0)
60 | row_nz = mask.sum(1)
61 | col_nz = np.where(col_nz!=0)[0]
62 | left = col_nz[0]
63 | right = col_nz[-1]
64 | row_nz = np.where(row_nz!=0)[0]
65 | top = row_nz[0]
66 | bot = row_nz[-1]
67 | mask = mask[top:bot, left:right]
68 | else:
69 | return self.object_mask(image_height, image_width)
70 | ## place in a random location on the extended canvas
71 | hm, wm = mask.shape
72 | canvas = np.zeros((hm+hb, wm+wb))
73 | y = random.randint(0, hb-1)
74 | x = random.randint(0, wb-1)
75 | canvas[y:y+hm, x:x+wm] = mask
76 | hole = canvas[int(hm/2):int(hm/2)+hb, int(wm/2):int(wm/2)+wb]
77 | th = 100 if self.match_size else 1000
78 | if hole.sum() < hb*wb / th:
79 | return self.object_mask(image_height, image_width)
80 | else:
81 | return hole.astype(float)
82 |
83 | def rectangle_mask(self, image_height=256, image_width=256, min_hole_size=64, max_hole_size=128):
84 | mask = np.zeros((image_height, image_width))
85 | hole_size = random.randint(min_hole_size, max_hole_size)
86 | hole_size = min(int(image_width*0.8), int(image_height*0.8), hole_size)
87 | x = random.randint(0, image_width-hole_size-1)
88 | y = random.randint(0, image_height-hole_size-1)
89 | mask[x:x+hole_size, y:y+hole_size] = 1
90 | return mask
91 |
92 | def random_brush(
93 | self,
94 | max_tries,
95 | image_height=256,
96 | image_width=256,
97 | min_num_vertex = 4,
98 | max_num_vertex = 18,
99 | mean_angle = 2*math.pi / 5,
100 | angle_range = 2*math.pi / 15,
101 | min_width = 12,
102 | max_width = 48):
103 | H, W = image_height, image_width
104 | average_radius = math.sqrt(H*H+W*W) / 8
105 | mask = Image.new('L', (W, H), 0)
106 | for _ in range(np.random.randint(max_tries)):
107 | num_vertex = np.random.randint(min_num_vertex, max_num_vertex)
108 | angle_min = mean_angle - np.random.uniform(0, angle_range)
109 | angle_max = mean_angle + np.random.uniform(0, angle_range)
110 | angles = []
111 | vertex = []
112 | for i in range(num_vertex):
113 | if i % 2 == 0:
114 | angles.append(2*math.pi - np.random.uniform(angle_min, angle_max))
115 | else:
116 | angles.append(np.random.uniform(angle_min, angle_max))
117 |
118 | w, h = mask.size
119 | vertex.append((int(np.random.randint(0, w)), int(np.random.randint(0, h))))
120 | for i in range(num_vertex):
121 | r = np.clip(
122 | np.random.normal(loc=average_radius, scale=average_radius//2),
123 | 0, 2*average_radius)
124 | new_x = np.clip(vertex[-1][0] + r * math.cos(angles[i]), 0, w)
125 | new_y = np.clip(vertex[-1][1] + r * math.sin(angles[i]), 0, h)
126 | vertex.append((int(new_x), int(new_y)))
127 |
128 | draw = ImageDraw.Draw(mask)
129 | width = int(np.random.uniform(min_width, max_width))
130 | draw.line(vertex, fill=1, width=width)
131 | for v in vertex:
132 | draw.ellipse((v[0] - width//2,
133 | v[1] - width//2,
134 | v[0] + width//2,
135 | v[1] + width//2),
136 | fill=1)
137 | if np.random.random() > 0.5:
138 | mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
139 | if np.random.random() > 0.5:
140 | mask = mask.transpose(Image.FLIP_TOP_BOTTOM)
141 | mask = np.asarray(mask, np.uint8)
142 | if np.random.random() > 0.5:
143 | mask = np.flip(mask, 0)
144 | if np.random.random() > 0.5:
145 | mask = np.flip(mask, 1)
146 | return mask
147 |
148 | def random_mask(self, image_height=256, image_width=256, hole_range=[0,1]):
149 | coef = min(hole_range[0] + hole_range[1], 1.0)
150 | #mask = self.random_brush(int(20 * coef), image_height, image_width)
151 | while True:
152 | mask = np.ones((image_height, image_width), np.uint8)
153 | def Fill(max_size):
154 | w, h = np.random.randint(max_size), np.random.randint(max_size)
155 | ww, hh = w // 2, h // 2
156 | x, y = np.random.randint(-ww, image_width - w + ww), np.random.randint(-hh, image_height - h + hh)
157 | mask[max(y, 0): min(y + h, image_height), max(x, 0): min(x + w, image_width)] = 0
158 | def MultiFill(max_tries, max_size):
159 | for _ in range(np.random.randint(max_tries)):
160 | Fill(max_size)
161 | MultiFill(int(10 * coef), max(image_height, image_width) // 2)
162 | MultiFill(int(5 * coef), max(image_height, image_width))
163 | mask = np.logical_and(mask, 1 - self.random_brush(int(20 * coef), image_height, image_width))
164 | hole_ratio = 1 - np.mean(mask)
165 | if hole_ratio >= hole_range[0] and hole_ratio <= hole_range[1]:
166 | break
167 | return 1-mask
168 |
169 | def stroke_mask(self, image_height=256, image_width=256, max_vertex=5, max_mask=5, max_length=128):
170 | max_angle = np.pi
171 | max_brush_width = max(1, int(max_length*0.4))
172 | min_brush_width = max(1, int(max_length*0.1))
173 |
174 | mask = np.zeros((image_height, image_width))
175 | for k in range(random.randint(1, max_mask)):
176 | num_vertex = random.randint(1, max_vertex)
177 | start_x = random.randint(0, image_width-1)
178 | start_y = random.randint(0, image_height-1)
179 | for i in range(num_vertex):
180 | angle = random.uniform(0, max_angle)
181 | if i % 2 == 0:
182 | angle = 2*np.pi - angle
183 | length = random.uniform(0, max_length)
184 | brush_width = random.randint(min_brush_width, max_brush_width)
185 | end_x = min(int(start_x + length * np.cos(angle)), image_width)
186 | end_y = min(int(start_y + length * np.sin(angle)), image_height)
187 | mask = cv2.line(mask, (start_x, start_y), (end_x, end_y), color=1, thickness=brush_width)
188 | start_x, start_y = end_x, end_y
189 | mask = cv2.circle(mask, (start_x, start_y), int(brush_width/2), 1)
190 | if random.randint(0, 1):
191 | mask = mask[:, ::-1].copy()
192 | if random.randint(0, 1):
193 | mask = mask[::-1, :].copy()
194 | return mask
195 |
196 |
197 | def get_spatial_discount(mask):
198 | H, W = mask.shape
199 | shift_up = np.zeros((H, W))
200 | shift_up[:-1, :] = mask[1:, :]
201 | shift_left = np.zeros((H, W))
202 | shift_left[:, :-1] = mask[:, 1:]
203 |
204 | boundary_y = mask - shift_up
205 | boundary_x = mask - shift_left
206 |
207 | boundary_y = np.abs(boundary_y)
208 | boundary_x = np.abs(boundary_x)
209 | boundary = boundary_x + boundary_y
210 | boundary[boundary != 0 ] = 1
211 | # plt.imshow(boundary)
212 | # plt.show()
213 |
214 | xx, yy = np.meshgrid(range(W), range(H))
215 | bd_x = xx[boundary==1]
216 | bd_y = yy[boundary==1]
217 | dis_x = xx[..., None] - bd_x[None, None, ...]
218 | dis_y = yy[..., None] - bd_y[None, None, ...]
219 | dis = np.sqrt(dis_x*dis_x + dis_y*dis_y)
220 | min_dis = dis.min(2)
221 | gamma = 0.9
222 | discount_map = (gamma**min_dis)*mask
223 | return discount_map
224 |
225 |
226 |
227 |
228 | if __name__ == "__main__":
229 | import os
230 | from tqdm import tqdm
231 | import pdb
232 | mask_creator = MaskCreator()
233 | mask = mask_creator.random_mask(image_height=512, image_width=512)
234 | Image.fromarray((mask*255).astype(np.uint8)).save("output/mask.png")
235 |
--------------------------------------------------------------------------------
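A minimal usage sketch of the mask utilities above: `random_mask` combines rectangular fills with free-form brush strokes and returns a binary mask in which 1 marks the hole region and 0 the known pixels. The sketch masks the example image shipped in `imgs/`; the output file name is illustrative.

```
import numpy as np
from PIL import Image
from models.create_mask import MaskCreator

creator = MaskCreator()
# Binary hole mask: 1 = pixel to complete, 0 = known pixel; hole_range bounds the hole ratio.
mask = creator.random_mask(image_height=512, image_width=512, hole_range=[0.2, 0.8])

img = np.asarray(Image.open("imgs/example_image.jpg").resize((512, 512)), dtype=np.float32) / 255.0
masked = img * (1 - mask)[..., None]  # zero out the hole region
Image.fromarray((masked * 255).astype(np.uint8)).save("output/masked_example.png")
```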
/models/networks/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """
5 |
6 | import torch
7 | from models.networks.base_network import BaseNetwork
8 | from models.networks.loss import *
9 | from models.networks.discriminator import *
10 | from models.networks.generator import *
11 | import util.util as util
12 | import pdb
13 |
14 |
15 | def find_network_using_name(target_network_name, filename):
16 | target_class_name = target_network_name + filename
17 | module_name = 'models.networks.' + filename
18 | network = util.find_class_in_module(target_class_name, module_name)
19 |
20 | assert issubclass(network, BaseNetwork), \
21 | "Class %s should be a subclass of BaseNetwork" % network
22 |
23 | return network
24 |
25 |
26 | def modify_commandline_options(parser, is_train):
27 | opt, _ = parser.parse_known_args()
28 |
29 | netG_cls = find_network_using_name(opt.netG, 'generator')
30 | parser = netG_cls.modify_commandline_options(parser, is_train)
31 | if is_train:
32 | netD_cls = find_network_using_name(opt.netD, 'discriminator')
33 | parser = netD_cls.modify_commandline_options(parser, is_train)
34 | return parser
35 |
36 |
37 | def create_network(cls, opt):
38 | net = cls(opt)
39 | net.print_network()
40 | if len(opt.gpu_ids) > 0:
41 | assert(torch.cuda.is_available())
42 | net.cuda()
43 | if opt.init_type is not None:
44 | net.init_weights(opt.init_type, opt.init_variance)
45 | return net
46 |
47 |
48 | def define_G(opt):
49 | netG_cls = find_network_using_name(opt.netG, 'generator')
50 | return create_network(netG_cls, opt)
51 |
52 |
53 | def define_D(opt):
54 | netD_cls = find_network_using_name(opt.netD, 'discriminator')
55 | return create_network(netD_cls, opt)
56 |
57 |
--------------------------------------------------------------------------------
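`find_network_using_name` resolves the `--netG`/`--netD` option values to classes by concatenating the option value with the module name and searching `models.networks.<file>` through `util.find_class_in_module` (a SPADE-style, case-insensitive lookup). A rough sketch of the resolution path; the option value `comodgan` is an assumption, and `define_G` would additionally need the full set of generator options on `opt`:

```
from argparse import Namespace
from models import networks

# Hypothetical option namespace; only the field used by the lookup is filled in.
opt = Namespace(netG="comodgan")

netG_cls = networks.find_network_using_name(opt.netG, "generator")
# "comodgan" + "generator" is expected to match CoModGANGenerator exported from generator.py
print(netG_cls)

# networks.define_G(opt) would also call create_network, which expects
# opt.gpu_ids, opt.init_type and the generator's own command-line options.
```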
/models/networks/architecture.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | import torchvision
10 |
11 | # VGG architecture, used for the perceptual loss with a pretrained VGG network
12 | class VGG19(torch.nn.Module):
13 | def __init__(self, requires_grad=False):
14 | super().__init__()
15 | vgg_pretrained_features = torchvision.models.vgg19(pretrained=True).features
16 | self.slice1 = torch.nn.Sequential()
17 | self.slice2 = torch.nn.Sequential()
18 | self.slice3 = torch.nn.Sequential()
19 | self.slice4 = torch.nn.Sequential()
20 | self.slice5 = torch.nn.Sequential()
21 | for x in range(2):
22 | self.slice1.add_module(str(x), vgg_pretrained_features[x])
23 | for x in range(2, 7):
24 | self.slice2.add_module(str(x), vgg_pretrained_features[x])
25 | for x in range(7, 12):
26 | self.slice3.add_module(str(x), vgg_pretrained_features[x])
27 | for x in range(12, 21):
28 | self.slice4.add_module(str(x), vgg_pretrained_features[x])
29 | for x in range(21, 30):
30 | self.slice5.add_module(str(x), vgg_pretrained_features[x])
31 | if not requires_grad:
32 | for param in self.parameters():
33 | param.requires_grad = False
34 |
35 | def forward(self, X):
36 | h_relu1 = self.slice1(X)
37 | h_relu2 = self.slice2(h_relu1)
38 | h_relu3 = self.slice3(h_relu2)
39 | h_relu4 = self.slice4(h_relu3)
40 | h_relu5 = self.slice5(h_relu4)
41 | out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
42 | return out
43 |
--------------------------------------------------------------------------------
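`VGG19` is used as a frozen feature extractor: the five slices return the relu1_1 through relu5_1 activations, which `VGGLoss` in `loss.py` compares with an L1 distance. A small sketch of pulling the features, assuming torchvision can load the pretrained VGG weights:

```
import torch
from models.networks.architecture import VGG19

vgg = VGG19().eval()                # parameters are frozen via requires_grad=False
x = torch.rand(1, 3, 256, 256)      # RGB batch
feats = vgg(x)
print([f.shape[1] for f in feats])  # channel counts of the five slices: [64, 128, 256, 512, 512]
```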
/models/networks/base_network.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """
5 |
6 | import torch.nn as nn
7 | from torch.nn import init
8 |
9 |
10 | class BaseNetwork(nn.Module):
11 | def __init__(self):
12 | super(BaseNetwork, self).__init__()
13 |
14 | @staticmethod
15 | def modify_commandline_options(parser, is_train):
16 | return parser
17 |
18 | def print_network(self):
19 | if isinstance(self, list):
20 | self = self[0]
21 | num_params = 0
22 | for param in self.parameters():
23 | num_params += param.numel()
24 | print('Network [%s] was created. Total number of parameters: %.1f million. '
25 | 'To see the architecture, do print(network).'
26 | % (type(self).__name__, num_params / 1000000))
27 |
28 | def init_weights(self, init_type='normal', gain=0.02):
29 | def init_func(m):
30 | classname = m.__class__.__name__
31 | if classname.find('BatchNorm2d') != -1:
32 | if hasattr(m, 'weight') and m.weight is not None:
33 | init.normal_(m.weight.data, 1.0, gain)
34 | if hasattr(m, 'bias') and m.bias is not None:
35 | init.constant_(m.bias.data, 0.0)
36 | elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
37 | if init_type == 'normal':
38 | init.normal_(m.weight.data, 0.0, gain)
39 | elif init_type == 'xavier':
40 | init.xavier_normal_(m.weight.data, gain=gain)
41 | elif init_type == 'xavier_uniform':
42 | init.xavier_uniform_(m.weight.data, gain=1.0)
43 | elif init_type == 'kaiming':
44 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
45 | elif init_type == 'orthogonal':
46 | init.orthogonal_(m.weight.data, gain=gain)
47 | elif init_type == 'none': # uses pytorch's default init method
48 | m.reset_parameters()
49 | else:
50 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
51 | if hasattr(m, 'bias') and m.bias is not None:
52 | init.constant_(m.bias.data, 0.0)
53 |
54 | self.apply(init_func)
55 |
56 | # propagate to children
57 | for m in self.children():
58 | if hasattr(m, 'init_weights'):
59 | m.init_weights(init_type, gain)
60 |
61 | def get_param_list(self, label):
62 | print("updating all params")
63 | return self.parameters()
64 | #raise NotImplementedError
65 |
--------------------------------------------------------------------------------
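Every network in this repo subclasses `BaseNetwork`, which contributes the option hook, parameter counting, and the `init_weights` dispatcher used by `create_network`. A minimal subclass sketch; the layer sizes are made up for illustration:

```
import torch.nn as nn
from models.networks.base_network import BaseNetwork

class TinyNet(BaseNetwork):
    def __init__(self):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(),
            nn.Conv2d(16, 3, 3, padding=1))

    def forward(self, x):
        return self.body(x)

net = TinyNet()
net.print_network()               # prints the parameter count in millions
net.init_weights("xavier", 0.02)  # re-initializes every Conv/Linear/BatchNorm module
```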
/models/networks/co_mod_gan.py:
--------------------------------------------------------------------------------
1 | import pdb
2 | import random
3 | from collections import OrderedDict
4 | import torch
5 | from torch import nn
6 | from torch.nn import functional as F
7 | from models.networks.op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d
8 | from models.networks.stylegan2 import PixelNorm, EqualLinear, EqualConv2d,ConvLayer,StyledConv,ToRGB,ConvToRGB,TransConvLayer
9 | import numpy as np
10 |
11 | from models.networks.base_network import BaseNetwork
12 |
13 | #----------------------------------------------------------------------------
14 | # Mapping network.
15 | # Transforms the input latent code (z) to the disentangled latent code (w).
16 | # Used in configs B-F (Table 1).
17 |
18 | class G_mapping(nn.Module):
19 | def __init__(self,
20 | opt
21 | ):
22 | latent_size = 512 # Latent vector (Z) dimensionality.
23 | label_size = 0 # Label dimensionality, 0 if no labels.
24 | dlatent_broadcast = None # Output disentangled latent (W) as [minibatch, dlatent_size] or [minibatch, dlatent_broadcast, dlatent_size].
25 | mapping_layers = 8 # Number of mapping layers.
26 | mapping_fmaps = 512 # Number of activations in the mapping layers.
27 | mapping_lrmul = 0.01 # Learning rate multiplier for the mapping layers.
28 | mapping_nonlinearity = 'lrelu' # Activation function: 'relu', 'lrelu', etc.
29 | normalize_latents = True # Normalize latent vectors (Z) before feeding them to the mapping layers?
30 | super().__init__()
31 | layers = []
32 |
33 | # Embed labels and concatenate them with latents.
34 | if label_size:
35 | raise NotImplementedError
36 |
37 | # Normalize latents.
38 | if normalize_latents:
39 | layers.append(
40 | ('Normalize', PixelNorm()))
41 | # Mapping layers.
42 | dim_in = latent_size
43 | for layer_idx in range(mapping_layers):
44 | fmaps = opt.dlatent_size if layer_idx == mapping_layers - 1 else mapping_fmaps
45 | layers.append(
46 | (
47 | 'Dense%d' % layer_idx,
48 | EqualLinear(
49 | dim_in,
50 | fmaps,
51 | lr_mul=mapping_lrmul,
52 | activation="fused_lrelu")
53 | ))
54 | dim_in = fmaps
55 | # Broadcast.
56 | if dlatent_broadcast is not None:
57 | raise NotImplementedError
58 | self.G_mapping = nn.Sequential(OrderedDict(layers))
59 |
60 | def forward(
61 | self,
62 | latents_in):
63 | styles = self.G_mapping(latents_in)
64 | return styles
65 |
66 | #----------------------------------------------------------------------------
67 | # CoModGAN synthesis network.
68 |
69 | class G_synthesis_co_mod_gan(nn.Module):
70 | def __init__(
71 | self,
72 | opt
73 | ):
74 | resolution_log2 = int(np.log2(opt.crop_size))
75 | assert opt.crop_size == 2**resolution_log2 and opt.crop_size >= 4
76 | def nf(stage): return np.clip(int(opt.fmap_base / (2.0 ** (stage * opt.fmap_decay))), opt.fmap_min, opt.fmap_max)
77 | assert opt.architecture in ['skip']
78 | assert opt.nonlinearity == 'lrelu'
79 | assert opt.fused_modconv
80 | assert not opt.pix2pix
81 | self.nf = nf
82 | super().__init__()
83 | act = opt.nonlinearity
84 | self.num_layers = resolution_log2 * 2 - 2
85 | self.resolution_log2 = resolution_log2
86 |
87 | class E_fromrgb(nn.Module): # res = 2..resolution_log2
88 | def __init__(self, res, channel_in=opt.num_channels+1):
89 | super().__init__()
90 | self.FromRGB = ConvLayer(
91 | channel_in,
92 | nf(res-1),
93 | 1,
94 | blur_kernel=opt.resample_kernel,
95 | activate=True)
96 | def forward(self, data):
97 | y, E_features = data
98 | t = self.FromRGB(y)
99 | return t, E_features
100 | class E_block(nn.Module): # res = 2..resolution_log2
101 | def __init__(self, res):
102 | super().__init__()
103 | self.Conv0 = ConvLayer(
104 | nf(res-1),
105 | nf(res-1),
106 | kernel_size=3,
107 | activate=True)
108 | self.Conv1_down = ConvLayer(
109 | nf(res-1),
110 | nf(res-2),
111 | kernel_size=3,
112 | downsample=True,
113 | blur_kernel=opt.resample_kernel,
114 | activate=True)
115 | self.res = res
116 | def forward(self, data):
117 | x, E_features = data
118 | x = self.Conv0(x)
119 | E_features[self.res] = x
120 | x = self.Conv1_down(x)
121 | return x, E_features
122 | class E_block_final(nn.Module): # res = 2..resolution_log2
123 | def __init__(self):
124 | super().__init__()
125 | self.Conv = ConvLayer(
126 | nf(2),
127 | nf(1),
128 | kernel_size=3,
129 | activate=True)
130 | self.Dense0 = EqualLinear(nf(1)*4*4, nf(1)*2,
131 | activation="fused_lrelu")
132 | self.dropout = nn.Dropout(opt.dropout_rate)
133 | def forward(self, data):
134 | x, E_features = data
135 | x = self.Conv(x)
136 | E_features[2] = x
137 | bsize = x.size(0)
138 | x = x.view(bsize, -1)
139 | x = self.Dense0(x)
140 | x = self.dropout(x)
141 | return x, E_features
142 | def make_encoder(channel_in=opt.num_channels+1):
143 | Es = []
144 | for res in range(self.resolution_log2, 2, -1):
145 | if res == self.resolution_log2:
146 | Es.append(
147 | (
148 | '%dx%d_0' % (2**res, 2**res),
149 | E_fromrgb(res, channel_in)
150 | ))
151 | Es.append(
152 | (
153 | '%dx%d' % (2**res, 2**res),
154 | E_block(res)
155 |
156 | ))
157 | # Final layers.
158 | Es.append(
159 | (
160 | '4x4',
161 | E_block_final()
162 |
163 | ))
164 | Es = nn.Sequential(OrderedDict(Es))
165 | return Es
166 | self.make_encoder = make_encoder
167 |
168 | # Main layers.
169 | c_in = opt.num_channels+1
170 | self.E = self.make_encoder(channel_in=c_in)
171 |
172 | # Single convolution layer with all the bells and whistles.
173 | # Building blocks for main layers.
174 | mod_size = 0
175 | if opt.style_mod:
176 | mod_size += opt.dlatent_size
177 | if opt.cond_mod:
178 | mod_size += nf(1)*2
179 | assert mod_size > 0
180 | self.mod_size = mod_size
181 | def get_mod(latent, idx, x_global):
182 | if isinstance(latent, list):
183 | latent = latent[:][idx]
184 | else:
185 | latent = latent[:,idx]
186 | mod_vector = []
187 | if opt.style_mod:
188 | mod_vector.append(latent)
189 | if opt.cond_mod:
190 | mod_vector.append(x_global)
191 | mod_vector = torch.cat(mod_vector, 1)
192 | return mod_vector
193 | self.get_mod = get_mod
194 | class Block(nn.Module):
195 | def __init__(self, res):
196 | super().__init__()
197 | self.res = res
198 | self.Conv0_up = StyledConv(
199 | nf(res-2),
200 | nf(res-1),
201 | kernel_size=3,
202 | style_dim=mod_size,
203 | upsample=True,
204 | blur_kernel=opt.resample_kernel)
205 | self.Conv1 = StyledConv(
206 | nf(res-1),
207 | nf(res-1),
208 | kernel_size=3,
209 | style_dim=mod_size,
210 | upsample=False)
211 | self.ToRGB = ToRGB(
212 | nf(res-1),
213 | mod_size, out_channel=opt.num_channels)
214 | def forward(self, x, y, dlatents_in, x_global, E_features):
215 | x_skip = E_features[self.res]
216 | mod_vector = get_mod(dlatents_in, self.res*2-5, x_global)  # self.res, matching Conv1/ToRGB below
217 | if opt.noise_injection:
218 | noise = None
219 | else:
220 | noise = 0
221 | x = self.Conv0_up(x, mod_vector, noise, x_skip=x_skip)
222 | x = x + x_skip
223 | mod_vector = get_mod(dlatents_in, self.res*2-4, x_global)
224 | x = self.Conv1(x, mod_vector, noise, x_skip=x_skip)
225 | mod_vector = get_mod(dlatents_in, self.res*2-3, x_global)
226 | y = self.ToRGB(x, mod_vector, skip=y, x_skip=x_skip)
227 | return x, y
228 | self.Block = Block
229 | class Block0(nn.Module):
230 | def __init__(self):
231 | super().__init__()
232 | self.Dense = EqualLinear(
233 | nf(1)*2,
234 | nf(1)*4*4,
235 | activation="fused_lrelu")
236 | self.Conv = StyledConv(
237 | nf(1),
238 | nf(1),
239 | kernel_size=3,
240 | style_dim=mod_size,
241 | )
242 | self.ToRGB = ToRGB(
243 | nf(1),
244 | style_dim=mod_size,
245 | upsample=False, out_channel=opt.num_channels)
246 | def forward(self, x, dlatents_in, x_global):
247 | x = self.Dense(x)
248 | x = x.view(-1, nf(1), 4, 4)
249 | mod_vector = get_mod(dlatents_in, 0, x_global)
250 | if opt.noise_injection:
251 | noise = None
252 | else:
253 | noise = 0
254 | x = self.Conv(x, mod_vector, noise)
255 | mod_vector = get_mod(dlatents_in, 1, x_global)
256 | y = self.ToRGB(x, mod_vector)
257 | return x, y
258 | # Early layers.
259 | self.G_4x4 = Block0()
260 | # Main layers.
261 | for res in range(3, resolution_log2 + 1):
262 | setattr(self, 'G_%dx%d' % (2**res, 2**res),
263 | Block(res))
264 |
265 | def forward(self, images_in, masks_in, dlatents_in):
266 | y = torch.cat([1-masks_in - 0.5, images_in * (1-masks_in)], 1)
267 | E_features = {}
268 | x_global, E_features = self.E((y, E_features))
269 | x = x_global
270 | x, y = self.G_4x4(x, dlatents_in, x_global)
271 | for res in range(3, self.resolution_log2 + 1):
272 | block = getattr(self, 'G_%dx%d' % (2**res, 2**res))
273 | x, y = block(x, y, dlatents_in, x_global, E_features)
274 | raw_out = y
275 | images_out = y * masks_in + images_in * (1-masks_in)
276 | return images_out, raw_out
277 |
278 | #----------------------------------------------------------------------------
279 | # Main generator network.
280 | # Composed of two sub-networks (mapping and synthesis) that are defined below.
281 | # Used in configs B-F (Table 1).
282 |
283 | class Generator(BaseNetwork):
284 | @staticmethod
285 | def modify_commandline_options(parser, is_train):
286 | parser.add_argument('--dlatent_size', type=int, default= 512 )# Disentangled latent (W) dimensionality.
287 | parser.add_argument('--num_channels', type=int, default= 3, )# Number of output color channels.
288 | parser.add_argument('--fmap_base', type=int, default= 16 << 10, )# Overall multiplier for the number of feature maps.
289 | parser.add_argument('--fmap_decay', type=float, default= 1.0, )# log2 feature map reduction when doubling the resolution.
290 | parser.add_argument('--fmap_min', type=int, default= 1, )# Minimum number of feature maps in any layer.
291 | parser.add_argument('--fmap_max', type=int, default= 512, )# Maximum number of feature maps in any layer.
292 | parser.add_argument('--randomize_noise', type=bool, default= True, )# True = randomize noise inputs every time (non-deterministic), False = read noise inputs from variables.
293 | parser.add_argument('--architecture', type=str, default= 'skip', )# Architecture: 'orig', 'skip', 'resnet'.
294 | parser.add_argument('--nonlinearity', type=str, default= 'lrelu', )# Activation function: 'relu', 'lrelu', etc.
295 | parser.add_argument('--resample_kernel', type=list, default= [1,3,3,1], )# Low-pass filter to apply when resampling activations. None = no filtering.
296 | parser.add_argument('--fused_modconv', type=bool, default= True, )# Implement modulated_conv2d_layer() as a single fused op?
297 | parser.add_argument('--pix2pix', type=bool, default= False)
298 | parser.add_argument('--dropout_rate', type=float, default= 0.5)
299 | parser.add_argument('--cond_mod', type=bool, default= True,)
300 | parser.add_argument('--style_mod', type=bool, default= True,)
301 | parser.add_argument('--noise_injection', type=bool, default= True,)
302 | return parser
303 | def __init__(
304 | self,
305 | opt=None): # Arguments for sub-networks (mapping and synthesis).
306 | super().__init__()
307 | self.G_mapping = G_mapping(opt)
308 | self.G_synthesis = G_synthesis_co_mod_gan(opt)
309 |
310 | def forward(
311 | self,
312 | images_in=None,
313 | masks_in=None,
314 | latents_in=None,
315 | return_latents=False,
316 | inject_index=None,
317 | truncation=None,
318 | truncation_latent=None,
319 | input_is_latent=False,
320 | get_latent=False,
321 | ):
322 | #assert isinstance(latents_in, list)
323 | if not input_is_latent:
324 | dlatents_in = [self.G_mapping(s) for s in latents_in]
325 | else:
326 | dlatents_in = latents_in
327 | if get_latent:
328 | return dlatents_in
329 | if truncation is not None:
330 | dlatents_t = []
331 | for style in dlatents_in:
332 | dlatents_t.append(
333 | truncation_latent + truncation * (style - truncation_latent)
334 | )
335 | dlatents_in = dlatents_t
336 | if len(dlatents_in) < 2:
337 | inject_index = self.G_synthesis.num_layers
338 | if dlatents_in[0].ndim < 3:
339 | dlatent = dlatents_in[0].unsqueeze(1).repeat(1, inject_index, 1)
340 | else:
341 | dlatent = dlatents_in[0]
342 | else:
343 | if inject_index is None:
344 | inject_index = random.randint(1, self.G_synthesis.num_layers - 1)
345 | dlatent = dlatents_in[0].unsqueeze(1).repeat(1, inject_index, 1)
346 | dlatent2 = dlatents_in[1].unsqueeze(1).repeat(1, self.G_synthesis.num_layers - inject_index, 1)
347 |
348 | dlatent = torch.cat([dlatent, dlatent2], 1)
349 | output, raw_out = self.G_synthesis(images_in, masks_in, dlatent)
350 | if return_latents:
351 | return output, raw_out, dlatent
352 | else:
353 | return output, raw_out, None
354 |
355 | #----------------------------------------------------------------------------
356 | # CoModGAN discriminator.
357 |
358 | class Discriminator(BaseNetwork):
359 | @staticmethod
360 | def modify_commandline_options(parser, is_train):
361 | parser.add_argument('--mbstd_num_features', type=int, default= 1, )# Number of features for the minibatch standard deviation layer.
362 | parser.add_argument('--mbstd_group_size', type=int, default= 4, )# Group size for the minibatch standard deviation layer, 0 = disable.
363 | return parser
364 | def __init__(
365 | self,
366 | opt):
367 | label_size = 0  # Dimensionality of the labels, 0 if no labels. Overridden based on dataset.
368 | architecture = 'resnet' # Architecture: 'orig', 'skip', 'resnet'.
369 | pix2pix = False
370 | assert not pix2pix
371 | assert opt.nonlinearity == 'lrelu'
372 | assert architecture == 'resnet'
373 | if opt is not None:
374 | resolution = opt.crop_size
375 |
376 | resolution_log2 = int(np.log2(resolution))
377 | assert resolution == 2**resolution_log2 and resolution >= 4
378 | def nf(stage): return np.clip(int(opt.fmap_base / (2.0 ** (stage * opt.fmap_decay))), opt.fmap_min, opt.fmap_max)
379 | #assert architecture in ['orig', 'skip', 'resnet']
380 |
381 | # Building blocks for main layers.
382 | super().__init__()
383 | layers = []
384 | c_in = opt.num_channels+1
385 | layers.append(
386 | (
387 | "ToRGB",
388 | ConvLayer(
389 | c_in,
390 | nf(resolution_log2-1),
391 | kernel_size=3,
392 | activate=True)
393 | )
394 | )
395 |
396 | class Block(nn.Module):
397 | def __init__(self, res):
398 | super().__init__()
399 | self.Conv0 = ConvLayer(
400 | nf(res-1),
401 | nf(res-1),
402 | kernel_size=3,
403 | activate=True)
404 | self.Conv1_down = ConvLayer(
405 | nf(res-1),
406 | nf(res-2),
407 | kernel_size=3,
408 | downsample=True,
409 | blur_kernel=opt.resample_kernel,
410 | activate=True)
411 | self.Skip = ConvLayer(
412 | nf(res-1),
413 | nf(res-2),
414 | kernel_size=1,
415 | downsample=True,
416 | blur_kernel=opt.resample_kernel,
417 | activate=False,
418 | bias=False)
419 | def forward(self, x):
420 | t = x
421 | x = self.Conv0(x)
422 | x = self.Conv1_down(x)
423 | t = self.Skip(t)
424 | x = (x + t) * (1/np.sqrt(2))
425 | return x
426 | # Main layers.
427 | for res in range(resolution_log2, 2, -1):
428 | layers.append(
429 | (
430 | '%dx%d' % (2**res, 2**res),
431 | Block(res)
432 | )
433 | )
434 | self.convs = nn.Sequential(OrderedDict(layers))
435 | # Final layers.
436 | self.mbstd_group_size = opt.mbstd_group_size
437 | self.mbstd_num_features = opt.mbstd_num_features
438 |
439 | self.Conv4x4 = ConvLayer(nf(1)+1, nf(1), kernel_size=3, activate=True)
440 | self.Dense0 = EqualLinear(nf(1)*4*4, nf(0), activation='fused_lrelu')
441 | self.Output = EqualLinear(nf(0), 1)
442 |
443 | def forward(self, images_in, masks_in):
444 | masks_in = 1-masks_in
445 | y = torch.cat([masks_in - 0.5, images_in], 1)
446 | out = self.convs(y)
447 | batch, channel, height, width = out.shape
448 | group_size = min(batch, self.mbstd_group_size)
449 | #print(out.shape)
450 | #pdb.set_trace()
451 | stddev = out.view(
452 | group_size,
453 | -1,
454 | self.mbstd_num_features,
455 | channel // self.mbstd_num_features,
456 | height, width
457 | )
458 | stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
459 | stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
460 | stddev = stddev.repeat(group_size, 1, height, width)
461 | out = torch.cat([out, stddev], 1)
462 | out = self.Conv4x4(out)
463 | out = out.view(batch, -1)
464 | out = self.Dense0(out)
465 | out = self.Output(out)
466 | return out
467 |
468 |
469 |
470 | if __name__ == "__main__":
471 | import cv2
472 | from PIL import Image
473 | path_img = "/home/zeng/co-mod-gan/imgs/example_image.jpg"
474 | path_mask = "/home/zeng/co-mod-gan/imgs/example_mask.jpg"
475 |
476 | real = np.asarray(Image.open(path_img)).transpose([2, 0, 1])/255.0
477 |
478 | masks = np.asarray(Image.open(path_mask).convert('1'), dtype=np.float32)
479 |
480 | images = torch.Tensor(real.copy())[None,...]*2-1
481 | masks = torch.Tensor(masks)[None,None,...].float()
482 | masks = (masks==0).float()
483 |
484 | net = Discriminator()  # NOTE: requires an opt namespace with the discriminator options; this debug snippet does not run as-is
485 | hh = net(images, masks)
486 | pdb.set_trace()
487 |
488 | #net = Generator()
489 | #net.G_mapping.load_from_tf_dict("/home/zeng/co-mod-gan/co-mod-gan-ffhq-9-025000.npz")
490 | #net.G_synthesis.load_from_tf_dict("/home/zeng/co-mod-gan/co-mod-gan-ffhq-9-025000.npz")
491 | #net.eval()
492 | #torch.save(net.state_dict(), "co-mod-gan-ffhq-9-025000.pth")
493 |
494 | #latents_in = torch.randn(1, 512)
495 |
496 | #hh = net(images, masks, [latents_in], truncation=None)
497 | #hh = hh.detach().cpu().numpy()
498 | #hh = (hh+1)/2
499 | #hh = (hh[0].transpose((1,2,0)))*255
500 | #cv2.imwrite("hh.png", hh[:,:,::-1].clip(0,255))
501 | #pdb.set_trace()
502 |
--------------------------------------------------------------------------------
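At inference time the `Generator` expects images scaled to [-1, 1], a mask in which 1 marks the pixels to complete, and a list of latent codes (run through `G_mapping` unless `input_is_latent=True`); the synthesis network then composites its raw output back onto the known pixels. A hedged sketch of a single completion pass: the options namespace is built here from the generator's own `modify_commandline_options` defaults plus an assumed 512x512 crop size, and the import requires that the CUDA ops in `models/networks/op` can be compiled.

```
import argparse
import torch
from models.networks.co_mod_gan import Generator

# Minimal options namespace: crop_size plus the defaults registered by
# Generator.modify_commandline_options (dlatent_size, fmap_*, ...).
parser = argparse.ArgumentParser()
parser.add_argument("--crop_size", type=int, default=512)
parser = Generator.modify_commandline_options(parser, is_train=False)
opt = parser.parse_args([])

netG = Generator(opt).eval()

images = torch.rand(1, 3, 512, 512) * 2 - 1   # stand-in image in [-1, 1]
masks = torch.zeros(1, 1, 512, 512)
masks[:, :, 128:384, 128:384] = 1             # 1 = region to be completed
latents = [torch.randn(1, 512)]               # one z code, broadcast over all layers

with torch.no_grad():
    completed, raw, _ = netG(images, masks, latents, truncation=None)
# `completed` keeps the known pixels and fills the masked region with synthesized output.
```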
/models/networks/discriminator.py:
--------------------------------------------------------------------------------
1 | from models.networks.co_mod_gan import Discriminator as CoModGANDiscriminator
2 |
--------------------------------------------------------------------------------
/models/networks/generator.py:
--------------------------------------------------------------------------------
1 | from models.networks.co_mod_gan import Generator as CoModGANGenerator
2 |
--------------------------------------------------------------------------------
/models/networks/loss.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | from models.networks.architecture import VGG19
10 |
11 |
12 | # Defines the GAN loss which uses either LSGAN or the regular GAN.
13 | # When LSGAN is used, it is basically the same as MSELoss,
14 | # but it abstracts away the need to create the target label tensor
15 | # that has the same size as the input
16 | class GANLoss(nn.Module):
17 | def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0,
18 | tensor=torch.FloatTensor, opt=None):
19 | super(GANLoss, self).__init__()
20 | self.real_label = target_real_label
21 | self.fake_label = target_fake_label
22 | self.real_label_tensor = None
23 | self.fake_label_tensor = None
24 | self.zero_tensor = None
25 | self.Tensor = tensor
26 | self.gan_mode = gan_mode
27 | self.opt = opt
28 | if gan_mode == 'ls':
29 | pass
30 | elif gan_mode == 'original':
31 | pass
32 | elif gan_mode == 'w':
33 | pass
34 | elif gan_mode == 'hinge':
35 | pass
36 | elif gan_mode == 'softplus':
37 | pass
38 | else:
39 | raise ValueError('Unexpected gan_mode {}'.format(gan_mode))
40 |
41 | def get_target_tensor(self, input, target_is_real):
42 | if target_is_real:
43 | if self.real_label_tensor is None:
44 | self.real_label_tensor = self.Tensor(1).fill_(self.real_label)
45 | self.real_label_tensor.requires_grad_(False)
46 | return self.real_label_tensor.expand_as(input)
47 | else:
48 | if self.fake_label_tensor is None:
49 | self.fake_label_tensor = self.Tensor(1).fill_(self.fake_label)
50 | self.fake_label_tensor.requires_grad_(False)
51 | return self.fake_label_tensor.expand_as(input)
52 |
53 | def get_zero_tensor(self, input):
54 | if self.zero_tensor is None:
55 | self.zero_tensor = self.Tensor(1).fill_(0)
56 | self.zero_tensor.requires_grad_(False)
57 | return self.zero_tensor.expand_as(input)
58 |
59 | def loss(self, input, target_is_real, for_discriminator=True):
60 | if self.gan_mode == 'original': # cross entropy loss
61 | target_tensor = self.get_target_tensor(input, target_is_real)
62 | loss = F.binary_cross_entropy_with_logits(input, target_tensor)
63 | return loss
64 | elif self.gan_mode == 'ls':
65 | target_tensor = self.get_target_tensor(input, target_is_real)
66 | return F.mse_loss(input, target_tensor)
67 | elif self.gan_mode == 'hinge':
68 | if for_discriminator:
69 | if target_is_real:
70 | minval = torch.min(input - 1, self.get_zero_tensor(input))
71 | loss = -torch.mean(minval)
72 | else:
73 | minval = torch.min(-input - 1, self.get_zero_tensor(input))
74 | loss = -torch.mean(minval)
75 | else:
76 | assert target_is_real, "The generator's hinge loss must be aiming for real"
77 | loss = -torch.mean(input)
78 | return loss
79 | elif self.gan_mode == 'softplus':
80 | # non-saturating logistic loss
81 | if target_is_real:
82 | return F.softplus(-input).mean()
83 | else:
84 | return F.softplus(input).mean()
85 | else:
86 | # wgan
87 | if target_is_real:
88 | return -input.mean()
89 | else:
90 | return input.mean()
91 |
92 | def __call__(self, input, target_is_real, for_discriminator=True):
93 | # computing loss is a bit complicated because |input| may not be
94 | # a tensor, but a list of tensors in the case of a multiscale discriminator
95 | if isinstance(input, list):
96 | loss = 0
97 | for pred_i in input:
98 | if isinstance(pred_i, list):
99 | pred_i = pred_i[-1]
100 | loss_tensor = self.loss(pred_i, target_is_real, for_discriminator)
101 | bs = 1 if len(loss_tensor.size()) == 0 else loss_tensor.size(0)
102 | new_loss = torch.mean(loss_tensor.view(bs, -1), dim=1)
103 | loss += new_loss
104 | return loss / len(input)
105 | else:
106 | return self.loss(input, target_is_real, for_discriminator)
107 |
108 |
109 | # Perceptual loss that uses a pretrained VGG network
110 | class VGGLoss(nn.Module):
111 | def __init__(self, gpu_ids):
112 | super(VGGLoss, self).__init__()
113 | self.vgg = VGG19().cuda()
114 | self.criterion = nn.L1Loss()
115 | self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]
116 |
117 | def forward(self, x, y, **kwargs):
118 | x_vgg, y_vgg = self.vgg(x), self.vgg(y)
119 | loss = 0
120 | for i in range(len(x_vgg)):
121 | loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
122 | return loss
123 |
124 | class MaskedVGGLoss(VGGLoss):
125 | def forward(self, x, y, mask, **kwargs):
126 | x_vgg, y_vgg = self.vgg(x*mask), self.vgg(y*mask)
127 | loss = 0
128 | for i in range(len(x_vgg)):
129 | loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
130 | return loss
131 |
132 |
133 | class VGGFaceLoss(nn.Module):
134 | def __init__(self, gpu_ids, weights_path):
135 | super(VGGFaceLoss, self).__init__()
136 | self.vgg = VGGFace(weights_path=weights_path).cuda()  # NOTE: VGGFace is not defined in this repo; an external implementation is required
137 | self.criterion = nn.L1Loss()
138 | self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]
139 |
140 | def forward(self, x, y):
141 | x_vgg, y_vgg = self.vgg(x), self.vgg(y)
142 | loss = 0
143 | for i in range(len(x_vgg)):
144 | loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
145 | return loss
146 |
147 |
148 | # KL Divergence loss used in VAE with an image encoder
149 | class KLDLoss(nn.Module):
150 | def forward(self, mu, logvar):
151 | return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
152 |
--------------------------------------------------------------------------------
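`GANLoss` hides target-tensor construction and switches its formula on `gan_mode`; the trainer only passes discriminator logits and whether they should be treated as real. A short sketch of the two typical calls with the `softplus` (non-saturating logistic) mode used in StyleGAN2-style training:

```
import torch
from models.networks.loss import GANLoss

criterion = GANLoss(gan_mode="softplus", tensor=torch.FloatTensor)

d_real = torch.randn(4, 1)  # stand-in discriminator logits on real images
d_fake = torch.randn(4, 1)  # stand-in discriminator logits on generated images

# Discriminator step: real logits pushed up, fake logits pushed down.
loss_d = (criterion(d_real, True, for_discriminator=True)
          + criterion(d_fake, False, for_discriminator=True))

# Generator step: generated samples should be classified as real.
loss_g = criterion(d_fake, True, for_discriminator=False)
print(loss_d.item(), loss_g.item())
```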
/models/networks/op/__init__.py:
--------------------------------------------------------------------------------
1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu
2 | from .upfirdn2d import upfirdn2d
3 |
--------------------------------------------------------------------------------
/models/networks/op/fused_act.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | from torch import nn
5 | from torch.nn import functional as F
6 | from torch.autograd import Function
7 | from torch.utils.cpp_extension import load
8 |
9 |
10 | module_path = os.path.dirname(__file__)
11 | fused = load(
12 | "fused",
13 | sources=[
14 | os.path.join(module_path, "fused_bias_act.cpp"),
15 | os.path.join(module_path, "fused_bias_act_kernel.cu"),
16 | ],
17 | )
18 |
19 |
20 | class FusedLeakyReLUFunctionBackward(Function):
21 | @staticmethod
22 | def forward(ctx, grad_output, out, bias, negative_slope, scale):
23 | ctx.save_for_backward(out)
24 | ctx.negative_slope = negative_slope
25 | ctx.scale = scale
26 |
27 | empty = grad_output.new_empty(0)
28 |
29 | grad_input = fused.fused_bias_act(
30 | grad_output, empty, out, 3, 1, negative_slope, scale
31 | )
32 |
33 | dim = [0]
34 |
35 | if grad_input.ndim > 2:
36 | dim += list(range(2, grad_input.ndim))
37 |
38 | if bias:
39 | grad_bias = grad_input.sum(dim).detach()
40 |
41 | else:
42 | grad_bias = empty
43 |
44 | return grad_input, grad_bias
45 |
46 | @staticmethod
47 | def backward(ctx, gradgrad_input, gradgrad_bias):
48 | out, = ctx.saved_tensors
49 | gradgrad_out = fused.fused_bias_act(
50 | gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale
51 | )
52 |
53 | return gradgrad_out, None, None, None, None
54 |
55 |
56 | class FusedLeakyReLUFunction(Function):
57 | @staticmethod
58 | def forward(ctx, input, bias, negative_slope, scale):
59 | empty = input.new_empty(0)
60 |
61 | ctx.bias = bias is not None
62 |
63 | if bias is None:
64 | bias = empty
65 |
66 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
67 | ctx.save_for_backward(out)
68 | ctx.negative_slope = negative_slope
69 | ctx.scale = scale
70 |
71 | return out
72 |
73 | @staticmethod
74 | def backward(ctx, grad_output):
75 | out, = ctx.saved_tensors
76 |
77 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
78 | grad_output, out, ctx.bias, ctx.negative_slope, ctx.scale
79 | )
80 |
81 | if not ctx.bias:
82 | grad_bias = None
83 |
84 | return grad_input, grad_bias, None, None
85 |
86 |
87 | class FusedLeakyReLU(nn.Module):
88 | def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5):
89 | super().__init__()
90 |
91 | if bias:
92 | self.bias = nn.Parameter(torch.zeros(channel))
93 |
94 | else:
95 | self.bias = None
96 |
97 | self.negative_slope = negative_slope
98 | self.scale = scale
99 |
100 | def forward(self, input):
101 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
102 |
103 |
104 | def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5):
105 | if input.device.type == "cpu":
106 | if bias is not None:
107 | rest_dim = [1] * (input.ndim - bias.ndim - 1)
108 | return (
109 | F.leaky_relu(
110 | input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=negative_slope
111 | )
112 | * scale
113 | )
114 |
115 | else:
116 | return F.leaky_relu(input, negative_slope=negative_slope) * scale
117 |
118 | else:
119 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
120 |
--------------------------------------------------------------------------------
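On CUDA the bias add, leaky ReLU, and the sqrt(2) gain run as one fused kernel; on CPU `fused_leaky_relu` falls back to plain PyTorch ops, so both paths compute `leaky_relu(x + bias) * sqrt(2)` with the bias broadcast over the channel dimension. A small reference sketch of that semantics (pure PyTorch, so it does not touch the compiled extension):

```
import torch
import torch.nn.functional as F

def fused_leaky_relu_reference(x, bias, negative_slope=0.2, scale=2 ** 0.5):
    # Reference semantics of the fused op: add the per-channel bias,
    # apply leaky ReLU, then rescale by sqrt(2) to keep the activation magnitude.
    rest_dim = [1] * (x.ndim - bias.ndim - 1)
    return F.leaky_relu(x + bias.view(1, -1, *rest_dim), negative_slope) * scale

x = torch.randn(2, 8, 4, 4)
bias = torch.zeros(8)
print(fused_leaky_relu_reference(x, bias).shape)  # torch.Size([2, 8, 4, 4])
```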
/models/networks/op/fused_bias_act.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 |
3 |
4 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
5 | int act, int grad, float alpha, float scale);
6 |
7 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
8 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
9 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
10 |
11 | torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
12 | int act, int grad, float alpha, float scale) {
13 | CHECK_CUDA(input);
14 | CHECK_CUDA(bias);
15 |
16 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
17 | }
18 |
19 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
20 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
21 | }
--------------------------------------------------------------------------------
/models/networks/op/fused_bias_act_kernel.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | //
3 | // This work is made available under the Nvidia Source Code License-NC.
4 | // To view a copy of this license, visit
5 | // https://nvlabs.github.io/stylegan2/license.html
6 |
7 | #include <torch/types.h>
8 |
9 | #include <ATen/ATen.h>
10 | #include <ATen/AccumulateType.h>
11 | #include <ATen/cuda/CUDAContext.h>
12 | #include <ATen/cuda/CUDAApplyUtils.cuh>
13 |
14 | #include <cuda.h>
15 | #include <cuda_runtime.h>
16 |
17 |
18 | template <typename scalar_t>
19 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
20 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
21 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
22 |
23 | scalar_t zero = 0.0;
24 |
25 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
26 | scalar_t x = p_x[xi];
27 |
28 | if (use_bias) {
29 | x += p_b[(xi / step_b) % size_b];
30 | }
31 |
32 | scalar_t ref = use_ref ? p_ref[xi] : zero;
33 |
34 | scalar_t y;
35 |
36 | switch (act * 10 + grad) {
37 | default:
38 | case 10: y = x; break;
39 | case 11: y = x; break;
40 | case 12: y = 0.0; break;
41 |
42 | case 30: y = (x > 0.0) ? x : x * alpha; break;
43 | case 31: y = (ref > 0.0) ? x : x * alpha; break;
44 | case 32: y = 0.0; break;
45 | }
46 |
47 | out[xi] = y * scale;
48 | }
49 | }
50 |
51 |
52 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
53 | int act, int grad, float alpha, float scale) {
54 | int curDevice = -1;
55 | cudaGetDevice(&curDevice);
56 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
57 |
58 | auto x = input.contiguous();
59 | auto b = bias.contiguous();
60 | auto ref = refer.contiguous();
61 |
62 | int use_bias = b.numel() ? 1 : 0;
63 | int use_ref = ref.numel() ? 1 : 0;
64 |
65 | int size_x = x.numel();
66 | int size_b = b.numel();
67 | int step_b = 1;
68 |
69 | for (int i = 1 + 1; i < x.dim(); i++) {
70 | step_b *= x.size(i);
71 | }
72 |
73 | int loop_x = 4;
74 | int block_size = 4 * 32;
75 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
76 |
77 | auto y = torch::empty_like(x);
78 |
79 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
80 | fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
81 | y.data_ptr<scalar_t>(),
82 | x.data_ptr<scalar_t>(),
83 | b.data_ptr<scalar_t>(),
84 | ref.data_ptr<scalar_t>(),
85 | act,
86 | grad,
87 | alpha,
88 | scale,
89 | loop_x,
90 | size_x,
91 | step_b,
92 | size_b,
93 | use_bias,
94 | use_ref
95 | );
96 | });
97 |
98 | return y;
99 | }
--------------------------------------------------------------------------------
/models/networks/op/upfirdn2d.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 |
3 |
4 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
5 | int up_x, int up_y, int down_x, int down_y,
6 | int pad_x0, int pad_x1, int pad_y0, int pad_y1);
7 |
8 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
9 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
10 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
11 |
12 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
13 | int up_x, int up_y, int down_x, int down_y,
14 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
15 | CHECK_CUDA(input);
16 | CHECK_CUDA(kernel);
17 |
18 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
19 | }
20 |
21 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
22 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
23 | }
--------------------------------------------------------------------------------
/models/networks/op/upfirdn2d.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | from torch.nn import functional as F
5 | from torch.autograd import Function
6 | from torch.utils.cpp_extension import load
7 |
8 |
9 | module_path = os.path.dirname(__file__)
10 | upfirdn2d_op = load(
11 | "upfirdn2d",
12 | sources=[
13 | os.path.join(module_path, "upfirdn2d.cpp"),
14 | os.path.join(module_path, "upfirdn2d_kernel.cu"),
15 | ],
16 | )
17 |
18 |
19 | class UpFirDn2dBackward(Function):
20 | @staticmethod
21 | def forward(
22 | ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
23 | ):
24 |
25 | up_x, up_y = up
26 | down_x, down_y = down
27 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
28 |
29 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
30 |
31 | grad_input = upfirdn2d_op.upfirdn2d(
32 | grad_output,
33 | grad_kernel,
34 | down_x,
35 | down_y,
36 | up_x,
37 | up_y,
38 | g_pad_x0,
39 | g_pad_x1,
40 | g_pad_y0,
41 | g_pad_y1,
42 | )
43 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
44 |
45 | ctx.save_for_backward(kernel)
46 |
47 | pad_x0, pad_x1, pad_y0, pad_y1 = pad
48 |
49 | ctx.up_x = up_x
50 | ctx.up_y = up_y
51 | ctx.down_x = down_x
52 | ctx.down_y = down_y
53 | ctx.pad_x0 = pad_x0
54 | ctx.pad_x1 = pad_x1
55 | ctx.pad_y0 = pad_y0
56 | ctx.pad_y1 = pad_y1
57 | ctx.in_size = in_size
58 | ctx.out_size = out_size
59 |
60 | return grad_input
61 |
62 | @staticmethod
63 | def backward(ctx, gradgrad_input):
64 | kernel, = ctx.saved_tensors
65 |
66 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
67 |
68 | gradgrad_out = upfirdn2d_op.upfirdn2d(
69 | gradgrad_input,
70 | kernel,
71 | ctx.up_x,
72 | ctx.up_y,
73 | ctx.down_x,
74 | ctx.down_y,
75 | ctx.pad_x0,
76 | ctx.pad_x1,
77 | ctx.pad_y0,
78 | ctx.pad_y1,
79 | )
80 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
81 | gradgrad_out = gradgrad_out.view(
82 | ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
83 | )
84 |
85 | return gradgrad_out, None, None, None, None, None, None, None, None
86 |
87 |
88 | class UpFirDn2d(Function):
89 | @staticmethod
90 | def forward(ctx, input, kernel, up, down, pad):
91 | up_x, up_y = up
92 | down_x, down_y = down
93 | pad_x0, pad_x1, pad_y0, pad_y1 = pad
94 |
95 | kernel_h, kernel_w = kernel.shape
96 | batch, channel, in_h, in_w = input.shape
97 | ctx.in_size = input.shape
98 |
99 | input = input.reshape(-1, in_h, in_w, 1)
100 |
101 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
102 |
103 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
104 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
105 | ctx.out_size = (out_h, out_w)
106 |
107 | ctx.up = (up_x, up_y)
108 | ctx.down = (down_x, down_y)
109 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
110 |
111 | g_pad_x0 = kernel_w - pad_x0 - 1
112 | g_pad_y0 = kernel_h - pad_y0 - 1
113 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
114 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
115 |
116 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
117 |
118 | out = upfirdn2d_op.upfirdn2d(
119 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
120 | )
121 | # out = out.view(major, out_h, out_w, minor)
122 | out = out.view(-1, channel, out_h, out_w)
123 |
124 | return out
125 |
126 | @staticmethod
127 | def backward(ctx, grad_output):
128 | kernel, grad_kernel = ctx.saved_tensors
129 |
130 | grad_input = UpFirDn2dBackward.apply(
131 | grad_output,
132 | kernel,
133 | grad_kernel,
134 | ctx.up,
135 | ctx.down,
136 | ctx.pad,
137 | ctx.g_pad,
138 | ctx.in_size,
139 | ctx.out_size,
140 | )
141 |
142 | return grad_input, None, None, None, None
143 |
144 |
145 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
146 | if input.device.type == "cpu":
147 | out = upfirdn2d_native(
148 | input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]
149 | )
150 |
151 | else:
152 | out = UpFirDn2d.apply(
153 | input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])
154 | )
155 |
156 | return out
157 |
158 |
159 | def upfirdn2d_native(
160 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
161 | ):
162 | _, channel, in_h, in_w = input.shape
163 | input = input.reshape(-1, in_h, in_w, 1)
164 |
165 | _, in_h, in_w, minor = input.shape
166 | kernel_h, kernel_w = kernel.shape
167 |
168 | out = input.view(-1, in_h, 1, in_w, 1, minor)
169 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
170 | out = out.view(-1, in_h * up_y, in_w * up_x, minor)
171 |
172 | out = F.pad(
173 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
174 | )
175 | out = out[
176 | :,
177 | max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
178 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
179 | :,
180 | ]
181 |
182 | out = out.permute(0, 3, 1, 2)
183 | out = out.reshape(
184 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
185 | )
186 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
187 | out = F.conv2d(out, w)
188 | out = out.reshape(
189 | -1,
190 | minor,
191 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
192 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
193 | )
194 | out = out.permute(0, 2, 3, 1)
195 | out = out[:, ::down_y, ::down_x, :]
196 |
197 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
198 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
199 |
200 | return out.view(-1, channel, out_h, out_w)
201 |
--------------------------------------------------------------------------------
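`upfirdn2d` is StyleGAN2's upsample-FIR-filter-downsample primitive: zero-stuff by `up`, pad, convolve with a small separable kernel, then subsample by `down`; `Blur`, `Upsample`, and `Downsample` in `stylegan2.py` wrap it with the normalized `[1, 3, 3, 1]` kernel. A short sketch of 2x upsampling on CPU tensors (the module still compiles the CUDA extension at import time; with CPU inputs the call is routed to `upfirdn2d_native`):

```
import torch
from models.networks.op.upfirdn2d import upfirdn2d

def make_kernel(k):
    # Same normalization as stylegan2.make_kernel: outer product, sum to 1.
    k = torch.tensor(k, dtype=torch.float32)
    if k.ndim == 1:
        k = k[None, :] * k[:, None]
    return k / k.sum()

x = torch.randn(1, 3, 8, 8)
kernel = make_kernel([1, 3, 3, 1]) * 4   # gain of factor**2 for 2x upsampling

p = kernel.shape[0] - 2                  # padding rule from stylegan2.Upsample
out = upfirdn2d(x, kernel, up=2, down=1, pad=((p + 1) // 2 + 1, p // 2))
print(out.shape)                         # torch.Size([1, 3, 16, 16])
```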
/models/networks/op/upfirdn2d_kernel.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | //
3 | // This work is made available under the Nvidia Source Code License-NC.
4 | // To view a copy of this license, visit
5 | // https://nvlabs.github.io/stylegan2/license.html
6 |
7 | #include <torch/types.h>
8 |
9 | #include <ATen/ATen.h>
10 | #include <ATen/AccumulateType.h>
11 | #include <ATen/cuda/CUDAContext.h>
12 | #include <ATen/cuda/CUDAApplyUtils.cuh>
13 |
14 | #include <cuda.h>
15 | #include <cuda_runtime.h>
16 |
17 | static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
18 | int c = a / b;
19 |
20 | if (c * b > a) {
21 | c--;
22 | }
23 |
24 | return c;
25 | }
26 |
27 | struct UpFirDn2DKernelParams {
28 | int up_x;
29 | int up_y;
30 | int down_x;
31 | int down_y;
32 | int pad_x0;
33 | int pad_x1;
34 | int pad_y0;
35 | int pad_y1;
36 |
37 | int major_dim;
38 | int in_h;
39 | int in_w;
40 | int minor_dim;
41 | int kernel_h;
42 | int kernel_w;
43 | int out_h;
44 | int out_w;
45 | int loop_major;
46 | int loop_x;
47 | };
48 |
49 | template <typename scalar_t>
50 | __global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input,
51 | const scalar_t *kernel,
52 | const UpFirDn2DKernelParams p) {
53 | int minor_idx = blockIdx.x * blockDim.x + threadIdx.x;
54 | int out_y = minor_idx / p.minor_dim;
55 | minor_idx -= out_y * p.minor_dim;
56 | int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y;
57 | int major_idx_base = blockIdx.z * p.loop_major;
58 |
59 | if (out_x_base >= p.out_w || out_y >= p.out_h ||
60 | major_idx_base >= p.major_dim) {
61 | return;
62 | }
63 |
64 | int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0;
65 | int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h);
66 | int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y;
67 | int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y;
68 |
69 | for (int loop_major = 0, major_idx = major_idx_base;
70 | loop_major < p.loop_major && major_idx < p.major_dim;
71 | loop_major++, major_idx++) {
72 | for (int loop_x = 0, out_x = out_x_base;
73 | loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) {
74 | int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0;
75 | int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w);
76 | int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x;
77 | int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x;
78 |
79 | const scalar_t *x_p =
80 | &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim +
81 | minor_idx];
82 | const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x];
83 | int x_px = p.minor_dim;
84 | int k_px = -p.up_x;
85 | int x_py = p.in_w * p.minor_dim;
86 | int k_py = -p.up_y * p.kernel_w;
87 |
88 | scalar_t v = 0.0f;
89 |
90 | for (int y = 0; y < h; y++) {
91 | for (int x = 0; x < w; x++) {
92 | v += static_cast<scalar_t>(*x_p) * static_cast<scalar_t>(*k_p);
93 | x_p += x_px;
94 | k_p += k_px;
95 | }
96 |
97 | x_p += x_py - w * x_px;
98 | k_p += k_py - w * k_px;
99 | }
100 |
101 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
102 | minor_idx] = v;
103 | }
104 | }
105 | }
106 |
107 | template <typename scalar_t, int up_x, int up_y, int down_x, int down_y,
108 | int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
109 | __global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input,
110 | const scalar_t *kernel,
111 | const UpFirDn2DKernelParams p) {
112 | const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
113 | const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;
114 |
115 | __shared__ volatile float sk[kernel_h][kernel_w];
116 | __shared__ volatile float sx[tile_in_h][tile_in_w];
117 |
118 | int minor_idx = blockIdx.x;
119 | int tile_out_y = minor_idx / p.minor_dim;
120 | minor_idx -= tile_out_y * p.minor_dim;
121 | tile_out_y *= tile_out_h;
122 | int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
123 | int major_idx_base = blockIdx.z * p.loop_major;
124 |
125 | if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h |
126 | major_idx_base >= p.major_dim) {
127 | return;
128 | }
129 |
130 | for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w;
131 | tap_idx += blockDim.x) {
132 | int ky = tap_idx / kernel_w;
133 | int kx = tap_idx - ky * kernel_w;
134 | scalar_t v = 0.0;
135 |
136 | if (kx < p.kernel_w & ky < p.kernel_h) {
137 | v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
138 | }
139 |
140 | sk[ky][kx] = v;
141 | }
142 |
143 | for (int loop_major = 0, major_idx = major_idx_base;
144 | loop_major < p.loop_major & major_idx < p.major_dim;
145 | loop_major++, major_idx++) {
146 | for (int loop_x = 0, tile_out_x = tile_out_x_base;
147 | loop_x < p.loop_x & tile_out_x < p.out_w;
148 | loop_x++, tile_out_x += tile_out_w) {
149 | int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
150 | int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
151 | int tile_in_x = floor_div(tile_mid_x, up_x);
152 | int tile_in_y = floor_div(tile_mid_y, up_y);
153 |
154 | __syncthreads();
155 |
156 | for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w;
157 | in_idx += blockDim.x) {
158 | int rel_in_y = in_idx / tile_in_w;
159 | int rel_in_x = in_idx - rel_in_y * tile_in_w;
160 | int in_x = rel_in_x + tile_in_x;
161 | int in_y = rel_in_y + tile_in_y;
162 |
163 | scalar_t v = 0.0;
164 |
165 | if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
166 | v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) *
167 | p.minor_dim +
168 | minor_idx];
169 | }
170 |
171 | sx[rel_in_y][rel_in_x] = v;
172 | }
173 |
174 | __syncthreads();
175 | for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w;
176 | out_idx += blockDim.x) {
177 | int rel_out_y = out_idx / tile_out_w;
178 | int rel_out_x = out_idx - rel_out_y * tile_out_w;
179 | int out_x = rel_out_x + tile_out_x;
180 | int out_y = rel_out_y + tile_out_y;
181 |
182 | int mid_x = tile_mid_x + rel_out_x * down_x;
183 | int mid_y = tile_mid_y + rel_out_y * down_y;
184 | int in_x = floor_div(mid_x, up_x);
185 | int in_y = floor_div(mid_y, up_y);
186 | int rel_in_x = in_x - tile_in_x;
187 | int rel_in_y = in_y - tile_in_y;
188 | int kernel_x = (in_x + 1) * up_x - mid_x - 1;
189 | int kernel_y = (in_y + 1) * up_y - mid_y - 1;
190 |
191 | scalar_t v = 0.0;
192 |
193 | #pragma unroll
194 | for (int y = 0; y < kernel_h / up_y; y++)
195 | #pragma unroll
196 | for (int x = 0; x < kernel_w / up_x; x++)
197 | v += sx[rel_in_y + y][rel_in_x + x] *
198 | sk[kernel_y + y * up_y][kernel_x + x * up_x];
199 |
200 | if (out_x < p.out_w & out_y < p.out_h) {
201 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
202 | minor_idx] = v;
203 | }
204 | }
205 | }
206 | }
207 | }
208 |
209 | torch::Tensor upfirdn2d_op(const torch::Tensor &input,
210 | const torch::Tensor &kernel, int up_x, int up_y,
211 | int down_x, int down_y, int pad_x0, int pad_x1,
212 | int pad_y0, int pad_y1) {
213 | int curDevice = -1;
214 | cudaGetDevice(&curDevice);
215 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
216 |
217 | UpFirDn2DKernelParams p;
218 |
219 | auto x = input.contiguous();
220 | auto k = kernel.contiguous();
221 |
222 | p.major_dim = x.size(0);
223 | p.in_h = x.size(1);
224 | p.in_w = x.size(2);
225 | p.minor_dim = x.size(3);
226 | p.kernel_h = k.size(0);
227 | p.kernel_w = k.size(1);
228 | p.up_x = up_x;
229 | p.up_y = up_y;
230 | p.down_x = down_x;
231 | p.down_y = down_y;
232 | p.pad_x0 = pad_x0;
233 | p.pad_x1 = pad_x1;
234 | p.pad_y0 = pad_y0;
235 | p.pad_y1 = pad_y1;
236 |
237 | p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) /
238 | p.down_y;
239 | p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) /
240 | p.down_x;
241 |
242 | auto out =
243 | at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());
244 |
245 | int mode = -1;
246 |
247 | int tile_out_h = -1;
248 | int tile_out_w = -1;
249 |
250 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
251 | p.kernel_h <= 4 && p.kernel_w <= 4) {
252 | mode = 1;
253 | tile_out_h = 16;
254 | tile_out_w = 64;
255 | }
256 |
257 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
258 | p.kernel_h <= 3 && p.kernel_w <= 3) {
259 | mode = 2;
260 | tile_out_h = 16;
261 | tile_out_w = 64;
262 | }
263 |
264 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
265 | p.kernel_h <= 4 && p.kernel_w <= 4) {
266 | mode = 3;
267 | tile_out_h = 16;
268 | tile_out_w = 64;
269 | }
270 |
271 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
272 | p.kernel_h <= 2 && p.kernel_w <= 2) {
273 | mode = 4;
274 | tile_out_h = 16;
275 | tile_out_w = 64;
276 | }
277 |
278 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
279 | p.kernel_h <= 4 && p.kernel_w <= 4) {
280 | mode = 5;
281 | tile_out_h = 8;
282 | tile_out_w = 32;
283 | }
284 |
285 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
286 | p.kernel_h <= 2 && p.kernel_w <= 2) {
287 | mode = 6;
288 | tile_out_h = 8;
289 | tile_out_w = 32;
290 | }
291 |
292 | dim3 block_size;
293 | dim3 grid_size;
294 |
295 | if (tile_out_h > 0 && tile_out_w > 0) {
296 | p.loop_major = (p.major_dim - 1) / 16384 + 1;
297 | p.loop_x = 1;
298 | block_size = dim3(32 * 8, 1, 1);
299 | grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
300 | (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
301 | (p.major_dim - 1) / p.loop_major + 1);
302 | } else {
303 | p.loop_major = (p.major_dim - 1) / 16384 + 1;
304 | p.loop_x = 4;
305 | block_size = dim3(4, 32, 1);
306 | grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1,
307 | (p.out_w - 1) / (p.loop_x * block_size.y) + 1,
308 | (p.major_dim - 1) / p.loop_major + 1);
309 | }
310 |
311 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
312 | switch (mode) {
313 | case 1:
314 | upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64>
315 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
316 | x.data_ptr<scalar_t>(),
317 | k.data_ptr<scalar_t>(), p);
318 |
319 | break;
320 |
321 | case 2:
322 | upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64>
323 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
324 | x.data_ptr<scalar_t>(),
325 | k.data_ptr<scalar_t>(), p);
326 |
327 | break;
328 |
329 | case 3:
330 | upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64>
331 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
332 | x.data_ptr<scalar_t>(),
333 | k.data_ptr<scalar_t>(), p);
334 |
335 | break;
336 |
337 | case 4:
338 | upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64>
339 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
340 | x.data_ptr<scalar_t>(),
341 | k.data_ptr<scalar_t>(), p);
342 |
343 | break;
344 |
345 | case 5:
346 | upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
347 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
348 | x.data_ptr<scalar_t>(),
349 | k.data_ptr<scalar_t>(), p);
350 |
351 | break;
352 |
353 | case 6:
354 | upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 2, 2, 8, 32>
355 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
356 | x.data_ptr<scalar_t>(),
357 | k.data_ptr<scalar_t>(), p);
358 |
359 | break;
360 |
361 | default:
362 | upfirdn2d_kernel_large<scalar_t><<<grid_size, block_size, 0, stream>>>(
363 | out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
364 | k.data_ptr<scalar_t>(), p);
365 | }
366 | });
367 |
368 | return out;
369 | }
--------------------------------------------------------------------------------
/models/networks/stylegan2.py:
--------------------------------------------------------------------------------
1 | import math
2 | import pdb
3 | import random
4 | import functools
5 | import operator
6 |
7 | import torch
8 | from torch import nn
9 | from torch.nn import functional as F
10 | from torch.autograd import Function
11 | try:
12 | from models.networks.sync_batchnorm import SynchronizedBatchNorm2d
13 | except:
14 | pass
15 |
16 | from models.networks.base_network import BaseNetwork
17 | from models.networks.op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d
18 |
19 | #from base_network import BaseNetwork
20 | #from op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d
21 |
22 |
23 | class PixelNorm(nn.Module):
24 | def __init__(self):
25 | super().__init__()
26 |
27 | def forward(self, input):
28 | return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
29 |
30 |
31 | def make_kernel(k):
32 | k = torch.tensor(k, dtype=torch.float32)
33 |
34 | if k.ndim == 1:
35 | k = k[None, :] * k[:, None]
36 |
37 | k /= k.sum()
38 |
39 | return k
40 |
41 |
42 | class Upsample(nn.Module):
43 | def __init__(self, kernel, factor=2):
44 | super().__init__()
45 |
46 | self.factor = factor
47 | kernel = make_kernel(kernel) * (factor ** 2)
48 | self.register_buffer("kernel", kernel)
49 |
50 | p = kernel.shape[0] - factor
51 |
52 | pad0 = (p + 1) // 2 + factor - 1
53 | pad1 = p // 2
54 |
55 | self.pad = (pad0, pad1)
56 |
57 | def forward(self, input):
58 | out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
59 |
60 | return out
61 |
62 |
63 | class Downsample(nn.Module):
64 | def __init__(self, kernel, factor=2):
65 | super().__init__()
66 |
67 | self.factor = factor
68 | kernel = make_kernel(kernel)
69 | self.register_buffer("kernel", kernel)
70 |
71 | p = kernel.shape[0] - factor
72 |
73 | pad0 = (p + 1) // 2
74 | pad1 = p // 2
75 |
76 | self.pad = (pad0, pad1)
77 |
78 | def forward(self, input):
79 | out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
80 |
81 | return out
82 |
83 |
84 | class Blur(nn.Module):
85 | def __init__(self, kernel, pad, upsample_factor=1):
86 | super().__init__()
87 |
88 | kernel = make_kernel(kernel)
89 |
90 | if upsample_factor > 1:
91 | kernel = kernel * (upsample_factor ** 2)
92 |
93 | self.register_buffer("kernel", kernel)
94 |
95 | self.pad = pad
96 |
97 | def forward(self, input):
98 | out = upfirdn2d(input, self.kernel, pad=self.pad)
99 |
100 | return out
101 |
102 |
103 | class EqualConv2d(nn.Module):
104 | def __init__(
105 | self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
106 | ):
107 | super().__init__()
108 |
109 | self.weight = nn.Parameter(
110 | torch.randn(out_channel, in_channel, kernel_size, kernel_size)
111 | )
112 | self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
113 |
114 | self.stride = stride
115 | self.padding = padding
116 |
117 | if bias:
118 | self.bias = nn.Parameter(torch.zeros(out_channel))
119 |
120 | else:
121 | self.bias = None
122 |
123 | def forward(self, input):
124 | out = F.conv2d(
125 | input,
126 | self.weight * self.scale,
127 | bias=self.bias,
128 | stride=self.stride,
129 | padding=self.padding,
130 | )
131 |
132 | return out
133 |
134 | def __repr__(self):
135 | return (
136 | f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},"
137 | f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})"
138 | )
139 |
140 |
141 | class EqualLinear(nn.Module):
142 | def __init__(
143 | self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
144 | ):
145 | super().__init__()
146 |
147 | self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
148 |
149 | if bias:
150 | self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
151 |
152 | else:
153 | self.bias = None
154 |
155 | self.activation = activation
156 |
157 | self.scale = (1 / math.sqrt(in_dim)) * lr_mul
158 | self.lr_mul = lr_mul
159 |
160 | def forward(self, input):
161 | if self.activation:
162 | out = F.linear(input, self.weight * self.scale)
163 | out = fused_leaky_relu(out, self.bias * self.lr_mul)
164 |
165 | else:
166 | out = F.linear(
167 | input, self.weight * self.scale, bias=self.bias * self.lr_mul
168 | )
169 |
170 | return out
171 |
172 | def __repr__(self):
173 | return (
174 | f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})"
175 | )
176 |
177 |
178 | class ModulatedConv2d(nn.Module):
179 | def __init__(
180 | self,
181 | in_channel,
182 | out_channel,
183 | kernel_size,
184 | style_dim,
185 | demodulate=True,
186 | upsample=False,
187 | downsample=False,
188 | blur_kernel=[1, 3, 3, 1],
189 | ):
190 | super().__init__()
191 |
192 | self.eps = 1e-8
193 | self.kernel_size = kernel_size
194 | self.in_channel = in_channel
195 | self.out_channel = out_channel
196 | self.upsample = upsample
197 | self.downsample = downsample
198 |
199 | if upsample:
200 | factor = 2
201 | p = (len(blur_kernel) - factor) - (kernel_size - 1)
202 | pad0 = (p + 1) // 2 + factor - 1
203 | pad1 = p // 2 + 1
204 |
205 | self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
206 |
207 | if downsample:
208 | factor = 2
209 | p = (len(blur_kernel) - factor) + (kernel_size - 1)
210 | pad0 = (p + 1) // 2
211 | pad1 = p // 2
212 |
213 | self.blur = Blur(blur_kernel, pad=(pad0, pad1))
214 |
215 | fan_in = in_channel * kernel_size ** 2
216 | self.scale = 1 / math.sqrt(fan_in)
217 | self.padding = kernel_size // 2
218 |
219 | self.weight = nn.Parameter(
220 | torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
221 | )
222 |
223 | self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
224 |
225 | self.demodulate = demodulate
226 |
227 | def __repr__(self):
228 | return (
229 | f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, "
230 | f"upsample={self.upsample}, downsample={self.downsample})"
231 | )
232 |
233 | def forward(self, input, style, **kwargs):
234 | batch, in_channel, height, width = input.shape
235 |
236 | style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
237 | weight = self.scale * self.weight * style
238 |
239 | if self.demodulate:
240 | demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
241 | weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
242 |
243 | weight = weight.view(
244 | batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
245 | )
246 |
247 | if self.upsample:
248 | input = input.view(1, batch * in_channel, height, width)
249 | weight = weight.view(
250 | batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
251 | )
252 | weight = weight.transpose(1, 2).reshape(
253 | batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
254 | )
255 | out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)
256 | _, _, height, width = out.shape
257 | out = out.view(batch, self.out_channel, height, width)
258 | out = self.blur(out)
259 |
260 | elif self.downsample:
261 | input = self.blur(input)
262 | _, _, height, width = input.shape
263 | input = input.view(1, batch * in_channel, height, width)
264 | out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
265 | _, _, height, width = out.shape
266 | out = out.view(batch, self.out_channel, height, width)
267 |
268 | else:
269 | input = input.view(1, batch * in_channel, height, width)
270 | out = F.conv2d(input, weight, padding=self.padding, groups=batch)
271 | _, _, height, width = out.shape
272 | out = out.view(batch, self.out_channel, height, width)
273 |
274 | return out
275 |
276 | def PositionalNorm2d(x, epsilon=1e-5):
277 | # x: B*C*W*H normalize in C dim
278 | mean = x.mean(dim=1, keepdim=True)
279 | std = x.var(dim=1, keepdim=True).add(epsilon).sqrt()
280 | output = (x - mean) / std
281 | return output
282 |
283 | class WeightedConv2d(ModulatedConv2d):
284 | def __init__(
285 | self, *args, **kwargs
286 | ):
287 | num_weight = kwargs.pop("num_weight")
288 | super().__init__(*args, **kwargs)
289 | self.param_free_norm = PositionalNorm2d
290 | ks = 1
291 | pw = ks // 2
292 |         # style input: segmentation map or feature
293 | #nhidden = 128
294 | #self.mpl_shared = nn.Sequential(
295 | # EqualConv2d(
296 | # num_weight,
297 | # nhidden,
298 | # 1),
299 | # FusedLeakyReLU(nhidden))
300 | nhidden = num_weight
301 | self.mpl_gamma = EqualConv2d(
302 | nhidden,
303 | self.out_channel,
304 | ks,
305 | padding=pw,
306 | )
307 | self.mpl_beta = EqualConv2d(
308 | nhidden,
309 | self.out_channel,
310 | ks,
311 | padding=pw,
312 | )
313 |
314 | def forward(self, input, style, skip):
315 | out = super().forward(input, style)
316 | # Part 1. generate parameter-free normalized activations
317 | normalized = self.param_free_norm(out)
318 | # Part 2. produce scaling and bias conditioned on semantic map
319 | #hidden = self.mpl_shared(skip)
320 | hidden = skip
321 | gamma = self.mpl_gamma(hidden)
322 | beta = self.mpl_beta(hidden)
323 | # apply scale and bias
324 | out = normalized * (1 + gamma) + beta
325 |
326 | return out
327 |
328 |
329 | class NoiseInjection(nn.Module):
330 | def __init__(self):
331 | super().__init__()
332 |
333 | self.weight = nn.Parameter(torch.zeros(1))
334 |
335 | def forward(self, image, noise=None):
336 | if noise is None:
337 | batch, _, height, width = image.shape
338 | noise = image.new_empty(batch, 1, height, width).normal_()
339 |
340 | return image + self.weight * noise
341 |
342 |
343 | class ConstantInput(nn.Module):
344 | def __init__(self, channel, size=4):
345 | super().__init__()
346 |
347 | self.input = nn.Parameter(torch.randn(1, channel, size, size))
348 |
349 | def forward(self, input):
350 | batch = input.shape[0]
351 | out = self.input.repeat(batch, 1, 1, 1)
352 |
353 | return out
354 |
355 |
356 | class StyledConv(nn.Module):
357 | def __init__(
358 | self,
359 | in_channel,
360 | out_channel,
361 | kernel_size,
362 | style_dim,
363 | upsample=False,
364 | blur_kernel=[1, 3, 3, 1],
365 | demodulate=True,
366 | weightedconv=False,
367 | num_weight=None,
368 | ):
369 | super().__init__()
370 |
371 | if weightedconv:
372 | self.conv = WeightedConv2d(
373 | in_channel,
374 | out_channel,
375 | kernel_size,
376 | style_dim,
377 | upsample=upsample,
378 | blur_kernel=blur_kernel,
379 | demodulate=demodulate,
380 | num_weight=num_weight
381 | )
382 | else:
383 | self.conv = ModulatedConv2d(
384 | in_channel,
385 | out_channel,
386 | kernel_size,
387 | style_dim,
388 | upsample=upsample,
389 | blur_kernel=blur_kernel,
390 | demodulate=demodulate,
391 | )
392 |
393 | self.noise = NoiseInjection()
394 | # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
395 | # self.activate = ScaledLeakyReLU(0.2)
396 | self.activate = FusedLeakyReLU(out_channel)
397 |
398 | def forward(self, input, style, noise=None, x_skip=None):
399 | out = self.conv(input, style, skip=x_skip)
400 | out = self.noise(out, noise=noise)
401 | # out = out + self.bias
402 | out = self.activate(out)
403 |
404 | return out
405 |
406 | class ConvToRGB(nn.Module):
407 | def __init__(self, in_channel, upsample=True, blur_kernel=[1, 3, 3, 1], out_channel=3):
408 | super().__init__()
409 |
410 | if upsample:
411 | self.upsample = Upsample(blur_kernel)
412 |
413 | self.conv = ConvLayer(in_channel, out_channel, 1)
414 |
415 |
416 | def forward(self, input, skip=None):
417 | out = self.conv(input)
418 |
419 | if skip is not None:
420 | skip = self.upsample(skip)
421 |
422 | out = out + skip
423 |
424 | return out
425 |
426 |
427 | class ToRGB(nn.Module):
428 | def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1], out_channel=3,weightedconv=False,
429 | num_weight=None):
430 | super().__init__()
431 |
432 | if upsample:
433 | self.upsample = Upsample(blur_kernel)
434 | if weightedconv:
435 | self.conv = WeightedConv2d(
436 | in_channel, out_channel, 1,
437 | style_dim, demodulate=False, num_weight=num_weight)
438 | else:
439 | self.conv = ModulatedConv2d(in_channel, out_channel, 1, style_dim, demodulate=False)
440 | self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
441 |
442 | def forward(self, input, style, skip=None, x_skip=None):
443 | out = self.conv(input, style, skip=x_skip)
444 | out = out + self.bias
445 |
446 | if skip is not None:
447 | skip = self.upsample(skip)
448 |
449 | out = out + skip
450 |
451 | return out
452 |
453 |
454 | class Generator(BaseNetwork):
455 | def __init__(
456 | self,
457 | opt
458 | ):
459 | super().__init__()
460 | size = opt.crop_size
461 | style_dim = opt.z_dim
462 | n_mlp = 8
463 | channel_multiplier=2
464 | blur_kernel=[1, 3, 3, 1]
465 | lr_mlp=0.01
466 |
467 | self.size = size
468 |
469 | self.style_dim = style_dim
470 |
471 | layers = [PixelNorm()]
472 |
473 | for i in range(n_mlp):
474 | layers.append(
475 | EqualLinear(
476 | style_dim, style_dim, lr_mul=lr_mlp, activation="fused_lrelu"
477 | )
478 | )
479 |
480 | self.style = nn.Sequential(*layers)
481 |
482 | self.channels = {
483 | 4: 512,
484 | 8: 512,
485 | 16: 512,
486 | 32: 512,
487 | 64: 256 * channel_multiplier,
488 | 128: 128 * channel_multiplier,
489 | 256: 64 * channel_multiplier,
490 | 512: 32 * channel_multiplier,
491 | 1024: 16 * channel_multiplier,
492 | }
493 |
494 | self.input = ConstantInput(self.channels[4])
495 | self.conv1 = StyledConv(
496 | self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
497 | )
498 | self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)
499 |
500 | self.log_size = int(math.log(size, 2))
501 | self.num_layers = (self.log_size - 2) * 2 + 1
502 |
503 | self.convs = nn.ModuleList()
504 | self.upsamples = nn.ModuleList()
505 | self.to_rgbs = nn.ModuleList()
506 | self.noises = nn.Module()
507 |
508 | in_channel = self.channels[4]
509 |
510 | for layer_idx in range(self.num_layers):
511 | res = (layer_idx + 5) // 2
512 | shape = [1, 1, 2 ** res, 2 ** res]
513 | self.noises.register_buffer(f"noise_{layer_idx}", torch.randn(*shape))
514 |
515 | for i in range(3, self.log_size + 1):
516 | out_channel = self.channels[2 ** i]
517 |
518 | self.convs.append(
519 | StyledConv(
520 | in_channel,
521 | out_channel,
522 | 3,
523 | style_dim,
524 | upsample=True,
525 | blur_kernel=blur_kernel,
526 | )
527 | )
528 |
529 | self.convs.append(
530 | StyledConv(
531 | out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
532 | )
533 | )
534 |
535 | self.to_rgbs.append(ToRGB(out_channel, style_dim))
536 |
537 | in_channel = out_channel
538 |
539 | self.n_latent = self.log_size * 2 - 2
540 |
541 | def make_noise(self):
542 | device = self.input.input.device
543 |
544 | noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]
545 |
546 | for i in range(3, self.log_size + 1):
547 | for _ in range(2):
548 | noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))
549 |
550 | return noises
551 |
552 | def mean_latent(self, n_latent):
553 | latent_in = torch.randn(
554 | n_latent, self.style_dim, device=self.input.input.device
555 | )
556 | latent = self.style(latent_in).mean(0, keepdim=True)
557 |
558 | return latent
559 |
560 | def get_latent(self, input):
561 | return self.style(input)
562 |
563 | def forward(
564 | self,
565 | styles,
566 | return_latents=False,
567 | inject_index=None,
568 | truncation=None,
569 | truncation_latent=None,
570 | input_is_latent=False,
571 | noise=None,
572 | randomize_noise=True,
573 | get_latent=False,
574 | ):
575 | if not input_is_latent:
576 | styles = [self.style(s) for s in styles]
577 | if get_latent:
578 | return styles
579 |
580 | if noise is None:
581 | if randomize_noise:
582 | noise = [None] * self.num_layers
583 | else:
584 | noise = [
585 | getattr(self.noises, f"noise_{i}") for i in range(self.num_layers)
586 | ]
587 |
588 | if truncation is not None:
589 |             assert 0 < truncation <= 1
--------------------------------------------------------------------------------
/options/base_options.py:
--------------------------------------------------------------------------------
102 |         message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
103 | message += '----------------- End -------------------'
104 | print(message)
105 |
106 | def option_file_path(self, opt, makedir=False):
107 | expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
108 | if makedir:
109 | util.mkdirs(expr_dir)
110 | file_name = os.path.join(expr_dir, 'opt')
111 | return file_name
112 |
113 | def save_options(self, opt):
114 | file_name = self.option_file_path(opt, makedir=True)
115 | with open(file_name + '.txt', 'wt') as opt_file:
116 | for k, v in sorted(vars(opt).items()):
117 | comment = ''
118 | default = self.parser.get_default(k)
119 | if v != default:
120 | comment = '\t[default: %s]' % str(default)
121 | opt_file.write('{:>25}: {:<30}{}\n'.format(str(k), str(v), comment))
122 |
123 | with open(file_name + '.pkl', 'wb') as opt_file:
124 | pickle.dump(opt, opt_file)
125 |
126 | def update_options_from_file(self, parser, opt):
127 | new_opt = self.load_options(opt)
128 | for k, v in sorted(vars(opt).items()):
129 | if hasattr(new_opt, k) and v != getattr(new_opt, k):
130 | new_val = getattr(new_opt, k)
131 | parser.set_defaults(**{k: new_val})
132 | return parser
133 |
134 | def load_options(self, opt):
135 | file_name = self.option_file_path(opt, makedir=False)
136 | new_opt = pickle.load(open(file_name + '.pkl', 'rb'))
137 | return new_opt
138 |
139 | def parse(self, save=False):
140 |
141 | opt = self.gather_options()
142 | opt.isTrain = self.isTrain # train or test
143 |
144 | self.print_options(opt)
145 | if opt.isTrain:
146 | self.save_options(opt)
147 |
148 | # set gpu ids
149 | str_ids = opt.gpu_ids.split(',')
150 | opt.gpu_ids = []
151 | for str_id in str_ids:
152 | id = int(str_id)
153 | if id >= 0:
154 | opt.gpu_ids.append(id)
155 | if len(opt.gpu_ids) > 0:
156 | torch.cuda.set_device(opt.gpu_ids[0])
157 |
158 | assert len(opt.gpu_ids) == 0 or opt.batchSize % len(opt.gpu_ids) == 0, \
159 | "Batch size %d is wrong. It must be a multiple of # GPUs %d." \
160 | % (opt.batchSize, len(opt.gpu_ids))
161 |
162 | self.opt = opt
163 | return self.opt
164 |
--------------------------------------------------------------------------------
/options/test_options.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """
5 |
6 | from .base_options import BaseOptions
7 |
8 |
9 | class TestOptions(BaseOptions):
10 | def initialize(self, parser):
11 | BaseOptions.initialize(self, parser)
12 | parser.add_argument('--dataset_mode', type=str, default='coco')
13 | parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
14 | parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
15 | parser.add_argument('--how_many', type=int, default=float("inf"), help='how many test images to run')
16 |
17 | parser.set_defaults(preprocess_mode='scale_width_and_crop', crop_size=256, load_size=256, display_winsize=256)
18 | parser.set_defaults(serial_batches=True)
19 | parser.set_defaults(no_flip=True)
20 | parser.set_defaults(phase='test')
21 | self.isTrain = False
22 | return parser
23 |
--------------------------------------------------------------------------------
/options/train_options.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """
5 |
6 | from .base_options import BaseOptions
7 |
8 |
9 | class TrainOptions(BaseOptions):
10 | def initialize(self, parser):
11 | BaseOptions.initialize(self, parser)
12 | parser.add_argument('--save_remote_gs', type=str, required=False)
13 | parser.add_argument('--trainer', type=str, default='stylegan2')
14 | # for displays
15 | parser.add_argument('--display_freq', type=int, default=101, help='frequency of showing training results on screen')
16 | parser.add_argument('--print_freq', type=int, default=101, help='frequency of showing training results on console')
17 | parser.add_argument('--save_latest_freq', type=int, default=50000, help='frequency of saving the latest results')
18 |         parser.add_argument('--validation_freq', type=int, default=50000, help='frequency of running validation')
19 | parser.add_argument('--save_epoch_freq', type=int, default=10, help='frequency of saving checkpoints at the end of epochs')
20 | parser.add_argument('--debug', action='store_true', help='only do one epoch and displays at each iteration')
21 |         # dataset
22 | parser.add_argument('--dataset_mode_train', type=str, default='coco')
23 | parser.add_argument('--dataset_mode', type=str, default='coco')
24 | parser.add_argument('--dataset_mode_val', type=str, required=False)
25 |
26 | # for training
27 | parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
28 | parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
29 |         parser.add_argument('--niter', type=int, default=50, help='# of iter at starting learning rate. This is NOT the total #epochs. Total #epochs is niter + niter_decay')
30 | parser.add_argument('--niter_decay', type=int, default=0, help='# of iter to linearly decay learning rate to zero')
31 | parser.add_argument('--optimizer', type=str, default='adam')
32 | parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
33 | parser.add_argument('--D_steps_per_G', type=int, default=1, help='number of discriminator iterations per generator iterations.')
34 |
35 | # for discriminators
36 | parser.add_argument('--lambda_vgg', type=float, default=10.0, help='weight for vgg loss')
37 | parser.add_argument('--lambda_l1', type=float, default=1.0, help='weight for l1 loss')
38 | parser.add_argument('--no_l1_loss', action='store_true', help='if specified, do *not* use l1 loss')
39 | parser.add_argument('--no_vgg_loss', action='store_true', help='if specified, do *not* use VGG feature matching loss')
40 | parser.add_argument('--gan_mode', type=str, default='hinge', help='(ls|original|hinge)')
41 | parser.add_argument('--netD', type=str, default='comodgan')
42 | parser.add_argument('--freeze_D', action='store_true', help='do not update D')
43 | self.isTrain = True
44 | return parser
45 |
--------------------------------------------------------------------------------
/output/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zengxianyu/co-mod-gan-pytorch/a178bfd15d2675214532151037ef8f6e34b3fd91/output/.gitkeep
--------------------------------------------------------------------------------
/save_remote_gs.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pdb
3 | from datetime import datetime
4 |
5 | def init_remote(opt):
6 | os.system(f"rm -rf output/{opt.name}")
7 | cwd = os.getcwd()
8 | os.system(f"gsutil cp -r {opt.save_remote_gs}/{opt.name} ./output/")
9 | if os.path.exists(f"output/{opt.name}/iter.txt") and not os.path.exists(f"checkpoints/{opt.name}/iter.txt"):
10 | os.system(f"cp output/{opt.name}/latest_net_*.pth checkpoints/{opt.name}/")
11 | os.system(f"cp output/{opt.name}/iter.txt checkpoints/{opt.name}/")
12 |
13 | def upload_remote(opt):
14 | os.system(f"gsutil cp -r {opt.save_remote_gs}/{opt.name}/savemodel ./output/{opt.name}")
15 | now = datetime.now()
16 | dt_string = now.strftime("%d/%m/%Y %H:%M:%S")
17 | os.system(f"cp output/{opt.name}.html output/{opt.name}/")
18 | os.system(f"cp checkpoints/{opt.name}/opt.txt output/{opt.name}/")
19 | os.system(f"cp checkpoints/{opt.name}/iter.txt output/{opt.name}/")
20 | os.system(f"echo {dt_string} > output/{opt.name}/time.txt")
21 | with open(f"output/{opt.name}/savemodel","r") as f:
22 | line = f.readlines()
23 | if line[0].startswith("y"):
24 | os.system(f"gsutil cp -r ./checkpoints/{opt.name}/latest_net_*.pth {opt.save_remote_gs}/{opt.name}/")
25 | with open(f"output/{opt.name}/savemodel", "w") as f:
26 | f.writelines("n")
27 | os.system(f"gsutil cp -r ./output/{opt.name} {opt.save_remote_gs}/")
28 |
29 |
30 |
31 | if __name__ == "__main__":
32 | class Temp:
33 | pass
34 | opt = Temp()
35 | opt.save_remote_gs = "gs://zengxianyu"
36 | opt.name = "cline"
37 | init_remote(opt)
38 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """
5 |
6 | import pdb
7 | import cv2
8 | import os
9 | from collections import OrderedDict
10 | import json
11 | from tqdm import tqdm
12 | import numpy as np
13 | import torch
14 | import data
15 | from options.test_options import TestOptions
16 | #from models.pix2pix_model import Pix2PixModel
17 | import models
18 |
19 |
20 | opt = TestOptions().parse()
21 |
22 | dataloader = data.create_dataloader(opt)
23 |
24 | model = models.create_model(opt)
25 | model.eval()
26 |
27 | for i, data_i in tqdm(enumerate(dataloader)):
28 | if i * opt.batchSize >= opt.how_many:
29 | break
30 | with torch.no_grad():
31 | generated,_ = model(data_i, mode='inference')
32 | generated = torch.clamp(generated, -1, 1)
33 | generated = (generated+1)/2*255
34 | generated = generated.cpu().numpy().astype(np.uint8)
35 | img_path = data_i['path']
36 | for b in range(generated.shape[0]):
37 | pred_im = generated[b].transpose((1,2,0))
38 | print('process image... %s' % img_path[b])
39 | cv2.imwrite(img_path[b], pred_im[:,:,::-1])
40 |
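For reference, the post-processing in the loop above maps the generator output from [-1, 1] to 8-bit RGB and flips channels for OpenCV. The same conversion for a single stand-in tensor (the random `fake` below is only a placeholder, and the output path is arbitrary):

```
import cv2
import numpy as np
import torch

fake = torch.rand(3, 256, 256) * 2 - 1               # placeholder for one generated sample
img = torch.clamp(fake, -1, 1)
img = ((img + 1) / 2 * 255).numpy().astype(np.uint8)
img = img.transpose((1, 2, 0))                       # CHW -> HWC (RGB)
cv2.imwrite("out.png", img[:, :, ::-1])              # cv2 expects BGR; any output path
```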
--------------------------------------------------------------------------------
/test.sh:
--------------------------------------------------------------------------------
1 | python test.py \
2 | --mixing 0 \
3 | --batchSize 1 \
4 | --nThreads 1 \
5 | --name comod-ffhq-512 \
6 | --dataset_mode testimage \
7 | --image_dir ./ffhq_debug/images \
8 | --mask_dir ./ffhq_debug/masks \
9 | --output_dir ./ffhq_debug \
10 | --load_size 512 \
11 | --crop_size 512 \
12 | --z_dim 512 \
13 | --model comod \
14 | --netG comodgan \
15 | --which_epoch co-mod-gan-ffhq-9-025000 \
16 | ${EXTRA} \
17 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """
5 | import pdb
6 | import sys
7 | import torch
8 | import numpy as np
9 | from collections import OrderedDict
10 | from options.train_options import TrainOptions
11 | import data
12 | from util.iter_counter import IterationCounter
13 | from logger import Logger
14 | from torchvision.utils import make_grid
15 | from trainers import create_trainer
16 | from save_remote_gs import init_remote, upload_remote
17 | from models.networks.sync_batchnorm import DataParallelWithCallback
18 | from pytorch_fid.fid_model import FIDModel
19 |
20 | # parse options
21 | opt = TrainOptions().parse()
22 |
23 | # fid
24 | fid_model = FIDModel().cuda()
25 | fid_model.model = DataParallelWithCallback(
26 | fid_model.model,
27 | device_ids=opt.gpu_ids)
28 |
29 |
30 | # load remote
31 | if opt.save_remote_gs is not None:
32 | init_remote(opt)
33 |
34 | # print options to help debugging
35 | print(' '.join(sys.argv))
36 |
37 | # load the dataset
38 | if opt.dataset_mode_val is not None:
39 | dataloader_train, dataloader_val = data.create_dataloader_trainval(opt)
40 | else:
41 | dataloader_train = data.create_dataloader(opt)
42 | dataloader_val = None
43 |
44 | # create trainer for our model
45 | trainer = create_trainer(opt)
46 | model = trainer.pix2pix_model
47 |
48 | # create tool for counting iterations
49 | iter_counter = IterationCounter(opt, len(dataloader_train))
50 |
51 | # create tool for visualization
52 | writer = Logger(f"output/{opt.name}")
53 | with open(f"output/{opt.name}/savemodel", "w") as f:
54 | f.writelines("n")
55 |
56 | trainer.save('latest')
57 |
58 | def get_psnr(generated, gt):
59 | generated = (generated+1)/2*255
60 | bsize, c, h, w = gt.shape
61 | gt = (gt+1)/2*255
62 | mse = ((generated-gt)**2).sum(3).sum(2).sum(1)
63 | mse /= (c*h*w)
64 | psnr = 10*torch.log10(255.0*255.0 / (mse+1e-8))
65 | return psnr.sum().item()
66 |
67 | def display_batch(epoch, data_i):
68 | losses = trainer.get_latest_losses()
69 | for k,v in losses.items():
70 | writer.add_scalar(k,v.mean().item(), iter_counter.total_steps_so_far)
71 | writer.write_console(epoch, iter_counter.epoch_iter, iter_counter.time_per_iter)
72 | num_print = min(4, data_i['image'].size(0))
73 | writer.add_single_image('inputs',
74 | (make_grid(trainer.get_latest_inputs()[:num_print])+1)/2,
75 | iter_counter.total_steps_so_far)
76 | infer_out,inp = trainer.pix2pix_model.forward(data_i, mode='inference')
77 | vis = (make_grid(inp[:num_print])+1)/2
78 | writer.add_single_image('infer_in',
79 | vis,
80 | iter_counter.total_steps_so_far)
81 | vis = (make_grid(infer_out[:num_print])+1)/2
82 | vis = torch.clamp(vis, 0,1)
83 | writer.add_single_image('infer_out',
84 | vis,
85 | iter_counter.total_steps_so_far)
86 | generated = trainer.get_latest_generated()
87 | for k,v in generated.items():
88 | if v is None:
89 | continue
90 | if 'label' in k:
91 | vis = make_grid(v[:num_print].expand(-1,3,-1,-1))[0]
92 | writer.add_single_label(k,
93 | vis,
94 | iter_counter.total_steps_so_far)
95 | else:
96 | if v.size(1) == 3:
97 | vis = (make_grid(v[:num_print])+1)/2
98 | vis = torch.clamp(vis, 0,1)
99 | else:
100 | vis = make_grid(v[:num_print])
101 | writer.add_single_image(k,
102 | vis,
103 | iter_counter.total_steps_so_far)
104 | writer.write_html()
105 |
106 | for epoch in iter_counter.training_epochs():
107 | iter_counter.record_epoch_start(epoch)
108 | for i, data_i in enumerate(dataloader_train, start=iter_counter.epoch_iter):
109 | iter_counter.record_one_iteration()
110 | # train discriminator
111 | if not opt.freeze_D:
112 | trainer.run_discriminator_one_step(data_i, i)
113 |
114 | # Training
115 | # train generator
116 | if i % opt.D_steps_per_G == 0:
117 | trainer.run_generator_one_step(data_i, i)
118 |
119 | if iter_counter.needs_displaying():
120 | display_batch(epoch, data_i)
121 | if opt.save_remote_gs is not None and iter_counter.needs_saving():
122 | upload_remote(opt)
123 | if iter_counter.needs_validation():
124 | print('saving the latest model (epoch %d, total_steps %d)' %
125 | (epoch, iter_counter.total_steps_so_far))
126 | trainer.save('epoch%d_step%d'%
127 | (epoch, iter_counter.total_steps_so_far))
128 | trainer.save('latest')
129 | iter_counter.record_current_iter()
130 | if dataloader_val is not None:
131 | print("doing validation")
132 | model.eval()
133 | num = 0
134 | psnr_total = 0
135 | for ii, data_ii in enumerate(dataloader_val):
136 | with torch.no_grad():
137 | generated,_ = model(data_ii, mode='inference')
138 | generated = generated.cpu()
139 | gt = data_ii['image']
140 | bsize = gt.size(0)
141 | psnr = get_psnr(generated, gt)
142 | psnr_total += psnr
143 | num += bsize
144 | fid_model.add_sample((generated+1)/2,(gt+1)/2)
145 | psnr_total /= num
146 | fid = fid_model.calculate_activation_statistics()
147 | writer.add_scalar("val.fid", fid, iter_counter.total_steps_so_far)
148 | writer.write_scalar("val.fid", fid, iter_counter.total_steps_so_far)
149 | writer.add_scalar("val.psnr", psnr_total, iter_counter.total_steps_so_far)
150 | writer.write_scalar("val.psnr", psnr_total, iter_counter.total_steps_so_far)
151 | writer.write_html()
152 | model.train()
153 | trainer.update_learning_rate(epoch)
154 | if epoch != 0 and epoch % 3 == 0 and opt.dataset_mode_train == 'cocomaskupdate':
155 | dataloader_train.dataset.update_dataset()
156 | iter_counter.record_epoch_end()
157 |
158 | print('Training was successfully finished.')
159 |
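Validation above reports PSNR via get_psnr(), i.e. PSNR = 10 * log10(255^2 / MSE) after mapping both tensors from [-1, 1] to the 8-bit range, summed over the batch and averaged later. A small sanity check with a known constant error (toy tensors, for illustration only):

```
import torch

def get_psnr(generated, gt):                       # same math as in train.py
    generated = (generated + 1) / 2 * 255
    bsize, c, h, w = gt.shape
    gt = (gt + 1) / 2 * 255
    mse = ((generated - gt) ** 2).sum(3).sum(2).sum(1) / (c * h * w)
    return (10 * torch.log10(255.0 * 255.0 / (mse + 1e-8))).sum().item()

gt = torch.zeros(1, 3, 64, 64)
pred = gt + 0.02                                   # constant error of 0.02 in [-1, 1] units
# error in 8-bit units is 0.02 * 127.5 = 2.55, so PSNR = 20*log10(255 / 2.55) = 40 dB
print(round(get_psnr(pred, gt), 2))                # ~40.0
```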
--------------------------------------------------------------------------------
/train.sh:
--------------------------------------------------------------------------------
1 | export CXX="g++"
2 | python train.py \
3 | --batchSize 2 \
4 | --nThreads 2 \
5 | --name comod_places \
6 | --train_image_dir ./datasets/places2sample1k_val/places2samples1k_crop256 \
7 | --train_image_list ./datasets/places2sample1k_val/files.txt \
8 | --train_image_postfix '.jpg' \
9 | --val_image_dir ./datasets/places2sample1k_val/places2samples1k_crop256 \
10 | --val_image_list ./datasets/places2sample1k_val/files.txt \
11 | --val_mask_dir ./datasets/places2sample1k_val/places2samples1k_256_mask_square128 \
12 | --load_size 512 \
13 | --crop_size 256 \
14 | --z_dim 512 \
15 | --validation_freq 10000 \
16 | --niter 50 \
17 | --dataset_mode trainimage \
18 | --trainer stylegan2 \
19 | --dataset_mode_train trainimage \
20 | --dataset_mode_val valimage \
21 | --model comod \
22 | --netG comodgan \
23 | --netD comodgan \
24 | --no_l1_loss \
25 | --no_vgg_loss \
26 | --preprocess_mode scale_shortside_and_crop \
27 | $EXTRA
28 |
--------------------------------------------------------------------------------
/trainers/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 |
3 | def find_trainer_using_name(model_name):
4 | model_filename = "trainers." + model_name + "_trainer"
5 | modellib = importlib.import_module(model_filename)
6 |
 7 |     # In the file, the class called ModelNameTrainer() will
 8 |     # be instantiated; the class name lookup is
 9 |     # case-insensitive.
10 | model = None
11 | target_model_name = model_name.replace('_', '') + 'trainer'
12 | for name, cls in modellib.__dict__.items():
13 | if name.lower() == target_model_name.lower():
14 | model = cls
15 |
16 | if model is None:
17 |         print("In %s.py, there should be a trainer class whose name matches %s in lowercase." % (model_filename, target_model_name))
18 | exit(0)
19 |
20 | return model
21 |
22 |
23 | def create_trainer(opt):
24 | model = find_trainer_using_name(opt.trainer)
25 | instance = model(opt)
26 | print("model [%s] was created" % (type(instance).__name__))
27 |
28 | return instance
29 |
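create_trainer() resolves `--trainer stylegan2` purely by name: find_trainer_using_name() imports trainers.stylegan2_trainer and picks the class whose lower-cased name equals `stylegan2trainer`, which matches StyleGAN2Trainer. A minimal illustration of the string match:

```
model_name = "stylegan2"
target_model_name = model_name.replace("_", "") + "trainer"   # 'stylegan2trainer'
print(target_model_name == "StyleGAN2Trainer".lower())        # True
```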
--------------------------------------------------------------------------------
/trainers/stylegan2_trainer.py:
--------------------------------------------------------------------------------
1 | import pdb
2 | import torch
3 | from models.networks.sync_batchnorm import DataParallelWithCallback
4 | import models
5 | #from models.pix2pix_model import Pix2PixModel
6 |
7 |
8 | class StyleGAN2Trainer():
9 | def __init__(self, opt):
10 | self.opt = opt
11 | self.pix2pix_model = models.create_model(opt)
12 | if len(opt.gpu_ids) > 0:
13 | self.pix2pix_model = DataParallelWithCallback(self.pix2pix_model,
14 | device_ids=opt.gpu_ids)
15 | self.pix2pix_model_on_one_gpu = self.pix2pix_model.module
16 | else:
17 | self.pix2pix_model_on_one_gpu = self.pix2pix_model
18 |
19 | self.generated = None
20 | self.inputs = None
21 | self.mean_path_length = torch.Tensor([0])
22 | if opt.isTrain:
23 | self.optimizer_G, self.optimizer_D = \
24 | self.pix2pix_model_on_one_gpu.create_optimizers(opt)
25 | self.old_lr = opt.lr
26 |
27 | def run_generator_one_step(self, data, i):
28 | self.optimizer_G.zero_grad()
29 | g_losses, inputs, generated = self.pix2pix_model(data, mode='generator')
30 | g_loss = sum(g_losses.values()).mean()
31 | g_loss.backward()
32 | self.optimizer_G.step()
33 | self.g_losses = g_losses
34 | self.generated = generated
35 | self.inputs = inputs
36 | g_regularize = (i % self.opt.g_reg_every == 0) and not (self.opt.no_g_reg)
37 | if g_regularize:
38 | self.optimizer_G.zero_grad()
39 | bsize = data['image'].size(0)
40 | data['mean_path_length'] = self.mean_path_length.expand(bsize)
41 | g_regs, self.mean_path_length \
42 | = self.pix2pix_model(data, mode='g_reg')
43 | g_reg = sum(g_regs.values()).mean()
44 | g_reg.backward()
45 | self.optimizer_G.step()
46 | self.g_losses = {
47 | **g_losses,
48 | **g_regs}
49 | bsize = inputs.size(0)
50 | accum = 0.5 ** (bsize / (10 * 1000)) # 32
51 | self.pix2pix_model_on_one_gpu.accumulate(accum)
52 |
53 | def run_discriminator_one_step(self, data, i):
54 | self.optimizer_D.zero_grad()
55 | d_losses_real = self.pix2pix_model(data, mode='dreal')
56 | d_loss_real = sum(d_losses_real.values()).mean()
57 | d_loss_real.backward()
58 | d_losses_fake = self.pix2pix_model(data, mode='dfake')
59 | d_loss_fake = sum(d_losses_fake.values()).mean()
60 | d_loss_fake.backward()
61 | self.d_losses = {
62 | **d_losses_real,
63 | **d_losses_fake}
64 | self.optimizer_D.step()
65 | d_regularize = i % self.opt.d_reg_every == 0
66 | if d_regularize:
67 | self.optimizer_D.zero_grad()
68 | d_regs = self.pix2pix_model(data, mode='d_reg')
69 | d_reg = sum(d_regs.values()).mean()
70 | d_reg.backward()
71 | self.optimizer_D.step()
72 | self.d_losses = {
73 | **self.d_losses,
74 | **d_regs}
75 |
76 | def get_latest_losses(self):
77 | if not self.opt.freeze_D:
78 | return {**self.g_losses, **self.d_losses}
79 | else:
80 | return self.g_losses
81 |
82 | def get_latest_generated(self):
83 | return self.generated
84 | def get_latest_inputs(self):
85 | return self.inputs
86 |
87 | def update_learning_rate(self, epoch):
88 | self.update_learning_rate(epoch)
89 |
90 | def save(self, epoch):
91 | self.pix2pix_model_on_one_gpu.save(epoch)
92 |
93 | ##################################################################
94 | # Helper functions
95 | ##################################################################
96 |
97 | def update_learning_rate(self, epoch):
98 | if epoch > self.opt.niter:
99 | lrd = self.opt.lr / self.opt.niter_decay
100 | new_lr = self.old_lr - lrd
101 | else:
102 | new_lr = self.old_lr
103 |
104 | if new_lr != self.old_lr:
105 | if self.opt.no_TTUR:
106 | new_lr_G = new_lr
107 | new_lr_D = new_lr
108 | else:
109 | new_lr_G = new_lr / 2
110 | new_lr_D = new_lr * 2
111 |
112 | for param_group in self.optimizer_D.param_groups:
113 | param_group['lr'] = new_lr_D
114 | for param_group in self.optimizer_G.param_groups:
115 | param_group['lr'] = new_lr_G
116 | print('update learning rate: %f -> %f' % (self.old_lr, new_lr))
117 | self.old_lr = new_lr
118 |
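The `accum` factor in run_generator_one_step() scales the EMA decay with the batch size, so the averaged generator keeps a half-life of roughly 10k images regardless of batchSize (the trailing `# 32` presumably notes the reference batch size). For example:

```
# EMA decay per step for a few batch sizes; half-life stays ~10k images.
for bsize in (2, 8, 32):
    print(bsize, round(0.5 ** (bsize / (10 * 1000)), 6))
# 2  -> 0.999861
# 8  -> 0.999446
# 32 -> 0.997784
```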
--------------------------------------------------------------------------------
/util/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """
5 |
--------------------------------------------------------------------------------
/util/coco.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """
5 |
6 |
7 | def id2label(id):
8 | if id == 182:
9 | id = 0
10 | else:
11 | id = id + 1
12 | labelmap = \
13 | {0: 'unlabeled',
14 | 1: 'person',
15 | 2: 'bicycle',
16 | 3: 'car',
17 | 4: 'motorcycle',
18 | 5: 'airplane',
19 | 6: 'bus',
20 | 7: 'train',
21 | 8: 'truck',
22 | 9: 'boat',
23 | 10: 'traffic light',
24 | 11: 'fire hydrant',
25 | 12: 'street sign',
26 | 13: 'stop sign',
27 | 14: 'parking meter',
28 | 15: 'bench',
29 | 16: 'bird',
30 | 17: 'cat',
31 | 18: 'dog',
32 | 19: 'horse',
33 | 20: 'sheep',
34 | 21: 'cow',
35 | 22: 'elephant',
36 | 23: 'bear',
37 | 24: 'zebra',
38 | 25: 'giraffe',
39 | 26: 'hat',
40 | 27: 'backpack',
41 | 28: 'umbrella',
42 | 29: 'shoe',
43 | 30: 'eye glasses',
44 | 31: 'handbag',
45 | 32: 'tie',
46 | 33: 'suitcase',
47 | 34: 'frisbee',
48 | 35: 'skis',
49 | 36: 'snowboard',
50 | 37: 'sports ball',
51 | 38: 'kite',
52 | 39: 'baseball bat',
53 | 40: 'baseball glove',
54 | 41: 'skateboard',
55 | 42: 'surfboard',
56 | 43: 'tennis racket',
57 | 44: 'bottle',
58 | 45: 'plate',
59 | 46: 'wine glass',
60 | 47: 'cup',
61 | 48: 'fork',
62 | 49: 'knife',
63 | 50: 'spoon',
64 | 51: 'bowl',
65 | 52: 'banana',
66 | 53: 'apple',
67 | 54: 'sandwich',
68 | 55: 'orange',
69 | 56: 'broccoli',
70 | 57: 'carrot',
71 | 58: 'hot dog',
72 | 59: 'pizza',
73 | 60: 'donut',
74 | 61: 'cake',
75 | 62: 'chair',
76 | 63: 'couch',
77 | 64: 'potted plant',
78 | 65: 'bed',
79 | 66: 'mirror',
80 | 67: 'dining table',
81 | 68: 'window',
82 | 69: 'desk',
83 | 70: 'toilet',
84 | 71: 'door',
85 | 72: 'tv',
86 | 73: 'laptop',
87 | 74: 'mouse',
88 | 75: 'remote',
89 | 76: 'keyboard',
90 | 77: 'cell phone',
91 | 78: 'microwave',
92 | 79: 'oven',
93 | 80: 'toaster',
94 | 81: 'sink',
95 | 82: 'refrigerator',
96 | 83: 'blender',
97 | 84: 'book',
98 | 85: 'clock',
99 | 86: 'vase',
100 | 87: 'scissors',
101 | 88: 'teddy bear',
102 | 89: 'hair drier',
103 | 90: 'toothbrush',
104 | 91: 'hair brush', # Last class of Thing
105 | 92: 'banner', # Beginning of Stuff
106 | 93: 'blanket',
107 | 94: 'branch',
108 | 95: 'bridge',
109 | 96: 'building-other',
110 | 97: 'bush',
111 | 98: 'cabinet',
112 | 99: 'cage',
113 | 100: 'cardboard',
114 | 101: 'carpet',
115 | 102: 'ceiling-other',
116 | 103: 'ceiling-tile',
117 | 104: 'cloth',
118 | 105: 'clothes',
119 | 106: 'clouds',
120 | 107: 'counter',
121 | 108: 'cupboard',
122 | 109: 'curtain',
123 | 110: 'desk-stuff',
124 | 111: 'dirt',
125 | 112: 'door-stuff',
126 | 113: 'fence',
127 | 114: 'floor-marble',
128 | 115: 'floor-other',
129 | 116: 'floor-stone',
130 | 117: 'floor-tile',
131 | 118: 'floor-wood',
132 | 119: 'flower',
133 | 120: 'fog',
134 | 121: 'food-other',
135 | 122: 'fruit',
136 | 123: 'furniture-other',
137 | 124: 'grass',
138 | 125: 'gravel',
139 | 126: 'ground-other',
140 | 127: 'hill',
141 | 128: 'house',
142 | 129: 'leaves',
143 | 130: 'light',
144 | 131: 'mat',
145 | 132: 'metal',
146 | 133: 'mirror-stuff',
147 | 134: 'moss',
148 | 135: 'mountain',
149 | 136: 'mud',
150 | 137: 'napkin',
151 | 138: 'net',
152 | 139: 'paper',
153 | 140: 'pavement',
154 | 141: 'pillow',
155 | 142: 'plant-other',
156 | 143: 'plastic',
157 | 144: 'platform',
158 | 145: 'playingfield',
159 | 146: 'railing',
160 | 147: 'railroad',
161 | 148: 'river',
162 | 149: 'road',
163 | 150: 'rock',
164 | 151: 'roof',
165 | 152: 'rug',
166 | 153: 'salad',
167 | 154: 'sand',
168 | 155: 'sea',
169 | 156: 'shelf',
170 | 157: 'sky-other',
171 | 158: 'skyscraper',
172 | 159: 'snow',
173 | 160: 'solid-other',
174 | 161: 'stairs',
175 | 162: 'stone',
176 | 163: 'straw',
177 | 164: 'structural-other',
178 | 165: 'table',
179 | 166: 'tent',
180 | 167: 'textile-other',
181 | 168: 'towel',
182 | 169: 'tree',
183 | 170: 'vegetable',
184 | 171: 'wall-brick',
185 | 172: 'wall-concrete',
186 | 173: 'wall-other',
187 | 174: 'wall-panel',
188 | 175: 'wall-stone',
189 | 176: 'wall-tile',
190 | 177: 'wall-wood',
191 | 178: 'water-other',
192 | 179: 'waterdrops',
193 | 180: 'window-blind',
194 | 181: 'window-other',
195 | 182: 'wood'}
196 | if id in labelmap:
197 | return labelmap[id]
198 | else:
199 | return 'unknown'
200 |
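id2label() shifts the raw id up by one before indexing the table above and maps raw id 182 back to 'unlabeled'. For example:

```
from util.coco import id2label

print(id2label(0))    # 'person'
print(id2label(181))  # 'wood'
print(id2label(182))  # 'unlabeled'
```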
--------------------------------------------------------------------------------
/util/html.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """
5 |
6 | import datetime
7 | import dominate
8 | from dominate.tags import *
9 | import os
10 |
11 |
12 | class HTML:
13 | def __init__(self, web_dir, title, refresh=0):
14 | if web_dir.endswith('.html'):
15 | web_dir, html_name = os.path.split(web_dir)
16 | else:
17 | web_dir, html_name = web_dir, 'index.html'
18 | self.title = title
19 | self.web_dir = web_dir
20 | self.html_name = html_name
21 | self.img_dir = os.path.join(self.web_dir, 'images')
22 | if len(self.web_dir) > 0 and not os.path.exists(self.web_dir):
23 | os.makedirs(self.web_dir)
24 | if len(self.web_dir) > 0 and not os.path.exists(self.img_dir):
25 | os.makedirs(self.img_dir)
26 |
27 | self.doc = dominate.document(title=title)
28 | with self.doc:
29 | h1(datetime.datetime.now().strftime("%I:%M%p on %B %d, %Y"))
30 | if refresh > 0:
31 | with self.doc.head:
32 | meta(http_equiv="refresh", content=str(refresh))
33 |
34 | def get_image_dir(self):
35 | return self.img_dir
36 |
37 | def add_header(self, str):
38 | with self.doc:
39 | h3(str)
40 |
41 | def add_table(self, border=1):
42 | self.t = table(border=border, style="table-layout: fixed;")
43 | self.doc.add(self.t)
44 |
45 | def add_images(self, ims, txts, links, width=512):
46 | self.add_table()
47 | with self.t:
48 | with tr():
49 | for im, txt, link in zip(ims, txts, links):
50 | with td(style="word-wrap: break-word;", halign="center", valign="top"):
51 | with p():
52 | with a(href=os.path.join('images', link)):
53 | img(style="width:%dpx" % (width), src=os.path.join('images', im))
54 | br()
55 | p(txt.encode('utf-8'))
56 |
57 | def save(self):
58 | html_file = os.path.join(self.web_dir, self.html_name)
59 | f = open(html_file, 'wt')
60 | f.write(self.doc.render())
61 | f.close()
62 |
63 |
64 | if __name__ == '__main__':
65 | html = HTML('web/', 'test_html')
66 | html.add_header('hello world')
67 |
68 | ims = []
69 | txts = []
70 | links = []
71 | for n in range(4):
72 | ims.append('image_%d.jpg' % n)
73 | txts.append('text_%d' % n)
74 | links.append('image_%d.jpg' % n)
75 | html.add_images(ims, txts, links)
76 | html.save()
77 |
--------------------------------------------------------------------------------
/util/iter_counter.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """
5 |
6 | import pdb
7 | import os
8 | import time
9 | import numpy as np
10 |
11 |
12 | # Helper class that keeps track of training iterations
13 | class IterationCounter():
14 | def __init__(self, opt, dataset_size):
15 | self.opt = opt
16 | self.dataset_size = dataset_size
17 |
18 | self.first_epoch = 1
19 | self.total_epochs = opt.niter + opt.niter_decay
20 | self.epoch_iter = 0 # iter number within each epoch
21 | self.iter_record_path = os.path.join(self.opt.checkpoints_dir, self.opt.name, 'iter.txt')
22 | if opt.isTrain and opt.continue_train:
23 | try:
24 | self.first_epoch, self.epoch_iter = np.loadtxt(
25 | self.iter_record_path, delimiter=',', dtype=int)
26 | print('Resuming from epoch %d at iteration %d' % (self.first_epoch, self.epoch_iter))
27 | except:
28 | print('Could not load iteration record at %s. Starting from beginning.' %
29 | self.iter_record_path)
30 |
31 | self.total_steps_so_far = (self.first_epoch - 1) * dataset_size + self.epoch_iter
32 |
33 | # return the iterator of epochs for the training
34 | def training_epochs(self):
35 | return range(self.first_epoch, self.total_epochs + 1)
36 |
37 | def record_epoch_start(self, epoch):
38 | self.epoch_start_time = time.time()
39 | self.last_iter_time = time.time()
40 | self.current_epoch = epoch
41 |
42 | def record_one_iteration(self):
43 | current_time = time.time()
44 |
45 | # the last remaining batch is dropped (see data/__init__.py),
46 | # so we can assume batch size is always opt.batchSize
47 | self.time_per_iter = (current_time - self.last_iter_time) / self.opt.batchSize
48 | self.last_iter_time = current_time
49 | self.total_steps_so_far += self.opt.batchSize
50 | self.epoch_iter += self.opt.batchSize
51 |
52 | def record_epoch_end(self):
53 | self.epoch_iter = 0
54 | current_time = time.time()
55 | self.time_per_epoch = current_time - self.epoch_start_time
56 | print('End of epoch %d / %d \t Time Taken: %d sec' %
57 | (self.current_epoch, self.total_epochs, self.time_per_epoch))
58 | if self.current_epoch % self.opt.save_epoch_freq == 0:
59 | np.savetxt(self.iter_record_path, (self.current_epoch + 1, 0),
60 | delimiter=',', fmt='%d')
61 | print('Saved current iteration count at %s.' % self.iter_record_path)
62 |
63 | def record_current_iter(self):
64 | np.savetxt(self.iter_record_path, (self.current_epoch, self.epoch_iter),
65 | delimiter=',', fmt='%d')
66 | print('Saved current iteration count at %s.' % self.iter_record_path)
67 |
68 | def needs_saving(self):
69 | return (self.total_steps_so_far % self.opt.save_latest_freq) < self.opt.batchSize
70 |
71 | def needs_validation(self):
72 | return (self.total_steps_so_far % self.opt.validation_freq) < self.opt.batchSize
73 |
74 | def needs_printing(self):
75 | return (self.total_steps_so_far % self.opt.print_freq) < self.opt.batchSize
76 |
77 | def needs_displaying(self):
78 | return (self.total_steps_so_far % self.opt.display_freq) < self.opt.batchSize
79 |
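The needs_*() checks fire roughly once every `freq` samples: total_steps_so_far grows by batchSize per iteration, so its remainder modulo the frequency drops below batchSize about once per window. A quick illustration with batchSize=2 (as in train.sh) and the display_freq default of 101:

```
batch_size, freq = 2, 101
hits = [s for s in range(0, 1000, batch_size) if s % freq < batch_size]
print(hits)   # [0, 102, 202, 304, 404, 506, 606, 708, 808, 910]
```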
--------------------------------------------------------------------------------
/util/util.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """
5 |
6 | import re
7 | import pdb
8 | import importlib
9 | import torch
10 | from argparse import Namespace
11 | import numpy as np
12 | from PIL import Image
13 | import os
14 | import argparse
15 | import dill as pickle
16 | import util.coco
17 |
18 |
19 | def save_obj(obj, name):
20 | with open(name, 'wb') as f:
21 | pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)
22 |
23 |
24 | def load_obj(name):
25 | with open(name, 'rb') as f:
26 | return pickle.load(f)
27 |
28 | # returns a configuration for creating a generator
29 | # |default_opt| should be the opt of the current experiment
30 | # |**kwargs|: if any configuration should be overriden, it can be specified here
31 |
32 |
33 | def copyconf(default_opt, **kwargs):
34 | conf = argparse.Namespace(**vars(default_opt))
35 | for key in kwargs:
36 | print(key, kwargs[key])
37 | setattr(conf, key, kwargs[key])
38 | return conf
39 |
40 |
41 | def tile_images(imgs, picturesPerRow=4):
42 | """ Code borrowed from
43 | https://stackoverflow.com/questions/26521365/cleanly-tile-numpy-array-of-images-stored-in-a-flattened-1d-format/26521997
44 | """
45 |
46 | # Padding
47 | if imgs.shape[0] % picturesPerRow == 0:
48 | rowPadding = 0
49 | else:
50 | rowPadding = picturesPerRow - imgs.shape[0] % picturesPerRow
51 | if rowPadding > 0:
52 | imgs = np.concatenate([imgs, np.zeros((rowPadding, *imgs.shape[1:]), dtype=imgs.dtype)], axis=0)
53 |
54 | # Tiling Loop (The conditionals are not necessary anymore)
55 | tiled = []
56 | for i in range(0, imgs.shape[0], picturesPerRow):
57 | tiled.append(np.concatenate([imgs[j] for j in range(i, i + picturesPerRow)], axis=1))
58 |
59 | tiled = np.concatenate(tiled, axis=0)
60 | return tiled
61 |
62 |
63 | # Converts a Tensor into a Numpy array
64 | # |imtype|: the desired type of the converted numpy array
65 | def tensor2im(image_tensor, imtype=np.uint8, normalize=True, tile=False):
66 | if isinstance(image_tensor, list):
67 | image_numpy = []
68 | for i in range(len(image_tensor)):
69 | image_numpy.append(tensor2im(image_tensor[i], imtype, normalize))
70 | return image_numpy
71 |
72 | if image_tensor.dim() == 4:
73 | # transform each image in the batch
74 | images_np = []
75 | for b in range(image_tensor.size(0)):
76 | one_image = image_tensor[b]
77 | one_image_np = tensor2im(one_image)
78 | images_np.append(one_image_np.reshape(1, *one_image_np.shape))
79 | images_np = np.concatenate(images_np, axis=0)
80 | if tile:
81 | images_tiled = tile_images(images_np)
82 | return images_tiled
83 | else:
84 | return images_np
85 |
86 | if image_tensor.dim() == 2:
87 | image_tensor = image_tensor.unsqueeze(0)
88 | image_numpy = image_tensor.detach().cpu().float().numpy()
89 | if normalize:
90 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
91 | else:
92 | image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0
93 | image_numpy = np.clip(image_numpy, 0, 255)
94 | if image_numpy.shape[2] == 1:
95 | image_numpy = image_numpy[:, :, 0]
96 | return image_numpy.astype(imtype)
97 |
98 |
99 | # Converts a one-hot tensor into a colorful label map
100 | def tensor2label(label_tensor, n_label, imtype=np.uint8, tile=False):
101 | if label_tensor.dim() == 4:
102 | # transform each image in the batch
103 | images_np = []
104 | for b in range(label_tensor.size(0)):
105 | one_image = label_tensor[b]
106 | one_image_np = tensor2label(one_image, n_label, imtype)
107 | images_np.append(one_image_np.reshape(1, *one_image_np.shape))
108 | images_np = np.concatenate(images_np, axis=0)
109 | if tile:
110 | images_tiled = tile_images(images_np)
111 | return images_tiled
112 | else:
113 | images_np = images_np[0]
114 | return images_np
115 |
116 | if label_tensor.dim() == 1:
117 | return np.zeros((64, 64, 3), dtype=np.uint8)
118 | if n_label == 0:
119 | return tensor2im(label_tensor, imtype)
120 | label_tensor = label_tensor.cpu().float()
121 | if label_tensor.size()[0] > 1:
122 | label_tensor = label_tensor.max(0, keepdim=True)[1]
123 | label_tensor = Colorize(n_label)(label_tensor)
124 | label_numpy = np.transpose(label_tensor.numpy(), (1, 2, 0))
125 | result = label_numpy.astype(imtype)
126 | return result
127 |
128 |
129 | def save_image(image_numpy, image_path, create_dir=False):
130 | if create_dir:
131 | os.makedirs(os.path.dirname(image_path), exist_ok=True)
132 | if len(image_numpy.shape) == 2:
133 | image_numpy = np.expand_dims(image_numpy, axis=2)
134 | if image_numpy.shape[2] == 1:
135 | image_numpy = np.repeat(image_numpy, 3, 2)
136 | image_pil = Image.fromarray(image_numpy)
137 |
138 | # save to png
139 | image_pil.save(image_path.replace('.jpg', '.png'))
140 |
141 |
142 | def mkdirs(paths):
143 | if isinstance(paths, list) and not isinstance(paths, str):
144 | for path in paths:
145 | mkdir(path)
146 | else:
147 | mkdir(paths)
148 |
149 |
150 | def mkdir(path):
151 | if not os.path.exists(path):
152 | os.makedirs(path)
153 |
154 |
155 | def atoi(text):
156 | return int(text) if text.isdigit() else text
157 |
158 |
159 | def natural_keys(text):
160 | '''
161 | alist.sort(key=natural_keys) sorts in human order
162 | http://nedbatchelder.com/blog/200712/human_sorting.html
163 | (See Toothy's implementation in the comments)
164 | '''
165 |     return [atoi(c) for c in re.split(r'(\d+)', text)]
166 |
167 |
168 | def natural_sort(items):
169 | items.sort(key=natural_keys)
170 |
171 |
172 | def str2bool(v):
173 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
174 | return True
175 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
176 | return False
177 | else:
178 | raise argparse.ArgumentTypeError('Boolean value expected.')
179 |
180 |
181 | def find_class_in_module(target_cls_name, module):
182 | target_cls_name = target_cls_name.replace('_', '').lower()
183 | clslib = importlib.import_module(module)
184 | cls = None
185 | for name, clsobj in clslib.__dict__.items():
186 | if name.lower() == target_cls_name:
187 | cls = clsobj
188 |
189 | if cls is None:
190 | print("In %s, there should be a class whose name matches %s in lowercase without underscore(_)" % (module, target_cls_name))
191 | exit(0)
192 |
193 | return cls
194 |
195 |
196 | def save_network(net, label, epoch, opt):
197 | save_filename = '%s_net_%s.pth' % (epoch, label)
198 | save_path = os.path.join(opt.checkpoints_dir, opt.name, save_filename)
199 | torch.save(net.cpu().state_dict(), save_path)
200 | if len(opt.gpu_ids) and torch.cuda.is_available():
201 | net.cuda()
202 |
203 | def load_network_path(net, save_path):
204 | weights = torch.load(save_path)
205 | new_dict = {}
206 | for k,v in weights.items():
207 | #if k.startswith("module.conv16") or k.startswith("module.conv17"):
208 | # continue
209 | if k.startswith("module."):
210 | k=k.replace("module.","")
211 | new_dict[k] = v
212 | net.load_state_dict(new_dict, strict=False)
213 | #net.load_state_dict(new_dict)
214 | return net
215 |
216 |
217 | def load_network(net, label, epoch, opt):
218 | save_filename = '%s_net_%s.pth' % (epoch, label)
219 | save_dir = os.path.join(opt.checkpoints_dir, opt.name)
220 | save_path = os.path.join(save_dir, save_filename)
221 | weights = torch.load(save_path)
222 | print("==============load path: =================")
223 | print(save_path)
224 | new_dict = {}
225 | for k,v in weights.items():
226 | if k.startswith("module."):
227 | k=k.replace("module.","")
228 | new_dict[k] = v
229 | net.load_state_dict(new_dict, strict=False)
230 | return net
231 |
232 |
233 | ###############################################################################
234 | # Code from
235 | # https://github.com/ycszen/pytorch-seg/blob/master/transform.py
236 | # Modified so it complies with the Cityscapes label map colors
237 | ###############################################################################
238 | def uint82bin(n, count=8):
239 | """Returns the binary representation of integer n; count is the number of bits"""
240 | return ''.join([str((n >> y) & 1) for y in range(count - 1, -1, -1)])
241 |
242 |
243 | def labelcolormap(N):
244 | if N == 35: # Cityscapes
245 | cmap = np.array([(0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0), (111, 74, 0), (81, 0, 81),
246 | (128, 64, 128), (244, 35, 232), (250, 170, 160), (230, 150, 140), (70, 70, 70), (102, 102, 156), (190, 153, 153),
247 | (180, 165, 180), (150, 100, 100), (150, 120, 90), (153, 153, 153), (153, 153, 153), (250, 170, 30), (220, 220, 0),
248 | (107, 142, 35), (152, 251, 152), (70, 130, 180), (220, 20, 60), (255, 0, 0), (0, 0, 142), (0, 0, 70),
249 | (0, 60, 100), (0, 0, 90), (0, 0, 110), (0, 80, 100), (0, 0, 230), (119, 11, 32), (0, 0, 142)],
250 | dtype=np.uint8)
251 | else:
252 | cmap = np.zeros((N, 3), dtype=np.uint8)
253 | for i in range(N):
254 | r, g, b = 0, 0, 0
255 | id = i + 1 # let's give 0 a color
256 | for j in range(7):
257 | str_id = uint82bin(id)
258 | r = r ^ (np.uint8(str_id[-1]) << (7 - j))
259 | g = g ^ (np.uint8(str_id[-2]) << (7 - j))
260 | b = b ^ (np.uint8(str_id[-3]) << (7 - j))
261 | id = id >> 3
262 | cmap[i, 0] = r
263 | cmap[i, 1] = g
264 | cmap[i, 2] = b
265 |
266 | if N == 182: # COCO
267 | important_colors = {
268 | 'sea': (54, 62, 167),
269 | 'sky-other': (95, 219, 255),
270 | 'tree': (140, 104, 47),
271 | 'clouds': (170, 170, 170),
272 | 'grass': (29, 195, 49)
273 | }
274 | for i in range(N):
275 | name = util.coco.id2label(i)
276 | if name in important_colors:
277 | color = important_colors[name]
278 | cmap[i] = np.array(list(color))
279 |
280 | return cmap
281 |
282 |
283 | class Colorize(object):
284 | def __init__(self, n=35):
285 | self.cmap = labelcolormap(n)
286 | self.cmap = torch.from_numpy(self.cmap[:n])
287 |
288 | def __call__(self, gray_image):
289 | size = gray_image.size()
290 | color_image = torch.ByteTensor(3, size[1], size[2]).fill_(0)
291 |
292 | for label in range(0, len(self.cmap)):
293 | mask = (label == gray_image[0]).cpu()
294 | color_image[0][mask] = self.cmap[label][0]
295 | color_image[1][mask] = self.cmap[label][1]
296 | color_image[2][mask] = self.cmap[label][2]
297 |
298 | return color_image
299 |
--------------------------------------------------------------------------------
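The helpers at the end of util/util.py above can be exercised on their own. A minimal sketch, assuming it is run from the repo root so that `util` is importable; the file names and the toy Linear module are placeholders, not part of this repo:

```
import numpy as np
import torch

from util.util import load_network_path, natural_sort, save_image

# natural_sort orders frame_2 before frame_10, unlike a plain lexicographic sort
names = ["frame_10.png", "frame_2.png", "frame_1.png"]
natural_sort(names)
print(names)  # ['frame_1.png', 'frame_2.png', 'frame_10.png']

# save_image always writes a .png file, replacing a .jpg extension if present
dummy = np.zeros((64, 64, 3), dtype=np.uint8)
save_image(dummy, "output/dummy.jpg", create_dir=True)  # written as output/dummy.png

# load_network_path strips the "module." prefix that DataParallel adds to checkpoint keys
torch.save(torch.nn.DataParallel(torch.nn.Linear(4, 4)).state_dict(), "output/tmp_net.pth")
net = load_network_path(torch.nn.Linear(4, 4), "output/tmp_net.pth")
```
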
/util/visualizer.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """
5 |
6 | import os
7 | import ntpath
8 | import time
9 | from . import util
10 | from . import html
11 | import scipy.misc  # note: scipy.misc.toimage (used in the tf_log branch below) has been removed in recent SciPy releases
12 | try:
13 | from StringIO import StringIO # Python 2.7
14 | except ImportError:
15 | from io import BytesIO # Python 3.x
16 |
17 | class Visualizer():
18 | def __init__(self, opt):
19 | self.opt = opt
20 | self.tf_log = opt.isTrain and opt.tf_log
21 | self.use_html = opt.isTrain and not opt.no_html
22 | self.win_size = opt.display_winsize
23 | self.name = opt.name
24 | if self.tf_log:
25 | import tensorflow as tf
26 | self.tf = tf
27 | self.log_dir = os.path.join(opt.checkpoints_dir, opt.name, 'logs')
28 | self.writer = tf.summary.FileWriter(self.log_dir)  # TF1-style API (tf.compat.v1.summary.FileWriter in TF2)
29 |
30 | if self.use_html:
31 | self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
32 | self.img_dir = os.path.join(self.web_dir, 'images')
33 | print('create web directory %s...' % self.web_dir)
34 | util.mkdirs([self.web_dir, self.img_dir])
35 | if opt.isTrain:
36 | self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
37 | with open(self.log_name, "a") as log_file:
38 | now = time.strftime("%c")
39 | log_file.write('================ Training Loss (%s) ================\n' % now)
40 |
41 | # |visuals|: dictionary of images to display or save
42 | def display_current_results(self, visuals, epoch, step):
43 |
44 | ## convert tensors to numpy arrays
45 | visuals = self.convert_visuals_to_numpy(visuals)
46 |
47 | if self.tf_log: # show images in tensorboard output
48 | img_summaries = []
49 | for label, image_numpy in visuals.items():
50 | # Write the image to a string
51 | try:
52 | s = StringIO()
53 | except:
54 | s = BytesIO()
55 | if len(image_numpy.shape) >= 4:
56 | image_numpy = image_numpy[0]
57 | scipy.misc.toimage(image_numpy).save(s, format="jpeg")
58 | # Create an Image object
59 | img_sum = self.tf.Summary.Image(encoded_image_string=s.getvalue(), height=image_numpy.shape[0], width=image_numpy.shape[1])
60 | # Create a Summary value
61 | img_summaries.append(self.tf.Summary.Value(tag=label, image=img_sum))
62 |
63 | # Create and write Summary
64 | summary = self.tf.Summary(value=img_summaries)
65 | self.writer.add_summary(summary, step)
66 |
67 | if self.use_html: # save images to a html file
68 | for label, image_numpy in visuals.items():
69 | if isinstance(image_numpy, list):
70 | for i in range(len(image_numpy)):
71 | img_path = os.path.join(self.img_dir, 'epoch%.3d_iter%.3d_%s_%d.png' % (epoch, step, label, i))
72 | util.save_image(image_numpy[i], img_path)
73 | else:
74 | img_path = os.path.join(self.img_dir, 'epoch%.3d_iter%.3d_%s.png' % (epoch, step, label))
75 | if len(image_numpy.shape) >= 4:
76 | image_numpy = image_numpy[0]
77 | util.save_image(image_numpy, img_path)
78 |
79 | # update website
80 | webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=5)
81 | for n in range(epoch, 0, -1):
82 | webpage.add_header('epoch [%d]' % n)
83 | ims = []
84 | txts = []
85 | links = []
86 |
87 | for label, image_numpy in visuals.items():
88 | if isinstance(image_numpy, list):
89 | for i in range(len(image_numpy)):
90 | img_path = 'epoch%.3d_iter%.3d_%s_%d.png' % (n, step, label, i)
91 | ims.append(img_path)
92 | txts.append(label+str(i))
93 | links.append(img_path)
94 | else:
95 | img_path = 'epoch%.3d_iter%.3d_%s.png' % (n, step, label)
96 | ims.append(img_path)
97 | txts.append(label)
98 | links.append(img_path)
99 | if len(ims) < 10:
100 | webpage.add_images(ims, txts, links, width=self.win_size)
101 | else:
102 | num = int(round(len(ims)/2.0))
103 | webpage.add_images(ims[:num], txts[:num], links[:num], width=self.win_size)
104 | webpage.add_images(ims[num:], txts[num:], links[num:], width=self.win_size)
105 | webpage.save()
106 |
107 | # errors: dictionary of error labels and values
108 | def plot_current_errors(self, errors, step):
109 | if self.tf_log:
110 | for tag, value in errors.items():
111 | value = value.mean().float()
112 | summary = self.tf.Summary(value=[self.tf.Summary.Value(tag=tag, simple_value=value)])
113 | self.writer.add_summary(summary, step)
114 |
115 | # errors: same format as |errors| of plot_current_errors
116 | def print_current_errors(self, epoch, i, errors, t):
117 | message = '(epoch: %d, iters: %d, time: %.3f) ' % (epoch, i, t)
118 | for k, v in errors.items():
119 | #print(v)
120 | #if v != 0:
121 | v = v.mean().float()
122 | message += '%s: %.3f ' % (k, v)
123 |
124 | print(message)
125 | with open(self.log_name, "a") as log_file:
126 | log_file.write('%s\n' % message)
127 |
128 | def convert_visuals_to_numpy(self, visuals):
129 | for key, t in visuals.items():
130 | tile = self.opt.batchSize > 8
131 | if 'input_label' == key:
132 | t = util.tensor2label(t, self.opt.label_nc + 2, tile=tile)
133 | else:
134 | t = util.tensor2im(t, tile=tile)
135 | visuals[key] = t
136 | return visuals
137 |
138 | # save image to the disk
139 | def save_images(self, webpage, visuals, image_path):
140 | visuals = self.convert_visuals_to_numpy(visuals)
141 |
142 | image_dir = webpage.get_image_dir()
143 | short_path = ntpath.basename(image_path[0])
144 | name = os.path.splitext(short_path)[0]
145 |
146 | webpage.add_header(name)
147 | ims = []
148 | txts = []
149 | links = []
150 |
151 | for label, image_numpy in visuals.items():
152 | image_name = os.path.join(label, '%s.png' % (name))
153 | save_path = os.path.join(image_dir, image_name)
154 | util.save_image(image_numpy, save_path, create_dir=True)
155 |
156 | ims.append(image_name)
157 | txts.append(label)
158 | links.append(image_name)
159 | webpage.add_images(ims, txts, links, width=self.win_size)
160 |
--------------------------------------------------------------------------------
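Visualizer is normally constructed with the options object produced by options/train_options.py, but its logging path can be exercised with a hand-built namespace. A minimal sketch, assuming an older SciPy that still ships the scipy.misc module (imported at the top of util/visualizer.py); the option values and experiment name are placeholders:

```
import torch
from argparse import Namespace

from util.visualizer import Visualizer

# minimal option object carrying only the fields Visualizer reads
opt = Namespace(
    isTrain=True,
    tf_log=False,            # keep the TensorFlow summary writer disabled
    no_html=False,           # write the HTML page under checkpoints/<name>/web
    display_winsize=256,
    name="comod_example",    # placeholder experiment name
    checkpoints_dir="./checkpoints",
    batchSize=1,             # batchSize and label_nc are only read by convert_visuals_to_numpy
    label_nc=0,
)
vis = Visualizer(opt)

# appends one formatted line to checkpoints/comod_example/loss_log.txt and prints it
losses = {"g_loss": torch.tensor(0.42), "d_loss": torch.tensor(0.37)}
vis.print_current_errors(epoch=1, i=100, errors=losses, t=0.25)
```
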