├── .gitignore ├── LICENSE ├── Px437_IBM_VGA_8x16.ttf ├── README.md ├── checkpoint └── .gitignore ├── generate.py ├── generated.png ├── model.py ├── prepare_data.py ├── sample └── .gitignore ├── stylegan2 ├── .gitignore ├── LICENSE ├── LICENSE-FID ├── LICENSE-LPIPS ├── LICENSE-NVIDIA ├── README.md ├── calc_inception.py ├── checkpoint │ └── .gitignore ├── convert_weight.py ├── dataset.py ├── distributed.py ├── doc │ ├── sample-metfaces.png │ ├── sample.png │ ├── stylegan2-church-config-f.png │ └── stylegan2-ffhq-config-f.png ├── fid.py ├── generate.py ├── inception.py ├── inception_ffhq.pkl ├── lpips │ ├── __init__.py │ ├── base_model.py │ ├── dist_model.py │ ├── networks_basic.py │ ├── pretrained_networks.py │ └── weights │ │ ├── v0.0 │ │ ├── alex.pth │ │ ├── squeeze.pth │ │ └── vgg.pth │ │ └── v0.1 │ │ ├── alex.pth │ │ ├── squeeze.pth │ │ └── vgg.pth ├── model.py ├── non_leaking.py ├── op │ ├── __init__.py │ ├── fused_act.py │ ├── fused_bias_act.cpp │ ├── fused_bias_act_kernel.cu │ ├── upfirdn2d.cpp │ ├── upfirdn2d.py │ └── upfirdn2d_kernel.cu ├── ppl.py ├── projector.py ├── sample │ └── .gitignore └── train.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | wandb/ 132 | *.lmdb 133 | -------------------------------------------------------------------------------- /Px437_IBM_VGA_8x16.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rosinality/swapping-autoencoder-pytorch/8265a8a4497ea098c83bbb47bf33960c999e7d7e/Px437_IBM_VGA_8x16.ttf -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # swapping-autoencoder-pytorch 2 | Unofficial implementation of Swapping Autoencoder for Deep Image Manipulation (https://arxiv.org/abs/2007.00653) in PyTorch 3 | 4 | ## Usage 5 | 6 | First create lmdb datasets: 7 | 8 | > python prepare_data.py --out LMDB_PATH --n_worker N_WORKER --size SIZE1,SIZE2,SIZE3,... DATASET_PATH 9 | 10 | This will convert images to JPEG and pre-resize them. This implementation does not use progressive growing, but you can create multiple-resolution datasets by passing a comma-separated list of sizes, in case you want to try other resolutions later. 11 | 12 | Then you can train the model in a distributed setting: 13 | 14 | > python -m torch.distributed.launch --nproc_per_node=N_GPU --master_port=PORT train.py --batch BATCH_SIZE LMDB_PATH 15 | 16 | train.py supports Weights & Biases logging. If you want to use it, add the --wandb argument to the script. 17 | 18 | ### Generate samples 19 | 20 | You can test the trained model using `generate.py`: 21 | 22 | > python generate.py --ckpt [CHECKPOINT PATH] IMG1 IMG2 IMG3 ...
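If you prefer to script the swap instead of using the command line, the fields that `generate.py` reads from the checkpoint (`args`, `e_ema`, `g_ema`) can be used directly. The snippet below is a minimal sketch of that flow, not part of the repository: the checkpoint and image paths are placeholders, and it assumes the custom CUDA ops under `stylegan2/op` compile on your machine.

```python
import numpy as np
import torch
from PIL import Image

from model import Encoder, Generator

device = "cuda"

# placeholder checkpoint path; generate.py stores the training args plus EMA weights
ckpt = torch.load("checkpoint/your_checkpoint.pt", map_location="cpu")
ckpt_args = ckpt["args"]

encoder = Encoder(ckpt_args.channel)
generator = Generator(ckpt_args.channel)
encoder.load_state_dict(ckpt["e_ema"])
generator.load_state_dict(ckpt["g_ema"])
encoder = encoder.eval().to(device)
generator = generator.eval().to(device)

def load_image(path, size):
    # resize and rescale to [-1, 1], matching the preprocessing in generate.py
    img = Image.open(path).convert("RGB").resize((size, size))
    tensor = torch.from_numpy(np.array(img)).float().div(255).add(-0.5).mul(2)
    return tensor.permute(2, 0, 1).unsqueeze(0).to(device)

img_a = load_image("IMG1.png", ckpt_args.size)  # placeholder image paths
img_b = load_image("IMG2.png", ckpt_args.size)

with torch.no_grad():
    structure_a, texture_a = encoder(img_a)
    structure_b, texture_b = encoder(img_b)
    swapped = generator(structure_a, texture_b)  # structure of A, texture of B
```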
23 | 24 | ## Samples 25 | 26 | ![Generated sample image](generated.png) 27 | -------------------------------------------------------------------------------- /checkpoint/.gitignore: -------------------------------------------------------------------------------- 1 | *.pt 2 | -------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import numpy as np 4 | import torch 5 | from PIL import Image, ImageFont, ImageDraw 6 | from torchvision import utils 7 | 8 | 9 | from model import Encoder, Generator 10 | 11 | 12 | def render(text, size, font, background=(0, 0, 0), foreground=(85, 255, 85)): 13 | total_height = 0 14 | max_width = 0 15 | 16 | for line in text.split("\n"): 17 | text_width, text_height = font.getsize(line) 18 | max_width = max(max_width, text_width) 19 | total_height += text_height 20 | 21 | width, height = size 22 | start_w = max((width - max_width) // 2, 0) 23 | start_h = max((height - total_height) // 2, 0) 24 | 25 | image = Image.new("RGB", size, background) 26 | draw = ImageDraw.Draw(image) 27 | draw.text((start_w, start_h), text, font=font, fill=foreground) 28 | 29 | return image 30 | 31 | 32 | def pil_to_tensor(pil_img): 33 | return ( 34 | torch.from_numpy(np.array(pil_img)) 35 | .to(torch.float32) 36 | .div(255) 37 | .add(-0.5) 38 | .mul(2) 39 | .permute(2, 0, 1) 40 | .unsqueeze(0) 41 | ) 42 | 43 | 44 | if __name__ == "__main__": 45 | parser = argparse.ArgumentParser() 46 | 47 | parser.add_argument("--out", type=str, default="generated.png") 48 | parser.add_argument("--ckpt", type=str, required=True) 49 | parser.add_argument("files", type=str, nargs="+") 50 | 51 | args = parser.parse_args() 52 | 53 | ckpt = torch.load(args.ckpt, map_location=lambda storage, loc: storage) 54 | ckpt_args = ckpt["args"] 55 | imgsize = ckpt_args.size 56 | 57 | enc = Encoder(ckpt_args.channel) 58 | gen = Generator(ckpt_args.channel) 59 | enc.load_state_dict(ckpt["e_ema"]) 60 | gen.load_state_dict(ckpt["g_ema"]) 61 | enc.eval() 62 | gen.eval() 63 | 64 | imgs = [] 65 | 66 | for imgpath in args.files[: len(args.files) // 2 * 2]: 67 | img = Image.open(imgpath).convert("RGB").resize((imgsize, imgsize)) 68 | img_a = ( 69 | torch.from_numpy(np.array(img)) 70 | .to(torch.float32) 71 | .div(255) 72 | .add_(-0.5) 73 | .mul_(2) 74 | .permute(2, 0, 1) 75 | ) 76 | imgs.append(img_a) 77 | 78 | imgs = torch.stack(imgs, 0) 79 | img1, img2 = imgs.chunk(2, dim=0) 80 | 81 | with torch.no_grad(): 82 | struct1, texture1 = enc(img1) 83 | struct2, texture2 = enc(img2) 84 | 85 | out1 = gen(struct1, texture1) 86 | out2 = gen(struct2, texture2) 87 | out12 = gen(struct1, texture2) 88 | out21 = gen(struct2, texture1) 89 | 90 | font = ImageFont.truetype( 91 | "/root/works/sandbox/swapping-autoencoder/Px437_IBM_VGA_8x16.ttf", 16 92 | ) 93 | 94 | guide1 = render('original\nfirst half of batch → "A"', (256, 256), font) 95 | guide2 = render('reconstruction of "A"', (256, 256), font) 96 | guide3 = render('original\nsecond half of batch → "B"', (256, 256), font) 97 | guide4 = render('reconstruction of "B"', (256, 256), font) 98 | guide5 = render( 99 | 'swapped image\nstructure of "A"\n+\ntexture of "B"', (256, 256), font 100 | ) 101 | guide6 = render( 102 | 'swapped image\nstructure of "B"\n+\ntexture of "A"', (256, 256), font 103 | ) 104 | 105 | imgsets = [ 106 | pil_to_tensor(guide1), 107 | img1, 108 | pil_to_tensor(guide2), 109 | out1, 110 | pil_to_tensor(guide3), 111 | img2, 112 | 
pil_to_tensor(guide4), 113 | out2, 114 | pil_to_tensor(guide5), 115 | out12, 116 | pil_to_tensor(guide6), 117 | out21, 118 | ] 119 | imgsets = torch.cat(imgsets, 0) 120 | grid = utils.save_image( 121 | imgsets, args.out, nrow=out1.shape[0] + 1, normalize=True, range=(-1, 1) 122 | ) 123 | 124 | -------------------------------------------------------------------------------- /generated.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rosinality/swapping-autoencoder-pytorch/8265a8a4497ea098c83bbb47bf33960c999e7d7e/generated.png -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | from stylegan2.model import StyledConv, Blur, EqualLinear, EqualConv2d, ScaledLeakyReLU 8 | from stylegan2.op import FusedLeakyReLU 9 | 10 | 11 | class EqualConvTranspose2d(nn.Module): 12 | def __init__( 13 | self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True 14 | ): 15 | super().__init__() 16 | 17 | self.weight = nn.Parameter( 18 | torch.randn(in_channel, out_channel, kernel_size, kernel_size) 19 | ) 20 | self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2) 21 | 22 | self.stride = stride 23 | self.padding = padding 24 | 25 | if bias: 26 | self.bias = nn.Parameter(torch.zeros(out_channel)) 27 | 28 | else: 29 | self.bias = None 30 | 31 | def forward(self, input): 32 | out = F.conv_transpose2d( 33 | input, 34 | self.weight * self.scale, 35 | bias=self.bias, 36 | stride=self.stride, 37 | padding=self.padding, 38 | ) 39 | 40 | return out 41 | 42 | def __repr__(self): 43 | return ( 44 | f"{self.__class__.__name__}({self.weight.shape[0]}, {self.weight.shape[1]}," 45 | f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})" 46 | ) 47 | 48 | 49 | class ConvLayer(nn.Sequential): 50 | def __init__( 51 | self, 52 | in_channel, 53 | out_channel, 54 | kernel_size, 55 | upsample=False, 56 | downsample=False, 57 | blur_kernel=(1, 3, 3, 1), 58 | bias=True, 59 | activate=True, 60 | padding="zero", 61 | ): 62 | layers = [] 63 | 64 | self.padding = 0 65 | stride = 1 66 | 67 | if downsample: 68 | factor = 2 69 | p = (len(blur_kernel) - factor) + (kernel_size - 1) 70 | pad0 = (p + 1) // 2 71 | pad1 = p // 2 72 | 73 | layers.append(Blur(blur_kernel, pad=(pad0, pad1))) 74 | 75 | stride = 2 76 | 77 | if upsample: 78 | layers.append( 79 | EqualConvTranspose2d( 80 | in_channel, 81 | out_channel, 82 | kernel_size, 83 | padding=0, 84 | stride=2, 85 | bias=bias and not activate, 86 | ) 87 | ) 88 | 89 | factor = 2 90 | p = (len(blur_kernel) - factor) - (kernel_size - 1) 91 | pad0 = (p + 1) // 2 + factor - 1 92 | pad1 = p // 2 + 1 93 | 94 | layers.append(Blur(blur_kernel, pad=(pad0, pad1))) 95 | 96 | else: 97 | if not downsample: 98 | if padding == "zero": 99 | self.padding = (kernel_size - 1) // 2 100 | 101 | elif padding == "reflect": 102 | padding = (kernel_size - 1) // 2 103 | 104 | if padding > 0: 105 | layers.append(nn.ReflectionPad2d(padding)) 106 | 107 | self.padding = 0 108 | 109 | elif padding != "valid": 110 | raise ValueError('Padding should be "zero", "reflect", or "valid"') 111 | 112 | layers.append( 113 | EqualConv2d( 114 | in_channel, 115 | out_channel, 116 | kernel_size, 117 | padding=self.padding, 118 | stride=stride, 119 | bias=bias and not activate, 120 | ) 121 | ) 122 | 123 | if 
activate: 124 | if bias: 125 | layers.append(FusedLeakyReLU(out_channel)) 126 | 127 | else: 128 | layers.append(ScaledLeakyReLU(0.2)) 129 | 130 | super().__init__(*layers) 131 | 132 | 133 | class StyledResBlock(nn.Module): 134 | def __init__( 135 | self, in_channel, out_channel, style_dim, upsample, blur_kernel=(1, 3, 3, 1) 136 | ): 137 | super().__init__() 138 | 139 | self.conv1 = StyledConv( 140 | in_channel, 141 | out_channel, 142 | 3, 143 | style_dim, 144 | upsample=upsample, 145 | blur_kernel=blur_kernel, 146 | ) 147 | 148 | self.conv2 = StyledConv(out_channel, out_channel, 3, style_dim) 149 | 150 | if upsample or in_channel != out_channel: 151 | self.skip = ConvLayer( 152 | in_channel, 153 | out_channel, 154 | 1, 155 | upsample=upsample, 156 | blur_kernel=blur_kernel, 157 | bias=False, 158 | activate=False, 159 | ) 160 | 161 | else: 162 | self.skip = None 163 | 164 | def forward(self, input, style, noise=None): 165 | out = self.conv1(input, style, noise) 166 | out = self.conv2(out, style, noise) 167 | 168 | if self.skip is not None: 169 | skip = self.skip(input) 170 | 171 | else: 172 | skip = input 173 | 174 | return (out + skip) / math.sqrt(2) 175 | 176 | 177 | class ResBlock(nn.Module): 178 | def __init__( 179 | self, 180 | in_channel, 181 | out_channel, 182 | downsample, 183 | padding="zero", 184 | blur_kernel=(1, 3, 3, 1), 185 | ): 186 | super().__init__() 187 | 188 | self.conv1 = ConvLayer(in_channel, out_channel, 3, padding=padding) 189 | 190 | self.conv2 = ConvLayer( 191 | out_channel, 192 | out_channel, 193 | 3, 194 | downsample=downsample, 195 | padding=padding, 196 | blur_kernel=blur_kernel, 197 | ) 198 | 199 | if downsample or in_channel != out_channel: 200 | self.skip = ConvLayer( 201 | in_channel, 202 | out_channel, 203 | 1, 204 | downsample=downsample, 205 | blur_kernel=blur_kernel, 206 | bias=False, 207 | activate=False, 208 | ) 209 | 210 | else: 211 | self.skip = None 212 | 213 | def forward(self, input): 214 | out = self.conv1(input) 215 | out = self.conv2(out) 216 | 217 | if self.skip is not None: 218 | skip = self.skip(input) 219 | 220 | else: 221 | skip = input 222 | 223 | # print(out.shape) 224 | 225 | return (out + skip) / math.sqrt(2) 226 | 227 | 228 | class Encoder(nn.Module): 229 | def __init__( 230 | self, 231 | channel, 232 | structure_channel=8, 233 | texture_channel=2048, 234 | blur_kernel=(1, 3, 3, 1), 235 | ): 236 | super().__init__() 237 | 238 | stem = [ConvLayer(3, channel, 1)] 239 | 240 | in_channel = channel 241 | for i in range(1, 5): 242 | ch = channel * (2 ** i) 243 | stem.append(ResBlock(in_channel, ch, downsample=True, padding="reflect")) 244 | in_channel = ch 245 | 246 | self.stem = nn.Sequential(*stem) 247 | 248 | self.structure = nn.Sequential( 249 | ConvLayer(ch, ch, 1), ConvLayer(ch, structure_channel, 1) 250 | ) 251 | 252 | self.texture = nn.Sequential( 253 | ConvLayer(ch, ch * 2, 3, downsample=True, padding="valid"), 254 | ConvLayer(ch * 2, ch * 4, 3, downsample=True, padding="valid"), 255 | nn.AdaptiveAvgPool2d(1), 256 | ConvLayer(ch * 4, ch * 4, 1), 257 | ) 258 | 259 | def forward(self, input): 260 | out = self.stem(input) 261 | 262 | structure = self.structure(out) 263 | texture = torch.flatten(self.texture(out), 1) 264 | 265 | return structure, texture 266 | 267 | 268 | class Generator(nn.Module): 269 | def __init__( 270 | self, 271 | channel, 272 | structure_channel=8, 273 | texture_channel=2048, 274 | blur_kernel=(1, 3, 3, 1), 275 | ): 276 | super().__init__() 277 | 278 | ch_multiplier = (4, 8, 12, 16, 16, 16, 8, 4) 279 | upsample 
= (False, False, False, False, True, True, True, True) 280 | 281 | self.layers = nn.ModuleList() 282 | in_ch = structure_channel 283 | for ch_mul, up in zip(ch_multiplier, upsample): 284 | self.layers.append( 285 | StyledResBlock( 286 | in_ch, channel * ch_mul, texture_channel, up, blur_kernel 287 | ) 288 | ) 289 | in_ch = channel * ch_mul 290 | 291 | self.to_rgb = ConvLayer(in_ch, 3, 1, activate=False) 292 | 293 | def forward(self, structure, texture, noises=None): 294 | if noises is None: 295 | noises = [None] * len(self.layers) 296 | 297 | out = structure 298 | for layer, noise in zip(self.layers, noises): 299 | out = layer(out, texture, noise) 300 | 301 | out = self.to_rgb(out) 302 | 303 | return out 304 | 305 | 306 | class Discriminator(nn.Module): 307 | def __init__(self, size, channel_multiplier=1, blur_kernel=(1, 3, 3, 1)): 308 | super().__init__() 309 | 310 | channels = { 311 | 4: 512, 312 | 8: 512, 313 | 16: 512, 314 | 32: 512, 315 | 64: 256 * channel_multiplier, 316 | 128: 128 * channel_multiplier, 317 | 256: 64 * channel_multiplier, 318 | 512: 32 * channel_multiplier, 319 | 1024: 16 * channel_multiplier, 320 | } 321 | 322 | convs = [ConvLayer(3, channels[size], 1)] 323 | 324 | log_size = int(math.log(size, 2)) 325 | 326 | in_channel = channels[size] 327 | 328 | for i in range(log_size, 2, -1): 329 | out_channel = channels[2 ** (i - 1)] 330 | 331 | convs.append(ResBlock(in_channel, out_channel, downsample=True)) 332 | 333 | in_channel = out_channel 334 | 335 | self.convs = nn.Sequential(*convs) 336 | 337 | self.final_conv = ConvLayer(in_channel, channels[4], 3) 338 | self.final_linear = nn.Sequential( 339 | EqualLinear(channels[4] * 4 * 4, channels[4], activation="fused_lrelu"), 340 | EqualLinear(channels[4], 1), 341 | ) 342 | 343 | def forward(self, input): 344 | out = self.convs(input) 345 | out = self.final_conv(out) 346 | 347 | out = out.view(out.shape[0], -1) 348 | out = self.final_linear(out) 349 | 350 | return out 351 | 352 | 353 | class CooccurDiscriminator(nn.Module): 354 | def __init__(self, channel, size=256): 355 | super().__init__() 356 | 357 | encoder = [ConvLayer(3, channel, 1)] 358 | 359 | ch_multiplier = (2, 4, 8, 12, 12, 24) 360 | downsample = (True, True, True, True, True, False) 361 | in_ch = channel 362 | for ch_mul, down in zip(ch_multiplier, downsample): 363 | encoder.append(ResBlock(in_ch, channel * ch_mul, down)) 364 | in_ch = channel * ch_mul 365 | 366 | if size > 511: 367 | k_size = 3 368 | feat_size = 2 * 2 369 | 370 | else: 371 | k_size = 2 372 | feat_size = 1 * 1 373 | 374 | encoder.append(ConvLayer(in_ch, channel * 12, k_size, padding="valid")) 375 | 376 | self.encoder = nn.Sequential(*encoder) 377 | 378 | self.linear = nn.Sequential( 379 | EqualLinear( 380 | channel * 12 * 2 * feat_size, channel * 32, activation="fused_lrelu" 381 | ), 382 | EqualLinear(channel * 32, channel * 32, activation="fused_lrelu"), 383 | EqualLinear(channel * 32, channel * 16, activation="fused_lrelu"), 384 | EqualLinear(channel * 16, 1), 385 | ) 386 | 387 | def forward(self, input, reference=None, ref_batch=None, ref_input=None): 388 | # print(input.shape) 389 | out_input = self.encoder(input) 390 | 391 | if ref_input is None: 392 | ref_input = self.encoder(reference) 393 | _, channel, height, width = ref_input.shape 394 | ref_input = ref_input.view(-1, ref_batch, channel, height, width) 395 | ref_input = ref_input.mean(1) 396 | 397 | out = torch.cat((out_input, ref_input), 1) 398 | out = torch.flatten(out, 1) 399 | out = self.linear(out) 400 | 401 | return out, ref_input 
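# Illustrative usage sketch (not part of the original training code): a quick
# shape check for the Encoder/Generator pair. It assumes channel=32, so the
# flattened texture code is channel * 64 = 2048 dimensions, matching the
# Generator's default texture_channel (the StyledResBlock style_dim), and it
# assumes 256x256 inputs, which the four downsampling ResBlocks reduce to a
# 16x16 structure map and the four upsampling StyledResBlocks restore to
# 256x256. Running it requires the compiled CUDA ops in stylegan2/op.
if __name__ == "__main__":
    device = "cuda"
    channel = 32

    encoder = Encoder(channel).to(device)
    generator = Generator(channel).to(device)

    images = torch.randn(2, 3, 256, 256, device=device)
    structure, texture = encoder(images)
    print(structure.shape)  # expected: torch.Size([2, 8, 16, 16])
    print(texture.shape)    # expected: torch.Size([2, 2048])

    # swap the two texture codes within the batch, keeping each structure code
    swapped = generator(structure, texture.flip(0))
    print(swapped.shape)    # expected: torch.Size([2, 3, 256, 256])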
402 | -------------------------------------------------------------------------------- /prepare_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from io import BytesIO 3 | import multiprocessing 4 | from functools import partial 5 | 6 | from PIL import Image 7 | import lmdb 8 | from tqdm import tqdm 9 | from torchvision import datasets 10 | from torchvision.transforms import functional as trans_fn 11 | 12 | 13 | def resize_and_convert(img, size, resample, quality=100): 14 | img = trans_fn.resize(img, size, resample) 15 | img = trans_fn.center_crop(img, size) 16 | buffer = BytesIO() 17 | img.save(buffer, format='jpeg', quality=quality) 18 | val = buffer.getvalue() 19 | 20 | return val 21 | 22 | 23 | def resize_multiple(img, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS, quality=100): 24 | imgs = [] 25 | 26 | for size in sizes: 27 | imgs.append(resize_and_convert(img, size, resample, quality)) 28 | 29 | return imgs 30 | 31 | 32 | def resize_worker(img_file, sizes, resample): 33 | i, file = img_file 34 | img = Image.open(file) 35 | img = img.convert('RGB') 36 | out = resize_multiple(img, sizes=sizes, resample=resample) 37 | 38 | return i, out 39 | 40 | 41 | def prepare(env, dataset, n_worker, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS): 42 | resize_fn = partial(resize_worker, sizes=sizes, resample=resample) 43 | 44 | files = sorted(dataset.imgs, key=lambda x: x[0]) 45 | files = [(i, file) for i, (file, label) in enumerate(files)] 46 | total = 0 47 | 48 | with multiprocessing.Pool(n_worker) as pool: 49 | for i, imgs in tqdm(pool.imap_unordered(resize_fn, files)): 50 | for size, img in zip(sizes, imgs): 51 | key = f'{size}-{str(i).zfill(5)}'.encode('utf-8') 52 | 53 | with env.begin(write=True) as txn: 54 | txn.put(key, img) 55 | 56 | total += 1 57 | 58 | with env.begin(write=True) as txn: 59 | txn.put('length'.encode('utf-8'), str(total).encode('utf-8')) 60 | 61 | 62 | if __name__ == '__main__': 63 | parser = argparse.ArgumentParser() 64 | parser.add_argument('--out', type=str) 65 | parser.add_argument('--size', type=str, default='128,256,512,1024') 66 | parser.add_argument('--n_worker', type=int, default=8) 67 | parser.add_argument('--resample', type=str, default='lanczos') 68 | parser.add_argument('path', type=str) 69 | 70 | args = parser.parse_args() 71 | 72 | resample_map = {'lanczos': Image.LANCZOS, 'bilinear': Image.BILINEAR} 73 | resample = resample_map[args.resample] 74 | 75 | sizes = [int(s.strip()) for s in args.size.split(',')] 76 | 77 | print(f'Make dataset of image sizes:', ', '.join(str(s) for s in sizes)) 78 | 79 | imgset = datasets.ImageFolder(args.path) 80 | 81 | with lmdb.open(args.out, map_size=1024 ** 4, readahead=False) as env: 82 | prepare(env, imgset, args.n_worker, sizes=sizes, resample=resample) 83 | -------------------------------------------------------------------------------- /sample/.gitignore: -------------------------------------------------------------------------------- 1 | *.png 2 | -------------------------------------------------------------------------------- /stylegan2/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 
| pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | wandb/ 132 | *.lmdb/ 133 | *.pkl 134 | -------------------------------------------------------------------------------- /stylegan2/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Kim Seonghyeon 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /stylegan2/LICENSE-FID: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /stylegan2/LICENSE-LPIPS: -------------------------------------------------------------------------------- 1 | Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 15 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 16 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 17 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 18 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 19 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 20 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 21 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 22 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 23 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 24 | 25 | -------------------------------------------------------------------------------- /stylegan2/LICENSE-NVIDIA: -------------------------------------------------------------------------------- 1 | Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | 3 | 4 | Nvidia Source Code License-NC 5 | 6 | ======================================================================= 7 | 8 | 1. Definitions 9 | 10 | "Licensor" means any person or entity that distributes its Work. 11 | 12 | "Software" means the original work of authorship made available under 13 | this License. 14 | 15 | "Work" means the Software and any additions to or derivative works of 16 | the Software that are made available under this License. 
17 | 18 | "Nvidia Processors" means any central processing unit (CPU), graphics 19 | processing unit (GPU), field-programmable gate array (FPGA), 20 | application-specific integrated circuit (ASIC) or any combination 21 | thereof designed, made, sold, or provided by Nvidia or its affiliates. 22 | 23 | The terms "reproduce," "reproduction," "derivative works," and 24 | "distribution" have the meaning as provided under U.S. copyright law; 25 | provided, however, that for the purposes of this License, derivative 26 | works shall not include works that remain separable from, or merely 27 | link (or bind by name) to the interfaces of, the Work. 28 | 29 | Works, including the Software, are "made available" under this License 30 | by including in or with the Work either (a) a copyright notice 31 | referencing the applicability of this License to the Work, or (b) a 32 | copy of this License. 33 | 34 | 2. License Grants 35 | 36 | 2.1 Copyright Grant. Subject to the terms and conditions of this 37 | License, each Licensor grants to you a perpetual, worldwide, 38 | non-exclusive, royalty-free, copyright license to reproduce, 39 | prepare derivative works of, publicly display, publicly perform, 40 | sublicense and distribute its Work and any resulting derivative 41 | works in any form. 42 | 43 | 3. Limitations 44 | 45 | 3.1 Redistribution. You may reproduce or distribute the Work only 46 | if (a) you do so under this License, (b) you include a complete 47 | copy of this License with your distribution, and (c) you retain 48 | without modification any copyright, patent, trademark, or 49 | attribution notices that are present in the Work. 50 | 51 | 3.2 Derivative Works. You may specify that additional or different 52 | terms apply to the use, reproduction, and distribution of your 53 | derivative works of the Work ("Your Terms") only if (a) Your Terms 54 | provide that the use limitation in Section 3.3 applies to your 55 | derivative works, and (b) you identify the specific derivative 56 | works that are subject to Your Terms. Notwithstanding Your Terms, 57 | this License (including the redistribution requirements in Section 58 | 3.1) will continue to apply to the Work itself. 59 | 60 | 3.3 Use Limitation. The Work and any derivative works thereof only 61 | may be used or intended for use non-commercially. The Work or 62 | derivative works thereof may be used or intended for use by Nvidia 63 | or its affiliates commercially or non-commercially. As used herein, 64 | "non-commercially" means for research or evaluation purposes only. 65 | 66 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim 67 | against any Licensor (including any claim, cross-claim or 68 | counterclaim in a lawsuit) to enforce any patents that you allege 69 | are infringed by any Work, then your rights under this License from 70 | such Licensor (including the grants in Sections 2.1 and 2.2) will 71 | terminate immediately. 72 | 73 | 3.5 Trademarks. This License does not grant any rights to use any 74 | Licensor's or its affiliates' names, logos, or trademarks, except 75 | as necessary to reproduce the notices described in this License. 76 | 77 | 3.6 Termination. If you violate any term of this License, then your 78 | rights under this License (including the grants in Sections 2.1 and 79 | 2.2) will terminate immediately. 80 | 81 | 4. Disclaimer of Warranty. 
82 | 83 | THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY 84 | KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF 85 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR 86 | NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER 87 | THIS LICENSE. 88 | 89 | 5. Limitation of Liability. 90 | 91 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL 92 | THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE 93 | SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, 94 | INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF 95 | OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK 96 | (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, 97 | LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER 98 | COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF 99 | THE POSSIBILITY OF SUCH DAMAGES. 100 | 101 | ======================================================================= 102 | -------------------------------------------------------------------------------- /stylegan2/README.md: -------------------------------------------------------------------------------- 1 | # StyleGAN 2 in PyTorch 2 | 3 | Implementation of Analyzing and Improving the Image Quality of StyleGAN (https://arxiv.org/abs/1912.04958) in PyTorch 4 | 5 | ## Notice 6 | 7 | I have tried to match the official implementation as closely as possible, but there may be some details I missed, so please use this implementation with care. 8 | 9 | ## Requirements 10 | 11 | I have tested on: 12 | 13 | * PyTorch 1.3.1 14 | * CUDA 10.1/10.2 15 | 16 | ## Usage 17 | 18 | First create lmdb datasets: 19 | 20 | > python prepare_data.py --out LMDB_PATH --n_worker N_WORKER --size SIZE1,SIZE2,SIZE3,... DATASET_PATH 21 | 22 | This will convert images to JPEG and pre-resize them. This implementation does not use progressive growing, but you can create multiple-resolution datasets by passing a comma-separated list of sizes, in case you want to try other resolutions later. 23 | 24 | Then you can train the model in a distributed setting: 25 | 26 | > python -m torch.distributed.launch --nproc_per_node=N_GPU --master_port=PORT train.py --batch BATCH_SIZE LMDB_PATH 27 | 28 | train.py supports Weights & Biases logging. If you want to use it, add the --wandb argument to the script. 29 | 30 | ### Convert weight from official checkpoints 31 | 32 | You need to clone the official repository (https://github.com/NVlabs/stylegan2), as it is required to load official checkpoints. 33 | 34 | For example, if you cloned the repository into ~/stylegan2 and downloaded stylegan2-ffhq-config-f.pkl, you can convert it like this: 35 | 36 | > python convert_weight.py --repo ~/stylegan2 stylegan2-ffhq-config-f.pkl 37 | 38 | This will create a converted stylegan2-ffhq-config-f.pt file. 39 | 40 | ### Generate samples 41 | 42 | > python generate.py --sample N_FACES --pics N_PICS --ckpt PATH_CHECKPOINT 43 | 44 | You should change the size argument (--size 256, for example) if you trained the model at another resolution. 45 | 46 | ### Project images to latent spaces 47 | 48 | > python projector.py --ckpt [CHECKPOINT] --size [GENERATOR_OUTPUT_SIZE] FILE1 FILE2 ... 49 | 50 | ## Pretrained Checkpoints 51 | 52 | [Link](https://drive.google.com/open?id=1PQutd-JboOCOZqmd95XWxWrO8gGEvRcO) 53 | 54 | I have trained the 256px model on FFHQ for 550k iterations and got an FID of about 4.5.
Differences in data preprocessing, resolution, or the training loop may account for this, but I currently don't know the exact reason for the FID difference. 55 | 56 | ## Samples 57 | 58 | ![Sample with truncation](doc/sample.png) 59 | 60 | Sample from FFHQ. At 110,000 iterations. (trained on 3.52M images) 61 | 62 | ![MetFaces sample with non-leaking augmentations](doc/sample-metfaces.png) 63 | 64 | Sample from MetFaces with Non-leaking augmentations. At 150,000 iterations. (trained on 4.8M images) 65 | 66 | 67 | ### Samples from converted weights 68 | 69 | ![Sample from FFHQ](doc/stylegan2-ffhq-config-f.png) 70 | 71 | Sample from FFHQ (1024px) 72 | 73 | ![Sample from LSUN Church](doc/stylegan2-church-config-f.png) 74 | 75 | Sample from LSUN Church (256px) 76 | 77 | ## License 78 | 79 | Model details and custom CUDA kernel code are from the official repository: https://github.com/NVlabs/stylegan2 80 | 81 | Code for Learned Perceptual Image Patch Similarity (LPIPS) came from https://github.com/richzhang/PerceptualSimilarity 82 | 83 | To match FID scores more closely with the official TensorFlow implementation, I have used the FID Inception V3 implementation from https://github.com/mseitzer/pytorch-fid 84 | -------------------------------------------------------------------------------- /stylegan2/calc_inception.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pickle 3 | import os 4 | 5 | import torch 6 | from torch import nn 7 | from torch.nn import functional as F 8 | from torch.utils.data import DataLoader 9 | from torchvision import transforms 10 | from torchvision.models import inception_v3, Inception3 11 | import numpy as np 12 | from tqdm import tqdm 13 | 14 | from inception import InceptionV3 15 | from dataset import MultiResolutionDataset 16 | 17 | 18 | class Inception3Feature(Inception3): 19 | def forward(self, x): 20 | if x.shape[2] != 299 or x.shape[3] != 299: 21 | x = F.interpolate(x, size=(299, 299), mode='bilinear', align_corners=True) 22 | 23 | x = self.Conv2d_1a_3x3(x) # 299 x 299 x 3 24 | x = self.Conv2d_2a_3x3(x) # 149 x 149 x 32 25 | x = self.Conv2d_2b_3x3(x) # 147 x 147 x 32 26 | x = F.max_pool2d(x, kernel_size=3, stride=2) # 147 x 147 x 64 27 | 28 | x = self.Conv2d_3b_1x1(x) # 73 x 73 x 64 29 | x = self.Conv2d_4a_3x3(x) # 73 x 73 x 80 30 | x = F.max_pool2d(x, kernel_size=3, stride=2) # 71 x 71 x 192 31 | 32 | x = self.Mixed_5b(x) # 35 x 35 x 192 33 | x = self.Mixed_5c(x) # 35 x 35 x 256 34 | x = self.Mixed_5d(x) # 35 x 35 x 288 35 | 36 | x = self.Mixed_6a(x) # 35 x 35 x 288 37 | x = self.Mixed_6b(x) # 17 x 17 x 768 38 | x = self.Mixed_6c(x) # 17 x 17 x 768 39 | x = self.Mixed_6d(x) # 17 x 17 x 768 40 | x = self.Mixed_6e(x) # 17 x 17 x 768 41 | 42 | x = self.Mixed_7a(x) # 17 x 17 x 768 43 | x = self.Mixed_7b(x) # 8 x 8 x 1280 44 | x = self.Mixed_7c(x) # 8 x 8 x 2048 45 | 46 | x = F.avg_pool2d(x, kernel_size=8) # 8 x 8 x 2048 47 | 48 | return x.view(x.shape[0], x.shape[1]) # 1 x 1 x 2048 49 | 50 | 51 | def load_patched_inception_v3(): 52 | # inception = inception_v3(pretrained=True) 53 | # inception_feat = Inception3Feature() 54 | # inception_feat.load_state_dict(inception.state_dict()) 55 | inception_feat = InceptionV3([3], normalize_input=False) 56 | 57 | return inception_feat 58 | 59 | 60 | @torch.no_grad() 61 | def extract_features(loader, inception, device): 62 | pbar = tqdm(loader) 63 | 64 | feature_list = [] 65 | 66 | for img in pbar: 67 | img = img.to(device) 68 | feature = inception(img)[0].view(img.shape[0], -1) 69 |
feature_list.append(feature.to('cpu')) 70 | 71 | features = torch.cat(feature_list, 0) 72 | 73 | return features 74 | 75 | 76 | if __name__ == '__main__': 77 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 78 | 79 | parser = argparse.ArgumentParser( 80 | description='Calculate Inception v3 features for datasets' 81 | ) 82 | parser.add_argument('--size', type=int, default=256) 83 | parser.add_argument('--batch', default=64, type=int, help='batch size') 84 | parser.add_argument('--n_sample', type=int, default=50000) 85 | parser.add_argument('--flip', action='store_true') 86 | parser.add_argument('path', metavar='PATH', help='path to datset lmdb file') 87 | 88 | args = parser.parse_args() 89 | 90 | inception = load_patched_inception_v3() 91 | inception = nn.DataParallel(inception).eval().to(device) 92 | 93 | transform = transforms.Compose( 94 | [ 95 | transforms.RandomHorizontalFlip(p=0.5 if args.flip else 0), 96 | transforms.ToTensor(), 97 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), 98 | ] 99 | ) 100 | 101 | dset = MultiResolutionDataset(args.path, transform=transform, resolution=args.size) 102 | loader = DataLoader(dset, batch_size=args.batch, num_workers=4) 103 | 104 | features = extract_features(loader, inception, device).numpy() 105 | 106 | features = features[: args.n_sample] 107 | 108 | print(f'extracted {features.shape[0]} features') 109 | 110 | mean = np.mean(features, 0) 111 | cov = np.cov(features, rowvar=False) 112 | 113 | name = os.path.splitext(os.path.basename(args.path))[0] 114 | 115 | with open(f'inception_{name}.pkl', 'wb') as f: 116 | pickle.dump({'mean': mean, 'cov': cov, 'size': args.size, 'path': args.path}, f) 117 | -------------------------------------------------------------------------------- /stylegan2/checkpoint/.gitignore: -------------------------------------------------------------------------------- 1 | *.pt 2 | -------------------------------------------------------------------------------- /stylegan2/convert_weight.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | import pickle 5 | import math 6 | 7 | import torch 8 | import numpy as np 9 | from torchvision import utils 10 | 11 | from model import Generator, Discriminator 12 | 13 | 14 | def convert_modconv(vars, source_name, target_name, flip=False): 15 | weight = vars[source_name + "/weight"].value().eval() 16 | mod_weight = vars[source_name + "/mod_weight"].value().eval() 17 | mod_bias = vars[source_name + "/mod_bias"].value().eval() 18 | noise = vars[source_name + "/noise_strength"].value().eval() 19 | bias = vars[source_name + "/bias"].value().eval() 20 | 21 | dic = { 22 | "conv.weight": np.expand_dims(weight.transpose((3, 2, 0, 1)), 0), 23 | "conv.modulation.weight": mod_weight.transpose((1, 0)), 24 | "conv.modulation.bias": mod_bias + 1, 25 | "noise.weight": np.array([noise]), 26 | "activate.bias": bias, 27 | } 28 | 29 | dic_torch = {} 30 | 31 | for k, v in dic.items(): 32 | dic_torch[target_name + "." 
+ k] = torch.from_numpy(v) 33 | 34 | if flip: 35 | dic_torch[target_name + ".conv.weight"] = torch.flip( 36 | dic_torch[target_name + ".conv.weight"], [3, 4] 37 | ) 38 | 39 | return dic_torch 40 | 41 | 42 | def convert_conv(vars, source_name, target_name, bias=True, start=0): 43 | weight = vars[source_name + "/weight"].value().eval() 44 | 45 | dic = {"weight": weight.transpose((3, 2, 0, 1))} 46 | 47 | if bias: 48 | dic["bias"] = vars[source_name + "/bias"].value().eval() 49 | 50 | dic_torch = {} 51 | 52 | dic_torch[target_name + f".{start}.weight"] = torch.from_numpy(dic["weight"]) 53 | 54 | if bias: 55 | dic_torch[target_name + f".{start + 1}.bias"] = torch.from_numpy(dic["bias"]) 56 | 57 | return dic_torch 58 | 59 | 60 | def convert_torgb(vars, source_name, target_name): 61 | weight = vars[source_name + "/weight"].value().eval() 62 | mod_weight = vars[source_name + "/mod_weight"].value().eval() 63 | mod_bias = vars[source_name + "/mod_bias"].value().eval() 64 | bias = vars[source_name + "/bias"].value().eval() 65 | 66 | dic = { 67 | "conv.weight": np.expand_dims(weight.transpose((3, 2, 0, 1)), 0), 68 | "conv.modulation.weight": mod_weight.transpose((1, 0)), 69 | "conv.modulation.bias": mod_bias + 1, 70 | "bias": bias.reshape((1, 3, 1, 1)), 71 | } 72 | 73 | dic_torch = {} 74 | 75 | for k, v in dic.items(): 76 | dic_torch[target_name + "." + k] = torch.from_numpy(v) 77 | 78 | return dic_torch 79 | 80 | 81 | def convert_dense(vars, source_name, target_name): 82 | weight = vars[source_name + "/weight"].value().eval() 83 | bias = vars[source_name + "/bias"].value().eval() 84 | 85 | dic = {"weight": weight.transpose((1, 0)), "bias": bias} 86 | 87 | dic_torch = {} 88 | 89 | for k, v in dic.items(): 90 | dic_torch[target_name + "." + k] = torch.from_numpy(v) 91 | 92 | return dic_torch 93 | 94 | 95 | def update(state_dict, new): 96 | for k, v in new.items(): 97 | if k not in state_dict: 98 | raise KeyError(k + " is not found") 99 | 100 | if v.shape != state_dict[k].shape: 101 | raise ValueError(f"Shape mismatch: {v.shape} vs {state_dict[k].shape}") 102 | 103 | state_dict[k] = v 104 | 105 | 106 | def discriminator_fill_statedict(statedict, vars, size): 107 | log_size = int(math.log(size, 2)) 108 | 109 | update(statedict, convert_conv(vars, f"{size}x{size}/FromRGB", "convs.0")) 110 | 111 | conv_i = 1 112 | 113 | for i in range(log_size - 2, 0, -1): 114 | reso = 4 * 2 ** i 115 | update( 116 | statedict, 117 | convert_conv(vars, f"{reso}x{reso}/Conv0", f"convs.{conv_i}.conv1"), 118 | ) 119 | update( 120 | statedict, 121 | convert_conv( 122 | vars, f"{reso}x{reso}/Conv1_down", f"convs.{conv_i}.conv2", start=1 123 | ), 124 | ) 125 | update( 126 | statedict, 127 | convert_conv( 128 | vars, f"{reso}x{reso}/Skip", f"convs.{conv_i}.skip", start=1, bias=False 129 | ), 130 | ) 131 | conv_i += 1 132 | 133 | update(statedict, convert_conv(vars, f"4x4/Conv", "final_conv")) 134 | update(statedict, convert_dense(vars, f"4x4/Dense0", "final_linear.0")) 135 | update(statedict, convert_dense(vars, f"Output", "final_linear.1")) 136 | 137 | return statedict 138 | 139 | 140 | def fill_statedict(state_dict, vars, size): 141 | log_size = int(math.log(size, 2)) 142 | 143 | for i in range(8): 144 | update(state_dict, convert_dense(vars, f"G_mapping/Dense{i}", f"style.{i + 1}")) 145 | 146 | update( 147 | state_dict, 148 | { 149 | "input.input": torch.from_numpy( 150 | vars["G_synthesis/4x4/Const/const"].value().eval() 151 | ) 152 | }, 153 | ) 154 | 155 | update(state_dict, convert_torgb(vars, "G_synthesis/4x4/ToRGB", 
"to_rgb1")) 156 | 157 | for i in range(log_size - 2): 158 | reso = 4 * 2 ** (i + 1) 159 | update( 160 | state_dict, 161 | convert_torgb(vars, f"G_synthesis/{reso}x{reso}/ToRGB", f"to_rgbs.{i}"), 162 | ) 163 | 164 | update(state_dict, convert_modconv(vars, "G_synthesis/4x4/Conv", "conv1")) 165 | 166 | conv_i = 0 167 | 168 | for i in range(log_size - 2): 169 | reso = 4 * 2 ** (i + 1) 170 | update( 171 | state_dict, 172 | convert_modconv( 173 | vars, 174 | f"G_synthesis/{reso}x{reso}/Conv0_up", 175 | f"convs.{conv_i}", 176 | flip=True, 177 | ), 178 | ) 179 | update( 180 | state_dict, 181 | convert_modconv( 182 | vars, f"G_synthesis/{reso}x{reso}/Conv1", f"convs.{conv_i + 1}" 183 | ), 184 | ) 185 | conv_i += 2 186 | 187 | for i in range(0, (log_size - 2) * 2 + 1): 188 | update( 189 | state_dict, 190 | { 191 | f"noises.noise_{i}": torch.from_numpy( 192 | vars[f"G_synthesis/noise{i}"].value().eval() 193 | ) 194 | }, 195 | ) 196 | 197 | return state_dict 198 | 199 | 200 | if __name__ == "__main__": 201 | device = "cuda" 202 | 203 | parser = argparse.ArgumentParser() 204 | parser.add_argument("--repo", type=str, required=True) 205 | parser.add_argument("--gen", action="store_true") 206 | parser.add_argument("--disc", action="store_true") 207 | parser.add_argument("--channel_multiplier", type=int, default=2) 208 | parser.add_argument("path", metavar="PATH") 209 | 210 | args = parser.parse_args() 211 | 212 | sys.path.append(args.repo) 213 | 214 | import dnnlib 215 | from dnnlib import tflib 216 | 217 | tflib.init_tf() 218 | 219 | with open(args.path, "rb") as f: 220 | generator, discriminator, g_ema = pickle.load(f) 221 | 222 | size = g_ema.output_shape[2] 223 | 224 | g = Generator(size, 512, 8, channel_multiplier=args.channel_multiplier) 225 | state_dict = g.state_dict() 226 | state_dict = fill_statedict(state_dict, g_ema.vars, size) 227 | 228 | g.load_state_dict(state_dict) 229 | 230 | latent_avg = torch.from_numpy(g_ema.vars["dlatent_avg"].value().eval()) 231 | 232 | ckpt = {"g_ema": state_dict, "latent_avg": latent_avg} 233 | 234 | if args.gen: 235 | g_train = Generator(size, 512, 8, channel_multiplier=args.channel_multiplier) 236 | g_train_state = g_train.state_dict() 237 | g_train_state = fill_statedict(g_train_state, generator.vars, size) 238 | ckpt["g"] = g_train_state 239 | 240 | if args.disc: 241 | disc = Discriminator(size, channel_multiplier=args.channel_multiplier) 242 | d_state = disc.state_dict() 243 | d_state = discriminator_fill_statedict(d_state, discriminator.vars, size) 244 | ckpt["d"] = d_state 245 | 246 | name = os.path.splitext(os.path.basename(args.path))[0] 247 | torch.save(ckpt, name + ".pt") 248 | 249 | batch_size = {256: 16, 512: 9, 1024: 4} 250 | n_sample = batch_size.get(size, 25) 251 | 252 | g = g.to(device) 253 | 254 | z = np.random.RandomState(0).randn(n_sample, 512).astype("float32") 255 | 256 | with torch.no_grad(): 257 | img_pt, _ = g( 258 | [torch.from_numpy(z).to(device)], 259 | truncation=0.5, 260 | truncation_latent=latent_avg.to(device), 261 | randomize_noise=False, 262 | ) 263 | 264 | Gs_kwargs = dnnlib.EasyDict() 265 | Gs_kwargs.randomize_noise = False 266 | img_tf = g_ema.run(z, None, **Gs_kwargs) 267 | img_tf = torch.from_numpy(img_tf).to(device) 268 | 269 | img_diff = ((img_pt + 1) / 2).clamp(0.0, 1.0) - ((img_tf.to(device) + 1) / 2).clamp( 270 | 0.0, 1.0 271 | ) 272 | 273 | img_concat = torch.cat((img_tf, img_pt, img_diff), dim=0) 274 | 275 | print(img_diff.abs().max()) 276 | 277 | utils.save_image( 278 | img_concat, name + ".png", nrow=n_sample, 
normalize=True, range=(-1, 1) 279 | ) 280 | 281 | -------------------------------------------------------------------------------- /stylegan2/dataset.py: -------------------------------------------------------------------------------- 1 | from io import BytesIO 2 | 3 | import lmdb 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | 7 | 8 | class MultiResolutionDataset(Dataset): 9 | def __init__(self, path, transform, resolution=256): 10 | self.env = lmdb.open( 11 | path, 12 | max_readers=32, 13 | readonly=True, 14 | lock=False, 15 | readahead=False, 16 | meminit=False, 17 | ) 18 | 19 | if not self.env: 20 | raise IOError('Cannot open lmdb dataset', path) 21 | 22 | with self.env.begin(write=False) as txn: 23 | self.length = int(txn.get('length'.encode('utf-8')).decode('utf-8')) 24 | 25 | self.resolution = resolution 26 | self.transform = transform 27 | 28 | def __len__(self): 29 | return self.length 30 | 31 | def __getitem__(self, index): 32 | with self.env.begin(write=False) as txn: 33 | key = f'{self.resolution}-{str(index).zfill(5)}'.encode('utf-8') 34 | img_bytes = txn.get(key) 35 | 36 | buffer = BytesIO(img_bytes) 37 | img = Image.open(buffer) 38 | img = self.transform(img) 39 | 40 | return img 41 | -------------------------------------------------------------------------------- /stylegan2/distributed.py: -------------------------------------------------------------------------------- 1 | import math 2 | import pickle 3 | 4 | import torch 5 | from torch import distributed as dist 6 | from torch.utils.data.sampler import Sampler 7 | 8 | 9 | def get_rank(): 10 | if not dist.is_available(): 11 | return 0 12 | 13 | if not dist.is_initialized(): 14 | return 0 15 | 16 | return dist.get_rank() 17 | 18 | 19 | def synchronize(): 20 | if not dist.is_available(): 21 | return 22 | 23 | if not dist.is_initialized(): 24 | return 25 | 26 | world_size = dist.get_world_size() 27 | 28 | if world_size == 1: 29 | return 30 | 31 | dist.barrier() 32 | 33 | 34 | def get_world_size(): 35 | if not dist.is_available(): 36 | return 1 37 | 38 | if not dist.is_initialized(): 39 | return 1 40 | 41 | return dist.get_world_size() 42 | 43 | 44 | def reduce_sum(tensor): 45 | if not dist.is_available(): 46 | return tensor 47 | 48 | if not dist.is_initialized(): 49 | return tensor 50 | 51 | tensor = tensor.clone() 52 | dist.all_reduce(tensor, op=dist.ReduceOp.SUM) 53 | 54 | return tensor 55 | 56 | 57 | def gather_grad(params): 58 | world_size = get_world_size() 59 | 60 | if world_size == 1: 61 | return 62 | 63 | for param in params: 64 | if param.grad is not None: 65 | dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM) 66 | param.grad.data.div_(world_size) 67 | 68 | 69 | def all_gather(data): 70 | world_size = get_world_size() 71 | 72 | if world_size == 1: 73 | return [data] 74 | 75 | buffer = pickle.dumps(data) 76 | storage = torch.ByteStorage.from_buffer(buffer) 77 | tensor = torch.ByteTensor(storage).to('cuda') 78 | 79 | local_size = torch.IntTensor([tensor.numel()]).to('cuda') 80 | size_list = [torch.IntTensor([0]).to('cuda') for _ in range(world_size)] 81 | dist.all_gather(size_list, local_size) 82 | size_list = [int(size.item()) for size in size_list] 83 | max_size = max(size_list) 84 | 85 | tensor_list = [] 86 | for _ in size_list: 87 | tensor_list.append(torch.ByteTensor(size=(max_size,)).to('cuda')) 88 | 89 | if local_size != max_size: 90 | padding = torch.ByteTensor(size=(max_size - local_size,)).to('cuda') 91 | tensor = torch.cat((tensor, padding), 0) 92 | 93 | 
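    # dist.all_gather requires every rank to contribute a tensor of identical
    # size, so each pickled payload is padded up to the largest serialized size
    # across ranks; the true lengths kept in size_list are used afterwards to
    # strip the padding before unpickling.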
dist.all_gather(tensor_list, tensor) 94 | 95 | data_list = [] 96 | 97 | for size, tensor in zip(size_list, tensor_list): 98 | buffer = tensor.cpu().numpy().tobytes()[:size] 99 | data_list.append(pickle.loads(buffer)) 100 | 101 | return data_list 102 | 103 | 104 | def reduce_loss_dict(loss_dict): 105 | world_size = get_world_size() 106 | 107 | if world_size < 2: 108 | return loss_dict 109 | 110 | with torch.no_grad(): 111 | keys = [] 112 | losses = [] 113 | 114 | for k in sorted(loss_dict.keys()): 115 | keys.append(k) 116 | losses.append(loss_dict[k]) 117 | 118 | losses = torch.stack(losses, 0) 119 | dist.reduce(losses, dst=0) 120 | 121 | if dist.get_rank() == 0: 122 | losses /= world_size 123 | 124 | reduced_losses = {k: v for k, v in zip(keys, losses)} 125 | 126 | return reduced_losses 127 | -------------------------------------------------------------------------------- /stylegan2/doc/sample-metfaces.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rosinality/swapping-autoencoder-pytorch/8265a8a4497ea098c83bbb47bf33960c999e7d7e/stylegan2/doc/sample-metfaces.png -------------------------------------------------------------------------------- /stylegan2/doc/sample.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rosinality/swapping-autoencoder-pytorch/8265a8a4497ea098c83bbb47bf33960c999e7d7e/stylegan2/doc/sample.png -------------------------------------------------------------------------------- /stylegan2/doc/stylegan2-church-config-f.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rosinality/swapping-autoencoder-pytorch/8265a8a4497ea098c83bbb47bf33960c999e7d7e/stylegan2/doc/stylegan2-church-config-f.png -------------------------------------------------------------------------------- /stylegan2/doc/stylegan2-ffhq-config-f.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rosinality/swapping-autoencoder-pytorch/8265a8a4497ea098c83bbb47bf33960c999e7d7e/stylegan2/doc/stylegan2-ffhq-config-f.png -------------------------------------------------------------------------------- /stylegan2/fid.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pickle 3 | 4 | import torch 5 | from torch import nn 6 | import numpy as np 7 | from scipy import linalg 8 | from tqdm import tqdm 9 | 10 | from model import Generator 11 | from calc_inception import load_patched_inception_v3 12 | 13 | 14 | @torch.no_grad() 15 | def extract_feature_from_samples( 16 | generator, inception, truncation, truncation_latent, batch_size, n_sample, device 17 | ): 18 | n_batch = n_sample // batch_size 19 | resid = n_sample - (n_batch * batch_size) 20 | batch_sizes = [batch_size] * n_batch + [resid] 21 | features = [] 22 | 23 | for batch in tqdm(batch_sizes): 24 | latent = torch.randn(batch, 512, device=device) 25 | img, _ = g([latent], truncation=truncation, truncation_latent=truncation_latent) 26 | feat = inception(img)[0].view(img.shape[0], -1) 27 | features.append(feat.to('cpu')) 28 | 29 | features = torch.cat(features, 0) 30 | 31 | return features 32 | 33 | 34 | def calc_fid(sample_mean, sample_cov, real_mean, real_cov, eps=1e-6): 35 | cov_sqrt, _ = linalg.sqrtm(sample_cov @ real_cov, disp=False) 36 | 37 | if not np.isfinite(cov_sqrt).all(): 38 | print('product of cov matrices is 
singular') 39 | offset = np.eye(sample_cov.shape[0]) * eps 40 | cov_sqrt = linalg.sqrtm((sample_cov + offset) @ (real_cov + offset)) 41 | 42 | if np.iscomplexobj(cov_sqrt): 43 | if not np.allclose(np.diagonal(cov_sqrt).imag, 0, atol=1e-3): 44 | m = np.max(np.abs(cov_sqrt.imag)) 45 | 46 | raise ValueError(f'Imaginary component {m}') 47 | 48 | cov_sqrt = cov_sqrt.real 49 | 50 | mean_diff = sample_mean - real_mean 51 | mean_norm = mean_diff @ mean_diff 52 | 53 | trace = np.trace(sample_cov) + np.trace(real_cov) - 2 * np.trace(cov_sqrt) 54 | 55 | fid = mean_norm + trace 56 | 57 | return fid 58 | 59 | 60 | if __name__ == '__main__': 61 | device = 'cuda' 62 | 63 | parser = argparse.ArgumentParser() 64 | 65 | parser.add_argument('--truncation', type=float, default=1) 66 | parser.add_argument('--truncation_mean', type=int, default=4096) 67 | parser.add_argument('--batch', type=int, default=64) 68 | parser.add_argument('--n_sample', type=int, default=50000) 69 | parser.add_argument('--size', type=int, default=256) 70 | parser.add_argument('--inception', type=str, default=None, required=True) 71 | parser.add_argument('ckpt', metavar='CHECKPOINT') 72 | 73 | args = parser.parse_args() 74 | 75 | ckpt = torch.load(args.ckpt) 76 | 77 | g = Generator(args.size, 512, 8).to(device) 78 | g.load_state_dict(ckpt['g_ema']) 79 | g = nn.DataParallel(g) 80 | g.eval() 81 | 82 | if args.truncation < 1: 83 | with torch.no_grad(): 84 | mean_latent = g.mean_latent(args.truncation_mean) 85 | 86 | else: 87 | mean_latent = None 88 | 89 | inception = nn.DataParallel(load_patched_inception_v3()).to(device) 90 | inception.eval() 91 | 92 | features = extract_feature_from_samples( 93 | g, inception, args.truncation, mean_latent, args.batch, args.n_sample, device 94 | ).numpy() 95 | print(f'extracted {features.shape[0]} features') 96 | 97 | sample_mean = np.mean(features, 0) 98 | sample_cov = np.cov(features, rowvar=False) 99 | 100 | with open(args.inception, 'rb') as f: 101 | embeds = pickle.load(f) 102 | real_mean = embeds['mean'] 103 | real_cov = embeds['cov'] 104 | 105 | fid = calc_fid(sample_mean, sample_cov, real_mean, real_cov) 106 | 107 | print('fid:', fid) 108 | -------------------------------------------------------------------------------- /stylegan2/generate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | from torchvision import utils 5 | from model import Generator 6 | from tqdm import tqdm 7 | def generate(args, g_ema, device, mean_latent): 8 | 9 | with torch.no_grad(): 10 | g_ema.eval() 11 | for i in tqdm(range(args.pics)): 12 | sample_z = torch.randn(args.sample, args.latent, device=device) 13 | 14 | sample, _ = g_ema([sample_z], truncation=args.truncation, truncation_latent=mean_latent) 15 | 16 | utils.save_image( 17 | sample, 18 | f'sample/{str(i).zfill(6)}.png', 19 | nrow=1, 20 | normalize=True, 21 | range=(-1, 1), 22 | ) 23 | 24 | if __name__ == '__main__': 25 | device = 'cuda' 26 | 27 | parser = argparse.ArgumentParser() 28 | 29 | parser.add_argument('--size', type=int, default=1024) 30 | parser.add_argument('--sample', type=int, default=1) 31 | parser.add_argument('--pics', type=int, default=20) 32 | parser.add_argument('--truncation', type=float, default=1) 33 | parser.add_argument('--truncation_mean', type=int, default=4096) 34 | parser.add_argument('--ckpt', type=str, default="stylegan2-ffhq-config-f.pt") 35 | parser.add_argument('--channel_multiplier', type=int, default=2) 36 | 37 | args = parser.parse_args() 38 | 39 | 
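    # Latent dimensionality and mapping-network depth are fixed to the
    # StyleGAN2 defaults (512-d latents, 8-layer mapping MLP) rather than
    # exposed as command-line options.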
args.latent = 512 40 | args.n_mlp = 8 41 | 42 | g_ema = Generator( 43 | args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier 44 | ).to(device) 45 | checkpoint = torch.load(args.ckpt) 46 | 47 | g_ema.load_state_dict(checkpoint['g_ema']) 48 | 49 | if args.truncation < 1: 50 | with torch.no_grad(): 51 | mean_latent = g_ema.mean_latent(args.truncation_mean) 52 | else: 53 | mean_latent = None 54 | 55 | generate(args, g_ema, device, mean_latent) 56 | -------------------------------------------------------------------------------- /stylegan2/inception.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torchvision import models 5 | 6 | try: 7 | from torchvision.models.utils import load_state_dict_from_url 8 | except ImportError: 9 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 10 | 11 | # Inception weights ported to Pytorch from 12 | # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 13 | FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' 14 | 15 | 16 | class InceptionV3(nn.Module): 17 | """Pretrained InceptionV3 network returning feature maps""" 18 | 19 | # Index of default block of inception to return, 20 | # corresponds to output of final average pooling 21 | DEFAULT_BLOCK_INDEX = 3 22 | 23 | # Maps feature dimensionality to their output blocks indices 24 | BLOCK_INDEX_BY_DIM = { 25 | 64: 0, # First max pooling features 26 | 192: 1, # Second max pooling featurs 27 | 768: 2, # Pre-aux classifier features 28 | 2048: 3 # Final average pooling features 29 | } 30 | 31 | def __init__(self, 32 | output_blocks=[DEFAULT_BLOCK_INDEX], 33 | resize_input=True, 34 | normalize_input=True, 35 | requires_grad=False, 36 | use_fid_inception=True): 37 | """Build pretrained InceptionV3 38 | 39 | Parameters 40 | ---------- 41 | output_blocks : list of int 42 | Indices of blocks to return features of. Possible values are: 43 | - 0: corresponds to output of first max pooling 44 | - 1: corresponds to output of second max pooling 45 | - 2: corresponds to output which is fed to aux classifier 46 | - 3: corresponds to output of final average pooling 47 | resize_input : bool 48 | If true, bilinearly resizes input to width and height 299 before 49 | feeding input to model. As the network without fully connected 50 | layers is fully convolutional, it should be able to handle inputs 51 | of arbitrary size, so resizing might not be strictly needed 52 | normalize_input : bool 53 | If true, scales the input from range (0, 1) to the range the 54 | pretrained Inception network expects, namely (-1, 1) 55 | requires_grad : bool 56 | If true, parameters of the model require gradients. Possibly useful 57 | for finetuning the network 58 | use_fid_inception : bool 59 | If true, uses the pretrained Inception model used in Tensorflow's 60 | FID implementation. If false, uses the pretrained Inception model 61 | available in torchvision. The FID Inception model has different 62 | weights and a slightly different structure from torchvision's 63 | Inception model. If you want to compute FID scores, you are 64 | strongly advised to set this parameter to true to get comparable 65 | results. 
66 | """ 67 | super(InceptionV3, self).__init__() 68 | 69 | self.resize_input = resize_input 70 | self.normalize_input = normalize_input 71 | self.output_blocks = sorted(output_blocks) 72 | self.last_needed_block = max(output_blocks) 73 | 74 | assert self.last_needed_block <= 3, \ 75 | 'Last possible output block index is 3' 76 | 77 | self.blocks = nn.ModuleList() 78 | 79 | if use_fid_inception: 80 | inception = fid_inception_v3() 81 | else: 82 | inception = models.inception_v3(pretrained=True) 83 | 84 | # Block 0: input to maxpool1 85 | block0 = [ 86 | inception.Conv2d_1a_3x3, 87 | inception.Conv2d_2a_3x3, 88 | inception.Conv2d_2b_3x3, 89 | nn.MaxPool2d(kernel_size=3, stride=2) 90 | ] 91 | self.blocks.append(nn.Sequential(*block0)) 92 | 93 | # Block 1: maxpool1 to maxpool2 94 | if self.last_needed_block >= 1: 95 | block1 = [ 96 | inception.Conv2d_3b_1x1, 97 | inception.Conv2d_4a_3x3, 98 | nn.MaxPool2d(kernel_size=3, stride=2) 99 | ] 100 | self.blocks.append(nn.Sequential(*block1)) 101 | 102 | # Block 2: maxpool2 to aux classifier 103 | if self.last_needed_block >= 2: 104 | block2 = [ 105 | inception.Mixed_5b, 106 | inception.Mixed_5c, 107 | inception.Mixed_5d, 108 | inception.Mixed_6a, 109 | inception.Mixed_6b, 110 | inception.Mixed_6c, 111 | inception.Mixed_6d, 112 | inception.Mixed_6e, 113 | ] 114 | self.blocks.append(nn.Sequential(*block2)) 115 | 116 | # Block 3: aux classifier to final avgpool 117 | if self.last_needed_block >= 3: 118 | block3 = [ 119 | inception.Mixed_7a, 120 | inception.Mixed_7b, 121 | inception.Mixed_7c, 122 | nn.AdaptiveAvgPool2d(output_size=(1, 1)) 123 | ] 124 | self.blocks.append(nn.Sequential(*block3)) 125 | 126 | for param in self.parameters(): 127 | param.requires_grad = requires_grad 128 | 129 | def forward(self, inp): 130 | """Get Inception feature maps 131 | 132 | Parameters 133 | ---------- 134 | inp : torch.autograd.Variable 135 | Input tensor of shape Bx3xHxW. Values are expected to be in 136 | range (0, 1) 137 | 138 | Returns 139 | ------- 140 | List of torch.autograd.Variable, corresponding to the selected output 141 | block, sorted ascending by index 142 | """ 143 | outp = [] 144 | x = inp 145 | 146 | if self.resize_input: 147 | x = F.interpolate(x, 148 | size=(299, 299), 149 | mode='bilinear', 150 | align_corners=False) 151 | 152 | if self.normalize_input: 153 | x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1) 154 | 155 | for idx, block in enumerate(self.blocks): 156 | x = block(x) 157 | if idx in self.output_blocks: 158 | outp.append(x) 159 | 160 | if idx == self.last_needed_block: 161 | break 162 | 163 | return outp 164 | 165 | 166 | def fid_inception_v3(): 167 | """Build pretrained Inception model for FID computation 168 | 169 | The Inception model for FID computation uses a different set of weights 170 | and has a slightly different structure than torchvision's Inception. 171 | 172 | This method first constructs torchvision's Inception and then patches the 173 | necessary parts that are different in the FID Inception model. 
174 | """ 175 | inception = models.inception_v3(num_classes=1008, 176 | aux_logits=False, 177 | pretrained=False) 178 | inception.Mixed_5b = FIDInceptionA(192, pool_features=32) 179 | inception.Mixed_5c = FIDInceptionA(256, pool_features=64) 180 | inception.Mixed_5d = FIDInceptionA(288, pool_features=64) 181 | inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128) 182 | inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160) 183 | inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160) 184 | inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192) 185 | inception.Mixed_7b = FIDInceptionE_1(1280) 186 | inception.Mixed_7c = FIDInceptionE_2(2048) 187 | 188 | state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True) 189 | inception.load_state_dict(state_dict) 190 | return inception 191 | 192 | 193 | class FIDInceptionA(models.inception.InceptionA): 194 | """InceptionA block patched for FID computation""" 195 | def __init__(self, in_channels, pool_features): 196 | super(FIDInceptionA, self).__init__(in_channels, pool_features) 197 | 198 | def forward(self, x): 199 | branch1x1 = self.branch1x1(x) 200 | 201 | branch5x5 = self.branch5x5_1(x) 202 | branch5x5 = self.branch5x5_2(branch5x5) 203 | 204 | branch3x3dbl = self.branch3x3dbl_1(x) 205 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 206 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) 207 | 208 | # Patch: Tensorflow's average pool does not use the padded zero's in 209 | # its average calculation 210 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 211 | count_include_pad=False) 212 | branch_pool = self.branch_pool(branch_pool) 213 | 214 | outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] 215 | return torch.cat(outputs, 1) 216 | 217 | 218 | class FIDInceptionC(models.inception.InceptionC): 219 | """InceptionC block patched for FID computation""" 220 | def __init__(self, in_channels, channels_7x7): 221 | super(FIDInceptionC, self).__init__(in_channels, channels_7x7) 222 | 223 | def forward(self, x): 224 | branch1x1 = self.branch1x1(x) 225 | 226 | branch7x7 = self.branch7x7_1(x) 227 | branch7x7 = self.branch7x7_2(branch7x7) 228 | branch7x7 = self.branch7x7_3(branch7x7) 229 | 230 | branch7x7dbl = self.branch7x7dbl_1(x) 231 | branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) 232 | branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) 233 | branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) 234 | branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) 235 | 236 | # Patch: Tensorflow's average pool does not use the padded zero's in 237 | # its average calculation 238 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 239 | count_include_pad=False) 240 | branch_pool = self.branch_pool(branch_pool) 241 | 242 | outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] 243 | return torch.cat(outputs, 1) 244 | 245 | 246 | class FIDInceptionE_1(models.inception.InceptionE): 247 | """First InceptionE block patched for FID computation""" 248 | def __init__(self, in_channels): 249 | super(FIDInceptionE_1, self).__init__(in_channels) 250 | 251 | def forward(self, x): 252 | branch1x1 = self.branch1x1(x) 253 | 254 | branch3x3 = self.branch3x3_1(x) 255 | branch3x3 = [ 256 | self.branch3x3_2a(branch3x3), 257 | self.branch3x3_2b(branch3x3), 258 | ] 259 | branch3x3 = torch.cat(branch3x3, 1) 260 | 261 | branch3x3dbl = self.branch3x3dbl_1(x) 262 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 263 | branch3x3dbl = [ 264 | self.branch3x3dbl_3a(branch3x3dbl), 265 | self.branch3x3dbl_3b(branch3x3dbl), 266 | ] 
267 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 268 | 269 | # Patch: Tensorflow's average pool does not use the padded zero's in 270 | # its average calculation 271 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 272 | count_include_pad=False) 273 | branch_pool = self.branch_pool(branch_pool) 274 | 275 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 276 | return torch.cat(outputs, 1) 277 | 278 | 279 | class FIDInceptionE_2(models.inception.InceptionE): 280 | """Second InceptionE block patched for FID computation""" 281 | def __init__(self, in_channels): 282 | super(FIDInceptionE_2, self).__init__(in_channels) 283 | 284 | def forward(self, x): 285 | branch1x1 = self.branch1x1(x) 286 | 287 | branch3x3 = self.branch3x3_1(x) 288 | branch3x3 = [ 289 | self.branch3x3_2a(branch3x3), 290 | self.branch3x3_2b(branch3x3), 291 | ] 292 | branch3x3 = torch.cat(branch3x3, 1) 293 | 294 | branch3x3dbl = self.branch3x3dbl_1(x) 295 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 296 | branch3x3dbl = [ 297 | self.branch3x3dbl_3a(branch3x3dbl), 298 | self.branch3x3dbl_3b(branch3x3dbl), 299 | ] 300 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 301 | 302 | # Patch: The FID Inception model uses max pooling instead of average 303 | # pooling. This is likely an error in this specific Inception 304 | # implementation, as other Inception models use average pooling here 305 | # (which matches the description in the paper). 306 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1) 307 | branch_pool = self.branch_pool(branch_pool) 308 | 309 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 310 | return torch.cat(outputs, 1) 311 | -------------------------------------------------------------------------------- /stylegan2/inception_ffhq.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rosinality/swapping-autoencoder-pytorch/8265a8a4497ea098c83bbb47bf33960c999e7d7e/stylegan2/inception_ffhq.pkl -------------------------------------------------------------------------------- /stylegan2/lpips/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import numpy as np 7 | from skimage.measure import compare_ssim 8 | import torch 9 | from torch.autograd import Variable 10 | 11 | from lpips import dist_model 12 | 13 | class PerceptualLoss(torch.nn.Module): 14 | def __init__(self, model='net-lin', net='alex', colorspace='rgb', spatial=False, use_gpu=True, gpu_ids=[0]): # VGG using our perceptually-learned weights (LPIPS metric) 15 | # def __init__(self, model='net', net='vgg', use_gpu=True): # "default" way of using VGG as a perceptual loss 16 | super(PerceptualLoss, self).__init__() 17 | print('Setting up Perceptual loss...') 18 | self.use_gpu = use_gpu 19 | self.spatial = spatial 20 | self.gpu_ids = gpu_ids 21 | self.model = dist_model.DistModel() 22 | self.model.initialize(model=model, net=net, use_gpu=use_gpu, colorspace=colorspace, spatial=self.spatial, gpu_ids=gpu_ids) 23 | print('...[%s] initialized'%self.model.name()) 24 | print('...Done') 25 | 26 | def forward(self, pred, target, normalize=False): 27 | """ 28 | Pred and target are Variables. 
29 | If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1] 30 | If normalize is False, assumes the images are already between [-1,+1] 31 | 32 | Inputs pred and target are Nx3xHxW 33 | Output pytorch Variable N long 34 | """ 35 | 36 | if normalize: 37 | target = 2 * target - 1 38 | pred = 2 * pred - 1 39 | 40 | return self.model.forward(target, pred) 41 | 42 | def normalize_tensor(in_feat,eps=1e-10): 43 | norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1,keepdim=True)) 44 | return in_feat/(norm_factor+eps) 45 | 46 | def l2(p0, p1, range=255.): 47 | return .5*np.mean((p0 / range - p1 / range)**2) 48 | 49 | def psnr(p0, p1, peak=255.): 50 | return 10*np.log10(peak**2/np.mean((1.*p0-1.*p1)**2)) 51 | 52 | def dssim(p0, p1, range=255.): 53 | return (1 - compare_ssim(p0, p1, data_range=range, multichannel=True)) / 2. 54 | 55 | def rgb2lab(in_img,mean_cent=False): 56 | from skimage import color 57 | img_lab = color.rgb2lab(in_img) 58 | if(mean_cent): 59 | img_lab[:,:,0] = img_lab[:,:,0]-50 60 | return img_lab 61 | 62 | def tensor2np(tensor_obj): 63 | # change dimension of a tensor object into a numpy array 64 | return tensor_obj[0].cpu().float().numpy().transpose((1,2,0)) 65 | 66 | def np2tensor(np_obj): 67 | # change dimenion of np array into tensor array 68 | return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 69 | 70 | def tensor2tensorlab(image_tensor,to_norm=True,mc_only=False): 71 | # image tensor to lab tensor 72 | from skimage import color 73 | 74 | img = tensor2im(image_tensor) 75 | img_lab = color.rgb2lab(img) 76 | if(mc_only): 77 | img_lab[:,:,0] = img_lab[:,:,0]-50 78 | if(to_norm and not mc_only): 79 | img_lab[:,:,0] = img_lab[:,:,0]-50 80 | img_lab = img_lab/100. 81 | 82 | return np2tensor(img_lab) 83 | 84 | def tensorlab2tensor(lab_tensor,return_inbnd=False): 85 | from skimage import color 86 | import warnings 87 | warnings.filterwarnings("ignore") 88 | 89 | lab = tensor2np(lab_tensor)*100. 90 | lab[:,:,0] = lab[:,:,0]+50 91 | 92 | rgb_back = 255.*np.clip(color.lab2rgb(lab.astype('float')),0,1) 93 | if(return_inbnd): 94 | # convert back to lab, see if we match 95 | lab_back = color.rgb2lab(rgb_back.astype('uint8')) 96 | mask = 1.*np.isclose(lab_back,lab,atol=2.) 97 | mask = np2tensor(np.prod(mask,axis=2)[:,:,np.newaxis]) 98 | return (im2tensor(rgb_back),mask) 99 | else: 100 | return im2tensor(rgb_back) 101 | 102 | def rgb2lab(input): 103 | from skimage import color 104 | return color.rgb2lab(input / 255.) 105 | 106 | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.): 107 | image_numpy = image_tensor[0].cpu().float().numpy() 108 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor 109 | return image_numpy.astype(imtype) 110 | 111 | def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.): 112 | return torch.Tensor((image / factor - cent) 113 | [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 114 | 115 | def tensor2vec(vector_tensor): 116 | return vector_tensor.data.cpu().numpy()[:, :, 0, 0] 117 | 118 | def voc_ap(rec, prec, use_07_metric=False): 119 | """ ap = voc_ap(rec, prec, [use_07_metric]) 120 | Compute VOC AP given precision and recall. 121 | If use_07_metric is true, uses the 122 | VOC 07 11 point method (default:False). 123 | """ 124 | if use_07_metric: 125 | # 11 point metric 126 | ap = 0. 127 | for t in np.arange(0., 1.1, 0.1): 128 | if np.sum(rec >= t) == 0: 129 | p = 0 130 | else: 131 | p = np.max(prec[rec >= t]) 132 | ap = ap + p / 11. 
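    # The branch above implements the VOC07 11-point interpolated AP; the
    # branch below computes the exact area under the precision-recall envelope.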
133 | else: 134 | # correct AP calculation 135 | # first append sentinel values at the end 136 | mrec = np.concatenate(([0.], rec, [1.])) 137 | mpre = np.concatenate(([0.], prec, [0.])) 138 | 139 | # compute the precision envelope 140 | for i in range(mpre.size - 1, 0, -1): 141 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) 142 | 143 | # to calculate area under PR curve, look for points 144 | # where X axis (recall) changes value 145 | i = np.where(mrec[1:] != mrec[:-1])[0] 146 | 147 | # and sum (\Delta recall) * prec 148 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) 149 | return ap 150 | 151 | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.): 152 | # def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.): 153 | image_numpy = image_tensor[0].cpu().float().numpy() 154 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor 155 | return image_numpy.astype(imtype) 156 | 157 | def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.): 158 | # def im2tensor(image, imtype=np.uint8, cent=1., factor=1.): 159 | return torch.Tensor((image / factor - cent) 160 | [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 161 | -------------------------------------------------------------------------------- /stylegan2/lpips/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | from torch.autograd import Variable 5 | from pdb import set_trace as st 6 | from IPython import embed 7 | 8 | class BaseModel(): 9 | def __init__(self): 10 | pass; 11 | 12 | def name(self): 13 | return 'BaseModel' 14 | 15 | def initialize(self, use_gpu=True, gpu_ids=[0]): 16 | self.use_gpu = use_gpu 17 | self.gpu_ids = gpu_ids 18 | 19 | def forward(self): 20 | pass 21 | 22 | def get_image_paths(self): 23 | pass 24 | 25 | def optimize_parameters(self): 26 | pass 27 | 28 | def get_current_visuals(self): 29 | return self.input 30 | 31 | def get_current_errors(self): 32 | return {} 33 | 34 | def save(self, label): 35 | pass 36 | 37 | # helper saving function that can be used by subclasses 38 | def save_network(self, network, path, network_label, epoch_label): 39 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 40 | save_path = os.path.join(path, save_filename) 41 | torch.save(network.state_dict(), save_path) 42 | 43 | # helper loading function that can be used by subclasses 44 | def load_network(self, network, network_label, epoch_label): 45 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 46 | save_path = os.path.join(self.save_dir, save_filename) 47 | print('Loading network from %s'%save_path) 48 | network.load_state_dict(torch.load(save_path)) 49 | 50 | def update_learning_rate(): 51 | pass 52 | 53 | def get_image_paths(self): 54 | return self.image_paths 55 | 56 | def save_done(self, flag=False): 57 | np.save(os.path.join(self.save_dir, 'done_flag'),flag) 58 | np.savetxt(os.path.join(self.save_dir, 'done_flag'),[flag,],fmt='%i') 59 | -------------------------------------------------------------------------------- /stylegan2/lpips/dist_model.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | 4 | import sys 5 | import numpy as np 6 | import torch 7 | from torch import nn 8 | import os 9 | from collections import OrderedDict 10 | from torch.autograd import Variable 11 | import itertools 12 | from .base_model import BaseModel 13 | from scipy.ndimage import zoom 14 | import 
fractions 15 | import functools 16 | import skimage.transform 17 | from tqdm import tqdm 18 | 19 | from IPython import embed 20 | 21 | from . import networks_basic as networks 22 | import lpips as util 23 | 24 | class DistModel(BaseModel): 25 | def name(self): 26 | return self.model_name 27 | 28 | def initialize(self, model='net-lin', net='alex', colorspace='Lab', pnet_rand=False, pnet_tune=False, model_path=None, 29 | use_gpu=True, printNet=False, spatial=False, 30 | is_train=False, lr=.0001, beta1=0.5, version='0.1', gpu_ids=[0]): 31 | ''' 32 | INPUTS 33 | model - ['net-lin'] for linearly calibrated network 34 | ['net'] for off-the-shelf network 35 | ['L2'] for L2 distance in Lab colorspace 36 | ['SSIM'] for ssim in RGB colorspace 37 | net - ['squeeze','alex','vgg'] 38 | model_path - if None, will look in weights/[NET_NAME].pth 39 | colorspace - ['Lab','RGB'] colorspace to use for L2 and SSIM 40 | use_gpu - bool - whether or not to use a GPU 41 | printNet - bool - whether or not to print network architecture out 42 | spatial - bool - whether to output an array containing varying distances across spatial dimensions 43 | spatial_shape - if given, output spatial shape. if None then spatial shape is determined automatically via spatial_factor (see below). 44 | spatial_factor - if given, specifies upsampling factor relative to the largest spatial extent of a convolutional layer. if None then resized to size of input images. 45 | spatial_order - spline order of filter for upsampling in spatial mode, by default 1 (bilinear). 46 | is_train - bool - [True] for training mode 47 | lr - float - initial learning rate 48 | beta1 - float - initial momentum term for adam 49 | version - 0.1 for latest, 0.0 was original (with a bug) 50 | gpu_ids - int array - [0] by default, gpus to use 51 | ''' 52 | BaseModel.initialize(self, use_gpu=use_gpu, gpu_ids=gpu_ids) 53 | 54 | self.model = model 55 | self.net = net 56 | self.is_train = is_train 57 | self.spatial = spatial 58 | self.gpu_ids = gpu_ids 59 | self.model_name = '%s [%s]'%(model,net) 60 | 61 | if(self.model == 'net-lin'): # pretrained net + linear layer 62 | self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_tune=pnet_tune, pnet_type=net, 63 | use_dropout=True, spatial=spatial, version=version, lpips=True) 64 | kw = {} 65 | if not use_gpu: 66 | kw['map_location'] = 'cpu' 67 | if(model_path is None): 68 | import inspect 69 | model_path = os.path.abspath(os.path.join(inspect.getfile(self.initialize), '..', 'weights/v%s/%s.pth'%(version,net))) 70 | 71 | if(not is_train): 72 | print('Loading model from: %s'%model_path) 73 | self.net.load_state_dict(torch.load(model_path, **kw), strict=False) 74 | 75 | elif(self.model=='net'): # pretrained network 76 | self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_type=net, lpips=False) 77 | elif(self.model in ['L2','l2']): 78 | self.net = networks.L2(use_gpu=use_gpu,colorspace=colorspace) # not really a network, only for testing 79 | self.model_name = 'L2' 80 | elif(self.model in ['DSSIM','dssim','SSIM','ssim']): 81 | self.net = networks.DSSIM(use_gpu=use_gpu,colorspace=colorspace) 82 | self.model_name = 'SSIM' 83 | else: 84 | raise ValueError("Model [%s] not recognized." 
% self.model) 85 | 86 | self.parameters = list(self.net.parameters()) 87 | 88 | if self.is_train: # training mode 89 | # extra network on top to go from distances (d0,d1) => predicted human judgment (h*) 90 | self.rankLoss = networks.BCERankingLoss() 91 | self.parameters += list(self.rankLoss.net.parameters()) 92 | self.lr = lr 93 | self.old_lr = lr 94 | self.optimizer_net = torch.optim.Adam(self.parameters, lr=lr, betas=(beta1, 0.999)) 95 | else: # test mode 96 | self.net.eval() 97 | 98 | if(use_gpu): 99 | self.net.to(gpu_ids[0]) 100 | self.net = torch.nn.DataParallel(self.net, device_ids=gpu_ids) 101 | if(self.is_train): 102 | self.rankLoss = self.rankLoss.to(device=gpu_ids[0]) # just put this on GPU0 103 | 104 | if(printNet): 105 | print('---------- Networks initialized -------------') 106 | networks.print_network(self.net) 107 | print('-----------------------------------------------') 108 | 109 | def forward(self, in0, in1, retPerLayer=False): 110 | ''' Function computes the distance between image patches in0 and in1 111 | INPUTS 112 | in0, in1 - torch.Tensor object of shape Nx3xXxY - image patch scaled to [-1,1] 113 | OUTPUT 114 | computed distances between in0 and in1 115 | ''' 116 | 117 | return self.net.forward(in0, in1, retPerLayer=retPerLayer) 118 | 119 | # ***** TRAINING FUNCTIONS ***** 120 | def optimize_parameters(self): 121 | self.forward_train() 122 | self.optimizer_net.zero_grad() 123 | self.backward_train() 124 | self.optimizer_net.step() 125 | self.clamp_weights() 126 | 127 | def clamp_weights(self): 128 | for module in self.net.modules(): 129 | if(hasattr(module, 'weight') and module.kernel_size==(1,1)): 130 | module.weight.data = torch.clamp(module.weight.data,min=0) 131 | 132 | def set_input(self, data): 133 | self.input_ref = data['ref'] 134 | self.input_p0 = data['p0'] 135 | self.input_p1 = data['p1'] 136 | self.input_judge = data['judge'] 137 | 138 | if(self.use_gpu): 139 | self.input_ref = self.input_ref.to(device=self.gpu_ids[0]) 140 | self.input_p0 = self.input_p0.to(device=self.gpu_ids[0]) 141 | self.input_p1 = self.input_p1.to(device=self.gpu_ids[0]) 142 | self.input_judge = self.input_judge.to(device=self.gpu_ids[0]) 143 | 144 | self.var_ref = Variable(self.input_ref,requires_grad=True) 145 | self.var_p0 = Variable(self.input_p0,requires_grad=True) 146 | self.var_p1 = Variable(self.input_p1,requires_grad=True) 147 | 148 | def forward_train(self): # run forward pass 149 | # print(self.net.module.scaling_layer.shift) 150 | # print(torch.norm(self.net.module.net.slice1[0].weight).item(), torch.norm(self.net.module.lin0.model[1].weight).item()) 151 | 152 | self.d0 = self.forward(self.var_ref, self.var_p0) 153 | self.d1 = self.forward(self.var_ref, self.var_p1) 154 | self.acc_r = self.compute_accuracy(self.d0,self.d1,self.input_judge) 155 | 156 | self.var_judge = Variable(1.*self.input_judge).view(self.d0.size()) 157 | 158 | self.loss_total = self.rankLoss.forward(self.d0, self.d1, self.var_judge*2.-1.) 
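        # var_judge holds the human 2AFC preferences in [0, 1]; judge * 2 - 1
        # rescales them to [-1, +1] before they are handed to the ranking loss
        # that compares the two distances d0 and d1.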
159 | 160 | return self.loss_total 161 | 162 | def backward_train(self): 163 | torch.mean(self.loss_total).backward() 164 | 165 | def compute_accuracy(self,d0,d1,judge): 166 | ''' d0, d1 are Variables, judge is a Tensor ''' 167 | d1_lt_d0 = (d1 %f' % (type,self.old_lr, lr)) 210 | self.old_lr = lr 211 | 212 | def score_2afc_dataset(data_loader, func, name=''): 213 | ''' Function computes Two Alternative Forced Choice (2AFC) score using 214 | distance function 'func' in dataset 'data_loader' 215 | INPUTS 216 | data_loader - CustomDatasetDataLoader object - contains a TwoAFCDataset inside 217 | func - callable distance function - calling d=func(in0,in1) should take 2 218 | pytorch tensors with shape Nx3xXxY, and return numpy array of length N 219 | OUTPUTS 220 | [0] - 2AFC score in [0,1], fraction of time func agrees with human evaluators 221 | [1] - dictionary with following elements 222 | d0s,d1s - N arrays containing distances between reference patch to perturbed patches 223 | gts - N array in [0,1], preferred patch selected by human evaluators 224 | (closer to "0" for left patch p0, "1" for right patch p1, 225 | "0.6" means 60pct people preferred right patch, 40pct preferred left) 226 | scores - N array in [0,1], corresponding to what percentage function agreed with humans 227 | CONSTS 228 | N - number of test triplets in data_loader 229 | ''' 230 | 231 | d0s = [] 232 | d1s = [] 233 | gts = [] 234 | 235 | for data in tqdm(data_loader.load_data(), desc=name): 236 | d0s+=func(data['ref'],data['p0']).data.cpu().numpy().flatten().tolist() 237 | d1s+=func(data['ref'],data['p1']).data.cpu().numpy().flatten().tolist() 238 | gts+=data['judge'].cpu().numpy().flatten().tolist() 239 | 240 | d0s = np.array(d0s) 241 | d1s = np.array(d1s) 242 | gts = np.array(gts) 243 | scores = (d0s 2: 36 | dim += list(range(2, grad_input.ndim)) 37 | 38 | grad_bias = grad_input.sum(dim).detach() 39 | 40 | return grad_input, grad_bias 41 | 42 | @staticmethod 43 | def backward(ctx, gradgrad_input, gradgrad_bias): 44 | (out,) = ctx.saved_tensors 45 | gradgrad_out = fused.fused_bias_act( 46 | gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale 47 | ) 48 | 49 | return gradgrad_out, None, None, None 50 | 51 | 52 | class FusedLeakyReLUFunction(Function): 53 | @staticmethod 54 | def forward(ctx, input, bias, negative_slope, scale): 55 | empty = input.new_empty(0) 56 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) 57 | ctx.save_for_backward(out) 58 | ctx.negative_slope = negative_slope 59 | ctx.scale = scale 60 | 61 | return out 62 | 63 | @staticmethod 64 | def backward(ctx, grad_output): 65 | (out,) = ctx.saved_tensors 66 | 67 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( 68 | grad_output, out, ctx.negative_slope, ctx.scale 69 | ) 70 | 71 | return grad_input, grad_bias, None, None 72 | 73 | 74 | class FusedLeakyReLU(nn.Module): 75 | def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): 76 | super().__init__() 77 | 78 | self.bias = nn.Parameter(torch.zeros(channel)) 79 | self.negative_slope = negative_slope 80 | self.scale = scale 81 | 82 | def forward(self, input): 83 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 84 | 85 | 86 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): 87 | if input.device.type == "cpu": 88 | rest_dim = [1] * (input.ndim - bias.ndim - 1) 89 | return ( 90 | F.leaky_relu( 91 | input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2 92 | ) 93 | * scale 94 | ) 95 | 
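    # The CPU branch above emulates the fused op with a broadcast bias add,
    # a plain leaky_relu, and the post-activation scale gain; on GPU the call
    # below dispatches to the compiled fused_bias_act CUDA extension.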
96 | else: 97 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) 98 | -------------------------------------------------------------------------------- /stylegan2/op/fused_bias_act.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | 4 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 5 | int act, int grad, float alpha, float scale); 6 | 7 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 8 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 9 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 10 | 11 | torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 12 | int act, int grad, float alpha, float scale) { 13 | CHECK_CUDA(input); 14 | CHECK_CUDA(bias); 15 | 16 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); 17 | } 18 | 19 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 20 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); 21 | } -------------------------------------------------------------------------------- /stylegan2/op/fused_bias_act_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | 18 | template 19 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, 20 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { 21 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; 22 | 23 | scalar_t zero = 0.0; 24 | 25 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { 26 | scalar_t x = p_x[xi]; 27 | 28 | if (use_bias) { 29 | x += p_b[(xi / step_b) % size_b]; 30 | } 31 | 32 | scalar_t ref = use_ref ? p_ref[xi] : zero; 33 | 34 | scalar_t y; 35 | 36 | switch (act * 10 + grad) { 37 | default: 38 | case 10: y = x; break; 39 | case 11: y = x; break; 40 | case 12: y = 0.0; break; 41 | 42 | case 30: y = (x > 0.0) ? x : x * alpha; break; 43 | case 31: y = (ref > 0.0) ? x : x * alpha; break; 44 | case 32: y = 0.0; break; 45 | } 46 | 47 | out[xi] = y * scale; 48 | } 49 | } 50 | 51 | 52 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 53 | int act, int grad, float alpha, float scale) { 54 | int curDevice = -1; 55 | cudaGetDevice(&curDevice); 56 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 57 | 58 | auto x = input.contiguous(); 59 | auto b = bias.contiguous(); 60 | auto ref = refer.contiguous(); 61 | 62 | int use_bias = b.numel() ? 1 : 0; 63 | int use_ref = ref.numel() ? 
1 : 0; 64 | 65 | int size_x = x.numel(); 66 | int size_b = b.numel(); 67 | int step_b = 1; 68 | 69 | for (int i = 1 + 1; i < x.dim(); i++) { 70 | step_b *= x.size(i); 71 | } 72 | 73 | int loop_x = 4; 74 | int block_size = 4 * 32; 75 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1; 76 | 77 | auto y = torch::empty_like(x); 78 | 79 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { 80 | fused_bias_act_kernel<<>>( 81 | y.data_ptr(), 82 | x.data_ptr(), 83 | b.data_ptr(), 84 | ref.data_ptr(), 85 | act, 86 | grad, 87 | alpha, 88 | scale, 89 | loop_x, 90 | size_x, 91 | step_b, 92 | size_b, 93 | use_bias, 94 | use_ref 95 | ); 96 | }); 97 | 98 | return y; 99 | } -------------------------------------------------------------------------------- /stylegan2/op/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | 4 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, 5 | int up_x, int up_y, int down_x, int down_y, 6 | int pad_x0, int pad_x1, int pad_y0, int pad_y1); 7 | 8 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 9 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 10 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 11 | 12 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, 13 | int up_x, int up_y, int down_x, int down_y, 14 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) { 15 | CHECK_CUDA(input); 16 | CHECK_CUDA(kernel); 17 | 18 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); 19 | } 20 | 21 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 22 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); 23 | } -------------------------------------------------------------------------------- /stylegan2/op/upfirdn2d.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch.nn import functional as F 5 | from torch.autograd import Function 6 | from torch.utils.cpp_extension import load 7 | 8 | 9 | module_path = os.path.dirname(__file__) 10 | upfirdn2d_op = load( 11 | "upfirdn2d", 12 | sources=[ 13 | os.path.join(module_path, "upfirdn2d.cpp"), 14 | os.path.join(module_path, "upfirdn2d_kernel.cu"), 15 | ], 16 | ) 17 | 18 | 19 | class UpFirDn2dBackward(Function): 20 | @staticmethod 21 | def forward( 22 | ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size 23 | ): 24 | 25 | up_x, up_y = up 26 | down_x, down_y = down 27 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad 28 | 29 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) 30 | 31 | grad_input = upfirdn2d_op.upfirdn2d( 32 | grad_output, 33 | grad_kernel, 34 | down_x, 35 | down_y, 36 | up_x, 37 | up_y, 38 | g_pad_x0, 39 | g_pad_x1, 40 | g_pad_y0, 41 | g_pad_y1, 42 | ) 43 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) 44 | 45 | ctx.save_for_backward(kernel) 46 | 47 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 48 | 49 | ctx.up_x = up_x 50 | ctx.up_y = up_y 51 | ctx.down_x = down_x 52 | ctx.down_y = down_y 53 | ctx.pad_x0 = pad_x0 54 | ctx.pad_x1 = pad_x1 55 | ctx.pad_y0 = pad_y0 56 | ctx.pad_y1 = pad_y1 57 | ctx.in_size = in_size 58 | ctx.out_size = out_size 59 | 60 | return grad_input 61 | 62 | @staticmethod 63 | def backward(ctx, gradgrad_input): 64 | kernel, = ctx.saved_tensors 65 | 66 | gradgrad_input = 
gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) 67 | 68 | gradgrad_out = upfirdn2d_op.upfirdn2d( 69 | gradgrad_input, 70 | kernel, 71 | ctx.up_x, 72 | ctx.up_y, 73 | ctx.down_x, 74 | ctx.down_y, 75 | ctx.pad_x0, 76 | ctx.pad_x1, 77 | ctx.pad_y0, 78 | ctx.pad_y1, 79 | ) 80 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3]) 81 | gradgrad_out = gradgrad_out.view( 82 | ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1] 83 | ) 84 | 85 | return gradgrad_out, None, None, None, None, None, None, None, None 86 | 87 | 88 | class UpFirDn2d(Function): 89 | @staticmethod 90 | def forward(ctx, input, kernel, up, down, pad): 91 | up_x, up_y = up 92 | down_x, down_y = down 93 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 94 | 95 | kernel_h, kernel_w = kernel.shape 96 | batch, channel, in_h, in_w = input.shape 97 | ctx.in_size = input.shape 98 | 99 | input = input.reshape(-1, in_h, in_w, 1) 100 | 101 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) 102 | 103 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 104 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 105 | ctx.out_size = (out_h, out_w) 106 | 107 | ctx.up = (up_x, up_y) 108 | ctx.down = (down_x, down_y) 109 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) 110 | 111 | g_pad_x0 = kernel_w - pad_x0 - 1 112 | g_pad_y0 = kernel_h - pad_y0 - 1 113 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 114 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 115 | 116 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) 117 | 118 | out = upfirdn2d_op.upfirdn2d( 119 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 120 | ) 121 | # out = out.view(major, out_h, out_w, minor) 122 | out = out.view(-1, channel, out_h, out_w) 123 | 124 | return out 125 | 126 | @staticmethod 127 | def backward(ctx, grad_output): 128 | kernel, grad_kernel = ctx.saved_tensors 129 | 130 | grad_input = UpFirDn2dBackward.apply( 131 | grad_output, 132 | kernel, 133 | grad_kernel, 134 | ctx.up, 135 | ctx.down, 136 | ctx.pad, 137 | ctx.g_pad, 138 | ctx.in_size, 139 | ctx.out_size, 140 | ) 141 | 142 | return grad_input, None, None, None, None 143 | 144 | 145 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): 146 | if input.device.type == "cpu": 147 | out = upfirdn2d_native( 148 | input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1] 149 | ) 150 | 151 | else: 152 | out = UpFirDn2d.apply( 153 | input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1]) 154 | ) 155 | 156 | return out 157 | 158 | 159 | def upfirdn2d_native( 160 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 161 | ): 162 | _, channel, in_h, in_w = input.shape 163 | input = input.reshape(-1, in_h, in_w, 1) 164 | 165 | _, in_h, in_w, minor = input.shape 166 | kernel_h, kernel_w = kernel.shape 167 | 168 | out = input.view(-1, in_h, 1, in_w, 1, minor) 169 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) 170 | out = out.view(-1, in_h * up_y, in_w * up_x, minor) 171 | 172 | out = F.pad( 173 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] 174 | ) 175 | out = out[ 176 | :, 177 | max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), 178 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), 179 | :, 180 | ] 181 | 182 | out = out.permute(0, 3, 1, 2) 183 | out = out.reshape( 184 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] 185 | ) 186 | w = torch.flip(kernel, [0, 
1]).view(1, 1, kernel_h, kernel_w) 187 | out = F.conv2d(out, w) 188 | out = out.reshape( 189 | -1, 190 | minor, 191 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, 192 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, 193 | ) 194 | out = out.permute(0, 2, 3, 1) 195 | out = out[:, ::down_y, ::down_x, :] 196 | 197 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 198 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 199 | 200 | return out.view(-1, channel, out_h, out_w) 201 | -------------------------------------------------------------------------------- /stylegan2/op/upfirdn2d_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | static __host__ __device__ __forceinline__ int floor_div(int a, int b) { 18 | int c = a / b; 19 | 20 | if (c * b > a) { 21 | c--; 22 | } 23 | 24 | return c; 25 | } 26 | 27 | struct UpFirDn2DKernelParams { 28 | int up_x; 29 | int up_y; 30 | int down_x; 31 | int down_y; 32 | int pad_x0; 33 | int pad_x1; 34 | int pad_y0; 35 | int pad_y1; 36 | 37 | int major_dim; 38 | int in_h; 39 | int in_w; 40 | int minor_dim; 41 | int kernel_h; 42 | int kernel_w; 43 | int out_h; 44 | int out_w; 45 | int loop_major; 46 | int loop_x; 47 | }; 48 | 49 | template 50 | __global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input, 51 | const scalar_t *kernel, 52 | const UpFirDn2DKernelParams p) { 53 | int minor_idx = blockIdx.x * blockDim.x + threadIdx.x; 54 | int out_y = minor_idx / p.minor_dim; 55 | minor_idx -= out_y * p.minor_dim; 56 | int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y; 57 | int major_idx_base = blockIdx.z * p.loop_major; 58 | 59 | if (out_x_base >= p.out_w || out_y >= p.out_h || 60 | major_idx_base >= p.major_dim) { 61 | return; 62 | } 63 | 64 | int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0; 65 | int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h); 66 | int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y; 67 | int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y; 68 | 69 | for (int loop_major = 0, major_idx = major_idx_base; 70 | loop_major < p.loop_major && major_idx < p.major_dim; 71 | loop_major++, major_idx++) { 72 | for (int loop_x = 0, out_x = out_x_base; 73 | loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) { 74 | int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0; 75 | int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w); 76 | int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x; 77 | int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x; 78 | 79 | const scalar_t *x_p = 80 | &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + 81 | minor_idx]; 82 | const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x]; 83 | int x_px = p.minor_dim; 84 | int k_px = -p.up_x; 85 | int x_py = p.in_w * p.minor_dim; 86 | int k_py = -p.up_y * p.kernel_w; 87 | 88 | scalar_t v = 0.0f; 89 | 90 | for (int y = 0; y < h; y++) { 91 | for (int x = 0; x < w; x++) { 92 | v += static_cast(*x_p) * static_cast(*k_p); 93 | x_p += x_px; 94 | k_p += k_px; 95 | } 96 | 97 | x_p += x_py - w * x_px; 98 | k_p += k_py - w * k_px; 99 | 
} 100 | 101 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + 102 | minor_idx] = v; 103 | } 104 | } 105 | } 106 | 107 | template <typename scalar_t, int up_x, int up_y, int down_x, int down_y, 108 | int kernel_h, int kernel_w, int tile_out_h, int tile_out_w> 109 | __global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input, 110 | const scalar_t *kernel, 111 | const UpFirDn2DKernelParams p) { 112 | const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1; 113 | const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1; 114 | 115 | __shared__ volatile float sk[kernel_h][kernel_w]; 116 | __shared__ volatile float sx[tile_in_h][tile_in_w]; 117 | 118 | int minor_idx = blockIdx.x; 119 | int tile_out_y = minor_idx / p.minor_dim; 120 | minor_idx -= tile_out_y * p.minor_dim; 121 | tile_out_y *= tile_out_h; 122 | int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w; 123 | int major_idx_base = blockIdx.z * p.loop_major; 124 | 125 | if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | 126 | major_idx_base >= p.major_dim) { 127 | return; 128 | } 129 | 130 | for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; 131 | tap_idx += blockDim.x) { 132 | int ky = tap_idx / kernel_w; 133 | int kx = tap_idx - ky * kernel_w; 134 | scalar_t v = 0.0; 135 | 136 | if (kx < p.kernel_w & ky < p.kernel_h) { 137 | v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)]; 138 | } 139 | 140 | sk[ky][kx] = v; 141 | } 142 | 143 | for (int loop_major = 0, major_idx = major_idx_base; 144 | loop_major < p.loop_major & major_idx < p.major_dim; 145 | loop_major++, major_idx++) { 146 | for (int loop_x = 0, tile_out_x = tile_out_x_base; 147 | loop_x < p.loop_x & tile_out_x < p.out_w; 148 | loop_x++, tile_out_x += tile_out_w) { 149 | int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0; 150 | int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0; 151 | int tile_in_x = floor_div(tile_mid_x, up_x); 152 | int tile_in_y = floor_div(tile_mid_y, up_y); 153 | 154 | __syncthreads(); 155 | 156 | for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; 157 | in_idx += blockDim.x) { 158 | int rel_in_y = in_idx / tile_in_w; 159 | int rel_in_x = in_idx - rel_in_y * tile_in_w; 160 | int in_x = rel_in_x + tile_in_x; 161 | int in_y = rel_in_y + tile_in_y; 162 | 163 | scalar_t v = 0.0; 164 | 165 | if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) { 166 | v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * 167 | p.minor_dim + 168 | minor_idx]; 169 | } 170 | 171 | sx[rel_in_y][rel_in_x] = v; 172 | } 173 | 174 | __syncthreads(); 175 | for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; 176 | out_idx += blockDim.x) { 177 | int rel_out_y = out_idx / tile_out_w; 178 | int rel_out_x = out_idx - rel_out_y * tile_out_w; 179 | int out_x = rel_out_x + tile_out_x; 180 | int out_y = rel_out_y + tile_out_y; 181 | 182 | int mid_x = tile_mid_x + rel_out_x * down_x; 183 | int mid_y = tile_mid_y + rel_out_y * down_y; 184 | int in_x = floor_div(mid_x, up_x); 185 | int in_y = floor_div(mid_y, up_y); 186 | int rel_in_x = in_x - tile_in_x; 187 | int rel_in_y = in_y - tile_in_y; 188 | int kernel_x = (in_x + 1) * up_x - mid_x - 1; 189 | int kernel_y = (in_y + 1) * up_y - mid_y - 1; 190 | 191 | scalar_t v = 0.0; 192 | 193 | #pragma unroll 194 | for (int y = 0; y < kernel_h / up_y; y++) 195 | #pragma unroll 196 | for (int x = 0; x < kernel_w / up_x; x++) 197 | v += sx[rel_in_y + y][rel_in_x + x] * 198 | sk[kernel_y + y * up_y][kernel_x + x * up_x]; 199 | 200 | if (out_x < p.out_w & out_y < p.out_h) { 201 | out[((major_idx * p.out_h + out_y) * p.out_w + 
out_x) * p.minor_dim + 202 | minor_idx] = v; 203 | } 204 | } 205 | } 206 | } 207 | } 208 | 209 | torch::Tensor upfirdn2d_op(const torch::Tensor &input, 210 | const torch::Tensor &kernel, int up_x, int up_y, 211 | int down_x, int down_y, int pad_x0, int pad_x1, 212 | int pad_y0, int pad_y1) { 213 | int curDevice = -1; 214 | cudaGetDevice(&curDevice); 215 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 216 | 217 | UpFirDn2DKernelParams p; 218 | 219 | auto x = input.contiguous(); 220 | auto k = kernel.contiguous(); 221 | 222 | p.major_dim = x.size(0); 223 | p.in_h = x.size(1); 224 | p.in_w = x.size(2); 225 | p.minor_dim = x.size(3); 226 | p.kernel_h = k.size(0); 227 | p.kernel_w = k.size(1); 228 | p.up_x = up_x; 229 | p.up_y = up_y; 230 | p.down_x = down_x; 231 | p.down_y = down_y; 232 | p.pad_x0 = pad_x0; 233 | p.pad_x1 = pad_x1; 234 | p.pad_y0 = pad_y0; 235 | p.pad_y1 = pad_y1; 236 | 237 | p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / 238 | p.down_y; 239 | p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / 240 | p.down_x; 241 | 242 | auto out = 243 | at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options()); 244 | 245 | int mode = -1; 246 | 247 | int tile_out_h = -1; 248 | int tile_out_w = -1; 249 | 250 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && 251 | p.kernel_h <= 4 && p.kernel_w <= 4) { 252 | mode = 1; 253 | tile_out_h = 16; 254 | tile_out_w = 64; 255 | } 256 | 257 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && 258 | p.kernel_h <= 3 && p.kernel_w <= 3) { 259 | mode = 2; 260 | tile_out_h = 16; 261 | tile_out_w = 64; 262 | } 263 | 264 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && 265 | p.kernel_h <= 4 && p.kernel_w <= 4) { 266 | mode = 3; 267 | tile_out_h = 16; 268 | tile_out_w = 64; 269 | } 270 | 271 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && 272 | p.kernel_h <= 2 && p.kernel_w <= 2) { 273 | mode = 4; 274 | tile_out_h = 16; 275 | tile_out_w = 64; 276 | } 277 | 278 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && 279 | p.kernel_h <= 4 && p.kernel_w <= 4) { 280 | mode = 5; 281 | tile_out_h = 8; 282 | tile_out_w = 32; 283 | } 284 | 285 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && 286 | p.kernel_h <= 2 && p.kernel_w <= 2) { 287 | mode = 6; 288 | tile_out_h = 8; 289 | tile_out_w = 32; 290 | } 291 | 292 | dim3 block_size; 293 | dim3 grid_size; 294 | 295 | if (tile_out_h > 0 && tile_out_w > 0) { 296 | p.loop_major = (p.major_dim - 1) / 16384 + 1; 297 | p.loop_x = 1; 298 | block_size = dim3(32 * 8, 1, 1); 299 | grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim, 300 | (p.out_w - 1) / (p.loop_x * tile_out_w) + 1, 301 | (p.major_dim - 1) / p.loop_major + 1); 302 | } else { 303 | p.loop_major = (p.major_dim - 1) / 16384 + 1; 304 | p.loop_x = 4; 305 | block_size = dim3(4, 32, 1); 306 | grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1, 307 | (p.out_w - 1) / (p.loop_x * block_size.y) + 1, 308 | (p.major_dim - 1) / p.loop_major + 1); 309 | } 310 | 311 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] { 312 | switch (mode) { 313 | case 1: 314 | upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64> 315 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(), 316 | x.data_ptr<scalar_t>(), 317 | k.data_ptr<scalar_t>(), p); 318 | 319 | break; 320 | 321 | case 2: 322 | upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64> 323 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(), 324 | x.data_ptr<scalar_t>(), 325 | k.data_ptr<scalar_t>(), p); 326 | 327 | break; 328 | 329 | case 3: 330 | upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64> 
331 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(), 332 | x.data_ptr<scalar_t>(), 333 | k.data_ptr<scalar_t>(), p); 334 | 335 | break; 336 | 337 | case 4: 338 | upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64> 339 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(), 340 | x.data_ptr<scalar_t>(), 341 | k.data_ptr<scalar_t>(), p); 342 | 343 | break; 344 | 345 | case 5: 346 | upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32> 347 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(), 348 | x.data_ptr<scalar_t>(), 349 | k.data_ptr<scalar_t>(), p); 350 | 351 | break; 352 | 353 | case 6: 354 | upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32> 355 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(), 356 | x.data_ptr<scalar_t>(), 357 | k.data_ptr<scalar_t>(), p); 358 | 359 | break; 360 | 361 | default: 362 | upfirdn2d_kernel_large<scalar_t><<<grid_size, block_size, 0, stream>>>( 363 | out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), 364 | k.data_ptr<scalar_t>(), p); 365 | } 366 | }); 367 | 368 | return out; 369 | } -------------------------------------------------------------------------------- /stylegan2/ppl.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | from torch.nn import functional as F 5 | import numpy as np 6 | from tqdm import tqdm 7 | 8 | import lpips 9 | from model import Generator 10 | 11 | 12 | def normalize(x): 13 | return x / torch.sqrt(x.pow(2).sum(-1, keepdim=True)) 14 | 15 | 16 | def slerp(a, b, t): 17 | a = normalize(a) 18 | b = normalize(b) 19 | d = (a * b).sum(-1, keepdim=True) 20 | p = t * torch.acos(d) 21 | c = normalize(b - d * a) 22 | d = a * torch.cos(p) + c * torch.sin(p) 23 | 24 | return normalize(d) 25 | 26 | 27 | def lerp(a, b, t): 28 | return a + (b - a) * t 29 | 30 | 31 | if __name__ == '__main__': 32 | device = 'cuda' 33 | 34 | parser = argparse.ArgumentParser() 35 | 36 | parser.add_argument('--space', choices=['z', 'w']) 37 | parser.add_argument('--batch', type=int, default=64) 38 | parser.add_argument('--n_sample', type=int, default=5000) 39 | parser.add_argument('--size', type=int, default=256) 40 | parser.add_argument('--eps', type=float, default=1e-4) 41 | parser.add_argument('--crop', action='store_true') 42 | parser.add_argument('ckpt', metavar='CHECKPOINT') 43 | 44 | args = parser.parse_args() 45 | 46 | latent_dim = 512 47 | 48 | ckpt = torch.load(args.ckpt) 49 | 50 | g = Generator(args.size, latent_dim, 8).to(device) 51 | g.load_state_dict(ckpt['g_ema']) 52 | g.eval() 53 | 54 | percept = lpips.PerceptualLoss( 55 | model='net-lin', net='vgg', use_gpu=device.startswith('cuda') 56 | ) 57 | 58 | distances = [] 59 | 60 | n_batch = args.n_sample // args.batch 61 | resid = args.n_sample - (n_batch * args.batch) 62 | batch_sizes = [args.batch] * n_batch + [resid] 63 | 64 | with torch.no_grad(): 65 | for batch in tqdm(batch_sizes): 66 | noise = g.make_noise() 67 | 68 | inputs = torch.randn([batch * 2, latent_dim], device=device) 69 | lerp_t = torch.rand(batch, device=device) 70 | 71 | if args.space == 'w': 72 | latent = g.get_latent(inputs) 73 | latent_t0, latent_t1 = latent[::2], latent[1::2] 74 | latent_e0 = lerp(latent_t0, latent_t1, lerp_t[:, None]) 75 | latent_e1 = lerp(latent_t0, latent_t1, lerp_t[:, None] + args.eps) 76 | latent_e = torch.stack([latent_e0, latent_e1], 1).view(*latent.shape) 77 | 78 | image, _ = g([latent_e], input_is_latent=True, noise=noise) 79 | 80 | if args.crop: 81 | c = image.shape[2] // 8 82 | image = image[:, :, c * 3 : c * 7, c * 2 : c * 6] 83 | 84 | factor = image.shape[2] // 256 85 | 86 | if factor > 1: 87 | image = F.interpolate( 88 | image, size=(256, 256), mode='bilinear', align_corners=False 89 | ) 90 | 91 | dist = percept(image[::2], image[1::2]).view(image.shape[0] // 2) / ( 92 | args.eps ** 2 93 | ) 94 | distances.append(dist.to('cpu').numpy()) 95 | 96 | distances = np.concatenate(distances, 0) 97 | 98 | 
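# Reject outliers as in the official StyleGAN2 PPL metric: keep only the distances between the 1st and 99th percentiles before averaging, so a few extreme LPIPS values do not dominate the reported score.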
lo = np.percentile(distances, 1, interpolation='lower') 99 | hi = np.percentile(distances, 99, interpolation='higher') 100 | filtered_dist = np.extract( 101 | np.logical_and(lo <= distances, distances <= hi), distances 102 | ) 103 | 104 | print('ppl:', filtered_dist.mean()) 105 | -------------------------------------------------------------------------------- /stylegan2/projector.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import math 3 | import os 4 | 5 | import torch 6 | from torch import optim 7 | from torch.nn import functional as F 8 | from torchvision import transforms 9 | from PIL import Image 10 | from tqdm import tqdm 11 | 12 | import lpips 13 | from model import Generator 14 | 15 | 16 | def noise_regularize(noises): 17 | loss = 0 18 | 19 | for noise in noises: 20 | size = noise.shape[2] 21 | 22 | while True: 23 | loss = ( 24 | loss 25 | + (noise * torch.roll(noise, shifts=1, dims=3)).mean().pow(2) 26 | + (noise * torch.roll(noise, shifts=1, dims=2)).mean().pow(2) 27 | ) 28 | 29 | if size <= 8: 30 | break 31 | 32 | noise = noise.reshape([-1, 1, size // 2, 2, size // 2, 2]) 33 | noise = noise.mean([3, 5]) 34 | size //= 2 35 | 36 | return loss 37 | 38 | 39 | def noise_normalize_(noises): 40 | for noise in noises: 41 | mean = noise.mean() 42 | std = noise.std() 43 | 44 | noise.data.add_(-mean).div_(std) 45 | 46 | 47 | def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05): 48 | lr_ramp = min(1, (1 - t) / rampdown) 49 | lr_ramp = 0.5 - 0.5 * math.cos(lr_ramp * math.pi) 50 | lr_ramp = lr_ramp * min(1, t / rampup) 51 | 52 | return initial_lr * lr_ramp 53 | 54 | 55 | def latent_noise(latent, strength): 56 | noise = torch.randn_like(latent) * strength 57 | 58 | return latent + noise 59 | 60 | 61 | def make_image(tensor): 62 | return ( 63 | tensor.detach() 64 | .clamp_(min=-1, max=1) 65 | .add(1) 66 | .div_(2) 67 | .mul(255) 68 | .type(torch.uint8) 69 | .permute(0, 2, 3, 1) 70 | .to("cpu") 71 | .numpy() 72 | ) 73 | 74 | 75 | if __name__ == "__main__": 76 | device = "cuda" 77 | 78 | parser = argparse.ArgumentParser() 79 | parser.add_argument("--ckpt", type=str, required=True) 80 | parser.add_argument("--size", type=int, default=256) 81 | parser.add_argument("--lr_rampup", type=float, default=0.05) 82 | parser.add_argument("--lr_rampdown", type=float, default=0.25) 83 | parser.add_argument("--lr", type=float, default=0.1) 84 | parser.add_argument("--noise", type=float, default=0.05) 85 | parser.add_argument("--noise_ramp", type=float, default=0.75) 86 | parser.add_argument("--step", type=int, default=1000) 87 | parser.add_argument("--noise_regularize", type=float, default=1e5) 88 | parser.add_argument("--mse", type=float, default=0) 89 | parser.add_argument("--w_plus", action="store_true") 90 | parser.add_argument("files", metavar="FILES", nargs="+") 91 | 92 | args = parser.parse_args() 93 | 94 | n_mean_latent = 10000 95 | 96 | resize = min(args.size, 256) 97 | 98 | transform = transforms.Compose( 99 | [ 100 | transforms.Resize(resize), 101 | transforms.CenterCrop(resize), 102 | transforms.ToTensor(), 103 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), 104 | ] 105 | ) 106 | 107 | imgs = [] 108 | 109 | for imgfile in args.files: 110 | img = transform(Image.open(imgfile).convert("RGB")) 111 | imgs.append(img) 112 | 113 | imgs = torch.stack(imgs, 0).to(device) 114 | 115 | g_ema = Generator(args.size, 512, 8) 116 | g_ema.load_state_dict(torch.load(args.ckpt)["g_ema"], strict=False) 117 | g_ema.eval() 118 | g_ema = 
g_ema.to(device) 119 | 120 | with torch.no_grad(): 121 | noise_sample = torch.randn(n_mean_latent, 512, device=device) 122 | latent_out = g_ema.style(noise_sample) 123 | 124 | latent_mean = latent_out.mean(0) 125 | latent_std = ((latent_out - latent_mean).pow(2).sum() / n_mean_latent) ** 0.5 126 | 127 | percept = lpips.PerceptualLoss( 128 | model="net-lin", net="vgg", use_gpu=device.startswith("cuda") 129 | ) 130 | 131 | noises_single = g_ema.make_noise() 132 | noises = [] 133 | for noise in noises_single: 134 | noises.append(noise.repeat(imgs.shape[0], 1, 1, 1).normal_()) 135 | 136 | latent_in = latent_mean.detach().clone().unsqueeze(0).repeat(imgs.shape[0], 1) 137 | 138 | if args.w_plus: 139 | latent_in = latent_in.unsqueeze(1).repeat(1, g_ema.n_latent, 1) 140 | 141 | latent_in.requires_grad = True 142 | 143 | for noise in noises: 144 | noise.requires_grad = True 145 | 146 | optimizer = optim.Adam([latent_in] + noises, lr=args.lr) 147 | 148 | pbar = tqdm(range(args.step)) 149 | latent_path = [] 150 | 151 | for i in pbar: 152 | t = i / args.step 153 | lr = get_lr(t, args.lr) 154 | optimizer.param_groups[0]["lr"] = lr 155 | noise_strength = latent_std * args.noise * max(0, 1 - t / args.noise_ramp) ** 2 156 | latent_n = latent_noise(latent_in, noise_strength.item()) 157 | 158 | img_gen, _ = g_ema([latent_n], input_is_latent=True, noise=noises) 159 | 160 | batch, channel, height, width = img_gen.shape 161 | 162 | if height > 256: 163 | factor = height // 256 164 | 165 | img_gen = img_gen.reshape( 166 | batch, channel, height // factor, factor, width // factor, factor 167 | ) 168 | img_gen = img_gen.mean([3, 5]) 169 | 170 | p_loss = percept(img_gen, imgs).sum() 171 | n_loss = noise_regularize(noises) 172 | mse_loss = F.mse_loss(img_gen, imgs) 173 | 174 | loss = p_loss + args.noise_regularize * n_loss + args.mse * mse_loss 175 | 176 | optimizer.zero_grad() 177 | loss.backward() 178 | optimizer.step() 179 | 180 | noise_normalize_(noises) 181 | 182 | if (i + 1) % 100 == 0: 183 | latent_path.append(latent_in.detach().clone()) 184 | 185 | pbar.set_description( 186 | ( 187 | f"perceptual: {p_loss.item():.4f}; noise regularize: {n_loss.item():.4f};" 188 | f" mse: {mse_loss.item():.4f}; lr: {lr:.4f}" 189 | ) 190 | ) 191 | 192 | img_gen, _ = g_ema([latent_path[-1]], input_is_latent=True, noise=noises) 193 | 194 | filename = os.path.splitext(os.path.basename(args.files[0]))[0] + ".pt" 195 | 196 | img_ar = make_image(img_gen) 197 | 198 | result_file = {} 199 | for i, input_name in enumerate(args.files): 200 | noise_single = [] 201 | for noise in noises: 202 | noise_single.append(noise[i : i + 1]) 203 | 204 | result_file[input_name] = { 205 | "img": img_gen[i], 206 | "latent": latent_in[i], 207 | "noise": noise_single, 208 | } 209 | 210 | img_name = os.path.splitext(os.path.basename(input_name))[0] + "-project.png" 211 | pil_img = Image.fromarray(img_ar[i]) 212 | pil_img.save(img_name) 213 | 214 | torch.save(result_file, filename) 215 | -------------------------------------------------------------------------------- /stylegan2/sample/.gitignore: -------------------------------------------------------------------------------- 1 | *.png 2 | -------------------------------------------------------------------------------- /stylegan2/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import math 3 | import random 4 | import os 5 | 6 | import numpy as np 7 | import torch 8 | from torch import nn, autograd, optim 9 | from torch.nn import functional 
as F 10 | from torch.utils import data 11 | import torch.distributed as dist 12 | from torchvision import transforms, utils 13 | from tqdm import tqdm 14 | 15 | try: 16 | import wandb 17 | 18 | except ImportError: 19 | wandb = None 20 | 21 | from model import Generator, Discriminator 22 | from dataset import MultiResolutionDataset 23 | from distributed import ( 24 | get_rank, 25 | synchronize, 26 | reduce_loss_dict, 27 | reduce_sum, 28 | get_world_size, 29 | ) 30 | from non_leaking import augment 31 | 32 | 33 | def data_sampler(dataset, shuffle, distributed): 34 | if distributed: 35 | return data.distributed.DistributedSampler(dataset, shuffle=shuffle) 36 | 37 | if shuffle: 38 | return data.RandomSampler(dataset) 39 | 40 | else: 41 | return data.SequentialSampler(dataset) 42 | 43 | 44 | def requires_grad(model, flag=True): 45 | for p in model.parameters(): 46 | p.requires_grad = flag 47 | 48 | 49 | def accumulate(model1, model2, decay=0.999): 50 | par1 = dict(model1.named_parameters()) 51 | par2 = dict(model2.named_parameters()) 52 | 53 | for k in par1.keys(): 54 | par1[k].data.mul_(decay).add_(par2[k].data, alpha=1 - decay) 55 | 56 | 57 | def sample_data(loader): 58 | while True: 59 | for batch in loader: 60 | yield batch 61 | 62 | 63 | def d_logistic_loss(real_pred, fake_pred): 64 | real_loss = F.softplus(-real_pred) 65 | fake_loss = F.softplus(fake_pred) 66 | 67 | return real_loss.mean() + fake_loss.mean() 68 | 69 | 70 | def d_r1_loss(real_pred, real_img): 71 | grad_real, = autograd.grad( 72 | outputs=real_pred.sum(), inputs=real_img, create_graph=True 73 | ) 74 | grad_penalty = grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean() 75 | 76 | return grad_penalty 77 | 78 | 79 | def g_nonsaturating_loss(fake_pred): 80 | loss = F.softplus(-fake_pred).mean() 81 | 82 | return loss 83 | 84 | 85 | def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01): 86 | noise = torch.randn_like(fake_img) / math.sqrt( 87 | fake_img.shape[2] * fake_img.shape[3] 88 | ) 89 | grad, = autograd.grad( 90 | outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True 91 | ) 92 | path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1)) 93 | 94 | path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length) 95 | 96 | path_penalty = (path_lengths - path_mean).pow(2).mean() 97 | 98 | return path_penalty, path_mean.detach(), path_lengths 99 | 100 | 101 | def make_noise(batch, latent_dim, n_noise, device): 102 | if n_noise == 1: 103 | return torch.randn(batch, latent_dim, device=device) 104 | 105 | noises = torch.randn(n_noise, batch, latent_dim, device=device).unbind(0) 106 | 107 | return noises 108 | 109 | 110 | def mixing_noise(batch, latent_dim, prob, device): 111 | if prob > 0 and random.random() < prob: 112 | return make_noise(batch, latent_dim, 2, device) 113 | 114 | else: 115 | return [make_noise(batch, latent_dim, 1, device)] 116 | 117 | 118 | def set_grad_none(model, targets): 119 | for n, p in model.named_parameters(): 120 | if n in targets: 121 | p.grad = None 122 | 123 | 124 | def train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device): 125 | loader = sample_data(loader) 126 | 127 | pbar = range(args.iter) 128 | 129 | if get_rank() == 0: 130 | pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01) 131 | 132 | mean_path_length = 0 133 | 134 | d_loss_val = 0 135 | r1_loss = torch.tensor(0.0, device=device) 136 | g_loss_val = 0 137 | path_loss = torch.tensor(0.0, device=device) 138 | path_lengths = torch.tensor(0.0, 
device=device) 139 | mean_path_length_avg = 0 140 | loss_dict = {} 141 | 142 | if args.distributed: 143 | g_module = generator.module 144 | d_module = discriminator.module 145 | 146 | else: 147 | g_module = generator 148 | d_module = discriminator 149 | 150 | accum = 0.5 ** (32 / (10 * 1000)) 151 | ada_augment = torch.tensor([0.0, 0.0], device=device) 152 | ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0 153 | ada_aug_step = args.ada_target / args.ada_length 154 | r_t_stat = 0 155 | 156 | sample_z = torch.randn(args.n_sample, args.latent, device=device) 157 | 158 | for idx in pbar: 159 | i = idx + args.start_iter 160 | 161 | if i > args.iter: 162 | print("Done!") 163 | 164 | break 165 | 166 | real_img = next(loader) 167 | real_img = real_img.to(device) 168 | 169 | requires_grad(generator, False) 170 | requires_grad(discriminator, True) 171 | 172 | noise = mixing_noise(args.batch, args.latent, args.mixing, device) 173 | fake_img, _ = generator(noise) 174 | 175 | if args.augment: 176 | real_img_aug, _ = augment(real_img, ada_aug_p) 177 | fake_img, _ = augment(fake_img, ada_aug_p) 178 | 179 | else: 180 | real_img_aug = real_img 181 | 182 | fake_pred = discriminator(fake_img) 183 | real_pred = discriminator(real_img_aug) 184 | d_loss = d_logistic_loss(real_pred, fake_pred) 185 | 186 | loss_dict["d"] = d_loss 187 | loss_dict["real_score"] = real_pred.mean() 188 | loss_dict["fake_score"] = fake_pred.mean() 189 | 190 | discriminator.zero_grad() 191 | d_loss.backward() 192 | d_optim.step() 193 | 194 | if args.augment and args.augment_p == 0: 195 | ada_augment += torch.tensor( 196 | (torch.sign(real_pred).sum().item(), real_pred.shape[0]), device=device 197 | ) 198 | ada_augment = reduce_sum(ada_augment) 199 | 200 | if ada_augment[1] > 255: 201 | pred_signs, n_pred = ada_augment.tolist() 202 | 203 | r_t_stat = pred_signs / n_pred 204 | 205 | if r_t_stat > args.ada_target: 206 | sign = 1 207 | 208 | else: 209 | sign = -1 210 | 211 | ada_aug_p += sign * ada_aug_step * n_pred 212 | ada_aug_p = min(1, max(0, ada_aug_p)) 213 | ada_augment.mul_(0) 214 | 215 | d_regularize = i % args.d_reg_every == 0 216 | 217 | if d_regularize: 218 | real_img.requires_grad = True 219 | real_pred = discriminator(real_img) 220 | r1_loss = d_r1_loss(real_pred, real_img) 221 | 222 | discriminator.zero_grad() 223 | (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward() 224 | 225 | d_optim.step() 226 | 227 | loss_dict["r1"] = r1_loss 228 | 229 | requires_grad(generator, True) 230 | requires_grad(discriminator, False) 231 | 232 | noise = mixing_noise(args.batch, args.latent, args.mixing, device) 233 | fake_img, _ = generator(noise) 234 | 235 | if args.augment: 236 | fake_img, _ = augment(fake_img, ada_aug_p) 237 | 238 | fake_pred = discriminator(fake_img) 239 | g_loss = g_nonsaturating_loss(fake_pred) 240 | 241 | loss_dict["g"] = g_loss 242 | 243 | generator.zero_grad() 244 | g_loss.backward() 245 | g_optim.step() 246 | 247 | g_regularize = i % args.g_reg_every == 0 248 | 249 | if g_regularize: 250 | path_batch_size = max(1, args.batch // args.path_batch_shrink) 251 | noise = mixing_noise(path_batch_size, args.latent, args.mixing, device) 252 | fake_img, latents = generator(noise, return_latents=True) 253 | 254 | path_loss, mean_path_length, path_lengths = g_path_regularize( 255 | fake_img, latents, mean_path_length 256 | ) 257 | 258 | generator.zero_grad() 259 | weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss 260 | 261 | if args.path_batch_shrink: 262 | weighted_path_loss += 
0 * fake_img[0, 0, 0, 0] 263 | 264 | weighted_path_loss.backward() 265 | 266 | g_optim.step() 267 | 268 | mean_path_length_avg = ( 269 | reduce_sum(mean_path_length).item() / get_world_size() 270 | ) 271 | 272 | loss_dict["path"] = path_loss 273 | loss_dict["path_length"] = path_lengths.mean() 274 | 275 | accumulate(g_ema, g_module, accum) 276 | 277 | loss_reduced = reduce_loss_dict(loss_dict) 278 | 279 | d_loss_val = loss_reduced["d"].mean().item() 280 | g_loss_val = loss_reduced["g"].mean().item() 281 | r1_val = loss_reduced["r1"].mean().item() 282 | path_loss_val = loss_reduced["path"].mean().item() 283 | real_score_val = loss_reduced["real_score"].mean().item() 284 | fake_score_val = loss_reduced["fake_score"].mean().item() 285 | path_length_val = loss_reduced["path_length"].mean().item() 286 | 287 | if get_rank() == 0: 288 | pbar.set_description( 289 | ( 290 | f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; " 291 | f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; " 292 | f"augment: {ada_aug_p:.4f}" 293 | ) 294 | ) 295 | 296 | if wandb and args.wandb: 297 | wandb.log( 298 | { 299 | "Generator": g_loss_val, 300 | "Discriminator": d_loss_val, 301 | "Augment": ada_aug_p, 302 | "Rt": r_t_stat, 303 | "R1": r1_val, 304 | "Path Length Regularization": path_loss_val, 305 | "Mean Path Length": mean_path_length, 306 | "Real Score": real_score_val, 307 | "Fake Score": fake_score_val, 308 | "Path Length": path_length_val, 309 | } 310 | ) 311 | 312 | if i % 100 == 0: 313 | with torch.no_grad(): 314 | g_ema.eval() 315 | sample, _ = g_ema([sample_z]) 316 | utils.save_image( 317 | sample, 318 | f"sample/{str(i).zfill(6)}.png", 319 | nrow=int(args.n_sample ** 0.5), 320 | normalize=True, 321 | range=(-1, 1), 322 | ) 323 | 324 | if i % 10000 == 0: 325 | torch.save( 326 | { 327 | "g": g_module.state_dict(), 328 | "d": d_module.state_dict(), 329 | "g_ema": g_ema.state_dict(), 330 | "g_optim": g_optim.state_dict(), 331 | "d_optim": d_optim.state_dict(), 332 | "args": args, 333 | "ada_aug_p": ada_aug_p, 334 | }, 335 | f"checkpoint/{str(i).zfill(6)}.pt", 336 | ) 337 | 338 | 339 | if __name__ == "__main__": 340 | device = "cuda" 341 | 342 | parser = argparse.ArgumentParser() 343 | 344 | parser.add_argument("path", type=str) 345 | parser.add_argument("--iter", type=int, default=800000) 346 | parser.add_argument("--batch", type=int, default=16) 347 | parser.add_argument("--n_sample", type=int, default=64) 348 | parser.add_argument("--size", type=int, default=256) 349 | parser.add_argument("--r1", type=float, default=10) 350 | parser.add_argument("--path_regularize", type=float, default=2) 351 | parser.add_argument("--path_batch_shrink", type=int, default=2) 352 | parser.add_argument("--d_reg_every", type=int, default=16) 353 | parser.add_argument("--g_reg_every", type=int, default=4) 354 | parser.add_argument("--mixing", type=float, default=0.9) 355 | parser.add_argument("--ckpt", type=str, default=None) 356 | parser.add_argument("--lr", type=float, default=0.002) 357 | parser.add_argument("--channel_multiplier", type=int, default=2) 358 | parser.add_argument("--wandb", action="store_true") 359 | parser.add_argument("--local_rank", type=int, default=0) 360 | parser.add_argument("--augment", action="store_true") 361 | parser.add_argument("--augment_p", type=float, default=0) 362 | parser.add_argument("--ada_target", type=float, default=0.6) 363 | parser.add_argument("--ada_length", type=int, default=500 * 1000) 364 | 365 | args = parser.parse_args() 366 | 367 | n_gpu = 
int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1 368 | args.distributed = n_gpu > 1 369 | 370 | if args.distributed: 371 | torch.cuda.set_device(args.local_rank) 372 | torch.distributed.init_process_group(backend="nccl", init_method="env://") 373 | synchronize() 374 | 375 | args.latent = 512 376 | args.n_mlp = 8 377 | 378 | args.start_iter = 0 379 | 380 | generator = Generator( 381 | args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier 382 | ).to(device) 383 | discriminator = Discriminator( 384 | args.size, channel_multiplier=args.channel_multiplier 385 | ).to(device) 386 | g_ema = Generator( 387 | args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier 388 | ).to(device) 389 | g_ema.eval() 390 | accumulate(g_ema, generator, 0) 391 | 392 | g_reg_ratio = args.g_reg_every / (args.g_reg_every + 1) 393 | d_reg_ratio = args.d_reg_every / (args.d_reg_every + 1) 394 | 395 | g_optim = optim.Adam( 396 | generator.parameters(), 397 | lr=args.lr * g_reg_ratio, 398 | betas=(0 ** g_reg_ratio, 0.99 ** g_reg_ratio), 399 | ) 400 | d_optim = optim.Adam( 401 | discriminator.parameters(), 402 | lr=args.lr * d_reg_ratio, 403 | betas=(0 ** d_reg_ratio, 0.99 ** d_reg_ratio), 404 | ) 405 | 406 | if args.ckpt is not None: 407 | print("load model:", args.ckpt) 408 | 409 | ckpt = torch.load(args.ckpt, map_location=lambda storage, loc: storage) 410 | 411 | try: 412 | ckpt_name = os.path.basename(args.ckpt) 413 | args.start_iter = int(os.path.splitext(ckpt_name)[0]) 414 | 415 | except ValueError: 416 | pass 417 | 418 | generator.load_state_dict(ckpt["g"]) 419 | discriminator.load_state_dict(ckpt["d"]) 420 | g_ema.load_state_dict(ckpt["g_ema"]) 421 | 422 | g_optim.load_state_dict(ckpt["g_optim"]) 423 | d_optim.load_state_dict(ckpt["d_optim"]) 424 | 425 | if args.distributed: 426 | generator = nn.parallel.DistributedDataParallel( 427 | generator, 428 | device_ids=[args.local_rank], 429 | output_device=args.local_rank, 430 | broadcast_buffers=False, 431 | ) 432 | 433 | discriminator = nn.parallel.DistributedDataParallel( 434 | discriminator, 435 | device_ids=[args.local_rank], 436 | output_device=args.local_rank, 437 | broadcast_buffers=False, 438 | ) 439 | 440 | transform = transforms.Compose( 441 | [ 442 | transforms.RandomHorizontalFlip(), 443 | transforms.ToTensor(), 444 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True), 445 | ] 446 | ) 447 | 448 | dataset = MultiResolutionDataset(args.path, transform, args.size) 449 | loader = data.DataLoader( 450 | dataset, 451 | batch_size=args.batch, 452 | sampler=data_sampler(dataset, shuffle=True, distributed=args.distributed), 453 | drop_last=True, 454 | ) 455 | 456 | if get_rank() == 0 and wandb is not None and args.wandb: 457 | wandb.init(project="stylegan 2") 458 | 459 | train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device) 460 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import math 3 | import random 4 | import os 5 | 6 | import numpy as np 7 | import torch 8 | from torch import nn, autograd, optim 9 | from torch.nn import functional as F 10 | from torch.utils import data 11 | import torch.distributed as dist 12 | from torchvision import transforms, utils 13 | from tqdm import tqdm 14 | 15 | try: 16 | import wandb 17 | 18 | except ImportError: 19 | wandb = None 20 | 21 | from model import Encoder, Generator, 
Discriminator, CooccurDiscriminator 22 | from stylegan2.dataset import MultiResolutionDataset 23 | from stylegan2.distributed import ( 24 | get_rank, 25 | synchronize, 26 | reduce_loss_dict, 27 | reduce_sum, 28 | get_world_size, 29 | ) 30 | 31 | 32 | def data_sampler(dataset, shuffle, distributed): 33 | if distributed: 34 | return data.distributed.DistributedSampler(dataset, shuffle=shuffle) 35 | 36 | if shuffle: 37 | return data.RandomSampler(dataset) 38 | 39 | else: 40 | return data.SequentialSampler(dataset) 41 | 42 | 43 | def requires_grad(model, flag=True): 44 | for p in model.parameters(): 45 | p.requires_grad = flag 46 | 47 | 48 | def accumulate(model1, model2, decay=0.999): 49 | par1 = dict(model1.named_parameters()) 50 | par2 = dict(model2.named_parameters()) 51 | 52 | for k in par1.keys(): 53 | par1[k].data.mul_(decay).add_(par2[k].data, alpha=1 - decay) 54 | 55 | 56 | def sample_data(loader): 57 | while True: 58 | for batch in loader: 59 | yield batch 60 | 61 | 62 | def d_logistic_loss(real_pred, fake_pred): 63 | real_loss = F.softplus(-real_pred) 64 | fake_loss = F.softplus(fake_pred) 65 | 66 | return real_loss.mean() + fake_loss.mean() 67 | 68 | 69 | def d_r1_loss(real_pred, real_img): 70 | (grad_real,) = autograd.grad( 71 | outputs=real_pred.sum(), inputs=real_img, create_graph=True 72 | ) 73 | grad_penalty = grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean() 74 | 75 | return grad_penalty 76 | 77 | 78 | def g_nonsaturating_loss(fake_pred): 79 | loss = F.softplus(-fake_pred).mean() 80 | 81 | return loss 82 | 83 | 84 | def set_grad_none(model, targets): 85 | for n, p in model.named_parameters(): 86 | if n in targets: 87 | p.grad = None 88 | 89 | 90 | def patchify_image(img, n_crop, min_size=1 / 8, max_size=1 / 4): 91 | crop_size = torch.rand(n_crop) * (max_size - min_size) + min_size 92 | batch, channel, height, width = img.shape 93 | target_h = int(height * max_size) 94 | target_w = int(width * max_size) 95 | crop_h = (crop_size * height).type(torch.int64).tolist() 96 | crop_w = (crop_size * width).type(torch.int64).tolist() 97 | 98 | patches = [] 99 | for c_h, c_w in zip(crop_h, crop_w): 100 | c_y = random.randrange(0, height - c_h) 101 | c_x = random.randrange(0, width - c_w) 102 | 103 | cropped = img[:, :, c_y : c_y + c_h, c_x : c_x + c_w] 104 | cropped = F.interpolate( 105 | cropped, size=(target_h, target_w), mode="bilinear", align_corners=False 106 | ) 107 | 108 | patches.append(cropped) 109 | 110 | patches = torch.stack(patches, 1).view(-1, channel, target_h, target_w) 111 | 112 | return patches 113 | 114 | 115 | def train( 116 | args, 117 | loader, 118 | encoder, 119 | generator, 120 | discriminator, 121 | cooccur, 122 | g_optim, 123 | d_optim, 124 | e_ema, 125 | g_ema, 126 | device, 127 | ): 128 | loader = sample_data(loader) 129 | 130 | pbar = range(args.iter) 131 | 132 | if get_rank() == 0: 133 | pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01) 134 | 135 | d_loss_val = 0 136 | r1_loss = torch.tensor(0.0, device=device) 137 | g_loss_val = 0 138 | loss_dict = {} 139 | 140 | if args.distributed: 141 | e_module = encoder.module 142 | g_module = generator.module 143 | d_module = discriminator.module 144 | c_module = cooccur.module 145 | 146 | else: 147 | e_module = encoder 148 | g_module = generator 149 | d_module = discriminator 150 | c_module = cooccur 151 | 152 | accum = 0.5 ** (32 / (10 * 1000)) 153 | 154 | for idx in pbar: 155 | i = idx + args.start_iter 156 | 157 | if i > args.iter: 158 | print("Done!") 159 | 160 | break 
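# Each iteration runs two updates: first the image discriminator and the patch co-occurrence discriminator are trained on a reconstruction (structure1 + texture1) and a texture swap (structure1 + texture2); then the encoder and generator are updated with the L1 reconstruction loss plus the two non-saturating adversarial losses.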
161 | 162 | real_img = next(loader) 163 | real_img = real_img.to(device) 164 | 165 | requires_grad(encoder, False) 166 | requires_grad(generator, False) 167 | requires_grad(discriminator, True) 168 | requires_grad(cooccur, True) 169 | 170 | real_img1, real_img2 = real_img.chunk(2, dim=0) 171 | 172 | structure1, texture1 = encoder(real_img1) 173 | _, texture2 = encoder(real_img2) 174 | 175 | fake_img1 = generator(structure1, texture1) 176 | fake_img2 = generator(structure1, texture2) 177 | 178 | fake_pred = discriminator(torch.cat((fake_img1, fake_img2), 0)) 179 | real_pred = discriminator(real_img) 180 | d_loss = d_logistic_loss(real_pred, fake_pred) 181 | 182 | fake_patch = patchify_image(fake_img2, args.n_crop) 183 | real_patch = patchify_image(real_img2, args.n_crop) 184 | ref_patch = patchify_image(real_img2, args.ref_crop * args.n_crop) 185 | fake_patch_pred, ref_input = cooccur( 186 | fake_patch, ref_patch, ref_batch=args.ref_crop 187 | ) 188 | real_patch_pred, _ = cooccur(real_patch, ref_input=ref_input) 189 | cooccur_loss = d_logistic_loss(real_patch_pred, fake_patch_pred) 190 | 191 | loss_dict["d"] = d_loss 192 | loss_dict["cooccur"] = cooccur_loss 193 | loss_dict["real_score"] = real_pred.mean() 194 | fake_pred1, fake_pred2 = fake_pred.chunk(2, dim=0) 195 | loss_dict["fake_score"] = fake_pred1.mean() 196 | loss_dict["hybrid_score"] = fake_pred2.mean() 197 | 198 | d_optim.zero_grad() 199 | (d_loss + cooccur_loss).backward() 200 | d_optim.step() 201 | 202 | d_regularize = i % args.d_reg_every == 0 203 | 204 | if d_regularize: 205 | real_img.requires_grad = True 206 | real_pred = discriminator(real_img) 207 | r1_loss = d_r1_loss(real_pred, real_img) 208 | 209 | real_patch.requires_grad = True 210 | real_patch_pred, _ = cooccur(real_patch, ref_patch, ref_batch=args.ref_crop) 211 | cooccur_r1_loss = d_r1_loss(real_patch_pred, real_patch) 212 | 213 | d_optim.zero_grad() 214 | 215 | r1_loss_sum = args.r1 / 2 * r1_loss * args.d_reg_every 216 | r1_loss_sum += args.cooccur_r1 / 2 * cooccur_r1_loss * args.d_reg_every 217 | r1_loss_sum += 0 * real_pred[0, 0] + 0 * real_patch_pred[0, 0] 218 | r1_loss_sum.backward() 219 | 220 | d_optim.step() 221 | 222 | loss_dict["r1"] = r1_loss 223 | loss_dict["cooccur_r1"] = cooccur_r1_loss 224 | 225 | requires_grad(encoder, True) 226 | requires_grad(generator, True) 227 | requires_grad(discriminator, False) 228 | requires_grad(cooccur, False) 229 | 230 | real_img.requires_grad = False 231 | 232 | structure1, texture1 = encoder(real_img1) 233 | _, texture2 = encoder(real_img2) 234 | 235 | fake_img1 = generator(structure1, texture1) 236 | fake_img2 = generator(structure1, texture2) 237 | 238 | recon_loss = F.l1_loss(fake_img1, real_img1) 239 | 240 | fake_pred = discriminator(torch.cat((fake_img1, fake_img2), 0)) 241 | g_loss = g_nonsaturating_loss(fake_pred) 242 | 243 | fake_patch = patchify_image(fake_img2, args.n_crop) 244 | ref_patch = patchify_image(real_img2, args.ref_crop * args.n_crop) 245 | fake_patch_pred, _ = cooccur(fake_patch, ref_patch, ref_batch=args.ref_crop) 246 | g_cooccur_loss = g_nonsaturating_loss(fake_patch_pred) 247 | 248 | loss_dict["recon"] = recon_loss 249 | loss_dict["g"] = g_loss 250 | loss_dict["g_cooccur"] = g_cooccur_loss 251 | 252 | g_optim.zero_grad() 253 | (recon_loss + g_loss + g_cooccur_loss).backward() 254 | g_optim.step() 255 | 256 | accumulate(e_ema, e_module, accum) 257 | accumulate(g_ema, g_module, accum) 258 | 259 | loss_reduced = reduce_loss_dict(loss_dict) 260 | 261 | d_loss_val = loss_reduced["d"].mean().item() 262 
| cooccur_val = loss_reduced["cooccur"].mean().item() 263 | recon_val = loss_reduced["recon"].mean().item() 264 | g_loss_val = loss_reduced["g"].mean().item() 265 | g_cooccur_val = loss_reduced["g_cooccur"].mean().item() 266 | r1_val = loss_reduced["r1"].mean().item() 267 | cooccur_r1_val = loss_reduced["cooccur_r1"].mean().item() 268 | real_score_val = loss_reduced["real_score"].mean().item() 269 | fake_score_val = loss_reduced["fake_score"].mean().item() 270 | hybrid_score_val = loss_reduced["hybrid_score"].mean().item() 271 | 272 | if get_rank() == 0: 273 | pbar.set_description( 274 | ( 275 | f"d: {d_loss_val:.4f}; c: {cooccur_val:.4f} g: {g_loss_val:.4f}; " 276 | f"g_cooccur: {g_cooccur_val:.4f}; recon: {recon_val:.4f}; r1: {r1_val:.4f}; " 277 | f"r1_cooccur: {cooccur_r1_val:.4f}" 278 | ) 279 | ) 280 | 281 | if wandb and args.wandb and i % 10 == 0: 282 | wandb.log( 283 | { 284 | "Generator": g_loss_val, 285 | "Discriminator": d_loss_val, 286 | "Cooccur": cooccur_val, 287 | "Recon": recon_val, 288 | "Generator Cooccur": g_cooccur_val, 289 | "R1": r1_val, 290 | "Cooccur R1": cooccur_r1_val, 291 | "Real Score": real_score_val, 292 | "Fake Score": fake_score_val, 293 | "Hybrid Score": hybrid_score_val, 294 | }, 295 | step=i, 296 | ) 297 | 298 | if i % 100 == 0: 299 | with torch.no_grad(): 300 | e_ema.eval() 301 | g_ema.eval() 302 | 303 | structure1, texture1 = e_ema(real_img1) 304 | _, texture2 = e_ema(real_img2) 305 | 306 | fake_img1 = g_ema(structure1, texture1) 307 | fake_img2 = g_ema(structure1, texture2) 308 | 309 | sample = torch.cat((fake_img1, fake_img2), 0) 310 | 311 | utils.save_image( 312 | sample, 313 | f"sample/{str(i).zfill(6)}.png", 314 | nrow=int(sample.shape[0] ** 0.5), 315 | normalize=True, 316 | range=(-1, 1), 317 | ) 318 | 319 | if i % 10000 == 0: 320 | torch.save( 321 | { 322 | "e": e_module.state_dict(), 323 | "g": g_module.state_dict(), 324 | "d": d_module.state_dict(), 325 | "cooccur": c_module.state_dict(), 326 | "e_ema": e_ema.state_dict(), 327 | "g_ema": g_ema.state_dict(), 328 | "g_optim": g_optim.state_dict(), 329 | "d_optim": d_optim.state_dict(), 330 | "args": args, 331 | }, 332 | f"checkpoint/{str(i).zfill(6)}.pt", 333 | ) 334 | 335 | 336 | if __name__ == "__main__": 337 | device = "cuda" 338 | 339 | torch.backends.cudnn.benchmark = True 340 | 341 | parser = argparse.ArgumentParser() 342 | 343 | parser.add_argument("path", type=str, nargs="+") 344 | parser.add_argument("--iter", type=int, default=800000) 345 | parser.add_argument("--batch", type=int, default=16) 346 | parser.add_argument("--size", type=int, default=256) 347 | parser.add_argument("--r1", type=float, default=10) 348 | parser.add_argument("--cooccur_r1", type=float, default=1) 349 | parser.add_argument("--ref_crop", type=int, default=4) 350 | parser.add_argument("--n_crop", type=int, default=8) 351 | parser.add_argument("--d_reg_every", type=int, default=16) 352 | parser.add_argument("--ckpt", type=str, default=None) 353 | parser.add_argument("--lr", type=float, default=0.002) 354 | parser.add_argument("--channel", type=int, default=32) 355 | parser.add_argument("--channel_multiplier", type=int, default=1) 356 | parser.add_argument("--wandb", action="store_true") 357 | parser.add_argument("--local_rank", type=int, default=0) 358 | 359 | args = parser.parse_args() 360 | 361 | n_gpu = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1 362 | args.distributed = n_gpu > 1 363 | 364 | if args.distributed: 365 | torch.cuda.set_device(args.local_rank) 366 | 
torch.distributed.init_process_group(backend="nccl", init_method="env://") 367 | synchronize() 368 | 369 | args.latent = 512 370 | args.n_mlp = 8 371 | 372 | args.start_iter = 0 373 | 374 | encoder = Encoder(args.channel).to(device) 375 | generator = Generator(args.channel).to(device) 376 | 377 | discriminator = Discriminator( 378 | args.size, channel_multiplier=args.channel_multiplier 379 | ).to(device) 380 | cooccur = CooccurDiscriminator(args.channel).to(device) 381 | 382 | e_ema = Encoder(args.channel).to(device) 383 | g_ema = Generator(args.channel).to(device) 384 | e_ema.eval() 385 | g_ema.eval() 386 | accumulate(e_ema, encoder, 0) 387 | accumulate(g_ema, generator, 0) 388 | 389 | d_reg_ratio = args.d_reg_every / (args.d_reg_every + 1) 390 | 391 | g_optim = optim.Adam( 392 | list(encoder.parameters()) + list(generator.parameters()), 393 | lr=args.lr, 394 | betas=(0, 0.99), 395 | ) 396 | d_optim = optim.Adam( 397 | list(discriminator.parameters()) + list(cooccur.parameters()), 398 | lr=args.lr * d_reg_ratio, 399 | betas=(0 ** d_reg_ratio, 0.99 ** d_reg_ratio), 400 | ) 401 | 402 | if args.ckpt is not None: 403 | print("load model:", args.ckpt) 404 | 405 | ckpt = torch.load(args.ckpt, map_location=lambda storage, loc: storage) 406 | 407 | try: 408 | ckpt_name = os.path.basename(args.ckpt) 409 | args.start_iter = int(os.path.splitext(ckpt_name)[0]) 410 | 411 | except ValueError: 412 | pass 413 | 414 | encoder.load_state_dict(ckpt["e"]) 415 | generator.load_state_dict(ckpt["g"]) 416 | discriminator.load_state_dict(ckpt["d"]) 417 | cooccur.load_state_dict(ckpt["cooccur"]) 418 | e_ema.load_state_dict(ckpt["e_ema"]) 419 | g_ema.load_state_dict(ckpt["g_ema"]) 420 | 421 | g_optim.load_state_dict(ckpt["g_optim"]) 422 | d_optim.load_state_dict(ckpt["d_optim"]) 423 | 424 | if args.distributed: 425 | encoder = nn.parallel.DistributedDataParallel( 426 | encoder, 427 | device_ids=[args.local_rank], 428 | output_device=args.local_rank, 429 | broadcast_buffers=False, 430 | ) 431 | 432 | generator = nn.parallel.DistributedDataParallel( 433 | generator, 434 | device_ids=[args.local_rank], 435 | output_device=args.local_rank, 436 | broadcast_buffers=False, 437 | ) 438 | 439 | discriminator = nn.parallel.DistributedDataParallel( 440 | discriminator, 441 | device_ids=[args.local_rank], 442 | output_device=args.local_rank, 443 | broadcast_buffers=False, 444 | ) 445 | 446 | cooccur = nn.parallel.DistributedDataParallel( 447 | cooccur, 448 | device_ids=[args.local_rank], 449 | output_device=args.local_rank, 450 | broadcast_buffers=False, 451 | ) 452 | 453 | transform = transforms.Compose( 454 | [ 455 | transforms.RandomHorizontalFlip(), 456 | transforms.ToTensor(), 457 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True), 458 | ] 459 | ) 460 | 461 | datasets = [] 462 | 463 | for path in args.path: 464 | dataset = MultiResolutionDataset(path, transform, args.size) 465 | datasets.append(dataset) 466 | 467 | loader = data.DataLoader( 468 | data.ConcatDataset(datasets), 469 | batch_size=args.batch, 470 | sampler=data_sampler(dataset, shuffle=True, distributed=args.distributed), 471 | drop_last=True, 472 | ) 473 | 474 | if get_rank() == 0 and wandb is not None and args.wandb: 475 | wandb.init(project="swapping autoencoder") 476 | 477 | train( 478 | args, 479 | loader, 480 | encoder, 481 | generator, 482 | discriminator, 483 | cooccur, 484 | g_optim, 485 | d_optim, 486 | e_ema, 487 | g_ema, 488 | device, 489 | ) 490 | 491 | 
--------------------------------------------------------------------------------