├── .gitignore ├── LICENSE_stylegan2 ├── README.md ├── StyleSpace_example.ipynb ├── checkpoint └── .gitignore ├── generate.py ├── model.py ├── op ├── __init__.py ├── conv2d_gradfix.py ├── fused_act.py ├── fused_bias_act.cpp ├── fused_bias_act_kernel.cu ├── upfirdn2d.cpp ├── upfirdn2d.py └── upfirdn2d_kernel.cu ├── results ├── .gitignore ├── default.png ├── eye.png ├── hair.png ├── lip.png └── mouth.png └── stylespace.py /.gitignore: -------------------------------------------------------------------------------- 1 | results2/ 2 | results3/ 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | pip-wheel-metadata/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | -------------------------------------------------------------------------------- /LICENSE_stylegan2: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Kim Seonghyeon 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Simple StyleSpace Run 2 | 3 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1EiEWC0426GfjUl7g621vKfROl54w0iaP?usp=sharing) 4 | 5 | | default | eye | hair | lip | mouth | 6 | |:---:|:---:|:---:|:---:|:---:| 7 | | ![img](./results/default.png) | ![img](./results/eye.png) | ![img](./results/hair.png) | ![img](./results/lip.png) | ![img](./results/mouth.png) | 8 | 9 | 10 | ### Run StyleSpace 11 | 1. locate `stylegan2-ffhq-config-f.pt` under the folder `checkpoint` 12 | 2. 
run `python stylespace.py` 13 | 14 | 15 | 16 | ## Credits 17 | 18 | StyleGAN2 model and implementation: 19 | https://github.com/rosinality/stylegan2-pytorch 20 | Copyright (c) 2019 Kim Seonghyeon 21 | License (MIT) https://github.com/rosinality/stylegan2-pytorch/blob/master/LICENSE 22 | 23 | StyleSpace pytorch implementation 24 | https://github.com/xrenaa/StyleSpace-pytorch 25 | -------------------------------------------------------------------------------- /checkpoint/.gitignore: -------------------------------------------------------------------------------- 1 | *.pt 2 | -------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | from torchvision import utils 5 | from model import Generator 6 | from tqdm import tqdm 7 | 8 | 9 | def generate(args, g_ema, device, mean_latent): 10 | 11 | with torch.no_grad(): 12 | g_ema.eval() 13 | for i in tqdm(range(args.pics)): 14 | sample_z = torch.randn(args.sample, args.latent, device=device) 15 | 16 | sample, _ = g_ema( 17 | [sample_z], truncation=args.truncation, truncation_latent=mean_latent 18 | ) 19 | 20 | utils.save_image( 21 | sample, 22 | f"results/{str(i).zfill(6)}.png", 23 | nrow=1, 24 | normalize=True, 25 | range=(-1, 1), 26 | ) 27 | 28 | 29 | if __name__ == "__main__": 30 | device = "cuda" 31 | 32 | parser = argparse.ArgumentParser(description="Generate samples from the generator") 33 | 34 | parser.add_argument( 35 | "--size", type=int, default=1024, help="output image size of the generator" 36 | ) 37 | parser.add_argument( 38 | "--sample", 39 | type=int, 40 | default=1, 41 | help="number of samples to be generated for each image", 42 | ) 43 | parser.add_argument( 44 | "--pics", type=int, default=5, help="number of images to be generated" 45 | ) 46 | parser.add_argument("--truncation", type=float, default=1, help="truncation ratio") 47 | parser.add_argument( 48 | "--truncation_mean", 49 | type=int, 50 | default=4096, 51 | help="number of vectors to calculate mean for the truncation", 52 | ) 53 | parser.add_argument( 54 | "--ckpt", 55 | type=str, 56 | default="checkpoint/stylegan2-ffhq-config-f.pt", 57 | help="path to the model checkpoint", 58 | ) 59 | parser.add_argument( 60 | "--channel_multiplier", 61 | type=int, 62 | default=2, 63 | help="channel multiplier of the generator. 
config-f = 2, else = 1", 64 | ) 65 | 66 | args = parser.parse_args() 67 | 68 | args.latent = 512 69 | args.n_mlp = 8 70 | 71 | g_ema = Generator( 72 | args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier 73 | ).to(device) 74 | checkpoint = torch.load(args.ckpt) 75 | 76 | g_ema.load_state_dict(checkpoint["g_ema"], strict=False) 77 | 78 | if args.truncation < 1: 79 | with torch.no_grad(): 80 | mean_latent = g_ema.mean_latent(args.truncation_mean) # mean_latent.shape = (1, 512) 81 | else: 82 | mean_latent = None 83 | 84 | generate(args, g_ema, device, mean_latent) 85 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | 4 | import torch 5 | from torch import nn 6 | from torch.nn import functional as F 7 | 8 | from op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d, conv2d_gradfix 9 | 10 | 11 | class PixelNorm(nn.Module): 12 | def __init__(self): 13 | super().__init__() 14 | 15 | def forward(self, input): 16 | return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8) 17 | 18 | 19 | def make_kernel(k): 20 | k = torch.tensor(k, dtype=torch.float32) 21 | 22 | if k.ndim == 1: 23 | k = k[None, :] * k[:, None] 24 | 25 | k /= k.sum() 26 | 27 | return k 28 | 29 | 30 | class Upsample(nn.Module): 31 | def __init__(self, kernel, factor=2): 32 | super().__init__() 33 | 34 | self.factor = factor 35 | kernel = make_kernel(kernel) * (factor ** 2) 36 | self.register_buffer("kernel", kernel) 37 | 38 | p = kernel.shape[0] - factor 39 | 40 | pad0 = (p + 1) // 2 + factor - 1 41 | pad1 = p // 2 42 | 43 | self.pad = (pad0, pad1) 44 | 45 | def forward(self, input): 46 | out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad) 47 | 48 | return out 49 | 50 | 51 | class Downsample(nn.Module): 52 | def __init__(self, kernel, factor=2): 53 | super().__init__() 54 | 55 | self.factor = factor 56 | kernel = make_kernel(kernel) 57 | self.register_buffer("kernel", kernel) 58 | 59 | p = kernel.shape[0] - factor 60 | 61 | pad0 = (p + 1) // 2 62 | pad1 = p // 2 63 | 64 | self.pad = (pad0, pad1) 65 | 66 | def forward(self, input): 67 | out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad) 68 | 69 | return out 70 | 71 | 72 | class Blur(nn.Module): 73 | def __init__(self, kernel, pad, upsample_factor=1): 74 | super().__init__() 75 | 76 | kernel = make_kernel(kernel) 77 | 78 | if upsample_factor > 1: 79 | kernel = kernel * (upsample_factor ** 2) 80 | 81 | self.register_buffer("kernel", kernel) 82 | 83 | self.pad = pad 84 | 85 | def forward(self, input): 86 | out = upfirdn2d(input, self.kernel, pad=self.pad) 87 | 88 | return out 89 | 90 | 91 | class EqualConv2d(nn.Module): 92 | def __init__( 93 | self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True 94 | ): 95 | super().__init__() 96 | 97 | self.weight = nn.Parameter( 98 | torch.randn(out_channel, in_channel, kernel_size, kernel_size) 99 | ) 100 | self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2) 101 | 102 | self.stride = stride 103 | self.padding = padding 104 | 105 | if bias: 106 | self.bias = nn.Parameter(torch.zeros(out_channel)) 107 | 108 | else: 109 | self.bias = None 110 | 111 | def forward(self, input): 112 | out = conv2d_gradfix.conv2d( 113 | input, 114 | self.weight * self.scale, 115 | bias=self.bias, 116 | stride=self.stride, 117 | padding=self.padding, 118 | ) 119 | 120 | return out 121 | 122 | def 
__repr__(self): 123 | return ( 124 | f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]}," 125 | f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})" 126 | ) 127 | 128 | 129 | class EqualLinear(nn.Module): 130 | def __init__( 131 | self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None 132 | ): 133 | super().__init__() 134 | 135 | self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul)) 136 | 137 | if bias: 138 | self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) 139 | 140 | else: 141 | self.bias = None 142 | 143 | self.activation = activation 144 | 145 | self.scale = (1 / math.sqrt(in_dim)) * lr_mul 146 | self.lr_mul = lr_mul 147 | 148 | def forward(self, input): 149 | if self.activation: 150 | out = F.linear(input, self.weight * self.scale) 151 | out = fused_leaky_relu(out, self.bias * self.lr_mul) 152 | 153 | else: 154 | out = F.linear( 155 | input, self.weight * self.scale, bias=self.bias * self.lr_mul 156 | ) 157 | 158 | return out 159 | 160 | def __repr__(self): 161 | return ( 162 | f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})" 163 | ) 164 | 165 | 166 | class ModulatedConv2d(nn.Module): 167 | def __init__( 168 | self, 169 | in_channel, 170 | out_channel, 171 | kernel_size, 172 | style_dim, 173 | demodulate=True, 174 | upsample=False, 175 | downsample=False, 176 | blur_kernel=[1, 3, 3, 1], 177 | fused=True, 178 | ): 179 | super().__init__() 180 | 181 | self.eps = 1e-8 182 | self.kernel_size = kernel_size 183 | self.in_channel = in_channel 184 | self.out_channel = out_channel 185 | self.upsample = upsample 186 | self.downsample = downsample 187 | 188 | if upsample: 189 | factor = 2 190 | p = (len(blur_kernel) - factor) - (kernel_size - 1) 191 | pad0 = (p + 1) // 2 + factor - 1 192 | pad1 = p // 2 + 1 193 | 194 | self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor) 195 | 196 | if downsample: 197 | factor = 2 198 | p = (len(blur_kernel) - factor) + (kernel_size - 1) 199 | pad0 = (p + 1) // 2 200 | pad1 = p // 2 201 | 202 | self.blur = Blur(blur_kernel, pad=(pad0, pad1)) 203 | 204 | fan_in = in_channel * kernel_size ** 2 205 | self.scale = 1 / math.sqrt(fan_in) 206 | self.padding = kernel_size // 2 207 | 208 | self.weight = nn.Parameter( 209 | torch.randn(1, out_channel, in_channel, kernel_size, kernel_size) 210 | ) 211 | 212 | self.modulation = EqualLinear(style_dim, in_channel, bias_init=1) 213 | 214 | self.demodulate = demodulate 215 | self.fused = fused 216 | 217 | def __repr__(self): 218 | return ( 219 | f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, " 220 | f"upsample={self.upsample}, downsample={self.downsample})" 221 | ) 222 | 223 | def forward(self, input, style): 224 | batch, in_channel, height, width = input.shape 225 | 226 | if not self.fused: 227 | weight = self.scale * self.weight.squeeze(0) 228 | style = self.modulation(style) 229 | 230 | if self.demodulate: 231 | w = weight.unsqueeze(0) * style.view(batch, 1, in_channel, 1, 1) 232 | dcoefs = (w.square().sum((2, 3, 4)) + 1e-8).rsqrt() 233 | 234 | input = input * style.reshape(batch, in_channel, 1, 1) 235 | 236 | if self.upsample: 237 | weight = weight.transpose(0, 1) 238 | out = conv2d_gradfix.conv_transpose2d( 239 | input, weight, padding=0, stride=2 240 | ) 241 | out = self.blur(out) 242 | 243 | elif self.downsample: 244 | input = self.blur(input) 245 | out = conv2d_gradfix.conv2d(input, weight, padding=0, stride=2) 246 | 247 | else: 248 | out = 
conv2d_gradfix.conv2d(input, weight, padding=self.padding) 249 | 250 | if self.demodulate: 251 | out = out * dcoefs.view(batch, -1, 1, 1) 252 | 253 | return out 254 | 255 | style = self.modulation(style).view(batch, 1, in_channel, 1, 1) 256 | weight = self.scale * self.weight * style 257 | 258 | if self.demodulate: 259 | demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8) 260 | weight = weight * demod.view(batch, self.out_channel, 1, 1, 1) 261 | 262 | weight = weight.view( 263 | batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size 264 | ) 265 | 266 | if self.upsample: 267 | input = input.view(1, batch * in_channel, height, width) 268 | weight = weight.view( 269 | batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size 270 | ) 271 | weight = weight.transpose(1, 2).reshape( 272 | batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size 273 | ) 274 | out = conv2d_gradfix.conv_transpose2d( 275 | input, weight, padding=0, stride=2, groups=batch 276 | ) 277 | _, _, height, width = out.shape 278 | out = out.view(batch, self.out_channel, height, width) 279 | out = self.blur(out) 280 | 281 | elif self.downsample: 282 | input = self.blur(input) 283 | _, _, height, width = input.shape 284 | input = input.view(1, batch * in_channel, height, width) 285 | out = conv2d_gradfix.conv2d( 286 | input, weight, padding=0, stride=2, groups=batch 287 | ) 288 | _, _, height, width = out.shape 289 | out = out.view(batch, self.out_channel, height, width) 290 | 291 | else: 292 | input = input.view(1, batch * in_channel, height, width) 293 | out = conv2d_gradfix.conv2d( 294 | input, weight, padding=self.padding, groups=batch 295 | ) 296 | _, _, height, width = out.shape 297 | out = out.view(batch, self.out_channel, height, width) 298 | 299 | return out 300 | 301 | 302 | class NoiseInjection(nn.Module): 303 | def __init__(self): 304 | super().__init__() 305 | 306 | self.weight = nn.Parameter(torch.zeros(1)) 307 | 308 | def forward(self, image, noise=None): 309 | if noise is None: 310 | batch, _, height, width = image.shape 311 | noise = image.new_empty(batch, 1, height, width).normal_() 312 | 313 | return image + self.weight * noise 314 | 315 | 316 | class ConstantInput(nn.Module): 317 | def __init__(self, channel, size=4): 318 | super().__init__() 319 | 320 | self.input = nn.Parameter(torch.randn(1, channel, size, size)) 321 | 322 | def forward(self, input): 323 | batch = input.shape[0] 324 | out = self.input.repeat(batch, 1, 1, 1) 325 | 326 | return out 327 | 328 | 329 | class StyledConv(nn.Module): 330 | def __init__( 331 | self, 332 | in_channel, 333 | out_channel, 334 | kernel_size, 335 | style_dim, 336 | upsample=False, 337 | blur_kernel=[1, 3, 3, 1], 338 | demodulate=True, 339 | ): 340 | super().__init__() 341 | 342 | self.conv = ModulatedConv2d( 343 | in_channel, 344 | out_channel, 345 | kernel_size, 346 | style_dim, 347 | upsample=upsample, 348 | blur_kernel=blur_kernel, 349 | demodulate=demodulate, 350 | ) 351 | 352 | self.noise = NoiseInjection() 353 | # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1)) 354 | # self.activate = ScaledLeakyReLU(0.2) 355 | self.activate = FusedLeakyReLU(out_channel) 356 | 357 | def forward(self, input, style, noise=None): 358 | out = self.conv(input, style) 359 | out = self.noise(out, noise=noise) 360 | # out = out + self.bias 361 | out = self.activate(out) 362 | 363 | return out 364 | 365 | 366 | class ToRGB(nn.Module): 367 | def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]): 
368 | super().__init__() 369 | 370 | if upsample: 371 | self.upsample = Upsample(blur_kernel) 372 | 373 | self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False) 374 | self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1)) 375 | 376 | def forward(self, input, style, skip=None): 377 | out = self.conv(input, style) 378 | out = out + self.bias 379 | 380 | if skip is not None: 381 | skip = self.upsample(skip) 382 | 383 | out = out + skip 384 | 385 | return out 386 | 387 | 388 | class Generator(nn.Module): 389 | def __init__( 390 | self, 391 | size, 392 | style_dim, 393 | n_mlp, 394 | channel_multiplier=2, 395 | blur_kernel=[1, 3, 3, 1], 396 | lr_mlp=0.01, 397 | ): 398 | super().__init__() 399 | 400 | self.size = size 401 | 402 | self.style_dim = style_dim 403 | 404 | layers = [PixelNorm()] 405 | 406 | for i in range(n_mlp): 407 | layers.append( 408 | EqualLinear( 409 | style_dim, style_dim, lr_mul=lr_mlp, activation="fused_lrelu" 410 | ) 411 | ) 412 | 413 | self.style = nn.Sequential(*layers) 414 | 415 | self.channels = { 416 | 4: 512, 417 | 8: 512, 418 | 16: 512, 419 | 32: 512, 420 | 64: 256 * channel_multiplier, 421 | 128: 128 * channel_multiplier, 422 | 256: 64 * channel_multiplier, 423 | 512: 32 * channel_multiplier, 424 | 1024: 16 * channel_multiplier, 425 | } 426 | 427 | self.input = ConstantInput(self.channels[4]) 428 | self.conv1 = StyledConv( 429 | self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel 430 | ) 431 | self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False) 432 | 433 | self.log_size = int(math.log(size, 2)) 434 | self.num_layers = (self.log_size - 2) * 2 + 1 435 | 436 | self.convs = nn.ModuleList() 437 | self.upsamples = nn.ModuleList() 438 | self.to_rgbs = nn.ModuleList() 439 | self.noises = nn.Module() 440 | 441 | in_channel = self.channels[4] 442 | 443 | for layer_idx in range(self.num_layers): 444 | res = (layer_idx + 5) // 2 445 | shape = [1, 1, 2 ** res, 2 ** res] 446 | self.noises.register_buffer(f"noise_{layer_idx}", torch.randn(*shape)) 447 | 448 | for i in range(3, self.log_size + 1): 449 | out_channel = self.channels[2 ** i] 450 | 451 | self.convs.append( 452 | StyledConv( 453 | in_channel, 454 | out_channel, 455 | 3, 456 | style_dim, 457 | upsample=True, 458 | blur_kernel=blur_kernel, 459 | ) 460 | ) 461 | 462 | self.convs.append( 463 | StyledConv( 464 | out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel 465 | ) 466 | ) 467 | 468 | self.to_rgbs.append(ToRGB(out_channel, style_dim)) 469 | 470 | in_channel = out_channel 471 | 472 | self.n_latent = self.log_size * 2 - 2 473 | 474 | def make_noise(self): 475 | device = self.input.input.device 476 | 477 | noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)] 478 | 479 | for i in range(3, self.log_size + 1): 480 | for _ in range(2): 481 | noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device)) 482 | 483 | return noises 484 | 485 | def mean_latent(self, n_latent): 486 | latent_in = torch.randn( 487 | n_latent, self.style_dim, device=self.input.input.device 488 | ) 489 | latent = self.style(latent_in).mean(0, keepdim=True) 490 | 491 | return latent 492 | 493 | def get_latent(self, input): 494 | return self.style(input) 495 | 496 | def forward( 497 | self, 498 | styles, 499 | return_latents=False, 500 | inject_index=None, 501 | truncation=1, 502 | truncation_latent=None, 503 | input_is_latent=False, 504 | noise=None, 505 | randomize_noise=True, 506 | ): 507 | if not input_is_latent: 508 | styles = [self.style(s) for s in styles] 509 | 510 | if noise 
is None: 511 | if randomize_noise: 512 | noise = [None] * self.num_layers 513 | else: 514 | noise = [ 515 | getattr(self.noises, f"noise_{i}") for i in range(self.num_layers) 516 | ] 517 | 518 | if truncation < 1: 519 | style_t = [] 520 | 521 | for style in styles: 522 | style_t.append( 523 | truncation_latent + truncation * (style - truncation_latent) 524 | ) 525 | 526 | styles = style_t 527 | 528 | if len(styles) < 2: 529 | inject_index = self.n_latent # 1024*1024: 18, 256*256: 14 530 | 531 | if styles[0].ndim < 3: 532 | latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) # (1, 18, 512) 533 | else: 534 | latent = styles[0] 535 | 536 | else: 537 | if inject_index is None: 538 | inject_index = random.randint(1, self.n_latent - 1) 539 | 540 | latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) 541 | latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1) 542 | 543 | latent = torch.cat([latent, latent2], 1) 544 | 545 | out = self.input(latent) # latent (this is not required... only to get the batch size): (1, 18, 512), out: (1, 512, 4, 4) 546 | out = self.conv1(out, latent[:, 0], noise=noise[0]) # out: (1, 512, 4, 4) 547 | skip = self.to_rgb1(out, latent[:, 1]) # skip: (1, 3, 4, 4) 548 | 549 | i = 1 550 | for conv1, conv2, noise1, noise2, to_rgb in zip( 551 | self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs 552 | ): 553 | out = conv1(out, latent[:, i], noise=noise1) 554 | # print("out.shape: ", out.shape) 555 | out = conv2(out, latent[:, i + 1], noise=noise2) 556 | # print("out.shape: ", out.shape) 557 | skip = to_rgb(out, latent[:, i + 2], skip) 558 | # print("skip.shape: ", skip.shape) 559 | # print(f"-{i}-", latent[:, i].shape, f"={i+1}=", latent[:, i+1].shape, f"-{i+2}-", latent[:, i+2].shape) 560 | 561 | i += 2 562 | # print(i) 563 | # exit() 564 | 565 | image = skip 566 | 567 | if return_latents: 568 | return image, latent 569 | 570 | else: 571 | return image, None 572 | 573 | 574 | class ConvLayer(nn.Sequential): 575 | def __init__( 576 | self, 577 | in_channel, 578 | out_channel, 579 | kernel_size, 580 | downsample=False, 581 | blur_kernel=[1, 3, 3, 1], 582 | bias=True, 583 | activate=True, 584 | ): 585 | layers = [] 586 | 587 | if downsample: 588 | factor = 2 589 | p = (len(blur_kernel) - factor) + (kernel_size - 1) 590 | pad0 = (p + 1) // 2 591 | pad1 = p // 2 592 | 593 | layers.append(Blur(blur_kernel, pad=(pad0, pad1))) 594 | 595 | stride = 2 596 | self.padding = 0 597 | 598 | else: 599 | stride = 1 600 | self.padding = kernel_size // 2 601 | 602 | layers.append( 603 | EqualConv2d( 604 | in_channel, 605 | out_channel, 606 | kernel_size, 607 | padding=self.padding, 608 | stride=stride, 609 | bias=bias and not activate, 610 | ) 611 | ) 612 | 613 | if activate: 614 | layers.append(FusedLeakyReLU(out_channel, bias=bias)) 615 | 616 | super().__init__(*layers) 617 | 618 | 619 | class ResBlock(nn.Module): 620 | def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]): 621 | super().__init__() 622 | 623 | self.conv1 = ConvLayer(in_channel, in_channel, 3) 624 | self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True) 625 | 626 | self.skip = ConvLayer( 627 | in_channel, out_channel, 1, downsample=True, activate=False, bias=False 628 | ) 629 | 630 | def forward(self, input): 631 | out = self.conv1(input) 632 | out = self.conv2(out) 633 | 634 | skip = self.skip(input) 635 | out = (out + skip) / math.sqrt(2) 636 | 637 | return out 638 | 639 | 640 | class Discriminator(nn.Module): 641 | def __init__(self, size, 
channel_multiplier=2, blur_kernel=[1, 3, 3, 1]): 642 | super().__init__() 643 | 644 | channels = { 645 | 4: 512, 646 | 8: 512, 647 | 16: 512, 648 | 32: 512, 649 | 64: 256 * channel_multiplier, 650 | 128: 128 * channel_multiplier, 651 | 256: 64 * channel_multiplier, 652 | 512: 32 * channel_multiplier, 653 | 1024: 16 * channel_multiplier, 654 | } 655 | 656 | convs = [ConvLayer(3, channels[size], 1)] 657 | 658 | log_size = int(math.log(size, 2)) 659 | 660 | in_channel = channels[size] 661 | 662 | for i in range(log_size, 2, -1): 663 | out_channel = channels[2 ** (i - 1)] 664 | 665 | convs.append(ResBlock(in_channel, out_channel, blur_kernel)) 666 | 667 | in_channel = out_channel 668 | 669 | self.convs = nn.Sequential(*convs) 670 | 671 | self.stddev_group = 4 672 | self.stddev_feat = 1 673 | 674 | self.final_conv = ConvLayer(in_channel + 1, channels[4], 3) 675 | self.final_linear = nn.Sequential( 676 | EqualLinear(channels[4] * 4 * 4, channels[4], activation="fused_lrelu"), 677 | EqualLinear(channels[4], 1), 678 | ) 679 | 680 | def forward(self, input): 681 | out = self.convs(input) 682 | 683 | batch, channel, height, width = out.shape 684 | group = min(batch, self.stddev_group) 685 | stddev = out.view( 686 | group, -1, self.stddev_feat, channel // self.stddev_feat, height, width 687 | ) 688 | stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8) 689 | stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2) 690 | stddev = stddev.repeat(group, 1, height, width) 691 | out = torch.cat([out, stddev], 1) 692 | 693 | out = self.final_conv(out) 694 | 695 | out = out.view(batch, -1) 696 | out = self.final_linear(out) 697 | 698 | return out 699 | 700 | -------------------------------------------------------------------------------- /op/__init__.py: -------------------------------------------------------------------------------- 1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu 2 | from .upfirdn2d import upfirdn2d 3 | -------------------------------------------------------------------------------- /op/conv2d_gradfix.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import warnings 3 | 4 | import torch 5 | from torch import autograd 6 | from torch.nn import functional as F 7 | 8 | enabled = True 9 | weight_gradients_disabled = False 10 | 11 | 12 | @contextlib.contextmanager 13 | def no_weight_gradients(): 14 | global weight_gradients_disabled 15 | 16 | old = weight_gradients_disabled 17 | weight_gradients_disabled = True 18 | yield 19 | weight_gradients_disabled = old 20 | 21 | 22 | def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): 23 | if could_use_op(input): 24 | return conv2d_gradfix( 25 | transpose=False, 26 | weight_shape=weight.shape, 27 | stride=stride, 28 | padding=padding, 29 | output_padding=0, 30 | dilation=dilation, 31 | groups=groups, 32 | ).apply(input, weight, bias) 33 | 34 | return F.conv2d( 35 | input=input, 36 | weight=weight, 37 | bias=bias, 38 | stride=stride, 39 | padding=padding, 40 | dilation=dilation, 41 | groups=groups, 42 | ) 43 | 44 | 45 | def conv_transpose2d( 46 | input, 47 | weight, 48 | bias=None, 49 | stride=1, 50 | padding=0, 51 | output_padding=0, 52 | groups=1, 53 | dilation=1, 54 | ): 55 | if could_use_op(input): 56 | return conv2d_gradfix( 57 | transpose=True, 58 | weight_shape=weight.shape, 59 | stride=stride, 60 | padding=padding, 61 | output_padding=output_padding, 62 | groups=groups, 63 | dilation=dilation, 64 | ).apply(input, weight, bias) 65 | 66 | return 
F.conv_transpose2d( 67 | input=input, 68 | weight=weight, 69 | bias=bias, 70 | stride=stride, 71 | padding=padding, 72 | output_padding=output_padding, 73 | dilation=dilation, 74 | groups=groups, 75 | ) 76 | 77 | 78 | def could_use_op(input): 79 | if (not enabled) or (not torch.backends.cudnn.enabled): 80 | return False 81 | 82 | if input.device.type != "cuda": 83 | return False 84 | 85 | if any(torch.__version__.startswith(x) for x in ["1.7.", "1.8."]): 86 | return True 87 | 88 | warnings.warn( 89 | f"conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d()." 90 | ) 91 | 92 | return False 93 | 94 | 95 | def ensure_tuple(xs, ndim): 96 | xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim 97 | 98 | return xs 99 | 100 | 101 | conv2d_gradfix_cache = dict() 102 | 103 | 104 | def conv2d_gradfix( 105 | transpose, weight_shape, stride, padding, output_padding, dilation, groups 106 | ): 107 | ndim = 2 108 | weight_shape = tuple(weight_shape) 109 | stride = ensure_tuple(stride, ndim) 110 | padding = ensure_tuple(padding, ndim) 111 | output_padding = ensure_tuple(output_padding, ndim) 112 | dilation = ensure_tuple(dilation, ndim) 113 | 114 | key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups) 115 | if key in conv2d_gradfix_cache: 116 | return conv2d_gradfix_cache[key] 117 | 118 | common_kwargs = dict( 119 | stride=stride, padding=padding, dilation=dilation, groups=groups 120 | ) 121 | 122 | def calc_output_padding(input_shape, output_shape): 123 | if transpose: 124 | return [0, 0] 125 | 126 | return [ 127 | input_shape[i + 2] 128 | - (output_shape[i + 2] - 1) * stride[i] 129 | - (1 - 2 * padding[i]) 130 | - dilation[i] * (weight_shape[i + 2] - 1) 131 | for i in range(ndim) 132 | ] 133 | 134 | class Conv2d(autograd.Function): 135 | @staticmethod 136 | def forward(ctx, input, weight, bias): 137 | if not transpose: 138 | out = F.conv2d(input=input, weight=weight, bias=bias, **common_kwargs) 139 | 140 | else: 141 | out = F.conv_transpose2d( 142 | input=input, 143 | weight=weight, 144 | bias=bias, 145 | output_padding=output_padding, 146 | **common_kwargs, 147 | ) 148 | 149 | ctx.save_for_backward(input, weight) 150 | 151 | return out 152 | 153 | @staticmethod 154 | def backward(ctx, grad_output): 155 | input, weight = ctx.saved_tensors 156 | grad_input, grad_weight, grad_bias = None, None, None 157 | 158 | if ctx.needs_input_grad[0]: 159 | p = calc_output_padding( 160 | input_shape=input.shape, output_shape=grad_output.shape 161 | ) 162 | grad_input = conv2d_gradfix( 163 | transpose=(not transpose), 164 | weight_shape=weight_shape, 165 | output_padding=p, 166 | **common_kwargs, 167 | ).apply(grad_output, weight, None) 168 | 169 | if ctx.needs_input_grad[1] and not weight_gradients_disabled: 170 | grad_weight = Conv2dGradWeight.apply(grad_output, input) 171 | 172 | if ctx.needs_input_grad[2]: 173 | grad_bias = grad_output.sum((0, 2, 3)) 174 | 175 | return grad_input, grad_weight, grad_bias 176 | 177 | class Conv2dGradWeight(autograd.Function): 178 | @staticmethod 179 | def forward(ctx, grad_output, input): 180 | op = torch._C._jit_get_operation( 181 | "aten::cudnn_convolution_backward_weight" 182 | if not transpose 183 | else "aten::cudnn_convolution_transpose_backward_weight" 184 | ) 185 | flags = [ 186 | torch.backends.cudnn.benchmark, 187 | torch.backends.cudnn.deterministic, 188 | torch.backends.cudnn.allow_tf32, 189 | ] 190 | grad_weight = op( 191 | weight_shape, 192 | grad_output, 193 | input, 194 | 
padding, 195 | stride, 196 | dilation, 197 | groups, 198 | *flags, 199 | ) 200 | ctx.save_for_backward(grad_output, input) 201 | 202 | return grad_weight 203 | 204 | @staticmethod 205 | def backward(ctx, grad_grad_weight): 206 | grad_output, input = ctx.saved_tensors 207 | grad_grad_output, grad_grad_input = None, None 208 | 209 | if ctx.needs_input_grad[0]: 210 | grad_grad_output = Conv2d.apply(input, grad_grad_weight, None) 211 | 212 | if ctx.needs_input_grad[1]: 213 | p = calc_output_padding( 214 | input_shape=input.shape, output_shape=grad_output.shape 215 | ) 216 | grad_grad_input = conv2d_gradfix( 217 | transpose=(not transpose), 218 | weight_shape=weight_shape, 219 | output_padding=p, 220 | **common_kwargs, 221 | ).apply(grad_output, grad_grad_weight, None) 222 | 223 | return grad_grad_output, grad_grad_input 224 | 225 | conv2d_gradfix_cache[key] = Conv2d 226 | 227 | return Conv2d 228 | -------------------------------------------------------------------------------- /op/fused_act.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | from torch.autograd import Function 7 | from torch.utils.cpp_extension import load 8 | 9 | 10 | module_path = os.path.dirname(__file__) 11 | fused = load( 12 | "fused", 13 | sources=[ 14 | os.path.join(module_path, "fused_bias_act.cpp"), 15 | os.path.join(module_path, "fused_bias_act_kernel.cu"), 16 | ], 17 | ) 18 | 19 | 20 | class FusedLeakyReLUFunctionBackward(Function): 21 | @staticmethod 22 | def forward(ctx, grad_output, out, bias, negative_slope, scale): 23 | ctx.save_for_backward(out) 24 | ctx.negative_slope = negative_slope 25 | ctx.scale = scale 26 | 27 | empty = grad_output.new_empty(0) 28 | 29 | grad_input = fused.fused_bias_act( 30 | grad_output.contiguous(), empty, out, 3, 1, negative_slope, scale 31 | ) 32 | 33 | dim = [0] 34 | 35 | if grad_input.ndim > 2: 36 | dim += list(range(2, grad_input.ndim)) 37 | 38 | if bias: 39 | grad_bias = grad_input.sum(dim).detach() 40 | 41 | else: 42 | grad_bias = empty 43 | 44 | return grad_input, grad_bias 45 | 46 | @staticmethod 47 | def backward(ctx, gradgrad_input, gradgrad_bias): 48 | out, = ctx.saved_tensors 49 | gradgrad_out = fused.fused_bias_act( 50 | gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale 51 | ) 52 | 53 | return gradgrad_out, None, None, None, None 54 | 55 | 56 | class FusedLeakyReLUFunction(Function): 57 | @staticmethod 58 | def forward(ctx, input, bias, negative_slope, scale): 59 | empty = input.new_empty(0) 60 | 61 | ctx.bias = bias is not None 62 | 63 | if bias is None: 64 | bias = empty 65 | 66 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) 67 | ctx.save_for_backward(out) 68 | ctx.negative_slope = negative_slope 69 | ctx.scale = scale 70 | 71 | return out 72 | 73 | @staticmethod 74 | def backward(ctx, grad_output): 75 | out, = ctx.saved_tensors 76 | 77 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( 78 | grad_output, out, ctx.bias, ctx.negative_slope, ctx.scale 79 | ) 80 | 81 | if not ctx.bias: 82 | grad_bias = None 83 | 84 | return grad_input, grad_bias, None, None 85 | 86 | 87 | class FusedLeakyReLU(nn.Module): 88 | def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5): 89 | super().__init__() 90 | 91 | if bias: 92 | self.bias = nn.Parameter(torch.zeros(channel)) 93 | 94 | else: 95 | self.bias = None 96 | 97 | self.negative_slope = negative_slope 98 | 
self.scale = scale 99 | 100 | def forward(self, input): 101 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 102 | 103 | 104 | def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5): 105 | if input.device.type == "cpu": 106 | if bias is not None: 107 | rest_dim = [1] * (input.ndim - bias.ndim - 1) 108 | return ( 109 | F.leaky_relu( 110 | input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2 111 | ) 112 | * scale 113 | ) 114 | 115 | else: 116 | return F.leaky_relu(input, negative_slope=0.2) * scale 117 | 118 | else: 119 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) 120 | -------------------------------------------------------------------------------- /op/fused_bias_act.cpp: -------------------------------------------------------------------------------- 1 | 2 | #include 3 | #include 4 | 5 | torch::Tensor fused_bias_act_op(const torch::Tensor &input, 6 | const torch::Tensor &bias, 7 | const torch::Tensor &refer, int act, int grad, 8 | float alpha, float scale); 9 | 10 | #define CHECK_CUDA(x) \ 11 | TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 12 | #define CHECK_CONTIGUOUS(x) \ 13 | TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 14 | #define CHECK_INPUT(x) \ 15 | CHECK_CUDA(x); \ 16 | CHECK_CONTIGUOUS(x) 17 | 18 | torch::Tensor fused_bias_act(const torch::Tensor &input, 19 | const torch::Tensor &bias, 20 | const torch::Tensor &refer, int act, int grad, 21 | float alpha, float scale) { 22 | CHECK_INPUT(input); 23 | CHECK_INPUT(bias); 24 | 25 | at::DeviceGuard guard(input.device()); 26 | 27 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); 28 | } 29 | 30 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 31 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); 32 | } -------------------------------------------------------------------------------- /op/fused_bias_act_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | 15 | #include 16 | #include 17 | 18 | template 19 | static __global__ void 20 | fused_bias_act_kernel(scalar_t *out, const scalar_t *p_x, const scalar_t *p_b, 21 | const scalar_t *p_ref, int act, int grad, scalar_t alpha, 22 | scalar_t scale, int loop_x, int size_x, int step_b, 23 | int size_b, int use_bias, int use_ref) { 24 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; 25 | 26 | scalar_t zero = 0.0; 27 | 28 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; 29 | loop_idx++, xi += blockDim.x) { 30 | scalar_t x = p_x[xi]; 31 | 32 | if (use_bias) { 33 | x += p_b[(xi / step_b) % size_b]; 34 | } 35 | 36 | scalar_t ref = use_ref ? p_ref[xi] : zero; 37 | 38 | scalar_t y; 39 | 40 | switch (act * 10 + grad) { 41 | default: 42 | case 10: 43 | y = x; 44 | break; 45 | case 11: 46 | y = x; 47 | break; 48 | case 12: 49 | y = 0.0; 50 | break; 51 | 52 | case 30: 53 | y = (x > 0.0) ? x : x * alpha; 54 | break; 55 | case 31: 56 | y = (ref > 0.0) ? 
x : x * alpha; 57 | break; 58 | case 32: 59 | y = 0.0; 60 | break; 61 | } 62 | 63 | out[xi] = y * scale; 64 | } 65 | } 66 | 67 | torch::Tensor fused_bias_act_op(const torch::Tensor &input, 68 | const torch::Tensor &bias, 69 | const torch::Tensor &refer, int act, int grad, 70 | float alpha, float scale) { 71 | int curDevice = -1; 72 | cudaGetDevice(&curDevice); 73 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 74 | 75 | auto x = input.contiguous(); 76 | auto b = bias.contiguous(); 77 | auto ref = refer.contiguous(); 78 | 79 | int use_bias = b.numel() ? 1 : 0; 80 | int use_ref = ref.numel() ? 1 : 0; 81 | 82 | int size_x = x.numel(); 83 | int size_b = b.numel(); 84 | int step_b = 1; 85 | 86 | for (int i = 1 + 1; i < x.dim(); i++) { 87 | step_b *= x.size(i); 88 | } 89 | 90 | int loop_x = 4; 91 | int block_size = 4 * 32; 92 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1; 93 | 94 | auto y = torch::empty_like(x); 95 | 96 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 97 | x.scalar_type(), "fused_bias_act_kernel", [&] { 98 | fused_bias_act_kernel<<>>( 99 | y.data_ptr(), x.data_ptr(), 100 | b.data_ptr(), ref.data_ptr(), act, grad, alpha, 101 | scale, loop_x, size_x, step_b, size_b, use_bias, use_ref); 102 | }); 103 | 104 | return y; 105 | } -------------------------------------------------------------------------------- /op/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | torch::Tensor upfirdn2d_op(const torch::Tensor &input, 5 | const torch::Tensor &kernel, int up_x, int up_y, 6 | int down_x, int down_y, int pad_x0, int pad_x1, 7 | int pad_y0, int pad_y1); 8 | 9 | #define CHECK_CUDA(x) \ 10 | TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 11 | #define CHECK_CONTIGUOUS(x) \ 12 | TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 13 | #define CHECK_INPUT(x) \ 14 | CHECK_CUDA(x); \ 15 | CHECK_CONTIGUOUS(x) 16 | 17 | torch::Tensor upfirdn2d(const torch::Tensor &input, const torch::Tensor &kernel, 18 | int up_x, int up_y, int down_x, int down_y, int pad_x0, 19 | int pad_x1, int pad_y0, int pad_y1) { 20 | CHECK_INPUT(input); 21 | CHECK_INPUT(kernel); 22 | 23 | at::DeviceGuard guard(input.device()); 24 | 25 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, 26 | pad_y0, pad_y1); 27 | } 28 | 29 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 30 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); 31 | } -------------------------------------------------------------------------------- /op/upfirdn2d.py: -------------------------------------------------------------------------------- 1 | from collections import abc 2 | import os 3 | 4 | import torch 5 | from torch.nn import functional as F 6 | from torch.autograd import Function 7 | from torch.utils.cpp_extension import load 8 | 9 | 10 | module_path = os.path.dirname(__file__) 11 | upfirdn2d_op = load( 12 | "upfirdn2d", 13 | sources=[ 14 | os.path.join(module_path, "upfirdn2d.cpp"), 15 | os.path.join(module_path, "upfirdn2d_kernel.cu"), 16 | ], 17 | ) 18 | 19 | 20 | class UpFirDn2dBackward(Function): 21 | @staticmethod 22 | def forward( 23 | ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size 24 | ): 25 | 26 | up_x, up_y = up 27 | down_x, down_y = down 28 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad 29 | 30 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) 31 | 32 | grad_input = upfirdn2d_op.upfirdn2d( 33 | grad_output, 34 | grad_kernel, 35 | down_x, 36 | down_y, 37 | 
up_x, 38 | up_y, 39 | g_pad_x0, 40 | g_pad_x1, 41 | g_pad_y0, 42 | g_pad_y1, 43 | ) 44 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) 45 | 46 | ctx.save_for_backward(kernel) 47 | 48 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 49 | 50 | ctx.up_x = up_x 51 | ctx.up_y = up_y 52 | ctx.down_x = down_x 53 | ctx.down_y = down_y 54 | ctx.pad_x0 = pad_x0 55 | ctx.pad_x1 = pad_x1 56 | ctx.pad_y0 = pad_y0 57 | ctx.pad_y1 = pad_y1 58 | ctx.in_size = in_size 59 | ctx.out_size = out_size 60 | 61 | return grad_input 62 | 63 | @staticmethod 64 | def backward(ctx, gradgrad_input): 65 | kernel, = ctx.saved_tensors 66 | 67 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) 68 | 69 | gradgrad_out = upfirdn2d_op.upfirdn2d( 70 | gradgrad_input, 71 | kernel, 72 | ctx.up_x, 73 | ctx.up_y, 74 | ctx.down_x, 75 | ctx.down_y, 76 | ctx.pad_x0, 77 | ctx.pad_x1, 78 | ctx.pad_y0, 79 | ctx.pad_y1, 80 | ) 81 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3]) 82 | gradgrad_out = gradgrad_out.view( 83 | ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1] 84 | ) 85 | 86 | return gradgrad_out, None, None, None, None, None, None, None, None 87 | 88 | 89 | class UpFirDn2d(Function): 90 | @staticmethod 91 | def forward(ctx, input, kernel, up, down, pad): 92 | up_x, up_y = up 93 | down_x, down_y = down 94 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 95 | 96 | kernel_h, kernel_w = kernel.shape 97 | batch, channel, in_h, in_w = input.shape 98 | ctx.in_size = input.shape 99 | 100 | input = input.reshape(-1, in_h, in_w, 1) 101 | 102 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) 103 | 104 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y 105 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x 106 | ctx.out_size = (out_h, out_w) 107 | 108 | ctx.up = (up_x, up_y) 109 | ctx.down = (down_x, down_y) 110 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) 111 | 112 | g_pad_x0 = kernel_w - pad_x0 - 1 113 | g_pad_y0 = kernel_h - pad_y0 - 1 114 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 115 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 116 | 117 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) 118 | 119 | out = upfirdn2d_op.upfirdn2d( 120 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 121 | ) 122 | # out = out.view(major, out_h, out_w, minor) 123 | out = out.view(-1, channel, out_h, out_w) 124 | 125 | return out 126 | 127 | @staticmethod 128 | def backward(ctx, grad_output): 129 | kernel, grad_kernel = ctx.saved_tensors 130 | 131 | grad_input = None 132 | 133 | if ctx.needs_input_grad[0]: 134 | grad_input = UpFirDn2dBackward.apply( 135 | grad_output, 136 | kernel, 137 | grad_kernel, 138 | ctx.up, 139 | ctx.down, 140 | ctx.pad, 141 | ctx.g_pad, 142 | ctx.in_size, 143 | ctx.out_size, 144 | ) 145 | 146 | return grad_input, None, None, None, None 147 | 148 | 149 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): 150 | if not isinstance(up, abc.Iterable): 151 | up = (up, up) 152 | 153 | if not isinstance(down, abc.Iterable): 154 | down = (down, down) 155 | 156 | if len(pad) == 2: 157 | pad = (pad[0], pad[1], pad[0], pad[1]) 158 | 159 | if input.device.type == "cpu": 160 | out = upfirdn2d_native(input, kernel, *up, *down, *pad) 161 | 162 | else: 163 | out = UpFirDn2d.apply(input, kernel, up, down, pad) 164 | 165 | return out 166 | 167 | 168 | def upfirdn2d_native( 169 | input, kernel, up_x, up_y, down_x, down_y, 
pad_x0, pad_x1, pad_y0, pad_y1 170 | ): 171 | _, channel, in_h, in_w = input.shape 172 | input = input.reshape(-1, in_h, in_w, 1) 173 | 174 | _, in_h, in_w, minor = input.shape 175 | kernel_h, kernel_w = kernel.shape 176 | 177 | out = input.view(-1, in_h, 1, in_w, 1, minor) 178 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) 179 | out = out.view(-1, in_h * up_y, in_w * up_x, minor) 180 | 181 | out = F.pad( 182 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] 183 | ) 184 | out = out[ 185 | :, 186 | max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), 187 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), 188 | :, 189 | ] 190 | 191 | out = out.permute(0, 3, 1, 2) 192 | out = out.reshape( 193 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] 194 | ) 195 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) 196 | out = F.conv2d(out, w) 197 | out = out.reshape( 198 | -1, 199 | minor, 200 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, 201 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, 202 | ) 203 | out = out.permute(0, 2, 3, 1) 204 | out = out[:, ::down_y, ::down_x, :] 205 | 206 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y 207 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x 208 | 209 | return out.view(-1, channel, out_h, out_w) 210 | -------------------------------------------------------------------------------- /op/upfirdn2d_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | static __host__ __device__ __forceinline__ int floor_div(int a, int b) { 18 | int c = a / b; 19 | 20 | if (c * b > a) { 21 | c--; 22 | } 23 | 24 | return c; 25 | } 26 | 27 | struct UpFirDn2DKernelParams { 28 | int up_x; 29 | int up_y; 30 | int down_x; 31 | int down_y; 32 | int pad_x0; 33 | int pad_x1; 34 | int pad_y0; 35 | int pad_y1; 36 | 37 | int major_dim; 38 | int in_h; 39 | int in_w; 40 | int minor_dim; 41 | int kernel_h; 42 | int kernel_w; 43 | int out_h; 44 | int out_w; 45 | int loop_major; 46 | int loop_x; 47 | }; 48 | 49 | template 50 | __global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input, 51 | const scalar_t *kernel, 52 | const UpFirDn2DKernelParams p) { 53 | int minor_idx = blockIdx.x * blockDim.x + threadIdx.x; 54 | int out_y = minor_idx / p.minor_dim; 55 | minor_idx -= out_y * p.minor_dim; 56 | int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y; 57 | int major_idx_base = blockIdx.z * p.loop_major; 58 | 59 | if (out_x_base >= p.out_w || out_y >= p.out_h || 60 | major_idx_base >= p.major_dim) { 61 | return; 62 | } 63 | 64 | int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0; 65 | int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h); 66 | int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y; 67 | int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y; 68 | 69 | for (int loop_major = 0, major_idx = major_idx_base; 70 | loop_major < p.loop_major && major_idx < p.major_dim; 71 | loop_major++, major_idx++) { 72 | for (int loop_x = 0, out_x = out_x_base; 73 | loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) { 
74 | int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0; 75 | int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w); 76 | int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x; 77 | int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x; 78 | 79 | const scalar_t *x_p = 80 | &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + 81 | minor_idx]; 82 | const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x]; 83 | int x_px = p.minor_dim; 84 | int k_px = -p.up_x; 85 | int x_py = p.in_w * p.minor_dim; 86 | int k_py = -p.up_y * p.kernel_w; 87 | 88 | scalar_t v = 0.0f; 89 | 90 | for (int y = 0; y < h; y++) { 91 | for (int x = 0; x < w; x++) { 92 | v += static_cast(*x_p) * static_cast(*k_p); 93 | x_p += x_px; 94 | k_p += k_px; 95 | } 96 | 97 | x_p += x_py - w * x_px; 98 | k_p += k_py - w * k_px; 99 | } 100 | 101 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + 102 | minor_idx] = v; 103 | } 104 | } 105 | } 106 | 107 | template 109 | __global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input, 110 | const scalar_t *kernel, 111 | const UpFirDn2DKernelParams p) { 112 | const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1; 113 | const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1; 114 | 115 | __shared__ volatile float sk[kernel_h][kernel_w]; 116 | __shared__ volatile float sx[tile_in_h][tile_in_w]; 117 | 118 | int minor_idx = blockIdx.x; 119 | int tile_out_y = minor_idx / p.minor_dim; 120 | minor_idx -= tile_out_y * p.minor_dim; 121 | tile_out_y *= tile_out_h; 122 | int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w; 123 | int major_idx_base = blockIdx.z * p.loop_major; 124 | 125 | if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | 126 | major_idx_base >= p.major_dim) { 127 | return; 128 | } 129 | 130 | for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; 131 | tap_idx += blockDim.x) { 132 | int ky = tap_idx / kernel_w; 133 | int kx = tap_idx - ky * kernel_w; 134 | scalar_t v = 0.0; 135 | 136 | if (kx < p.kernel_w & ky < p.kernel_h) { 137 | v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)]; 138 | } 139 | 140 | sk[ky][kx] = v; 141 | } 142 | 143 | for (int loop_major = 0, major_idx = major_idx_base; 144 | loop_major < p.loop_major & major_idx < p.major_dim; 145 | loop_major++, major_idx++) { 146 | for (int loop_x = 0, tile_out_x = tile_out_x_base; 147 | loop_x < p.loop_x & tile_out_x < p.out_w; 148 | loop_x++, tile_out_x += tile_out_w) { 149 | int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0; 150 | int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0; 151 | int tile_in_x = floor_div(tile_mid_x, up_x); 152 | int tile_in_y = floor_div(tile_mid_y, up_y); 153 | 154 | __syncthreads(); 155 | 156 | for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; 157 | in_idx += blockDim.x) { 158 | int rel_in_y = in_idx / tile_in_w; 159 | int rel_in_x = in_idx - rel_in_y * tile_in_w; 160 | int in_x = rel_in_x + tile_in_x; 161 | int in_y = rel_in_y + tile_in_y; 162 | 163 | scalar_t v = 0.0; 164 | 165 | if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) { 166 | v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * 167 | p.minor_dim + 168 | minor_idx]; 169 | } 170 | 171 | sx[rel_in_y][rel_in_x] = v; 172 | } 173 | 174 | __syncthreads(); 175 | for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; 176 | out_idx += blockDim.x) { 177 | int rel_out_y = out_idx / tile_out_w; 178 | int rel_out_x = out_idx - rel_out_y 
* tile_out_w; 179 | int out_x = rel_out_x + tile_out_x; 180 | int out_y = rel_out_y + tile_out_y; 181 | 182 | int mid_x = tile_mid_x + rel_out_x * down_x; 183 | int mid_y = tile_mid_y + rel_out_y * down_y; 184 | int in_x = floor_div(mid_x, up_x); 185 | int in_y = floor_div(mid_y, up_y); 186 | int rel_in_x = in_x - tile_in_x; 187 | int rel_in_y = in_y - tile_in_y; 188 | int kernel_x = (in_x + 1) * up_x - mid_x - 1; 189 | int kernel_y = (in_y + 1) * up_y - mid_y - 1; 190 | 191 | scalar_t v = 0.0; 192 | 193 | #pragma unroll 194 | for (int y = 0; y < kernel_h / up_y; y++) 195 | #pragma unroll 196 | for (int x = 0; x < kernel_w / up_x; x++) 197 | v += sx[rel_in_y + y][rel_in_x + x] * 198 | sk[kernel_y + y * up_y][kernel_x + x * up_x]; 199 | 200 | if (out_x < p.out_w & out_y < p.out_h) { 201 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + 202 | minor_idx] = v; 203 | } 204 | } 205 | } 206 | } 207 | } 208 | 209 | torch::Tensor upfirdn2d_op(const torch::Tensor &input, 210 | const torch::Tensor &kernel, int up_x, int up_y, 211 | int down_x, int down_y, int pad_x0, int pad_x1, 212 | int pad_y0, int pad_y1) { 213 | int curDevice = -1; 214 | cudaGetDevice(&curDevice); 215 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 216 | 217 | UpFirDn2DKernelParams p; 218 | 219 | auto x = input.contiguous(); 220 | auto k = kernel.contiguous(); 221 | 222 | p.major_dim = x.size(0); 223 | p.in_h = x.size(1); 224 | p.in_w = x.size(2); 225 | p.minor_dim = x.size(3); 226 | p.kernel_h = k.size(0); 227 | p.kernel_w = k.size(1); 228 | p.up_x = up_x; 229 | p.up_y = up_y; 230 | p.down_x = down_x; 231 | p.down_y = down_y; 232 | p.pad_x0 = pad_x0; 233 | p.pad_x1 = pad_x1; 234 | p.pad_y0 = pad_y0; 235 | p.pad_y1 = pad_y1; 236 | 237 | p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / 238 | p.down_y; 239 | p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / 240 | p.down_x; 241 | 242 | auto out = 243 | at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options()); 244 | 245 | int mode = -1; 246 | 247 | int tile_out_h = -1; 248 | int tile_out_w = -1; 249 | 250 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && 251 | p.kernel_h <= 4 && p.kernel_w <= 4) { 252 | mode = 1; 253 | tile_out_h = 16; 254 | tile_out_w = 64; 255 | } 256 | 257 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && 258 | p.kernel_h <= 3 && p.kernel_w <= 3) { 259 | mode = 2; 260 | tile_out_h = 16; 261 | tile_out_w = 64; 262 | } 263 | 264 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && 265 | p.kernel_h <= 4 && p.kernel_w <= 4) { 266 | mode = 3; 267 | tile_out_h = 16; 268 | tile_out_w = 64; 269 | } 270 | 271 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && 272 | p.kernel_h <= 2 && p.kernel_w <= 2) { 273 | mode = 4; 274 | tile_out_h = 16; 275 | tile_out_w = 64; 276 | } 277 | 278 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && 279 | p.kernel_h <= 4 && p.kernel_w <= 4) { 280 | mode = 5; 281 | tile_out_h = 8; 282 | tile_out_w = 32; 283 | } 284 | 285 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && 286 | p.kernel_h <= 2 && p.kernel_w <= 2) { 287 | mode = 6; 288 | tile_out_h = 8; 289 | tile_out_w = 32; 290 | } 291 | 292 | dim3 block_size; 293 | dim3 grid_size; 294 | 295 | if (tile_out_h > 0 && tile_out_w > 0) { 296 | p.loop_major = (p.major_dim - 1) / 16384 + 1; 297 | p.loop_x = 1; 298 | block_size = dim3(32 * 8, 1, 1); 299 | grid_size = dim3(((p.out_h - 1) / 
tile_out_h + 1) * p.minor_dim, 300 | (p.out_w - 1) / (p.loop_x * tile_out_w) + 1, 301 | (p.major_dim - 1) / p.loop_major + 1); 302 | } else { 303 | p.loop_major = (p.major_dim - 1) / 16384 + 1; 304 | p.loop_x = 4; 305 | block_size = dim3(4, 32, 1); 306 | grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1, 307 | (p.out_w - 1) / (p.loop_x * block_size.y) + 1, 308 | (p.major_dim - 1) / p.loop_major + 1); 309 | } 310 | 311 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] { 312 | switch (mode) { 313 | case 1: 314 | upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64> 315 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(), 316 | x.data_ptr<scalar_t>(), 317 | k.data_ptr<scalar_t>(), p); 318 | 319 | break; 320 | 321 | case 2: 322 | upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64> 323 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(), 324 | x.data_ptr<scalar_t>(), 325 | k.data_ptr<scalar_t>(), p); 326 | 327 | break; 328 | 329 | case 3: 330 | upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64> 331 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(), 332 | x.data_ptr<scalar_t>(), 333 | k.data_ptr<scalar_t>(), p); 334 | 335 | break; 336 | 337 | case 4: 338 | upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64> 339 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(), 340 | x.data_ptr<scalar_t>(), 341 | k.data_ptr<scalar_t>(), p); 342 | 343 | break; 344 | 345 | case 5: 346 | upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32> 347 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(), 348 | x.data_ptr<scalar_t>(), 349 | k.data_ptr<scalar_t>(), p); 350 | 351 | break; 352 | 353 | case 6: 354 | upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 2, 2, 8, 32> 355 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(), 356 | x.data_ptr<scalar_t>(), 357 | k.data_ptr<scalar_t>(), p); 358 | 359 | break; 360 | 361 | default: 362 | upfirdn2d_kernel_large<scalar_t><<<grid_size, block_size, 0, stream>>>( 363 | out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), 364 | k.data_ptr<scalar_t>(), p); 365 | } 366 | }); 367 | 368 | return out; 369 | } -------------------------------------------------------------------------------- /results/.gitignore: -------------------------------------------------------------------------------- 1 | img_*.* -------------------------------------------------------------------------------- /results/default.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hyungkwonko/StyleSpace-pytorch/c001d4ba97425b95f75dce725032701dd5a776d7/results/default.png -------------------------------------------------------------------------------- /results/eye.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hyungkwonko/StyleSpace-pytorch/c001d4ba97425b95f75dce725032701dd5a776d7/results/eye.png -------------------------------------------------------------------------------- /results/hair.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hyungkwonko/StyleSpace-pytorch/c001d4ba97425b95f75dce725032701dd5a776d7/results/hair.png -------------------------------------------------------------------------------- /results/lip.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hyungkwonko/StyleSpace-pytorch/c001d4ba97425b95f75dce725032701dd5a776d7/results/lip.png -------------------------------------------------------------------------------- /results/mouth.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hyungkwonko/StyleSpace-pytorch/c001d4ba97425b95f75dce725032701dd5a776d7/results/mouth.png -------------------------------------------------------------------------------- /stylespace.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import numpy as np 4 | from PIL import Image 5 | from model import Generator 6 | from tqdm import tqdm 7 | 8 | import torch 9
| from torch.nn import functional as F 10 | 11 | 12 | def conv_warper(layer, input, style, noise): 13 | conv = layer.conv 14 | batch, in_channel, height, width = input.shape 15 | 16 | style = style.view(batch, 1, in_channel, 1, 1) # reshape (e.g., 512 --> 1,512,1,1) 17 | weight = conv.scale * conv.weight * style 18 | 19 | if conv.demodulate: 20 | demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8) 21 | weight = weight * demod.view(batch, conv.out_channel, 1, 1, 1) 22 | 23 | weight = weight.view( 24 | batch * conv.out_channel, in_channel, conv.kernel_size, conv.kernel_size 25 | ) 26 | 27 | if conv.upsample: 28 | input = input.view(1, batch * in_channel, height, width) 29 | weight = weight.view( 30 | batch, conv.out_channel, in_channel, conv.kernel_size, conv.kernel_size 31 | ) 32 | weight = weight.transpose(1, 2).reshape( 33 | batch * in_channel, conv.out_channel, conv.kernel_size, conv.kernel_size 34 | ) 35 | out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch) 36 | _, _, height, width = out.shape 37 | out = out.view(batch, conv.out_channel, height, width) 38 | out = conv.blur(out) 39 | 40 | elif conv.downsample: 41 | input = conv.blur(input) 42 | _, _, height, width = input.shape 43 | input = input.view(1, batch * in_channel, height, width) 44 | out = F.conv2d(input, weight, padding=0, stride=2, groups=batch) 45 | _, _, height, width = out.shape 46 | out = out.view(batch, conv.out_channel, height, width) 47 | 48 | else: 49 | input = input.view(1, batch * in_channel, height, width) 50 | out = F.conv2d(input, weight, padding=conv.padding, groups=batch) 51 | _, _, height, width = out.shape 52 | out = out.view(batch, conv.out_channel, height, width) 53 | 54 | out = layer.noise(out, noise=noise) 55 | out = layer.activate(out) 56 | 57 | return out 58 | 59 | 60 | def encoder(G, noise): 61 | styles = [noise] # (1, 512) 62 | style_space = [] 63 | 64 | styles = [G.style(s) for s in styles] 65 | 66 | noise = [getattr(G.noises, 'noise_{}'.format(i)) for i in range(G.num_layers)] 67 | inject_index = G.n_latent 68 | latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) # (18, 512) 69 | style_space.append(G.conv1.conv.modulation(latent[:, 0])) # () 70 | 71 | i = 1 72 | 73 | # EqualLinear layers to fit the channel dimension (e.g., 512 --> 64) 74 | for conv1, conv2 in zip(G.convs[::2], G.convs[1::2]): 75 | style_space.append(conv1.conv.modulation(latent[:, i])) 76 | style_space.append(conv2.conv.modulation(latent[:, i+1])) 77 | i += 2 78 | return style_space, latent, noise 79 | 80 | 81 | def decoder(G, style_space, latent, noise): 82 | out = G.input(latent) 83 | out = conv_warper(G.conv1, out, style_space[0], noise[0]) 84 | skip = G.to_rgb1(out, latent[:, 1]) 85 | 86 | i = 1 87 | for conv1, conv2, noise1, noise2, to_rgb in zip( 88 | G.convs[::2], G.convs[1::2], noise[1::2], noise[2::2], G.to_rgbs 89 | ): 90 | out = conv_warper(conv1, out, style_space[i], noise=noise1) 91 | out = conv_warper(conv2, out, style_space[i+1], noise=noise2) 92 | skip = to_rgb(out, latent[:, i + 2], skip) 93 | 94 | i += 2 95 | 96 | image = skip 97 | 98 | return image 99 | 100 | 101 | def generate_img(generator, input, layer_no, channel_no, degree=30): 102 | style_space, latent, noise = encoder(generator, input) # len(style_space) = 17 103 | style_space[index[layer_no]][:, channel_no] += degree 104 | image = decoder(generator, style_space, latent, noise) 105 | return image 106 | 107 | 108 | def save_fig(output, name, size=128): 109 | output = (output + 1)/2 110 | output = torch.clamp(output, 0, 1) 
111 | if output.shape[1] == 1: 112 | output = torch.cat([output, output, output], 1) 113 | output = output[0].detach().cpu().permute(1,2,0).numpy() 114 | output = (output*255).astype(np.uint8) 115 | im = Image.fromarray(output).resize((size, size), Image.LANCZOS)  # LANCZOS is the filter behind the removed Image.ANTIALIAS alias 116 | im.save(name) 117 | 118 | 119 | if __name__ == '__main__': 120 | 121 | parser = argparse.ArgumentParser(description="Generate samples from the generator") 122 | 123 | parser.add_argument("--latent", type=int, default=512) 124 | parser.add_argument("--n_mlp", type=int, default=8) 125 | parser.add_argument("--ckpt", type=str, default="checkpoint/stylegan2-ffhq-config-f.pt") 126 | parser.add_argument("--out_dir", type=str, default='sample') 127 | parser.add_argument("--channel_multiplier", type=int, default=2) 128 | parser.add_argument("--seed", type=int, default=9) 129 | parser.add_argument("--save_all_attr", type=int, default=0) 130 | 131 | args = parser.parse_args() 132 | 133 | generator = Generator(size=1024, style_dim=args.latent, n_mlp=args.n_mlp, channel_multiplier=args.channel_multiplier) 134 | generator.load_state_dict(torch.load(args.ckpt)['g_ema'], strict=False) 135 | generator.eval() 136 | generator.cuda() 137 | 138 | print(generator) 139 | 140 | index = [0,1,1,2,2,3,4,4,5,6,6,7,8,8,9,10,10,11,12,12,13,14,14,15,16,16] 141 | s_channel = [ 142 | 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 143 | 512, 512, 256, 256, 256, 128, 128, 128, 64, 64, 64, 32, 32 144 | ] 145 | 146 | os.makedirs(args.out_dir, exist_ok=True) 147 | 148 | # default image generation 149 | torch.manual_seed(args.seed) 150 | input = torch.randn(1, args.latent).cuda() 151 | image, _ = generator([input], False) 152 | save_fig(image, os.path.join(args.out_dir, f'{str(args.seed).zfill(6)}_default.png')) 153 | 154 | if args.save_all_attr: 155 | # 1. SAVE ALL ATTRIBUTE MANIPULATION RESULTS: sweep every channel of every layer 156 | # TAKES SOME TIME 157 | for ix in range(len(index)): 158 | os.makedirs(os.path.join(args.out_dir, str(ix)), exist_ok=True) 159 | for i in tqdm(range(s_channel[ix])): 160 | image = generate_img(generator, input, layer_no=ix, channel_no=i, degree=30) 161 | save_fig(image, os.path.join(args.out_dir, str(ix), f'{str(args.seed).zfill(6)}_{ix}_{i}.png')) 162 | else: 163 | # 2. MANIPULATE SPECIFIC ATTRIBUTES 164 | # pose (?) 165 | for i in [-30, -10, 10, 30]: 166 | image = generate_img(generator, input, layer_no=3, channel_no=95, degree=i) 167 | save_fig(image, os.path.join(args.out_dir, f'{str(args.seed).zfill(6)}_pose_{i}.png')) 168 | 169 | # eye 170 | image = generate_img(generator, input, layer_no=9, channel_no=409, degree=10) 171 | save_fig(image, os.path.join(args.out_dir, f'{str(args.seed).zfill(6)}_eye.png')) 172 | 173 | # hair 174 | image = generate_img(generator, input, layer_no=12, channel_no=330, degree=-50) 175 | save_fig(image, os.path.join(args.out_dir, f'{str(args.seed).zfill(6)}_hair.png')) 176 | 177 | # mouth 178 | image = generate_img(generator, input, layer_no=6, channel_no=259, degree=-20) 179 | save_fig(image, os.path.join(args.out_dir, f'{str(args.seed).zfill(6)}_mouth.png')) 180 | 181 | # lip 182 | image = generate_img(generator, input, layer_no=15, channel_no=45, degree=-3) 183 | save_fig(image, os.path.join(args.out_dir, f'{str(args.seed).zfill(6)}_lip.png')) 184 | 185 | print("generation complete...!") --------------------------------------------------------------------------------
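Note: the manipulation in stylespace.py reduces to three steps: encode the latent into the 17 per-layer style vectors, shift one channel of one vector, and decode. Below is a minimal sketch of doing that directly, without argparse; it assumes model.py and the FFHQ checkpoint from the README are in place, copies the index table and the eye edit (layer 9, channel 409, degree 10) from the script, and uses an illustrative output filename.

import torch
from model import Generator
from stylespace import encoder, decoder, save_fig

# Build and load the generator exactly as the script does.
generator = Generator(size=1024, style_dim=512, n_mlp=8, channel_multiplier=2)
generator.load_state_dict(torch.load("checkpoint/stylegan2-ffhq-config-f.pt")["g_ema"], strict=False)
generator.eval().cuda()

# Maps the script's 26 layer slots onto the 17 style vectors returned by encoder()
# (copied from stylespace.py).
index = [0, 1, 1, 2, 2, 3, 4, 4, 5, 6, 6, 7, 8, 8, 9, 10, 10, 11, 12, 12, 13, 14, 14, 15, 16, 16]

torch.manual_seed(9)
z = torch.randn(1, 512).cuda()

with torch.no_grad():
    style_space, latent, noise = encoder(generator, z)   # per-layer style vectors
    layer_no, channel_no, degree = 9, 409, 10            # the "eye" edit from the script
    style_space[index[layer_no]][:, channel_no] += degree
    image = decoder(generator, style_space, latent, noise)

save_fig(image, "results/img_eye_edit.png")  # illustrative name; matches the img_*.* pattern in results/.gitignore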