├── .gitignore ├── LICENSE ├── README.md ├── checkpoint └── .gitignore ├── dataset.py ├── distributed.py ├── mask.py ├── model.py ├── sample.png ├── sample └── .gitignore └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Kim Seonghyeon 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # semantic-pyramid-pytorch 2 | Implementation of Semantic Pyramid for Image Generation (https://arxiv.org/abs/2003.06221) in PyTorch 3 | 4 | Details of this implementation will differ from the official model, as the official implementation has not been released yet. 5 | 6 | 7 | ## Usage 8 | 9 | ```bash 10 | > python -m torch.distributed.launch --nproc_per_node=[NUM GPUS] --master_port=[PORT] train.py [Places365 PATH] 11 | ``` 12 | 13 | 14 | ## Samples 15 | 16 | ![Samples from semantic pyramid model](sample.png) 17 | 18 | The first column shows the source images from which the VGG features are extracted. The following columns show images generated conditioned on a specific level of those features. (Conv1 ~ Conv5, FC7, FC8) -------------------------------------------------------------------------------- /checkpoint/.gitignore: -------------------------------------------------------------------------------- 1 | *.pt 2 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from PIL import Image 4 | from torch.utils.data import Dataset 5 | 6 | 7 | IMG_EXTENSIONS = ( 8 | '.jpg', 9 | '.jpeg', 10 | '.png', 11 | '.ppm', 12 | '.bmp', 13 | '.pgm', 14 | '.tif', 15 | '.tiff', 16 | '.webp', 17 | ) 18 | 19 | 20 | class Places365(Dataset): 21 | def __init__(self, root, transform=None): 22 | self.root = root 23 | self.transform = transform 24 | 25 | self.data = [] 26 | categories = set() 27 | 28 | for dirpath, dirnames, filenames in os.walk(self.root): 29 | if len(filenames) > 0: 30 | relpath = os.path.relpath(dirpath, root) 31 | _, category = os.path.split(dirpath) 32 | categories.add(category) 33 | 34 | for file in filenames: 35 | if file.lower().endswith(IMG_EXTENSIONS): 36 | self.data.append((os.path.join(relpath, file), category)) 37 | 38 | categories = sorted(list(categories)) 39 | 40 | self.category_map = {cat: i for i, cat in enumerate(categories)} 41 | 42 | self.n_class = 365 43 | 44 | def __len__(self): 45 | return len(self.data) 46 | 47 | def __getitem__(self, index): 48 | path, category = self.data[index] 49 | 50 | img = Image.open(os.path.join(self.root, path)).convert('RGB') 51 | label = self.category_map[category] 52 | 53 | if self.transform is not None: 54 | img = self.transform(img) 55 | 56 | return img, label 57 | -------------------------------------------------------------------------------- /distributed.py: -------------------------------------------------------------------------------- 1 | import math 2 | import pickle 3 | 4 | import torch 5 | from torch import distributed as dist 6 | from torch.utils import data 7 | 8 | 9 | def get_rank(): 10 | if not dist.is_available(): 11 | return 0 12 | 13 | if not dist.is_initialized(): 14 | return 0 15 | 16 | return dist.get_rank() 17 | 18 | 19 | def synchronize(): 20 
| if not dist.is_available(): 21 | return 22 | 23 | if not dist.is_initialized(): 24 | return 25 | 26 | world_size = dist.get_world_size() 27 | 28 | if world_size == 1: 29 | return 30 | 31 | dist.barrier() 32 | 33 | 34 | def get_world_size(): 35 | if not dist.is_available(): 36 | return 1 37 | 38 | if not dist.is_initialized(): 39 | return 1 40 | 41 | return dist.get_world_size() 42 | 43 | 44 | def all_gather(data): 45 | world_size = get_world_size() 46 | 47 | if world_size == 1: 48 | return [data] 49 | 50 | buffer = pickle.dumps(data) 51 | storage = torch.ByteStorage.from_buffer(buffer) 52 | tensor = torch.ByteTensor(storage).to('cuda') 53 | 54 | local_size = torch.IntTensor([tensor.numel()]).to('cuda') 55 | size_list = [torch.IntTensor([1]).to('cuda') for _ in range(world_size)] 56 | dist.all_gather(size_list, local_size) 57 | size_list = [int(size.item()) for size in size_list] 58 | max_size = max(size_list) 59 | 60 | tensor_list = [] 61 | for _ in size_list: 62 | tensor_list.append(torch.ByteTensor(size=(max_size,)).to('cuda')) 63 | 64 | if local_size != max_size: 65 | padding = torch.ByteTensor(size=(max_size - local_size,)).to('cuda') 66 | tensor = torch.cat((tensor, padding), 0) 67 | 68 | dist.all_gather(tensor_list, tensor) 69 | 70 | data_list = [] 71 | 72 | for size, tensor in zip(size_list, tensor_list): 73 | buffer = tensor.cpu().numpy().tobytes()[:size] 74 | data_list.append(pickle.loads(buffer)) 75 | 76 | return data_list 77 | 78 | 79 | def reduce_dict(input_dict, average=True): 80 | world_size = get_world_size() 81 | 82 | if world_size < 2: 83 | return input_dict 84 | 85 | with torch.no_grad(): 86 | keys = [] 87 | values = [] 88 | 89 | for k in sorted(input_dict.keys()): 90 | keys.append(k) 91 | values.append(input_dict[k]) 92 | 93 | values = torch.stack(values, 0) 94 | dist.reduce(values, dst=0) 95 | 96 | if dist.get_rank() == 0 and average: 97 | values /= world_size 98 | 99 | reduced_dict = {k: v for k, v in zip(keys, values)} 100 | 101 | return reduced_dict 102 | 103 | 104 | def data_sampler(dataset, shuffle, distributed): 105 | if distributed: 106 | return data.distributed.DistributedSampler(dataset, shuffle=shuffle) 107 | 108 | if shuffle: 109 | return data.RandomSampler(dataset) 110 | 111 | else: 112 | return data.SequentialSampler(dataset) 113 | -------------------------------------------------------------------------------- /mask.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import cv2 5 | import torch 6 | from torch.nn import functional as F 7 | 8 | # Took from https://github.com/jshyunbin/inpainting_cGAN/blob/master/src/mask_generator.py 9 | def pattern_mask(img_size, kernel_size=7, num_points=1, ratio=0.25): 10 | mask = np.zeros((img_size, img_size), dtype=np.float32) 11 | kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size)) 12 | 13 | for num in range(num_points): 14 | coordinate = np.random.randint(img_size, size=2) 15 | mask[coordinate[0], coordinate[1]] = 1.0 16 | mask = cv2.dilate(mask, kernel, iterations=1) 17 | 18 | i = 0 19 | 20 | while np.sum(mask) < ratio * img_size * img_size: 21 | i += 1 22 | flag = True 23 | while flag: 24 | coordinate = np.random.randint(img_size, size=2) 25 | if mask[coordinate[0], coordinate[1]] == 1.0: 26 | mask2 = np.zeros((img_size, img_size), dtype=np.float32) 27 | mask2[coordinate[0], coordinate[1]] = 1.0 28 | mask2 = cv2.dilate(mask2, kernel, iterations=1) 29 | 30 | mask[mask + mask2 >= 1.0] = 1.0 31 | flag = 
False 32 | 33 | mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel) 34 | 35 | return 1.0 - mask 36 | 37 | 38 | def make_crop_mask(crop_size, crop_kernel, sizes, device): 39 | crop_mask = ( 40 | torch.from_numpy(pattern_mask(crop_size, crop_kernel)) 41 | .view(1, 1, crop_size, crop_size) 42 | .to(device) 43 | ) 44 | masks = [] 45 | 46 | for size in sizes: 47 | if size[0] != crop_size or size[1] != crop_size: 48 | crop = F.interpolate(crop_mask, size=size, mode='nearest') 49 | 50 | else: 51 | crop = crop_mask 52 | 53 | masks.append(crop.squeeze()) 54 | 55 | return masks 56 | 57 | 58 | def make_mask_pyramid(selected, n_mask, sizes, device): 59 | masks = [] 60 | 61 | for i in range(n_mask): 62 | if i == selected: 63 | if i < len(sizes): 64 | m = torch.ones(*sizes[i], device=device) 65 | 66 | else: 67 | m = torch.ones(1, device=device) 68 | 69 | masks.append(m) 70 | 71 | else: 72 | if i < len(sizes): 73 | m = torch.zeros(*sizes[i], device=device) 74 | 75 | else: 76 | m = torch.zeros(1, device=device) 77 | 78 | masks.append(m) 79 | 80 | return masks 81 | 82 | 83 | def make_crop_mask_pyramid(selected, n_mask, crop_size, crop_kernel, sizes, device): 84 | masks = make_crop_mask(crop_size, crop_kernel, sizes[:selected], device) 85 | 86 | masks.append(torch.ones(*sizes[selected], device=device)) 87 | 88 | for h, w in sizes[selected + 1 :]: 89 | masks.append(torch.zeros(h, w, device=device)) 90 | 91 | for _ in range(n_mask - len(sizes)): 92 | masks.append(torch.zeros(1, device=device)) 93 | 94 | return masks 95 | 96 | 97 | def make_mask( 98 | batch_size, 99 | device, 100 | crop_prob=0.3, 101 | n_mask=7, 102 | sizes=((112, 112), (56, 56), (28, 28), (14, 14), (7, 7)), 103 | crop_size=56, 104 | crop_kernel=31, 105 | ): 106 | selected = torch.randint(0, n_mask, (batch_size,)) 107 | 108 | mask_batch = [] 109 | 110 | for sel in selected: 111 | if sel < len(sizes) and random.random() < crop_prob: 112 | masks = make_crop_mask_pyramid( 113 | sel, n_mask, crop_size, crop_kernel, sizes, device 114 | ) 115 | 116 | else: 117 | masks = make_mask_pyramid(sel, n_mask, sizes, device) 118 | 119 | mask_batch.append(masks) 120 | 121 | masks_zip = [] 122 | 123 | for masks in zip(*mask_batch): 124 | masks_zip.append(torch.stack(masks, 0).unsqueeze(1)) 125 | 126 | return masks_zip 127 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | from torchvision.models import vgg16, vgg16_bn, vgg19, vgg19_bn 6 | 7 | 8 | def spectral_norm(module): 9 | nn.init.xavier_uniform_(module.weight, 2 ** 0.5) 10 | 11 | if hasattr(module, 'bias') and module.bias is not None: 12 | module.bias.data.zero_() 13 | 14 | return nn.utils.spectral_norm(module) 15 | 16 | 17 | def get_activation(name): 18 | if name == 'leaky_relu': 19 | activation = nn.LeakyReLU(0.2) 20 | 21 | elif name == 'relu': 22 | activation = nn.ReLU() 23 | 24 | return activation 25 | 26 | 27 | class VGGFeature(nn.Module): 28 | def __init__(self, arch, indices, use_fc=False, normalize=True, min_max=(-1, 1)): 29 | super().__init__() 30 | 31 | vgg = { 32 | 'vgg16': vgg16, 33 | 'vgg16_bn': vgg16_bn, 34 | 'vgg19': vgg19, 35 | 'vgg19_bn': vgg19_bn, 36 | }.get(arch)(pretrained=True) 37 | 38 | for p in vgg.parameters(): 39 | p.requires_grad = False 40 | 41 | self.slices = nn.ModuleList() 42 | 43 | for i, j in zip([-1] + indices, indices + [None]): 44 | if j is None: 45 | break 
46 | 47 | self.slices.append(vgg.features[slice(i + 1, j + 1)]) 48 | 49 | self.use_fc = use_fc 50 | 51 | if use_fc: 52 | self.rest_layer = vgg.features[indices[-1] :] 53 | self.fc6 = vgg.classifier[:3] 54 | self.fc7 = vgg.classifier[3:6] 55 | self.fc8 = vgg.classifier[6] 56 | 57 | mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1) 58 | std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1) 59 | 60 | val_range = min_max[1] - min_max[0] 61 | mean = mean * (val_range) + min_max[0] 62 | std = std * val_range 63 | 64 | self.register_buffer('mean', mean) 65 | self.register_buffer('std', std) 66 | 67 | self.normalize = normalize 68 | 69 | def forward(self, input): 70 | if self.normalize: 71 | input = (input - self.mean) / self.std 72 | 73 | features = [] 74 | 75 | out = input 76 | for layer in self.slices: 77 | out = layer(out) 78 | features.append(out) 79 | 80 | fcs = [] 81 | 82 | if self.use_fc: 83 | out = self.rest_layer(out) 84 | out = torch.flatten(F.adaptive_avg_pool2d(out, (7, 7)), 1) 85 | 86 | fc6 = self.fc6(out) 87 | fc7 = self.fc7(fc6) 88 | fc8 = self.fc8(fc7) 89 | 90 | fcs = [fc6, fc7, fc8] 91 | 92 | return features, fcs 93 | 94 | 95 | class AdaptiveBatchNorm2d(nn.Module): 96 | def __init__(self, in_channel, embed_dim): 97 | super().__init__() 98 | 99 | self.norm = nn.BatchNorm2d(in_channel, affine=False) 100 | 101 | self.weight = spectral_norm(nn.Linear(embed_dim, in_channel, bias=False)) 102 | self.bias = spectral_norm(nn.Linear(embed_dim, in_channel, bias=False)) 103 | 104 | self.in_channel = in_channel 105 | self.embed_dim = embed_dim 106 | 107 | def forward(self, input, embed): 108 | out = self.norm(input) 109 | 110 | batch_size = input.shape[0] 111 | 112 | weight = self.weight(embed).view(batch_size, -1, 1, 1) 113 | bias = self.bias(embed).view(batch_size, -1, 1, 1) 114 | 115 | out = (weight + 1) * out + bias 116 | 117 | return out 118 | 119 | def __repr__(self): 120 | return ( 121 | f'{self.__class__.__name__}({self.in_channel}, embed_dim={self.embed_dim})' 122 | ) 123 | 124 | 125 | class ResBlock(nn.Module): 126 | def __init__( 127 | self, 128 | in_channel, 129 | out_channel, 130 | norm=False, 131 | embed_dim=None, 132 | upsample=False, 133 | downsample=False, 134 | first=False, 135 | activation='relu', 136 | ): 137 | super().__init__() 138 | 139 | self.first = first 140 | self.norm = norm 141 | 142 | bias = False if norm else True 143 | 144 | if norm: 145 | self.norm1 = AdaptiveBatchNorm2d(in_channel, embed_dim) 146 | 147 | if not self.first: 148 | self.activation1 = get_activation(activation) 149 | 150 | if upsample: 151 | self.upsample = nn.Upsample(scale_factor=2) 152 | 153 | else: 154 | self.upsample = None 155 | 156 | self.conv1 = spectral_norm( 157 | nn.Conv2d(in_channel, out_channel, 3, padding=1, bias=bias) 158 | ) 159 | 160 | if norm: 161 | self.norm2 = AdaptiveBatchNorm2d(out_channel, embed_dim) 162 | 163 | self.activation2 = get_activation(activation) 164 | 165 | self.conv2 = spectral_norm( 166 | nn.Conv2d(out_channel, out_channel, 3, padding=1, bias=bias) 167 | ) 168 | 169 | if downsample: 170 | self.downsample = nn.AvgPool2d(2) 171 | 172 | else: 173 | self.downsample = None 174 | 175 | self.skip = None 176 | 177 | if in_channel != out_channel or upsample or downsample: 178 | self.skip = spectral_norm(nn.Conv2d(in_channel, out_channel, 1, bias=False)) 179 | 180 | def forward(self, input, embed=None): 181 | out = input 182 | 183 | if self.norm: 184 | out = self.norm1(out, embed) 185 | 186 | if not self.first: 187 | out = self.activation1(out) 188 | 189 
| if self.upsample: 190 | out = self.upsample(out) 191 | 192 | out = self.conv1(out) 193 | 194 | if self.norm: 195 | out = self.norm2(out, embed) 196 | 197 | out = self.activation2(out) 198 | out = self.conv2(out) 199 | 200 | if self.downsample: 201 | out = self.downsample(out) 202 | 203 | skip = input 204 | 205 | if self.skip is not None: 206 | if self.upsample: 207 | skip = self.upsample(skip) 208 | 209 | if self.downsample and self.first: 210 | skip = self.downsample(skip) 211 | 212 | skip = self.skip(skip) 213 | 214 | if self.downsample and not self.first: 215 | skip = self.downsample(skip) 216 | 217 | return out + skip 218 | 219 | 220 | class SelfAttention(nn.Module): 221 | def __init__(self, in_channel, divider=8): 222 | super().__init__() 223 | 224 | self.query = spectral_norm( 225 | nn.Conv2d(in_channel, in_channel // divider, 1, bias=False) 226 | ) 227 | self.key = spectral_norm( 228 | nn.Conv2d(in_channel, in_channel // divider, 1, bias=False) 229 | ) 230 | self.value = spectral_norm( 231 | nn.Conv2d(in_channel, in_channel // 2, 1, bias=False) 232 | ) 233 | self.out = spectral_norm(nn.Conv2d(in_channel // 2, in_channel, 1, bias=False)) 234 | 235 | self.divider = divider 236 | 237 | self.gamma = nn.Parameter(torch.tensor(0.0)) 238 | 239 | def forward(self, input): 240 | batch, channel, height, width = input.shape 241 | 242 | query = ( 243 | self.query(input) 244 | .view(batch, channel // self.divider, height * width) 245 | .transpose(1, 2) 246 | ) 247 | key = F.max_pool2d(self.key(input), 2).view( 248 | batch, channel // self.divider, height * width // 4 249 | ) 250 | value = F.max_pool2d(self.value(input), 2).view( 251 | batch, channel // 2, height * width // 4 252 | ) 253 | query_key = torch.bmm(query, key) 254 | attn = F.softmax(query_key, 2) 255 | attn = torch.bmm(value, attn.transpose(1, 2)).view( 256 | batch, channel // 2, height, width 257 | ) 258 | attn = self.out(attn) 259 | out = self.gamma * attn + input 260 | 261 | return out 262 | 263 | 264 | class Generator(nn.Module): 265 | def __init__( 266 | self, 267 | n_class, 268 | dim_z, 269 | dim_class, 270 | feature_channels=(64, 128, 256, 512, 512, 4096, 1000), 271 | channel_multiplier=64, 272 | channels=(8, 8, 4, 2, 2, 1), 273 | blocks='rrrrar', 274 | upsample='nuuunu', 275 | activation='relu', 276 | feature_kernel_size=1, 277 | ): 278 | super().__init__() 279 | 280 | self.n_resblock = len([c for c in blocks if c == 'r']) 281 | self.use_affine = [b == 'r' for b in blocks] 282 | 283 | self.embed = nn.Embedding(n_class, dim_class) 284 | 285 | self.linears = nn.ModuleList() 286 | self.feature_linears = nn.ModuleList() 287 | 288 | in_dim = dim_z 289 | feat_i = 6 290 | for _ in range(2): 291 | dim = feature_channels[feat_i] 292 | 293 | self.linears.append( 294 | nn.Sequential(spectral_norm(nn.Linear(in_dim, dim)), nn.LeakyReLU(0.2)) 295 | ) 296 | 297 | self.feature_linears.append(spectral_norm(nn.Linear(dim, dim))) 298 | 299 | in_dim = dim 300 | feat_i -= 1 301 | 302 | self.linear_expand = spectral_norm( 303 | nn.Linear(in_dim, 7 * 7 * feature_channels[-3]) 304 | ) 305 | 306 | self.blocks = nn.ModuleList() 307 | self.feature_blocks = nn.ModuleList() 308 | 309 | in_channel = channels[0] * channel_multiplier 310 | for block, ch, up in zip(blocks, channels, upsample): 311 | if block == 'r': 312 | self.blocks.append( 313 | ResBlock( 314 | in_channel, 315 | ch * channel_multiplier, 316 | norm=True, 317 | embed_dim=dim_class, 318 | upsample=up == 'u', 319 | activation=activation, 320 | ) 321 | ) 322 | 323 | 
self.feature_blocks.append( 324 | spectral_norm( 325 | nn.Conv2d( 326 | feature_channels[feat_i], 327 | ch * channel_multiplier, 328 | feature_kernel_size, 329 | padding=(feature_kernel_size - 1) // 2, 330 | bias=False, 331 | ) 332 | ) 333 | ) 334 | 335 | feat_i -= 1 336 | 337 | elif block == 'a': 338 | self.blocks.append(SelfAttention(in_channel)) 339 | 340 | in_channel = ch * channel_multiplier 341 | 342 | self.colorize = nn.Sequential( 343 | nn.BatchNorm2d(in_channel), 344 | get_activation(activation), 345 | nn.Upsample(scale_factor=2), 346 | spectral_norm(nn.Conv2d(in_channel, 3, 3, padding=1)), 347 | nn.Tanh(), 348 | ) 349 | 350 | def forward(self, input, class_id, features, masks): 351 | embed = self.embed(class_id) 352 | 353 | batch_size = input.shape[0] 354 | 355 | feat_i = len(features) - 1 356 | 357 | out = input 358 | 359 | for linear, feat_linear in zip(self.linears, self.feature_linears): 360 | out = linear(out) 361 | 362 | # print(out.shape, features[feat_i].shape, masks[feat_i].shape) 363 | 364 | out = out + feat_linear(features[feat_i] * masks[feat_i].squeeze(-1)) 365 | 366 | # print(out.shape) 367 | 368 | feat_i -= 1 369 | 370 | out = self.linear_expand(out).view(batch_size, -1, 7, 7) 371 | 372 | layer_i = feat_i 373 | 374 | for affine, block in zip(self.use_affine, self.blocks): 375 | # print(out.shape) 376 | if affine: 377 | out = block(out, embed) 378 | # print(out.shape, features[feat_i].shape, masks[feat_i].shape) 379 | out = out + self.feature_blocks[layer_i - feat_i]( 380 | features[feat_i] * masks[feat_i] 381 | ) 382 | feat_i -= 1 383 | 384 | else: 385 | out = block(out) 386 | 387 | out = self.colorize(out) 388 | 389 | return out 390 | 391 | 392 | class Discriminator(nn.Module): 393 | def __init__( 394 | self, 395 | n_class, 396 | channel_multiplier=64, 397 | channels=(1, 2, 2, 4, 8, 16, 16), 398 | blocks='rrarrrr', 399 | downsample='ddndddn', 400 | activation='relu', 401 | ): 402 | super().__init__() 403 | 404 | blocks_list = [] 405 | 406 | in_channel = 3 407 | for i, (block, ch, down) in enumerate(zip(blocks, channels, downsample)): 408 | if block == 'r': 409 | blocks_list.append( 410 | ResBlock( 411 | in_channel, 412 | ch * channel_multiplier, 413 | downsample=down == 'd', 414 | first=i == 0, 415 | activation=activation, 416 | ) 417 | ) 418 | 419 | elif block == 'a': 420 | blocks_list.append(SelfAttention(in_channel)) 421 | 422 | in_channel = ch * channel_multiplier 423 | 424 | blocks_list += [get_activation(activation)] 425 | 426 | self.blocks = nn.Sequential(*blocks_list) 427 | 428 | self.embed = spectral_norm(nn.Embedding(n_class, in_channel)) 429 | self.linear = spectral_norm(nn.Linear(in_channel, 1)) 430 | 431 | def forward(self, input, class_id): 432 | out = self.blocks(input) 433 | 434 | out = out.sum([2, 3]) 435 | out_linear = self.linear(out) 436 | 437 | embed = self.embed(class_id) 438 | prod = (out * embed).sum(1, keepdim=True) 439 | 440 | out_linear = out_linear + prod 441 | 442 | return out_linear.squeeze(1) 443 | -------------------------------------------------------------------------------- /sample.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rosinality/semantic-pyramid-pytorch/be3cb40dc8a2f8dd92295628db77265fbe6be45e/sample.png -------------------------------------------------------------------------------- /sample/.gitignore: -------------------------------------------------------------------------------- 1 | *.png 2 | 
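Before train.py, a minimal sketch of how the pieces above fit together for a single conditional forward pass: the frozen `VGGFeature` extractor from model.py, the mask pyramid from `make_mask` in mask.py, and the `Generator`. The 224×224 size, VGG slice indices, and default dimensions mirror train.py below; the random tensors are stand-ins for a real Places365 batch, so this is illustrative rather than part of the repository.

```python
# Illustrative sketch only; mirrors the defaults used in train.py below.
import torch

from mask import make_mask
from model import VGGFeature, Generator

device = 'cuda'
batch = 4

# Frozen VGG16; the indices select the activation after each conv stage (conv1..conv5).
vgg = VGGFeature('vgg16', [4, 9, 16, 23, 30], use_fc=True).eval().to(device)
gen = Generator(n_class=365, dim_z=128, dim_class=128).to(device)

real = torch.randn(batch, 3, 224, 224, device=device)      # stand-in for a Places365 batch in [-1, 1]
class_id = torch.randint(0, 365, (batch,), device=device)

with torch.no_grad():
    features, fcs = vgg(real)
features = features + fcs[1:]        # conv1..conv5 feature maps plus fc7/fc8, as train.py does

masks = make_mask(batch, device)     # per-sample mask pyramid that keeps one feature level
z = torch.randn(batch, 128, device=device)

fake = gen(z, class_id, features, masks)   # (batch, 3, 224, 224) images in [-1, 1]
```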
-------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import torch 5 | from torch import optim, nn 6 | from torch.nn import functional as F 7 | from torch.utils import data 8 | from torchvision import transforms, utils 9 | from tqdm import tqdm 10 | 11 | from dataset import Places365 12 | from model import VGGFeature, Generator, Discriminator 13 | from mask import make_mask 14 | import distributed as dist 15 | 16 | 17 | def requires_grad(module, flag): 18 | for m in module.parameters(): 19 | m.requires_grad = flag 20 | 21 | 22 | def d_ls_loss(real_predict, fake_predict): 23 | loss = (real_predict - 1).pow(2).mean() + fake_predict.pow(2).mean() 24 | 25 | return loss 26 | 27 | 28 | def g_ls_loss(real_predict, fake_predict): 29 | loss = (fake_predict - 1).pow(2).mean() 30 | 31 | return loss 32 | 33 | 34 | def recon_loss(features_fake, features_real, masks): 35 | r_loss = 0 36 | 37 | for f_fake, f_real, m in zip(features_fake, features_real, masks): 38 | if f_fake.ndim == 4: 39 | f_fake = F.max_pool2d(f_fake, 2, ceil_mode=True) 40 | f_real = F.max_pool2d(f_real, 2, ceil_mode=True) 41 | f_mask = F.max_pool2d(m, 2, ceil_mode=True) 42 | 43 | r_loss = ( 44 | r_loss + (F.l1_loss(f_fake, f_real, reduction="none") * f_mask).mean() 45 | ) 46 | 47 | else: 48 | r_loss = ( 49 | r_loss 50 | + (F.l1_loss(f_fake, f_real, reduction="none") * m.squeeze(-1)).mean() 51 | ) 52 | 53 | return r_loss 54 | 55 | 56 | def diversity_loss(z1, z2, fake1, fake2, eps=1e-8): 57 | div_z = F.l1_loss(z1, z2, reduction="none").mean(1) 58 | div_fake = F.l1_loss(fake1, fake2, reduction="none").mean((1, 2, 3)) 59 | 60 | d_loss = (div_z / (div_fake + eps)).mean() 61 | 62 | return d_loss 63 | 64 | 65 | def accumulate(model1, model2, decay=0.999): 66 | par1 = dict(model1.named_parameters()) 67 | par2 = dict(model2.named_parameters()) 68 | 69 | for k in par1.keys(): 70 | par1[k].data.mul_(decay).add_(1 - decay, par2[k].data) 71 | 72 | 73 | def sample_data(loader): 74 | loader_iter = iter(loader) 75 | 76 | while True: 77 | try: 78 | yield next(loader_iter) 79 | 80 | except StopIteration: 81 | loader_iter = iter(loader) 82 | 83 | yield next(loader_iter) 84 | 85 | 86 | def train(args, dataset, gen, dis, g_ema, device): 87 | if args.distributed: 88 | g_module = gen.module 89 | d_module = dis.module 90 | 91 | else: 92 | g_module = gen 93 | d_module = dis 94 | 95 | vgg = VGGFeature("vgg16", [4, 9, 16, 23, 30], use_fc=True).eval().to(device) 96 | requires_grad(vgg, False) 97 | 98 | g_optim = optim.Adam(gen.parameters(), lr=1e-4, betas=(0, 0.999)) 99 | d_optim = optim.Adam(dis.parameters(), lr=1e-4, betas=(0, 0.999)) 100 | 101 | loader = data.DataLoader( 102 | dataset, 103 | batch_size=args.batch, 104 | num_workers=4, 105 | sampler=dist.data_sampler(dataset, shuffle=True, distributed=args.distributed), 106 | drop_last=True, 107 | ) 108 | 109 | loader_iter = sample_data(loader) 110 | 111 | pbar = range(args.start_iter, args.iter) 112 | 113 | if dist.get_rank() == 0: 114 | pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True) 115 | 116 | eps = 1e-8 117 | 118 | for i in pbar: 119 | real, class_id = next(loader_iter) 120 | 121 | real = real.to(device) 122 | class_id = class_id.to(device) 123 | 124 | masks = make_mask(real.shape[0], device, args.crop_prob) 125 | features, fcs = vgg(real) 126 | features = features + fcs[1:] 127 | 128 | requires_grad(dis, True) 129 | 
requires_grad(gen, False) 130 | 131 | real_pred = dis(real, class_id) 132 | 133 | z = torch.randn(args.batch, args.dim_z, device=device) 134 | 135 | fake = gen(z, class_id, features, masks) 136 | 137 | fake_pred = dis(fake, class_id) 138 | 139 | d_loss = d_ls_loss(real_pred, fake_pred) 140 | 141 | d_optim.zero_grad() 142 | d_loss.backward() 143 | d_optim.step() 144 | 145 | z1 = torch.randn(args.batch, args.dim_z, device=device) 146 | z2 = torch.randn(args.batch, args.dim_z, device=device) 147 | 148 | requires_grad(gen, True) 149 | requires_grad(dis, False) 150 | 151 | masks = make_mask(real.shape[0], device, args.crop_prob) 152 | 153 | if args.distributed: 154 | gen.broadcast_buffers = True 155 | 156 | fake1 = gen(z1, class_id, features, masks) 157 | 158 | if args.distributed: 159 | gen.broadcast_buffers = False 160 | 161 | fake2 = gen(z2, class_id, features, masks) 162 | 163 | fake_pred = dis(fake1, class_id) 164 | 165 | a_loss = g_ls_loss(None, fake_pred) 166 | 167 | features_fake, fcs_fake = vgg(fake1) 168 | features_fake = features_fake + fcs_fake[1:] 169 | 170 | r_loss = recon_loss(features_fake, features, masks) 171 | div_loss = diversity_loss(z1, z2, fake1, fake2, eps) 172 | 173 | g_loss = a_loss + args.rec_weight * r_loss + args.div_weight * div_loss 174 | 175 | g_optim.zero_grad() 176 | g_loss.backward() 177 | g_optim.step() 178 | 179 | accumulate(g_ema, g_module) 180 | 181 | if dist.get_rank() == 0: 182 | pbar.set_description( 183 | f"d: {d_loss.item():.4f}; g: {a_loss.item():.4f}; rec: {r_loss.item():.4f}; div: {div_loss.item():.4f}" 184 | ) 185 | 186 | if i % 100 == 0: 187 | utils.save_image( 188 | fake1, 189 | f"sample/{str(i).zfill(6)}.png", 190 | nrow=int(args.batch ** 0.5), 191 | normalize=True, 192 | range=(-1, 1), 193 | ) 194 | 195 | if i % 10000 == 0: 196 | torch.save( 197 | { 198 | "args": args, 199 | "g_ema": g_ema.state_dict(), 200 | "g": g_module.state_dict(), 201 | "d": d_module.state_dict(), 202 | }, 203 | f"checkpoint/{str(i).zfill(6)}.pt", 204 | ) 205 | 206 | 207 | if __name__ == "__main__": 208 | device = "cuda" 209 | 210 | torch.backends.cudnn.benchmark = True 211 | 212 | parser = argparse.ArgumentParser() 213 | 214 | parser.add_argument("--local_rank", type=int, default=0) 215 | parser.add_argument("--ckpt", type=str, default=None) 216 | parser.add_argument("--iter", type=int, default=500000) 217 | parser.add_argument("--start_iter", type=int, default=0) 218 | parser.add_argument("--batch", type=int, default=32) 219 | parser.add_argument("--dim_z", type=int, default=128) 220 | parser.add_argument("--dim_class", type=int, default=128) 221 | parser.add_argument("--rec_weight", type=float, default=0.1) 222 | parser.add_argument("--div_weight", type=float, default=0.1) 223 | parser.add_argument("--crop_prob", type=float, default=0.3) 224 | parser.add_argument("path", metavar="PATH") 225 | 226 | args = parser.parse_args() 227 | 228 | n_gpu = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1 229 | args.distributed = n_gpu > 1 230 | 231 | transform = transforms.Compose( 232 | [ 233 | transforms.Resize(224), 234 | transforms.RandomCrop(224), 235 | transforms.RandomHorizontalFlip(), 236 | transforms.ToTensor(), 237 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), 238 | ] 239 | ) 240 | 241 | dset = Places365(args.path, transform=transform) 242 | args.n_class = dset.n_class 243 | 244 | if args.distributed: 245 | torch.cuda.set_device(args.local_rank) 246 | torch.distributed.init_process_group(backend="nccl", init_method="env://") 247 | 
dist.synchronize() 248 | 249 | gen = Generator(args.n_class, args.dim_z, args.dim_class).to(device) 250 | g_ema = Generator(args.n_class, args.dim_z, args.dim_class).to(device) 251 | accumulate(g_ema, gen, 0) 252 | dis = Discriminator(args.n_class).to(device) 253 | 254 | if args.ckpt is not None: 255 | ckpt = torch.load(args.ckpt, map_location=lambda storage, loc: storage) 256 | 257 | gen.load_state_dict(ckpt["g"]) 258 | g_ema.load_state_dict(ckpt["g_ema"]) 259 | dis.load_state_dict(ckpt["d"]) 260 | 261 | if args.distributed: 262 | gen = nn.parallel.DistributedDataParallel( 263 | gen, 264 | device_ids=[args.local_rank], 265 | output_device=args.local_rank, 266 | broadcast_buffers=True, 267 | ) 268 | 269 | dis = nn.parallel.DistributedDataParallel( 270 | dis, 271 | device_ids=[args.local_rank], 272 | output_device=args.local_rank, 273 | broadcast_buffers=True, 274 | ) 275 | 276 | train(args, dset, gen, dis, g_ema, device) 277 | --------------------------------------------------------------------------------
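The checkpoints written by train.py store the EMA generator under the `g_ema` key, and mask.py exposes `make_mask_pyramid` for keeping exactly one feature level. A short sketch under stated assumptions — the `checkpoint/100000.pt` path is hypothetical and `img` stands in for a single image normalized with the training transform — of conditioning the EMA generator on one chosen level, similar in spirit to the per-column samples described in the README:

```python
# Sketch only: assumes a checkpoint written by train.py exists at the hypothetical
# path below and that `img` is replaced by a real 224x224 image scaled to [-1, 1].
import torch

from mask import make_mask_pyramid
from model import VGGFeature, Generator

device = 'cuda'
ckpt = torch.load('checkpoint/100000.pt', map_location='cpu')   # hypothetical checkpoint path
args = ckpt['args']

g_ema = Generator(args.n_class, args.dim_z, args.dim_class).to(device)
g_ema.load_state_dict(ckpt['g_ema'])
g_ema.eval()

vgg = VGGFeature('vgg16', [4, 9, 16, 23, 30], use_fc=True).eval().to(device)

img = torch.randn(1, 3, 224, 224, device=device)                # stand-in for a real image
class_id = torch.tensor([0], device=device)

with torch.no_grad():
    features, fcs = vgg(img)
    features = features + fcs[1:]

    # Mask pyramid keeping exactly one level: indices 0..4 select conv1..conv5,
    # 5 selects fc7, 6 selects fc8.
    sizes = ((112, 112), (56, 56), (28, 28), (14, 14), (7, 7))
    masks = make_mask_pyramid(2, 7, sizes, device)
    masks = [m.unsqueeze(0).unsqueeze(1) for m in masks]         # add batch/channel dims, like make_mask

    z = torch.randn(1, args.dim_z, device=device)
    sample = g_ema(z, class_id, features, masks)                 # image conditioned on conv3 only
```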