├── .gitignore
├── LICENSE
├── README.md
├── overview.jpg
├── style_fusion_simple.py
├── stylefusion.ipynb
├── stylefusion
│   ├── fusion_net.py
│   ├── sf_hierarchy.py
│   └── sf_stylegan2.py
├── stylegan2
│   ├── __init__.py
│   ├── model.py
│   └── op
│       ├── __init__.py
│       ├── conv2d_gradfix.py
│       ├── fused_act.py
│       ├── fused_bias_act.cpp
│       ├── fused_bias_act_kernel.cu
│       ├── upfirdn2d.cpp
│       ├── upfirdn2d.py
│       └── upfirdn2d_kernel.cu
└── weights
    ├── car
    │   └── .placeholder
    ├── car_weights.json
    ├── church
    │   └── .placeholder
    ├── church_weights.json
    ├── ffhq
    │   └── .placeholder
    └── ffhq_weights.json

/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | weights/ffhq/*
3 | weights/car/*
4 | weights/church/*
5 | /.ipynb_checkpoints/
6 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2021 Omer Kafri
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # StyleFusion: A Generative Model for Disentangling Spatial Segments
2 | 
3 | ![](overview.jpg)
4 | 
5 | > **StyleFusion: A Generative Model for Disentangling Spatial Segments**<br>
6 | > Omer Kafri, Or Patashnik, Yuval Alaluf, Daniel Cohen-Or
7 | > https://arxiv.org/abs/2107.07437
8 | > 9 | >**Abstract:** We present StyleFusion, a new mapping architecture for StyleGAN, which takes as input a number of latent codes and fuses them into a single style code. Inserting the resulting style code into a pre-trained StyleGAN generator results in a single harmonized image in which each semantic region is controlled by one of the input latent codes. Effectively, StyleFusion yields a disentangled representation of the image, providing fine-grained control over each region of the generated image. Moreover, to help facilitate global control over the generated image, a special input latent code is incorporated into the fused representation. StyleFusion operates in a hierarchical manner, where each level is tasked with learning to disentangle a pair of image regions (e.g., the car body and wheels). The resulting learned disentanglement allows one to modify both local, fine-grained semantics (e.g., facial features) as well as more global features (e.g., pose and background), providing improved flexibility in the synthesis process. As a natural extension, StyleFusion enables one to perform semantically-aware cross-image mixing of regions that are not necessarily aligned. Finally, we demonstrate how StyleFusion can be paired with existing editing techniques to more faithfully constrain the edit to the user's region of interest. 10 | 11 | ## Citation 12 | 13 | If you use this code for your research, please cite our paper: 14 | 15 | ``` 16 | @misc{kafri2021stylefusion, 17 | title={StyleFusion: A Generative Model for Disentangling Spatial Segments}, 18 | author={Omer Kafri and Or Patashnik and Yuval Alaluf and Daniel Cohen-Or}, 19 | year={2021}, 20 | eprint={2107.07437}, 21 | archivePrefix={arXiv}, 22 | primaryClass={cs.CV} 23 | } 24 | ``` 25 | -------------------------------------------------------------------------------- /overview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OmerKafri/StyleFusion/c1aa57267a192b3c50ee501339ed50a0aedb35fe/overview.jpg -------------------------------------------------------------------------------- /style_fusion_simple.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | 4 | import torch 5 | import torch.nn.functional 6 | 7 | sys.path.append(".") 8 | sys.path.append("..") 9 | 10 | from stylefusion.sf_stylegan2 import SFGenerator 11 | from stylefusion.sf_hierarchy import SFHierarchyFFHQ, SFHierarchyCar, SFHierarchyChurch 12 | from PIL import Image 13 | 14 | 15 | def tensor2im(var): 16 | var = var.cpu().detach().transpose(0, 2).transpose(0, 1).numpy() 17 | var = ((var + 1) / 2) 18 | var[var < 0] = 0 19 | var[var > 1] = 1 20 | var = var * 255 21 | return Image.fromarray(var.astype('uint8')) 22 | 23 | 24 | class StyleFusionSimple: 25 | def __init__(self, stylegan_type, stylegan_weights, fusion_nets_weights): 26 | self.stylegan_type = stylegan_type 27 | if self.stylegan_type == "ffhq": 28 | self.truncation = 0.7 29 | self.stylegan_size = 1024 30 | self.stylegan_layers = 18 31 | elif self.stylegan_type == "car": 32 | self.truncation = 0.5 33 | self.stylegan_size = 512 34 | self.stylegan_layers = 16 35 | elif self.stylegan_type == "church": 36 | self.truncation = 0.5 37 | self.stylegan_size = 256 38 | self.stylegan_layers = 14 39 | 40 | self.device = 'cuda:0' 41 | 42 | stylegan_ckpt = torch.load(stylegan_weights, map_location='cpu') 43 | self.original_net = SFGenerator(self.stylegan_size, 512, 8) 44 | 
self.original_net.load_state_dict(stylegan_ckpt['g_ema'], strict=True) 45 | 46 | self.original_net.to(self.device) 47 | 48 | with torch.no_grad(): 49 | self.mean_latent = self.original_net.mean_latent(4096) 50 | 51 | if self.stylegan_type == "ffhq": 52 | self.sf_hierarchy = SFHierarchyFFHQ() 53 | self.base_blender = self.sf_hierarchy.nodes["all"] 54 | elif self.stylegan_type == "car": 55 | self.sf_hierarchy = SFHierarchyCar() 56 | self.base_blender = self.sf_hierarchy.nodes["all"] 57 | elif self.stylegan_type == "church": 58 | self.sf_hierarchy = SFHierarchyChurch() 59 | self.base_blender = self.sf_hierarchy.nodes["all"] 60 | 61 | with open(fusion_nets_weights, 'r') as f: 62 | fusion_nets_paths = json.load(f) 63 | 64 | keys = fusion_nets_paths.keys() 65 | 66 | for key in keys: 67 | self.sf_hierarchy.nodes[key].load_fusion_net(fusion_nets_paths[key]) 68 | self.sf_hierarchy.nodes[key].fusion_net.to(self.device) 69 | self.sf_hierarchy.nodes[key].fusion_net.eval() 70 | 71 | def generate_img(self, base_latent, latents_type="z", hair=None, face=None, background=None, all=None, mouth=None, 72 | eyes=None, wheels=None, car=None, bg_top=None, bg_bottom=None): 73 | s_dict = dict() 74 | parts = self.sf_hierarchy.nodes["all"].get_all_active_parts() 75 | for part in parts: 76 | s_dict[part] = self.general_latent_to_s(base_latent, latents_type) 77 | 78 | def swap(value, keys): 79 | if value is None: 80 | return 81 | for k in keys: 82 | s_dict[k] = self.general_latent_to_s(value, latents_type) 83 | 84 | swap(hair, ["bg_hair_clothes", "hair"]) 85 | swap(face, ["face", "eyes", "skin_mouth", "mouth", "skin", "shirt"]) 86 | swap(background, ["background", "background_top", "background_bottom", "bg"]) 87 | swap(all, ["all"]) 88 | swap(mouth, ["skin_mouth", "face"]) 89 | swap(eyes, ["eyes", "face"]) 90 | swap(wheels, ["wheels"]) 91 | swap(car, ["car", "body", "wheels", "car_body"]) 92 | swap(bg_top, ["background_top"]) 93 | swap(bg_bottom, ["background_bottom"]) 94 | 95 | return self.s_dict_to_image(s_dict)[0] 96 | 97 | def seed_to_z(self, seed): 98 | torch.manual_seed(seed[0]) 99 | z_regular = torch.randn((seed[1] + 1, 1, 512), device=self.device) 100 | return z_regular[seed[1]] 101 | 102 | def z_to_s(self, z): 103 | return self.original_net([z], 104 | truncation=self.truncation, truncation_latent=self.mean_latent, 105 | randomize_noise=False, return_style_vector=True) 106 | 107 | def z_to_w_plus(self, z): 108 | _, res = self.original_net([z], 109 | truncation=self.truncation, truncation_latent=self.mean_latent, 110 | randomize_noise=False, return_latents=True) 111 | return res[0] 112 | 113 | def w_plus_to_s(self, w_plus, truncation): 114 | return self.original_net([w_plus], input_is_latent=True, 115 | truncation=truncation, truncation_latent=self.mean_latent, 116 | randomize_noise=False, return_style_vector=True) 117 | 118 | def general_latent_to_s(self, l, latent_type): 119 | assert latent_type in ["z", "w", "w+", "s"] 120 | 121 | if latent_type == "z": 122 | assert l.size() == (1, 512) 123 | return self.z_to_s(l) 124 | elif latent_type == "w" or latent_type == "w+": 125 | assert l.size() == (1, 512) or l.size() == (1, self.stylegan_layers, 512) 126 | if l.dim() == 2: 127 | return self.w_plus_to_s(l.unsqueeze(0).repeat(1, self.stylegan_layers, 1), truncation=1) 128 | else: 129 | return self.w_plus_to_s(l, truncation=1) 130 | else: 131 | return l 132 | 133 | def s_to_image(self, s): 134 | img, _ = self.original_net([torch.zeros(1, 512, device=self.device)], 135 | randomize_noise=False, style_vector=s) 136 | 
return img 137 | 138 | def w_plus_to_image(self, w_plus): 139 | s = self.w_plus_to_s(w_plus, truncation=1) 140 | return self.s_to_image(s) 141 | 142 | def z_to_image(self, z): 143 | s = self.z_to_s(z) 144 | return self.s_to_image(s) 145 | 146 | def s_dict_to_image(self, s_dict): 147 | s = self.base_blender.forward(s_dict) 148 | return self.s_to_image(s) 149 | 150 | def w_plus_dict_to_image(self, w_plus_dict, truncation=1): 151 | s_dict = dict() 152 | for key in w_plus_dict.keys(): 153 | s_dict[key] = self.w_plus_to_s(w_plus_dict[key], truncation=truncation) 154 | return self.s_dict_to_image(s_dict) 155 | 156 | def z_dict_to_image(self, z_dict): 157 | s_dict = dict() 158 | for key in z_dict.keys(): 159 | s_dict[key] = self.z_to_s(z_dict[key]) 160 | return self.s_dict_to_image(s_dict) 161 | -------------------------------------------------------------------------------- /stylefusion/fusion_net.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from stylegan2.model import EqualLinear 7 | 8 | MAX_LAYERS = 18 9 | DEFUALT_COMMON_LAYERS = 5 10 | DEFAULT_INDEX_BITS = 10 11 | 12 | 13 | class LatentBlender(nn.Module): 14 | def __init__(self, size, index_bits=DEFAULT_INDEX_BITS, extras=0): 15 | super(LatentBlender, self).__init__() 16 | layers = [] 17 | 18 | self.size = size 19 | self.index_bits = index_bits 20 | 21 | sizes = [size * 2 + index_bits * (MAX_LAYERS + 1) + extras * size, size * 5, size * 5, size * 5, size, size] 22 | 23 | for i in range(len(sizes) - 1): 24 | layers.append( 25 | EqualLinear( 26 | sizes[i], sizes[i + 1], lr_mul=0.01, activation='fused_lrelu' 27 | ) 28 | ) 29 | 30 | self.disentangle_net = nn.Sequential(*layers) 31 | 32 | def forward(self, a, b, i=None, rgb=False, extra=None): 33 | x = torch.cat((a[0], b[0])) 34 | if not extra is None: 35 | x = torch.cat((x, extra)) 36 | if not (i is None): 37 | indicator = torch.zeros(MAX_LAYERS + 1, device="cuda:0") 38 | indicator[i] = 1 39 | if rgb: 40 | indicator[-1] = 1 41 | x = torch.cat([x, indicator.repeat(self.index_bits)]) 42 | x = self.disentangle_net(x) 43 | x = torch.sigmoid(x) 44 | x = x.unsqueeze(0) 45 | 46 | return a[0] + x * (b[0] - a[0]) # Fusion procedure 47 | 48 | 49 | class FusionNet(nn.Module): 50 | def __init__(self, index_bits=DEFAULT_INDEX_BITS, common_layers=DEFUALT_COMMON_LAYERS): 51 | super(FusionNet, self).__init__() 52 | 53 | self.max_pools = [nn.MaxPool1d(512 // x) for x in [512, 256, 128, 64, 32]] 54 | self.upsample = nn.Upsample(size=512, mode='nearest') 55 | 56 | self.blender_segments = LatentBlender(512, index_bits=index_bits, extras=1) 57 | self.blender_common = LatentBlender(512, index_bits=index_bits) 58 | self.common_layers = common_layers 59 | 60 | def up(self, x): 61 | return self.upsample(x.unsqueeze(0)).squeeze(0) 62 | 63 | def pool(self, x, size): 64 | size_index = 9 - int(math.log2(size)) 65 | return self.max_pools[size_index](x.unsqueeze(0)).squeeze(0) 66 | 67 | def forward(self, s0, s1, s2, only_common=False, only_segments=False): 68 | if only_common: 69 | assert s1 is None 70 | if only_segments: 71 | assert s2 is None 72 | res = [[],[]] 73 | for rgb in [0, 1]: 74 | for i in range(len(s0[rgb])): 75 | s0_up = self.up(s0[rgb][i]) 76 | if not only_common: 77 | s1_up = self.up(s1[rgb][i]) 78 | if not only_segments: 79 | s2_up = self.up(s2[rgb][i]) 80 | 81 | skip_common_layer = False 82 | if not (self.common_layers is None): 83 | skip_common_layer = i >= self.common_layers 84 | 85 | if 
only_common: 86 | if skip_common_layer: 87 | res[rgb].append(s0[rgb][i]) 88 | else: 89 | x = self.blender_common(s0_up, s2_up, i=i, rgb=(rgb==1)) 90 | res[rgb].append(self.pool(x, s0[rgb][i].size()[1])) 91 | elif only_segments or skip_common_layer: 92 | x = self.blender_segments(s0_up, s1_up, i=i, rgb=False, extra=torch.zeros_like(s1_up[0])) 93 | res[rgb].append(self.pool(x, s0[rgb][i].size()[1])) 94 | else: 95 | x0 = self.blender_common(s0_up, s2_up, i=i, rgb=False) 96 | x0 = self.pool(x0, s0[rgb][i].size()[1]) 97 | x1 = self.blender_common(s1_up, s2_up, i=i, rgb=False) 98 | x1 = self.pool(x1, s0[rgb][i].size()[1]) 99 | 100 | x = self.blender_segments(self.up(x0), self.up(x1), i=i, rgb=False, extra=s2_up[0]) 101 | res[rgb].append(self.pool(x, s0[rgb][i].size()[1])) 102 | 103 | return res 104 | -------------------------------------------------------------------------------- /stylefusion/sf_hierarchy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from stylefusion.fusion_net import FusionNet 4 | 5 | 6 | class SFHierarchyFFHQ: 7 | def __init__(self): 8 | self.nodes = dict() 9 | self.nodes["clothes"] = SFNode("clothes") 10 | self.nodes["mouth"] = SFNode("mouth") 11 | self.nodes["eyes"] = SFNode("eyes") 12 | self.nodes["bg"] = SFNode("bg") 13 | self.nodes["hair"] = SFNode("hair") 14 | self.nodes["skin"] = SFNode("skin") 15 | self.nodes["skin_mouth"] = SFNode("skin_mouth", child1=self.nodes["mouth"], child2=self.nodes["skin"]) 16 | self.nodes["face"] = SFNode("face", child1=self.nodes["skin_mouth"], child2=self.nodes["eyes"]) 17 | self.nodes["bg_clothes"] = SFNode("bg_clothes", child1=self.nodes["clothes"], child2=self.nodes["bg"]) 18 | self.nodes["bg_hair_clothes"] = SFNode("bg_hair_clothes", 19 | child1=self.nodes["bg_clothes"], child2=self.nodes["hair"]) 20 | self.nodes["all"] = SFNode("all", child1=self.nodes["face"], child2=self.nodes["bg_hair_clothes"]) 21 | 22 | 23 | class SFHierarchyCar: 24 | def __init__(self): 25 | self.nodes = dict() 26 | self.nodes["wheels"] = SFNode("wheels") 27 | self.nodes["car_body"] = SFNode("car_body") 28 | self.nodes["background_top"] = SFNode("background_top") 29 | self.nodes["background_bottom"] = SFNode("background_bottom") 30 | self.nodes["car"] = SFNode("car", child1=self.nodes["car_body"], child2=self.nodes["wheels"]) 31 | self.nodes["background"] = SFNode("background", 32 | child1=self.nodes["background_top"], child2=self.nodes["background_bottom"]) 33 | self.nodes["all"] = SFNode("all", child1=self.nodes["car"], child2=self.nodes["background"]) 34 | 35 | 36 | class SFHierarchyChurch: 37 | def __init__(self=None): 38 | self.nodes = dict() 39 | self.nodes["church"] = SFNode("church") 40 | self.nodes["background"] = SFNode("background") 41 | self.nodes["all"] = SFNode("all", child1=self.nodes["church"], child2=self.nodes["background"]) 42 | 43 | 44 | class SFNode: 45 | def __init__(self, name, child1=None, child2=None): 46 | self.name = name 47 | self.child1 = child1 48 | self.child2 = child2 49 | self.fusion_net = None 50 | if child1 is None or child2 is None: 51 | assert child1 is None and child2 is None 52 | self._leaf = True 53 | else: 54 | self._leaf = False 55 | 56 | def get_all_parts(self): 57 | if self._leaf: 58 | return [self.name] 59 | else: 60 | return [self.name] + self.child1.get_all_parts() + self.child2.get_all_parts() 61 | 62 | def get_all_active_parts(self): 63 | if self.fusion_net is None: 64 | return [self.name] 65 | else: 66 | return [self.name] + 
self.child1.get_all_active_parts() + self.child2.get_all_active_parts() 67 | 68 | def get_fusion_nets_amount(self): 69 | if self.fusion_net is None: 70 | return 0 71 | if self._leaf: 72 | return 1 73 | else: 74 | return 1 + self.child1.get_fusion_nets_amount() + self.child2.get_fusion_nets_amount() 75 | 76 | def get_fusion_nets(self): 77 | if self.fusion_net is None: 78 | return [] 79 | if self._leaf: 80 | return [self.fusion_net] 81 | else: 82 | return [self.fusion_net] + self.child1.get_fusion_nets() + self.child2.get_fusion_nets() 83 | 84 | def forward(self, s_dict): 85 | if self.fusion_net is None: 86 | return s_dict[self.name] 87 | 88 | if not (self.child1.name in s_dict.keys() and self.child2.name in s_dict.keys()): 89 | return s_dict[self.name] 90 | 91 | return self.fusion_net( 92 | self.child1.forward(s_dict), 93 | self.child2.forward(s_dict), 94 | s_dict[self.name]) 95 | 96 | def load_fusion_net(self, path): 97 | data = torch.load(path) 98 | 99 | self.fusion_net = FusionNet() 100 | 101 | if "state_dict" in data.keys(): 102 | data = data["state_dict"] 103 | 104 | self.fusion_net.load_state_dict(data) 105 | -------------------------------------------------------------------------------- /stylefusion/sf_stylegan2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from stylegan2.model import * 5 | 6 | 7 | class SFModulatedConv2d(ModulatedConv2d): 8 | def forward(self, input, style, style_vector=None): 9 | batch, in_channel, height, width = input.shape 10 | 11 | if not self.fused: 12 | weight = self.scale * self.weight.squeeze(0) 13 | style = self.modulation(style) 14 | 15 | if self.demodulate: 16 | w = weight.unsqueeze(0) * style.view(batch, 1, in_channel, 1, 1) 17 | dcoefs = (w.square().sum((2, 3, 4)) + 1e-8).rsqrt() 18 | 19 | input = input * style.reshape(batch, in_channel, 1, 1) 20 | 21 | if self.upsample: 22 | weight = weight.transpose(0, 1) 23 | out = conv2d_gradfix.conv_transpose2d( 24 | input, weight, padding=0, stride=2 25 | ) 26 | out = self.blur(out) 27 | 28 | elif self.downsample: 29 | input = self.blur(input) 30 | out = conv2d_gradfix.conv2d(input, weight, padding=0, stride=2) 31 | 32 | else: 33 | out = conv2d_gradfix.conv2d(input, weight, padding=self.padding) 34 | 35 | if self.demodulate: 36 | out = out * dcoefs.view(batch, -1, 1, 1) 37 | 38 | return out 39 | 40 | if style_vector is None: 41 | style = self.modulation(style).view(batch, 1, in_channel, 1, 1) 42 | else: 43 | style = style_vector.view(batch, 1, in_channel, 1, 1) 44 | 45 | weight = self.scale * self.weight * style 46 | 47 | if self.demodulate: 48 | demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8) 49 | weight = weight * demod.view(batch, self.out_channel, 1, 1, 1) 50 | 51 | weight = weight.view( 52 | batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size 53 | ) 54 | 55 | if self.upsample: 56 | input = input.view(1, batch * in_channel, height, width) 57 | weight = weight.view( 58 | batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size 59 | ) 60 | weight = weight.transpose(1, 2).reshape( 61 | batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size 62 | ) 63 | out = conv2d_gradfix.conv_transpose2d( 64 | input, weight, padding=0, stride=2, groups=batch 65 | ) 66 | _, _, height, width = out.shape 67 | out = out.view(batch, self.out_channel, height, width) 68 | out = self.blur(out) 69 | 70 | elif self.downsample: 71 | input = self.blur(input) 72 | _, _, height, width = 
input.shape 73 | input = input.view(1, batch * in_channel, height, width) 74 | out = conv2d_gradfix.conv2d( 75 | input, weight, padding=0, stride=2, groups=batch 76 | ) 77 | _, _, height, width = out.shape 78 | out = out.view(batch, self.out_channel, height, width) 79 | 80 | else: 81 | input = input.view(1, batch * in_channel, height, width) 82 | out = conv2d_gradfix.conv2d( 83 | input, weight, padding=self.padding, groups=batch 84 | ) 85 | _, _, height, width = out.shape 86 | out = out.view(batch, self.out_channel, height, width) 87 | 88 | return out 89 | 90 | 91 | class SFStyledConv(StyledConv): 92 | def __init__( 93 | self, 94 | in_channel, 95 | out_channel, 96 | kernel_size, 97 | style_dim, 98 | upsample=False, 99 | blur_kernel=[1, 3, 3, 1], 100 | demodulate=True, 101 | ): 102 | super(StyledConv, self).__init__() 103 | 104 | self.conv = SFModulatedConv2d( 105 | in_channel, 106 | out_channel, 107 | kernel_size, 108 | style_dim, 109 | upsample=upsample, 110 | blur_kernel=blur_kernel, 111 | demodulate=demodulate, 112 | ) 113 | 114 | self.noise = NoiseInjection() 115 | # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1)) 116 | # self.activate = ScaledLeakyReLU(0.2) 117 | self.activate = FusedLeakyReLU(out_channel) 118 | 119 | def forward(self, input, style, noise=None, style_vector=None): 120 | out = self.conv(input, style, style_vector=style_vector) 121 | out = self.noise(out, noise=noise) 122 | # out = out + self.bias 123 | out = self.activate(out) 124 | 125 | return out 126 | 127 | def get_style_vector(self, style): 128 | return self.conv.modulation(style) 129 | 130 | 131 | class SFToRGB(ToRGB): 132 | def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]): 133 | super(ToRGB, self).__init__() 134 | 135 | if upsample: 136 | self.upsample = Upsample(blur_kernel) 137 | 138 | self.conv = SFModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False) 139 | self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1)) 140 | 141 | def forward(self, input, style, skip=None, style_vector=None): 142 | out = self.conv(input, style, style_vector=style_vector) 143 | out = out + self.bias 144 | 145 | if skip is not None: 146 | skip = self.upsample(skip) 147 | 148 | out = out + skip 149 | 150 | return out 151 | 152 | def get_style_vector(self, style): 153 | return self.conv.modulation(style) 154 | 155 | 156 | class SFGenerator(Generator): 157 | def __init__( 158 | self, 159 | size, 160 | style_dim, 161 | n_mlp, 162 | channel_multiplier=2, 163 | blur_kernel=[1, 3, 3, 1], 164 | lr_mlp=0.01, 165 | ): 166 | super(Generator, self).__init__() 167 | 168 | self.size = size 169 | 170 | self.style_dim = style_dim 171 | 172 | layers = [PixelNorm()] 173 | 174 | for i in range(n_mlp): 175 | layers.append( 176 | EqualLinear( 177 | style_dim, style_dim, lr_mul=lr_mlp, activation="fused_lrelu" 178 | ) 179 | ) 180 | 181 | self.style = nn.Sequential(*layers) 182 | 183 | self.channels = { 184 | 4: 512, 185 | 8: 512, 186 | 16: 512, 187 | 32: 512, 188 | 64: 256 * channel_multiplier, 189 | 128: 128 * channel_multiplier, 190 | 256: 64 * channel_multiplier, 191 | 512: 32 * channel_multiplier, 192 | 1024: 16 * channel_multiplier, 193 | } 194 | 195 | self.input = ConstantInput(self.channels[4]) 196 | self.conv1 = SFStyledConv( 197 | self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel 198 | ) 199 | self.to_rgb1 = SFToRGB(self.channels[4], style_dim, upsample=False) 200 | 201 | self.log_size = int(math.log(size, 2)) 202 | self.num_layers = (self.log_size - 2) * 2 + 1 203 | 204 | 
self.convs = nn.ModuleList() 205 | self.upsamples = nn.ModuleList() 206 | self.to_rgbs = nn.ModuleList() 207 | self.noises = nn.Module() 208 | 209 | in_channel = self.channels[4] 210 | 211 | for layer_idx in range(self.num_layers): 212 | res = (layer_idx + 5) // 2 213 | shape = [1, 1, 2 ** res, 2 ** res] 214 | self.noises.register_buffer(f"noise_{layer_idx}", torch.randn(*shape)) 215 | 216 | for i in range(3, self.log_size + 1): 217 | out_channel = self.channels[2 ** i] 218 | 219 | self.convs.append( 220 | SFStyledConv( 221 | in_channel, 222 | out_channel, 223 | 3, 224 | style_dim, 225 | upsample=True, 226 | blur_kernel=blur_kernel, 227 | ) 228 | ) 229 | 230 | self.convs.append( 231 | SFStyledConv( 232 | out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel 233 | ) 234 | ) 235 | 236 | self.to_rgbs.append(SFToRGB(out_channel, style_dim)) 237 | 238 | in_channel = out_channel 239 | 240 | self.n_latent = self.log_size * 2 - 2 241 | 242 | def forward( 243 | self, 244 | styles, 245 | return_latents=False, 246 | return_style_vector=False, 247 | inject_index=None, 248 | truncation=1, 249 | truncation_latent=None, 250 | input_is_latent=False, 251 | style_vector=None, 252 | noise=None, 253 | randomize_noise=True, 254 | ): 255 | if not input_is_latent: 256 | styles = [self.style(s) for s in styles] 257 | 258 | if noise is None: 259 | if randomize_noise: 260 | noise = [None] * self.num_layers 261 | else: 262 | noise = [ 263 | getattr(self.noises, f'noise_{i}') for i in range(self.num_layers) 264 | ] 265 | 266 | if truncation < 1: 267 | style_t = [] 268 | 269 | for style in styles: 270 | style_t.append( 271 | truncation_latent + truncation * (style - truncation_latent) 272 | ) 273 | 274 | styles = style_t 275 | 276 | if len(styles) < 2: 277 | inject_index = self.n_latent 278 | 279 | if styles[0].ndim < 3: 280 | latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) 281 | else: 282 | latent = styles[0] 283 | 284 | else: 285 | if inject_index is None: 286 | inject_index = 2#random.randint(1, self.n_latent - 1) 287 | 288 | latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) 289 | latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1) 290 | 291 | latent = torch.cat([latent, latent2], 1) 292 | 293 | if return_style_vector: 294 | s = [self.conv1.get_style_vector(latent[:, 0])] + \ 295 | [self.convs[i].get_style_vector(latent[:, i + 1]) for i in range(len(self.convs))] 296 | s_rgb = [self.to_rgb1.get_style_vector(latent[:, 1])] + \ 297 | [self.to_rgbs[i].get_style_vector(latent[:, i * 2 + 2]) for i in range(len(self.to_rgbs))] 298 | return [s, s_rgb] 299 | 300 | if style_vector is None: 301 | style_vector = [[None] * (len(self.convs) + 1) , [None] * (len(self.to_rgbs) + 1)] 302 | 303 | out = self.input(latent) 304 | out = self.conv1(out, latent[:, 0], noise=noise[0], style_vector=style_vector[0][0]) 305 | 306 | skip = self.to_rgb1(out, latent[:, 1], style_vector=style_vector[1][0]) 307 | 308 | i = 1 309 | for conv1, conv2, noise1, noise2, to_rgb in zip( 310 | self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs 311 | ): 312 | out = conv1(out, latent[:, i], noise=noise1, style_vector=style_vector[0][i]) 313 | out = conv2(out, latent[:, i + 1], noise=noise2, style_vector=style_vector[0][i + 1]) 314 | skip = to_rgb(out, latent[:, i + 2], skip, style_vector=style_vector[1][i//2 + 1]) 315 | 316 | i += 2 317 | 318 | image = skip 319 | 320 | if return_latents: 321 | return image, latent 322 | else: 323 | return image, None 
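The four modules above (`style_fusion_simple.py`, `stylefusion/fusion_net.py`, `stylefusion/sf_hierarchy.py`, `stylefusion/sf_stylegan2.py`) make up the whole inference path: latent codes are converted to StyleGAN style vectors, blended region by region by the fusion nets along the hierarchy, and decoded by `SFGenerator`. A minimal usage sketch follows, under the stated assumptions — the StyleGAN2 checkpoint filename is a placeholder (the actual weights must be downloaded into `weights/ffhq/`), a CUDA device is required since `StyleFusionSimple` hard-codes `cuda:0` and `stylegan2/op` builds CUDA kernels at import, and only the calls defined in `StyleFusionSimple` above are used:

```python
# Usage sketch only (not part of the repository). Assumes a CUDA device,
# a downloaded StyleGAN2-FFHQ checkpoint (filename below is a placeholder),
# and the pretrained fusion nets referenced by weights/ffhq_weights.json.
import torch

from style_fusion_simple import StyleFusionSimple, tensor2im

sf = StyleFusionSimple(
    stylegan_type="ffhq",
    stylegan_weights="weights/ffhq/stylegan2-ffhq.pt",  # placeholder path
    fusion_nets_weights="weights/ffhq_weights.json",
)

with torch.no_grad():
    base = sf.seed_to_z((0, 0))   # latent code used for every region by default
    donor = sf.seed_to_z((1, 0))  # latent code that will drive the hair region
    # generate_img routes `donor` to the "hair" / "bg_hair_clothes" entries of the
    # style dictionary, while the remaining regions keep `base`.
    img = sf.generate_img(base, latents_type="z", hair=donor)

tensor2im(img).save("fused.jpg")
```

The same pattern applies to the `car` and `church` configurations; only the region keywords accepted by `generate_img` change (e.g. `wheels`, `car`, `bg_top`, `bg_bottom` for cars), following the corresponding `SFHierarchy*` class.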
-------------------------------------------------------------------------------- /stylegan2/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OmerKafri/StyleFusion/c1aa57267a192b3c50ee501339ed50a0aedb35fe/stylegan2/__init__.py -------------------------------------------------------------------------------- /stylegan2/model.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | 4 | import torch 5 | from torch import nn 6 | from torch.nn import functional as F 7 | 8 | from stylegan2.op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d, conv2d_gradfix 9 | 10 | 11 | class PixelNorm(nn.Module): 12 | def __init__(self): 13 | super().__init__() 14 | 15 | def forward(self, input): 16 | return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8) 17 | 18 | 19 | def make_kernel(k): 20 | k = torch.tensor(k, dtype=torch.float32) 21 | 22 | if k.ndim == 1: 23 | k = k[None, :] * k[:, None] 24 | 25 | k /= k.sum() 26 | 27 | return k 28 | 29 | 30 | class Upsample(nn.Module): 31 | def __init__(self, kernel, factor=2): 32 | super().__init__() 33 | 34 | self.factor = factor 35 | kernel = make_kernel(kernel) * (factor ** 2) 36 | self.register_buffer("kernel", kernel) 37 | 38 | p = kernel.shape[0] - factor 39 | 40 | pad0 = (p + 1) // 2 + factor - 1 41 | pad1 = p // 2 42 | 43 | self.pad = (pad0, pad1) 44 | 45 | def forward(self, input): 46 | out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad) 47 | 48 | return out 49 | 50 | 51 | class Downsample(nn.Module): 52 | def __init__(self, kernel, factor=2): 53 | super().__init__() 54 | 55 | self.factor = factor 56 | kernel = make_kernel(kernel) 57 | self.register_buffer("kernel", kernel) 58 | 59 | p = kernel.shape[0] - factor 60 | 61 | pad0 = (p + 1) // 2 62 | pad1 = p // 2 63 | 64 | self.pad = (pad0, pad1) 65 | 66 | def forward(self, input): 67 | out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad) 68 | 69 | return out 70 | 71 | 72 | class Blur(nn.Module): 73 | def __init__(self, kernel, pad, upsample_factor=1): 74 | super().__init__() 75 | 76 | kernel = make_kernel(kernel) 77 | 78 | if upsample_factor > 1: 79 | kernel = kernel * (upsample_factor ** 2) 80 | 81 | self.register_buffer("kernel", kernel) 82 | 83 | self.pad = pad 84 | 85 | def forward(self, input): 86 | out = upfirdn2d(input, self.kernel, pad=self.pad) 87 | 88 | return out 89 | 90 | 91 | class EqualConv2d(nn.Module): 92 | def __init__( 93 | self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True 94 | ): 95 | super().__init__() 96 | 97 | self.weight = nn.Parameter( 98 | torch.randn(out_channel, in_channel, kernel_size, kernel_size) 99 | ) 100 | self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2) 101 | 102 | self.stride = stride 103 | self.padding = padding 104 | 105 | if bias: 106 | self.bias = nn.Parameter(torch.zeros(out_channel)) 107 | 108 | else: 109 | self.bias = None 110 | 111 | def forward(self, input): 112 | out = conv2d_gradfix.conv2d( 113 | input, 114 | self.weight * self.scale, 115 | bias=self.bias, 116 | stride=self.stride, 117 | padding=self.padding, 118 | ) 119 | 120 | return out 121 | 122 | def __repr__(self): 123 | return ( 124 | f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]}," 125 | f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})" 126 | ) 127 | 128 | 129 | class EqualLinear(nn.Module): 130 | def __init__( 131 | 
self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None 132 | ): 133 | super().__init__() 134 | 135 | self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul)) 136 | 137 | if bias: 138 | self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) 139 | 140 | else: 141 | self.bias = None 142 | 143 | self.activation = activation 144 | 145 | self.scale = (1 / math.sqrt(in_dim)) * lr_mul 146 | self.lr_mul = lr_mul 147 | 148 | def forward(self, input): 149 | if self.activation: 150 | out = F.linear(input, self.weight * self.scale) 151 | out = fused_leaky_relu(out, self.bias * self.lr_mul) 152 | 153 | else: 154 | out = F.linear( 155 | input, self.weight * self.scale, bias=self.bias * self.lr_mul 156 | ) 157 | 158 | return out 159 | 160 | def __repr__(self): 161 | return ( 162 | f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})" 163 | ) 164 | 165 | 166 | class ModulatedConv2d(nn.Module): 167 | def __init__( 168 | self, 169 | in_channel, 170 | out_channel, 171 | kernel_size, 172 | style_dim, 173 | demodulate=True, 174 | upsample=False, 175 | downsample=False, 176 | blur_kernel=[1, 3, 3, 1], 177 | fused=True, 178 | ): 179 | super().__init__() 180 | 181 | self.eps = 1e-8 182 | self.kernel_size = kernel_size 183 | self.in_channel = in_channel 184 | self.out_channel = out_channel 185 | self.upsample = upsample 186 | self.downsample = downsample 187 | 188 | if upsample: 189 | factor = 2 190 | p = (len(blur_kernel) - factor) - (kernel_size - 1) 191 | pad0 = (p + 1) // 2 + factor - 1 192 | pad1 = p // 2 + 1 193 | 194 | self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor) 195 | 196 | if downsample: 197 | factor = 2 198 | p = (len(blur_kernel) - factor) + (kernel_size - 1) 199 | pad0 = (p + 1) // 2 200 | pad1 = p // 2 201 | 202 | self.blur = Blur(blur_kernel, pad=(pad0, pad1)) 203 | 204 | fan_in = in_channel * kernel_size ** 2 205 | self.scale = 1 / math.sqrt(fan_in) 206 | self.padding = kernel_size // 2 207 | 208 | self.weight = nn.Parameter( 209 | torch.randn(1, out_channel, in_channel, kernel_size, kernel_size) 210 | ) 211 | 212 | self.modulation = EqualLinear(style_dim, in_channel, bias_init=1) 213 | 214 | self.demodulate = demodulate 215 | self.fused = fused 216 | 217 | def __repr__(self): 218 | return ( 219 | f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, " 220 | f"upsample={self.upsample}, downsample={self.downsample})" 221 | ) 222 | 223 | def forward(self, input, style): 224 | batch, in_channel, height, width = input.shape 225 | 226 | if not self.fused: 227 | weight = self.scale * self.weight.squeeze(0) 228 | style = self.modulation(style) 229 | 230 | if self.demodulate: 231 | w = weight.unsqueeze(0) * style.view(batch, 1, in_channel, 1, 1) 232 | dcoefs = (w.square().sum((2, 3, 4)) + 1e-8).rsqrt() 233 | 234 | input = input * style.reshape(batch, in_channel, 1, 1) 235 | 236 | if self.upsample: 237 | weight = weight.transpose(0, 1) 238 | out = conv2d_gradfix.conv_transpose2d( 239 | input, weight, padding=0, stride=2 240 | ) 241 | out = self.blur(out) 242 | 243 | elif self.downsample: 244 | input = self.blur(input) 245 | out = conv2d_gradfix.conv2d(input, weight, padding=0, stride=2) 246 | 247 | else: 248 | out = conv2d_gradfix.conv2d(input, weight, padding=self.padding) 249 | 250 | if self.demodulate: 251 | out = out * dcoefs.view(batch, -1, 1, 1) 252 | 253 | return out 254 | 255 | style = self.modulation(style).view(batch, 1, in_channel, 1, 1) 256 | weight = self.scale * self.weight 
* style 257 | 258 | if self.demodulate: 259 | demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8) 260 | weight = weight * demod.view(batch, self.out_channel, 1, 1, 1) 261 | 262 | weight = weight.view( 263 | batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size 264 | ) 265 | 266 | if self.upsample: 267 | input = input.view(1, batch * in_channel, height, width) 268 | weight = weight.view( 269 | batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size 270 | ) 271 | weight = weight.transpose(1, 2).reshape( 272 | batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size 273 | ) 274 | out = conv2d_gradfix.conv_transpose2d( 275 | input, weight, padding=0, stride=2, groups=batch 276 | ) 277 | _, _, height, width = out.shape 278 | out = out.view(batch, self.out_channel, height, width) 279 | out = self.blur(out) 280 | 281 | elif self.downsample: 282 | input = self.blur(input) 283 | _, _, height, width = input.shape 284 | input = input.view(1, batch * in_channel, height, width) 285 | out = conv2d_gradfix.conv2d( 286 | input, weight, padding=0, stride=2, groups=batch 287 | ) 288 | _, _, height, width = out.shape 289 | out = out.view(batch, self.out_channel, height, width) 290 | 291 | else: 292 | input = input.view(1, batch * in_channel, height, width) 293 | out = conv2d_gradfix.conv2d( 294 | input, weight, padding=self.padding, groups=batch 295 | ) 296 | _, _, height, width = out.shape 297 | out = out.view(batch, self.out_channel, height, width) 298 | 299 | return out 300 | 301 | 302 | class NoiseInjection(nn.Module): 303 | def __init__(self): 304 | super().__init__() 305 | 306 | self.weight = nn.Parameter(torch.zeros(1)) 307 | 308 | def forward(self, image, noise=None): 309 | if noise is None: 310 | batch, _, height, width = image.shape 311 | noise = image.new_empty(batch, 1, height, width).normal_() 312 | 313 | return image + self.weight * noise 314 | 315 | 316 | class ConstantInput(nn.Module): 317 | def __init__(self, channel, size=4): 318 | super().__init__() 319 | 320 | self.input = nn.Parameter(torch.randn(1, channel, size, size)) 321 | 322 | def forward(self, input): 323 | batch = input.shape[0] 324 | out = self.input.repeat(batch, 1, 1, 1) 325 | 326 | return out 327 | 328 | 329 | class StyledConv(nn.Module): 330 | def __init__( 331 | self, 332 | in_channel, 333 | out_channel, 334 | kernel_size, 335 | style_dim, 336 | upsample=False, 337 | blur_kernel=[1, 3, 3, 1], 338 | demodulate=True, 339 | ): 340 | super().__init__() 341 | 342 | self.conv = ModulatedConv2d( 343 | in_channel, 344 | out_channel, 345 | kernel_size, 346 | style_dim, 347 | upsample=upsample, 348 | blur_kernel=blur_kernel, 349 | demodulate=demodulate, 350 | ) 351 | 352 | self.noise = NoiseInjection() 353 | # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1)) 354 | # self.activate = ScaledLeakyReLU(0.2) 355 | self.activate = FusedLeakyReLU(out_channel) 356 | 357 | def forward(self, input, style, noise=None): 358 | out = self.conv(input, style) 359 | out = self.noise(out, noise=noise) 360 | # out = out + self.bias 361 | out = self.activate(out) 362 | 363 | return out 364 | 365 | 366 | class ToRGB(nn.Module): 367 | def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]): 368 | super().__init__() 369 | 370 | if upsample: 371 | self.upsample = Upsample(blur_kernel) 372 | 373 | self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False) 374 | self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1)) 375 | 376 | def forward(self, input, 
style, skip=None): 377 | out = self.conv(input, style) 378 | out = out + self.bias 379 | 380 | if skip is not None: 381 | skip = self.upsample(skip) 382 | 383 | out = out + skip 384 | 385 | return out 386 | 387 | 388 | class Generator(nn.Module): 389 | def __init__( 390 | self, 391 | size, 392 | style_dim, 393 | n_mlp, 394 | channel_multiplier=2, 395 | blur_kernel=[1, 3, 3, 1], 396 | lr_mlp=0.01, 397 | ): 398 | super().__init__() 399 | 400 | self.size = size 401 | 402 | self.style_dim = style_dim 403 | 404 | layers = [PixelNorm()] 405 | 406 | for i in range(n_mlp): 407 | layers.append( 408 | EqualLinear( 409 | style_dim, style_dim, lr_mul=lr_mlp, activation="fused_lrelu" 410 | ) 411 | ) 412 | 413 | self.style = nn.Sequential(*layers) 414 | 415 | self.channels = { 416 | 4: 512, 417 | 8: 512, 418 | 16: 512, 419 | 32: 512, 420 | 64: 256 * channel_multiplier, 421 | 128: 128 * channel_multiplier, 422 | 256: 64 * channel_multiplier, 423 | 512: 32 * channel_multiplier, 424 | 1024: 16 * channel_multiplier, 425 | } 426 | 427 | self.input = ConstantInput(self.channels[4]) 428 | self.conv1 = StyledConv( 429 | self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel 430 | ) 431 | self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False) 432 | 433 | self.log_size = int(math.log(size, 2)) 434 | self.num_layers = (self.log_size - 2) * 2 + 1 435 | 436 | self.convs = nn.ModuleList() 437 | self.upsamples = nn.ModuleList() 438 | self.to_rgbs = nn.ModuleList() 439 | self.noises = nn.Module() 440 | 441 | in_channel = self.channels[4] 442 | 443 | for layer_idx in range(self.num_layers): 444 | res = (layer_idx + 5) // 2 445 | shape = [1, 1, 2 ** res, 2 ** res] 446 | self.noises.register_buffer(f"noise_{layer_idx}", torch.randn(*shape)) 447 | 448 | for i in range(3, self.log_size + 1): 449 | out_channel = self.channels[2 ** i] 450 | 451 | self.convs.append( 452 | StyledConv( 453 | in_channel, 454 | out_channel, 455 | 3, 456 | style_dim, 457 | upsample=True, 458 | blur_kernel=blur_kernel, 459 | ) 460 | ) 461 | 462 | self.convs.append( 463 | StyledConv( 464 | out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel 465 | ) 466 | ) 467 | 468 | self.to_rgbs.append(ToRGB(out_channel, style_dim)) 469 | 470 | in_channel = out_channel 471 | 472 | self.n_latent = self.log_size * 2 - 2 473 | 474 | def make_noise(self): 475 | device = self.input.input.device 476 | 477 | noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)] 478 | 479 | for i in range(3, self.log_size + 1): 480 | for _ in range(2): 481 | noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device)) 482 | 483 | return noises 484 | 485 | def mean_latent(self, n_latent): 486 | latent_in = torch.randn( 487 | n_latent, self.style_dim, device=self.input.input.device 488 | ) 489 | latent = self.style(latent_in).mean(0, keepdim=True) 490 | 491 | return latent 492 | 493 | def get_latent(self, input): 494 | return self.style(input) 495 | 496 | def forward( 497 | self, 498 | styles, 499 | return_latents=False, 500 | inject_index=None, 501 | truncation=1, 502 | truncation_latent=None, 503 | input_is_latent=False, 504 | noise=None, 505 | randomize_noise=True, 506 | ): 507 | if not input_is_latent: 508 | styles = [self.style(s) for s in styles] 509 | 510 | if noise is None: 511 | if randomize_noise: 512 | noise = [None] * self.num_layers 513 | else: 514 | noise = [ 515 | getattr(self.noises, f"noise_{i}") for i in range(self.num_layers) 516 | ] 517 | 518 | if truncation < 1: 519 | style_t = [] 520 | 521 | for style in styles: 522 | 
style_t.append( 523 | truncation_latent + truncation * (style - truncation_latent) 524 | ) 525 | 526 | styles = style_t 527 | 528 | if len(styles) < 2: 529 | inject_index = self.n_latent 530 | 531 | if styles[0].ndim < 3: 532 | latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) 533 | 534 | else: 535 | latent = styles[0] 536 | 537 | else: 538 | if inject_index is None: 539 | inject_index = random.randint(1, self.n_latent - 1) 540 | 541 | latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) 542 | latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1) 543 | 544 | latent = torch.cat([latent, latent2], 1) 545 | 546 | out = self.input(latent) 547 | out = self.conv1(out, latent[:, 0], noise=noise[0]) 548 | 549 | skip = self.to_rgb1(out, latent[:, 1]) 550 | 551 | i = 1 552 | for conv1, conv2, noise1, noise2, to_rgb in zip( 553 | self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs 554 | ): 555 | out = conv1(out, latent[:, i], noise=noise1) 556 | out = conv2(out, latent[:, i + 1], noise=noise2) 557 | skip = to_rgb(out, latent[:, i + 2], skip) 558 | 559 | i += 2 560 | 561 | image = skip 562 | 563 | if return_latents: 564 | return image, latent 565 | 566 | else: 567 | return image, None 568 | 569 | 570 | class ConvLayer(nn.Sequential): 571 | def __init__( 572 | self, 573 | in_channel, 574 | out_channel, 575 | kernel_size, 576 | downsample=False, 577 | blur_kernel=[1, 3, 3, 1], 578 | bias=True, 579 | activate=True, 580 | ): 581 | layers = [] 582 | 583 | if downsample: 584 | factor = 2 585 | p = (len(blur_kernel) - factor) + (kernel_size - 1) 586 | pad0 = (p + 1) // 2 587 | pad1 = p // 2 588 | 589 | layers.append(Blur(blur_kernel, pad=(pad0, pad1))) 590 | 591 | stride = 2 592 | self.padding = 0 593 | 594 | else: 595 | stride = 1 596 | self.padding = kernel_size // 2 597 | 598 | layers.append( 599 | EqualConv2d( 600 | in_channel, 601 | out_channel, 602 | kernel_size, 603 | padding=self.padding, 604 | stride=stride, 605 | bias=bias and not activate, 606 | ) 607 | ) 608 | 609 | if activate: 610 | layers.append(FusedLeakyReLU(out_channel, bias=bias)) 611 | 612 | super().__init__(*layers) 613 | 614 | 615 | class ResBlock(nn.Module): 616 | def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]): 617 | super().__init__() 618 | 619 | self.conv1 = ConvLayer(in_channel, in_channel, 3) 620 | self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True) 621 | 622 | self.skip = ConvLayer( 623 | in_channel, out_channel, 1, downsample=True, activate=False, bias=False 624 | ) 625 | 626 | def forward(self, input): 627 | out = self.conv1(input) 628 | out = self.conv2(out) 629 | 630 | skip = self.skip(input) 631 | out = (out + skip) / math.sqrt(2) 632 | 633 | return out 634 | 635 | 636 | class Discriminator(nn.Module): 637 | def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]): 638 | super().__init__() 639 | 640 | channels = { 641 | 4: 512, 642 | 8: 512, 643 | 16: 512, 644 | 32: 512, 645 | 64: 256 * channel_multiplier, 646 | 128: 128 * channel_multiplier, 647 | 256: 64 * channel_multiplier, 648 | 512: 32 * channel_multiplier, 649 | 1024: 16 * channel_multiplier, 650 | } 651 | 652 | convs = [ConvLayer(3, channels[size], 1)] 653 | 654 | log_size = int(math.log(size, 2)) 655 | 656 | in_channel = channels[size] 657 | 658 | for i in range(log_size, 2, -1): 659 | out_channel = channels[2 ** (i - 1)] 660 | 661 | convs.append(ResBlock(in_channel, out_channel, blur_kernel)) 662 | 663 | in_channel = out_channel 664 | 665 | self.convs = 
nn.Sequential(*convs) 666 | 667 | self.stddev_group = 4 668 | self.stddev_feat = 1 669 | 670 | self.final_conv = ConvLayer(in_channel + 1, channels[4], 3) 671 | self.final_linear = nn.Sequential( 672 | EqualLinear(channels[4] * 4 * 4, channels[4], activation="fused_lrelu"), 673 | EqualLinear(channels[4], 1), 674 | ) 675 | 676 | def forward(self, input): 677 | out = self.convs(input) 678 | 679 | batch, channel, height, width = out.shape 680 | group = min(batch, self.stddev_group) 681 | stddev = out.view( 682 | group, -1, self.stddev_feat, channel // self.stddev_feat, height, width 683 | ) 684 | stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8) 685 | stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2) 686 | stddev = stddev.repeat(group, 1, height, width) 687 | out = torch.cat([out, stddev], 1) 688 | 689 | out = self.final_conv(out) 690 | 691 | out = out.view(batch, -1) 692 | out = self.final_linear(out) 693 | 694 | return out 695 | 696 | -------------------------------------------------------------------------------- /stylegan2/op/__init__.py: -------------------------------------------------------------------------------- 1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu 2 | from .upfirdn2d import upfirdn2d 3 | -------------------------------------------------------------------------------- /stylegan2/op/conv2d_gradfix.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import warnings 3 | 4 | import torch 5 | from torch import autograd 6 | from torch.nn import functional as F 7 | 8 | enabled = True 9 | weight_gradients_disabled = False 10 | 11 | 12 | @contextlib.contextmanager 13 | def no_weight_gradients(): 14 | global weight_gradients_disabled 15 | 16 | old = weight_gradients_disabled 17 | weight_gradients_disabled = True 18 | yield 19 | weight_gradients_disabled = old 20 | 21 | 22 | def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): 23 | if could_use_op(input): 24 | return conv2d_gradfix( 25 | transpose=False, 26 | weight_shape=weight.shape, 27 | stride=stride, 28 | padding=padding, 29 | output_padding=0, 30 | dilation=dilation, 31 | groups=groups, 32 | ).apply(input, weight, bias) 33 | 34 | return F.conv2d( 35 | input=input, 36 | weight=weight, 37 | bias=bias, 38 | stride=stride, 39 | padding=padding, 40 | dilation=dilation, 41 | groups=groups, 42 | ) 43 | 44 | 45 | def conv_transpose2d( 46 | input, 47 | weight, 48 | bias=None, 49 | stride=1, 50 | padding=0, 51 | output_padding=0, 52 | groups=1, 53 | dilation=1, 54 | ): 55 | if could_use_op(input): 56 | return conv2d_gradfix( 57 | transpose=True, 58 | weight_shape=weight.shape, 59 | stride=stride, 60 | padding=padding, 61 | output_padding=output_padding, 62 | groups=groups, 63 | dilation=dilation, 64 | ).apply(input, weight, bias) 65 | 66 | return F.conv_transpose2d( 67 | input=input, 68 | weight=weight, 69 | bias=bias, 70 | stride=stride, 71 | padding=padding, 72 | output_padding=output_padding, 73 | dilation=dilation, 74 | groups=groups, 75 | ) 76 | 77 | 78 | def could_use_op(input): 79 | if (not enabled) or (not torch.backends.cudnn.enabled): 80 | return False 81 | 82 | if input.device.type != "cuda": 83 | return False 84 | 85 | if any(torch.__version__.startswith(x) for x in ["1.7.", "1.8."]): 86 | return True 87 | 88 | warnings.warn( 89 | f"conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d()." 
90 | ) 91 | 92 | return False 93 | 94 | 95 | def ensure_tuple(xs, ndim): 96 | xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim 97 | 98 | return xs 99 | 100 | 101 | conv2d_gradfix_cache = dict() 102 | 103 | 104 | def conv2d_gradfix( 105 | transpose, weight_shape, stride, padding, output_padding, dilation, groups 106 | ): 107 | ndim = 2 108 | weight_shape = tuple(weight_shape) 109 | stride = ensure_tuple(stride, ndim) 110 | padding = ensure_tuple(padding, ndim) 111 | output_padding = ensure_tuple(output_padding, ndim) 112 | dilation = ensure_tuple(dilation, ndim) 113 | 114 | key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups) 115 | if key in conv2d_gradfix_cache: 116 | return conv2d_gradfix_cache[key] 117 | 118 | common_kwargs = dict( 119 | stride=stride, padding=padding, dilation=dilation, groups=groups 120 | ) 121 | 122 | def calc_output_padding(input_shape, output_shape): 123 | if transpose: 124 | return [0, 0] 125 | 126 | return [ 127 | input_shape[i + 2] 128 | - (output_shape[i + 2] - 1) * stride[i] 129 | - (1 - 2 * padding[i]) 130 | - dilation[i] * (weight_shape[i + 2] - 1) 131 | for i in range(ndim) 132 | ] 133 | 134 | class Conv2d(autograd.Function): 135 | @staticmethod 136 | def forward(ctx, input, weight, bias): 137 | if not transpose: 138 | out = F.conv2d(input=input, weight=weight, bias=bias, **common_kwargs) 139 | 140 | else: 141 | out = F.conv_transpose2d( 142 | input=input, 143 | weight=weight, 144 | bias=bias, 145 | output_padding=output_padding, 146 | **common_kwargs, 147 | ) 148 | 149 | ctx.save_for_backward(input, weight) 150 | 151 | return out 152 | 153 | @staticmethod 154 | def backward(ctx, grad_output): 155 | input, weight = ctx.saved_tensors 156 | grad_input, grad_weight, grad_bias = None, None, None 157 | 158 | if ctx.needs_input_grad[0]: 159 | p = calc_output_padding( 160 | input_shape=input.shape, output_shape=grad_output.shape 161 | ) 162 | grad_input = conv2d_gradfix( 163 | transpose=(not transpose), 164 | weight_shape=weight_shape, 165 | output_padding=p, 166 | **common_kwargs, 167 | ).apply(grad_output, weight, None) 168 | 169 | if ctx.needs_input_grad[1] and not weight_gradients_disabled: 170 | grad_weight = Conv2dGradWeight.apply(grad_output, input) 171 | 172 | if ctx.needs_input_grad[2]: 173 | grad_bias = grad_output.sum((0, 2, 3)) 174 | 175 | return grad_input, grad_weight, grad_bias 176 | 177 | class Conv2dGradWeight(autograd.Function): 178 | @staticmethod 179 | def forward(ctx, grad_output, input): 180 | op = torch._C._jit_get_operation( 181 | "aten::cudnn_convolution_backward_weight" 182 | if not transpose 183 | else "aten::cudnn_convolution_transpose_backward_weight" 184 | ) 185 | flags = [ 186 | torch.backends.cudnn.benchmark, 187 | torch.backends.cudnn.deterministic, 188 | torch.backends.cudnn.allow_tf32, 189 | ] 190 | grad_weight = op( 191 | weight_shape, 192 | grad_output, 193 | input, 194 | padding, 195 | stride, 196 | dilation, 197 | groups, 198 | *flags, 199 | ) 200 | ctx.save_for_backward(grad_output, input) 201 | 202 | return grad_weight 203 | 204 | @staticmethod 205 | def backward(ctx, grad_grad_weight): 206 | grad_output, input = ctx.saved_tensors 207 | grad_grad_output, grad_grad_input = None, None 208 | 209 | if ctx.needs_input_grad[0]: 210 | grad_grad_output = Conv2d.apply(input, grad_grad_weight, None) 211 | 212 | if ctx.needs_input_grad[1]: 213 | p = calc_output_padding( 214 | input_shape=input.shape, output_shape=grad_output.shape 215 | ) 216 | grad_grad_input = conv2d_gradfix( 217 | 
transpose=(not transpose), 218 | weight_shape=weight_shape, 219 | output_padding=p, 220 | **common_kwargs, 221 | ).apply(grad_output, grad_grad_weight, None) 222 | 223 | return grad_grad_output, grad_grad_input 224 | 225 | conv2d_gradfix_cache[key] = Conv2d 226 | 227 | return Conv2d 228 | -------------------------------------------------------------------------------- /stylegan2/op/fused_act.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | from torch.autograd import Function 7 | from torch.utils.cpp_extension import load 8 | 9 | 10 | module_path = os.path.dirname(__file__) 11 | fused = load( 12 | "fused", 13 | sources=[ 14 | os.path.join(module_path, "fused_bias_act.cpp"), 15 | os.path.join(module_path, "fused_bias_act_kernel.cu"), 16 | ], 17 | ) 18 | 19 | 20 | class FusedLeakyReLUFunctionBackward(Function): 21 | @staticmethod 22 | def forward(ctx, grad_output, out, bias, negative_slope, scale): 23 | ctx.save_for_backward(out) 24 | ctx.negative_slope = negative_slope 25 | ctx.scale = scale 26 | 27 | empty = grad_output.new_empty(0) 28 | 29 | grad_input = fused.fused_bias_act( 30 | grad_output.contiguous(), empty, out, 3, 1, negative_slope, scale 31 | ) 32 | 33 | dim = [0] 34 | 35 | if grad_input.ndim > 2: 36 | dim += list(range(2, grad_input.ndim)) 37 | 38 | if bias: 39 | grad_bias = grad_input.sum(dim).detach() 40 | 41 | else: 42 | grad_bias = empty 43 | 44 | return grad_input, grad_bias 45 | 46 | @staticmethod 47 | def backward(ctx, gradgrad_input, gradgrad_bias): 48 | out, = ctx.saved_tensors 49 | gradgrad_out = fused.fused_bias_act( 50 | gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale 51 | ) 52 | 53 | return gradgrad_out, None, None, None, None 54 | 55 | 56 | class FusedLeakyReLUFunction(Function): 57 | @staticmethod 58 | def forward(ctx, input, bias, negative_slope, scale): 59 | empty = input.new_empty(0) 60 | 61 | ctx.bias = bias is not None 62 | 63 | if bias is None: 64 | bias = empty 65 | 66 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) 67 | ctx.save_for_backward(out) 68 | ctx.negative_slope = negative_slope 69 | ctx.scale = scale 70 | 71 | return out 72 | 73 | @staticmethod 74 | def backward(ctx, grad_output): 75 | out, = ctx.saved_tensors 76 | 77 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( 78 | grad_output, out, ctx.bias, ctx.negative_slope, ctx.scale 79 | ) 80 | 81 | if not ctx.bias: 82 | grad_bias = None 83 | 84 | return grad_input, grad_bias, None, None 85 | 86 | 87 | class FusedLeakyReLU(nn.Module): 88 | def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5): 89 | super().__init__() 90 | 91 | if bias: 92 | self.bias = nn.Parameter(torch.zeros(channel)) 93 | 94 | else: 95 | self.bias = None 96 | 97 | self.negative_slope = negative_slope 98 | self.scale = scale 99 | 100 | def forward(self, input): 101 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 102 | 103 | 104 | def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5): 105 | if input.device.type == "cpu": 106 | if bias is not None: 107 | rest_dim = [1] * (input.ndim - bias.ndim - 1) 108 | return ( 109 | F.leaky_relu( 110 | input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2 111 | ) 112 | * scale 113 | ) 114 | 115 | else: 116 | return F.leaky_relu(input, negative_slope=0.2) * scale 117 | 118 | else: 119 | return 
FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) 120 | -------------------------------------------------------------------------------- /stylegan2/op/fused_bias_act.cpp: -------------------------------------------------------------------------------- 1 | 2 | #include 3 | #include 4 | 5 | torch::Tensor fused_bias_act_op(const torch::Tensor &input, 6 | const torch::Tensor &bias, 7 | const torch::Tensor &refer, int act, int grad, 8 | float alpha, float scale); 9 | 10 | #define CHECK_CUDA(x) \ 11 | TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 12 | #define CHECK_CONTIGUOUS(x) \ 13 | TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 14 | #define CHECK_INPUT(x) \ 15 | CHECK_CUDA(x); \ 16 | CHECK_CONTIGUOUS(x) 17 | 18 | torch::Tensor fused_bias_act(const torch::Tensor &input, 19 | const torch::Tensor &bias, 20 | const torch::Tensor &refer, int act, int grad, 21 | float alpha, float scale) { 22 | CHECK_INPUT(input); 23 | CHECK_INPUT(bias); 24 | 25 | at::DeviceGuard guard(input.device()); 26 | 27 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); 28 | } 29 | 30 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 31 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); 32 | } -------------------------------------------------------------------------------- /stylegan2/op/fused_bias_act_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | 15 | #include 16 | #include 17 | 18 | template 19 | static __global__ void 20 | fused_bias_act_kernel(scalar_t *out, const scalar_t *p_x, const scalar_t *p_b, 21 | const scalar_t *p_ref, int act, int grad, scalar_t alpha, 22 | scalar_t scale, int loop_x, int size_x, int step_b, 23 | int size_b, int use_bias, int use_ref) { 24 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; 25 | 26 | scalar_t zero = 0.0; 27 | 28 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; 29 | loop_idx++, xi += blockDim.x) { 30 | scalar_t x = p_x[xi]; 31 | 32 | if (use_bias) { 33 | x += p_b[(xi / step_b) % size_b]; 34 | } 35 | 36 | scalar_t ref = use_ref ? p_ref[xi] : zero; 37 | 38 | scalar_t y; 39 | 40 | switch (act * 10 + grad) { 41 | default: 42 | case 10: 43 | y = x; 44 | break; 45 | case 11: 46 | y = x; 47 | break; 48 | case 12: 49 | y = 0.0; 50 | break; 51 | 52 | case 30: 53 | y = (x > 0.0) ? x : x * alpha; 54 | break; 55 | case 31: 56 | y = (ref > 0.0) ? x : x * alpha; 57 | break; 58 | case 32: 59 | y = 0.0; 60 | break; 61 | } 62 | 63 | out[xi] = y * scale; 64 | } 65 | } 66 | 67 | torch::Tensor fused_bias_act_op(const torch::Tensor &input, 68 | const torch::Tensor &bias, 69 | const torch::Tensor &refer, int act, int grad, 70 | float alpha, float scale) { 71 | int curDevice = -1; 72 | cudaGetDevice(&curDevice); 73 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 74 | 75 | auto x = input.contiguous(); 76 | auto b = bias.contiguous(); 77 | auto ref = refer.contiguous(); 78 | 79 | int use_bias = b.numel() ? 1 : 0; 80 | int use_ref = ref.numel() ? 
1 : 0; 81 | 82 | int size_x = x.numel(); 83 | int size_b = b.numel(); 84 | int step_b = 1; 85 | 86 | for (int i = 1 + 1; i < x.dim(); i++) { 87 | step_b *= x.size(i); 88 | } 89 | 90 | int loop_x = 4; 91 | int block_size = 4 * 32; 92 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1; 93 | 94 | auto y = torch::empty_like(x); 95 | 96 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 97 | x.scalar_type(), "fused_bias_act_kernel", [&] { 98 | fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>( 99 | y.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), 100 | b.data_ptr<scalar_t>(), ref.data_ptr<scalar_t>(), act, grad, alpha, 101 | scale, loop_x, size_x, step_b, size_b, use_bias, use_ref); 102 | }); 103 | 104 | return y; 105 | } -------------------------------------------------------------------------------- /stylegan2/op/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | #include <ATen/ATen.h> 2 | #include <torch/extension.h> 3 | 4 | torch::Tensor upfirdn2d_op(const torch::Tensor &input, 5 | const torch::Tensor &kernel, int up_x, int up_y, 6 | int down_x, int down_y, int pad_x0, int pad_x1, 7 | int pad_y0, int pad_y1); 8 | 9 | #define CHECK_CUDA(x) \ 10 | TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 11 | #define CHECK_CONTIGUOUS(x) \ 12 | TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 13 | #define CHECK_INPUT(x) \ 14 | CHECK_CUDA(x); \ 15 | CHECK_CONTIGUOUS(x) 16 | 17 | torch::Tensor upfirdn2d(const torch::Tensor &input, const torch::Tensor &kernel, 18 | int up_x, int up_y, int down_x, int down_y, int pad_x0, 19 | int pad_x1, int pad_y0, int pad_y1) { 20 | CHECK_INPUT(input); 21 | CHECK_INPUT(kernel); 22 | 23 | at::DeviceGuard guard(input.device()); 24 | 25 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, 26 | pad_y0, pad_y1); 27 | } 28 | 29 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 30 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); 31 | } -------------------------------------------------------------------------------- /stylegan2/op/upfirdn2d.py: -------------------------------------------------------------------------------- 1 | from collections import abc 2 | import os 3 | 4 | import torch 5 | from torch.nn import functional as F 6 | from torch.autograd import Function 7 | from torch.utils.cpp_extension import load 8 | 9 | 10 | module_path = os.path.dirname(__file__) 11 | upfirdn2d_op = load( 12 | "upfirdn2d", 13 | sources=[ 14 | os.path.join(module_path, "upfirdn2d.cpp"), 15 | os.path.join(module_path, "upfirdn2d_kernel.cu"), 16 | ], 17 | ) 18 | 19 | 20 | class UpFirDn2dBackward(Function): 21 | @staticmethod 22 | def forward( 23 | ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size 24 | ): 25 | 26 | up_x, up_y = up 27 | down_x, down_y = down 28 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad 29 | 30 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) 31 | 32 | grad_input = upfirdn2d_op.upfirdn2d( 33 | grad_output, 34 | grad_kernel, 35 | down_x, 36 | down_y, 37 | up_x, 38 | up_y, 39 | g_pad_x0, 40 | g_pad_x1, 41 | g_pad_y0, 42 | g_pad_y1, 43 | ) 44 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) 45 | 46 | ctx.save_for_backward(kernel) 47 | 48 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 49 | 50 | ctx.up_x = up_x 51 | ctx.up_y = up_y 52 | ctx.down_x = down_x 53 | ctx.down_y = down_y 54 | ctx.pad_x0 = pad_x0 55 | ctx.pad_x1 = pad_x1 56 | ctx.pad_y0 = pad_y0 57 | ctx.pad_y1 = pad_y1 58 | ctx.in_size = in_size 59 | ctx.out_size = out_size 60 | 61 | return grad_input 62 | 63 | @staticmethod 64 | def backward(ctx, 
gradgrad_input): 65 | kernel, = ctx.saved_tensors 66 | 67 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) 68 | 69 | gradgrad_out = upfirdn2d_op.upfirdn2d( 70 | gradgrad_input, 71 | kernel, 72 | ctx.up_x, 73 | ctx.up_y, 74 | ctx.down_x, 75 | ctx.down_y, 76 | ctx.pad_x0, 77 | ctx.pad_x1, 78 | ctx.pad_y0, 79 | ctx.pad_y1, 80 | ) 81 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3]) 82 | gradgrad_out = gradgrad_out.view( 83 | ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1] 84 | ) 85 | 86 | return gradgrad_out, None, None, None, None, None, None, None, None 87 | 88 | 89 | class UpFirDn2d(Function): 90 | @staticmethod 91 | def forward(ctx, input, kernel, up, down, pad): 92 | up_x, up_y = up 93 | down_x, down_y = down 94 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 95 | 96 | kernel_h, kernel_w = kernel.shape 97 | batch, channel, in_h, in_w = input.shape 98 | ctx.in_size = input.shape 99 | 100 | input = input.reshape(-1, in_h, in_w, 1) 101 | 102 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) 103 | 104 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y 105 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x 106 | ctx.out_size = (out_h, out_w) 107 | 108 | ctx.up = (up_x, up_y) 109 | ctx.down = (down_x, down_y) 110 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) 111 | 112 | g_pad_x0 = kernel_w - pad_x0 - 1 113 | g_pad_y0 = kernel_h - pad_y0 - 1 114 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 115 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 116 | 117 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) 118 | 119 | out = upfirdn2d_op.upfirdn2d( 120 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 121 | ) 122 | # out = out.view(major, out_h, out_w, minor) 123 | out = out.view(-1, channel, out_h, out_w) 124 | 125 | return out 126 | 127 | @staticmethod 128 | def backward(ctx, grad_output): 129 | kernel, grad_kernel = ctx.saved_tensors 130 | 131 | grad_input = None 132 | 133 | if ctx.needs_input_grad[0]: 134 | grad_input = UpFirDn2dBackward.apply( 135 | grad_output, 136 | kernel, 137 | grad_kernel, 138 | ctx.up, 139 | ctx.down, 140 | ctx.pad, 141 | ctx.g_pad, 142 | ctx.in_size, 143 | ctx.out_size, 144 | ) 145 | 146 | return grad_input, None, None, None, None 147 | 148 | 149 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): 150 | if not isinstance(up, abc.Iterable): 151 | up = (up, up) 152 | 153 | if not isinstance(down, abc.Iterable): 154 | down = (down, down) 155 | 156 | if len(pad) == 2: 157 | pad = (pad[0], pad[1], pad[0], pad[1]) 158 | 159 | if input.device.type == "cpu": 160 | out = upfirdn2d_native(input, kernel, *up, *down, *pad) 161 | 162 | else: 163 | out = UpFirDn2d.apply(input, kernel, up, down, pad) 164 | 165 | return out 166 | 167 | 168 | def upfirdn2d_native( 169 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 170 | ): 171 | _, channel, in_h, in_w = input.shape 172 | input = input.reshape(-1, in_h, in_w, 1) 173 | 174 | _, in_h, in_w, minor = input.shape 175 | kernel_h, kernel_w = kernel.shape 176 | 177 | out = input.view(-1, in_h, 1, in_w, 1, minor) 178 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) 179 | out = out.view(-1, in_h * up_y, in_w * up_x, minor) 180 | 181 | out = F.pad( 182 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] 183 | ) 184 | out = out[ 185 | :, 186 | max(-pad_y0, 0) : out.shape[1] - 
max(-pad_y1, 0), 187 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), 188 | :, 189 | ] 190 | 191 | out = out.permute(0, 3, 1, 2) 192 | out = out.reshape( 193 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] 194 | ) 195 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) 196 | out = F.conv2d(out, w) 197 | out = out.reshape( 198 | -1, 199 | minor, 200 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, 201 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, 202 | ) 203 | out = out.permute(0, 2, 3, 1) 204 | out = out[:, ::down_y, ::down_x, :] 205 | 206 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y 207 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x 208 | 209 | return out.view(-1, channel, out_h, out_w) 210 | -------------------------------------------------------------------------------- /stylegan2/op/upfirdn2d_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include <torch/types.h> 8 | 9 | #include <ATen/ATen.h> 10 | #include <ATen/AccumulateType.h> 11 | #include <ATen/cuda/CUDAApplyUtils.cuh> 12 | #include <ATen/cuda/CUDAContext.h> 13 | 14 | #include <cuda.h> 15 | #include <cuda_runtime.h> 16 | 17 | static __host__ __device__ __forceinline__ int floor_div(int a, int b) { 18 | int c = a / b; 19 | 20 | if (c * b > a) { 21 | c--; 22 | } 23 | 24 | return c; 25 | } 26 | 27 | struct UpFirDn2DKernelParams { 28 | int up_x; 29 | int up_y; 30 | int down_x; 31 | int down_y; 32 | int pad_x0; 33 | int pad_x1; 34 | int pad_y0; 35 | int pad_y1; 36 | 37 | int major_dim; 38 | int in_h; 39 | int in_w; 40 | int minor_dim; 41 | int kernel_h; 42 | int kernel_w; 43 | int out_h; 44 | int out_w; 45 | int loop_major; 46 | int loop_x; 47 | }; 48 | 49 | template <typename scalar_t> 50 | __global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input, 51 | const scalar_t *kernel, 52 | const UpFirDn2DKernelParams p) { 53 | int minor_idx = blockIdx.x * blockDim.x + threadIdx.x; 54 | int out_y = minor_idx / p.minor_dim; 55 | minor_idx -= out_y * p.minor_dim; 56 | int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y; 57 | int major_idx_base = blockIdx.z * p.loop_major; 58 | 59 | if (out_x_base >= p.out_w || out_y >= p.out_h || 60 | major_idx_base >= p.major_dim) { 61 | return; 62 | } 63 | 64 | int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0; 65 | int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h); 66 | int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y; 67 | int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y; 68 | 69 | for (int loop_major = 0, major_idx = major_idx_base; 70 | loop_major < p.loop_major && major_idx < p.major_dim; 71 | loop_major++, major_idx++) { 72 | for (int loop_x = 0, out_x = out_x_base; 73 | loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) { 74 | int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0; 75 | int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w); 76 | int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x; 77 | int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x; 78 | 79 | const scalar_t *x_p = 80 | &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + 81 | minor_idx]; 82 | const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x]; 83 | int x_px = p.minor_dim; 84 | int k_px = -p.up_x; 85 | int x_py = p.in_w * p.minor_dim; 86 | int k_py = 
-p.up_y * p.kernel_w; 87 | 88 | scalar_t v = 0.0f; 89 | 90 | for (int y = 0; y < h; y++) { 91 | for (int x = 0; x < w; x++) { 92 | v += static_cast<scalar_t>(*x_p) * static_cast<scalar_t>(*k_p); 93 | x_p += x_px; 94 | k_p += k_px; 95 | } 96 | 97 | x_p += x_py - w * x_px; 98 | k_p += k_py - w * k_px; 99 | } 100 | 101 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + 102 | minor_idx] = v; 103 | } 104 | } 105 | } 106 | 107 | template <typename scalar_t, int up_x, int up_y, int down_x, int down_y, 108 | int kernel_h, int kernel_w, int tile_out_h, int tile_out_w> 109 | __global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input, 110 | const scalar_t *kernel, 111 | const UpFirDn2DKernelParams p) { 112 | const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1; 113 | const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1; 114 | 115 | __shared__ volatile float sk[kernel_h][kernel_w]; 116 | __shared__ volatile float sx[tile_in_h][tile_in_w]; 117 | 118 | int minor_idx = blockIdx.x; 119 | int tile_out_y = minor_idx / p.minor_dim; 120 | minor_idx -= tile_out_y * p.minor_dim; 121 | tile_out_y *= tile_out_h; 122 | int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w; 123 | int major_idx_base = blockIdx.z * p.loop_major; 124 | 125 | if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | 126 | major_idx_base >= p.major_dim) { 127 | return; 128 | } 129 | 130 | for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; 131 | tap_idx += blockDim.x) { 132 | int ky = tap_idx / kernel_w; 133 | int kx = tap_idx - ky * kernel_w; 134 | scalar_t v = 0.0; 135 | 136 | if (kx < p.kernel_w & ky < p.kernel_h) { 137 | v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)]; 138 | } 139 | 140 | sk[ky][kx] = v; 141 | } 142 | 143 | for (int loop_major = 0, major_idx = major_idx_base; 144 | loop_major < p.loop_major & major_idx < p.major_dim; 145 | loop_major++, major_idx++) { 146 | for (int loop_x = 0, tile_out_x = tile_out_x_base; 147 | loop_x < p.loop_x & tile_out_x < p.out_w; 148 | loop_x++, tile_out_x += tile_out_w) { 149 | int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0; 150 | int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0; 151 | int tile_in_x = floor_div(tile_mid_x, up_x); 152 | int tile_in_y = floor_div(tile_mid_y, up_y); 153 | 154 | __syncthreads(); 155 | 156 | for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; 157 | in_idx += blockDim.x) { 158 | int rel_in_y = in_idx / tile_in_w; 159 | int rel_in_x = in_idx - rel_in_y * tile_in_w; 160 | int in_x = rel_in_x + tile_in_x; 161 | int in_y = rel_in_y + tile_in_y; 162 | 163 | scalar_t v = 0.0; 164 | 165 | if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) { 166 | v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * 167 | p.minor_dim + 168 | minor_idx]; 169 | } 170 | 171 | sx[rel_in_y][rel_in_x] = v; 172 | } 173 | 174 | __syncthreads(); 175 | for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; 176 | out_idx += blockDim.x) { 177 | int rel_out_y = out_idx / tile_out_w; 178 | int rel_out_x = out_idx - rel_out_y * tile_out_w; 179 | int out_x = rel_out_x + tile_out_x; 180 | int out_y = rel_out_y + tile_out_y; 181 | 182 | int mid_x = tile_mid_x + rel_out_x * down_x; 183 | int mid_y = tile_mid_y + rel_out_y * down_y; 184 | int in_x = floor_div(mid_x, up_x); 185 | int in_y = floor_div(mid_y, up_y); 186 | int rel_in_x = in_x - tile_in_x; 187 | int rel_in_y = in_y - tile_in_y; 188 | int kernel_x = (in_x + 1) * up_x - mid_x - 1; 189 | int kernel_y = (in_y + 1) * up_y - mid_y - 1; 190 | 191 | scalar_t v = 0.0; 192 | 193 | #pragma unroll 194 | for (int y = 0; y < kernel_h / 
up_y; y++) 195 | #pragma unroll 196 | for (int x = 0; x < kernel_w / up_x; x++) 197 | v += sx[rel_in_y + y][rel_in_x + x] * 198 | sk[kernel_y + y * up_y][kernel_x + x * up_x]; 199 | 200 | if (out_x < p.out_w & out_y < p.out_h) { 201 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + 202 | minor_idx] = v; 203 | } 204 | } 205 | } 206 | } 207 | } 208 | 209 | torch::Tensor upfirdn2d_op(const torch::Tensor &input, 210 | const torch::Tensor &kernel, int up_x, int up_y, 211 | int down_x, int down_y, int pad_x0, int pad_x1, 212 | int pad_y0, int pad_y1) { 213 | int curDevice = -1; 214 | cudaGetDevice(&curDevice); 215 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 216 | 217 | UpFirDn2DKernelParams p; 218 | 219 | auto x = input.contiguous(); 220 | auto k = kernel.contiguous(); 221 | 222 | p.major_dim = x.size(0); 223 | p.in_h = x.size(1); 224 | p.in_w = x.size(2); 225 | p.minor_dim = x.size(3); 226 | p.kernel_h = k.size(0); 227 | p.kernel_w = k.size(1); 228 | p.up_x = up_x; 229 | p.up_y = up_y; 230 | p.down_x = down_x; 231 | p.down_y = down_y; 232 | p.pad_x0 = pad_x0; 233 | p.pad_x1 = pad_x1; 234 | p.pad_y0 = pad_y0; 235 | p.pad_y1 = pad_y1; 236 | 237 | p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / 238 | p.down_y; 239 | p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / 240 | p.down_x; 241 | 242 | auto out = 243 | at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options()); 244 | 245 | int mode = -1; 246 | 247 | int tile_out_h = -1; 248 | int tile_out_w = -1; 249 | 250 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && 251 | p.kernel_h <= 4 && p.kernel_w <= 4) { 252 | mode = 1; 253 | tile_out_h = 16; 254 | tile_out_w = 64; 255 | } 256 | 257 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && 258 | p.kernel_h <= 3 && p.kernel_w <= 3) { 259 | mode = 2; 260 | tile_out_h = 16; 261 | tile_out_w = 64; 262 | } 263 | 264 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && 265 | p.kernel_h <= 4 && p.kernel_w <= 4) { 266 | mode = 3; 267 | tile_out_h = 16; 268 | tile_out_w = 64; 269 | } 270 | 271 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && 272 | p.kernel_h <= 2 && p.kernel_w <= 2) { 273 | mode = 4; 274 | tile_out_h = 16; 275 | tile_out_w = 64; 276 | } 277 | 278 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && 279 | p.kernel_h <= 4 && p.kernel_w <= 4) { 280 | mode = 5; 281 | tile_out_h = 8; 282 | tile_out_w = 32; 283 | } 284 | 285 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && 286 | p.kernel_h <= 2 && p.kernel_w <= 2) { 287 | mode = 6; 288 | tile_out_h = 8; 289 | tile_out_w = 32; 290 | } 291 | 292 | dim3 block_size; 293 | dim3 grid_size; 294 | 295 | if (tile_out_h > 0 && tile_out_w > 0) { 296 | p.loop_major = (p.major_dim - 1) / 16384 + 1; 297 | p.loop_x = 1; 298 | block_size = dim3(32 * 8, 1, 1); 299 | grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim, 300 | (p.out_w - 1) / (p.loop_x * tile_out_w) + 1, 301 | (p.major_dim - 1) / p.loop_major + 1); 302 | } else { 303 | p.loop_major = (p.major_dim - 1) / 16384 + 1; 304 | p.loop_x = 4; 305 | block_size = dim3(4, 32, 1); 306 | grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1, 307 | (p.out_w - 1) / (p.loop_x * block_size.y) + 1, 308 | (p.major_dim - 1) / p.loop_major + 1); 309 | } 310 | 311 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] { 312 | switch (mode) { 313 | case 1: 314 | 
upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64> 315 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(), 316 | x.data_ptr<scalar_t>(), 317 | k.data_ptr<scalar_t>(), p); 318 | 319 | break; 320 | 321 | case 2: 322 | upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64> 323 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(), 324 | x.data_ptr<scalar_t>(), 325 | k.data_ptr<scalar_t>(), p); 326 | 327 | break; 328 | 329 | case 3: 330 | upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64> 331 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(), 332 | x.data_ptr<scalar_t>(), 333 | k.data_ptr<scalar_t>(), p); 334 | 335 | break; 336 | 337 | case 4: 338 | upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64> 339 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(), 340 | x.data_ptr<scalar_t>(), 341 | k.data_ptr<scalar_t>(), p); 342 | 343 | break; 344 | 345 | case 5: 346 | upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32> 347 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(), 348 | x.data_ptr<scalar_t>(), 349 | k.data_ptr<scalar_t>(), p); 350 | 351 | break; 352 | 353 | case 6: 354 | upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 2, 2, 8, 32> 355 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(), 356 | x.data_ptr<scalar_t>(), 357 | k.data_ptr<scalar_t>(), p); 358 | 359 | break; 360 | 361 | default: 362 | upfirdn2d_kernel_large<scalar_t><<<grid_size, block_size, 0, stream>>>( 363 | out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), 364 | k.data_ptr<scalar_t>(), p); 365 | } 366 | }); 367 | 368 | return out; 369 | } -------------------------------------------------------------------------------- /weights/car/.placeholder: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OmerKafri/StyleFusion/c1aa57267a192b3c50ee501339ed50a0aedb35fe/weights/car/.placeholder -------------------------------------------------------------------------------- /weights/car_weights.json: -------------------------------------------------------------------------------- 1 | { 2 | "all": "weights/car/all.pt", 3 | "car": "weights/car/car.pt", 4 | "background": "weights/car/background.pt" 5 | } 6 | -------------------------------------------------------------------------------- /weights/church/.placeholder: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OmerKafri/StyleFusion/c1aa57267a192b3c50ee501339ed50a0aedb35fe/weights/church/.placeholder -------------------------------------------------------------------------------- /weights/church_weights.json: -------------------------------------------------------------------------------- 1 | { 2 | "all": "weights/church/all.pt" 3 | } 4 | -------------------------------------------------------------------------------- /weights/ffhq/.placeholder: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OmerKafri/StyleFusion/c1aa57267a192b3c50ee501339ed50a0aedb35fe/weights/ffhq/.placeholder -------------------------------------------------------------------------------- /weights/ffhq_weights.json: -------------------------------------------------------------------------------- 1 | { 2 | "all": "weights/ffhq/all.pt", 3 | "bg_hair_clothes": "weights/ffhq/bg_hair_clothes.pt", 4 | "bg_clothes": "weights/ffhq/bg_clothes.pt", 5 | "face": "weights/ffhq/face.pt", 6 | "skin_mouth": "weights/ffhq/skin_mouth.pt" 7 | } 8 | --------------------------------------------------------------------------------
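Note: the `*_weights.json` files above only map semantic-region names to fusion-net checkpoint paths; the `.pt` checkpoints themselves are not stored in the repository (the `.placeholder` files stand in for them). The snippet below is a minimal illustrative sketch, not part of the codebase, showing how such a mapping might be loaded and validated before use; the helper name `load_fusion_weight_paths` is an assumption introduced here for illustration.

```python
import json
import os


def load_fusion_weight_paths(json_path):
    # Illustrative helper (not part of the repository): read a
    # region-name -> checkpoint-path mapping such as weights/ffhq_weights.json
    # and report any checkpoints that have not been downloaded yet.
    with open(json_path, "r") as f:
        paths = json.load(f)
    missing = [name for name, path in paths.items() if not os.path.isfile(path)]
    if missing:
        print("Missing fusion-net checkpoints for: " + ", ".join(missing))
    return paths


if __name__ == "__main__":
    # e.g. {"all": "weights/ffhq/all.pt", "face": "weights/ffhq/face.pt", ...}
    weights = load_fusion_weight_paths("weights/ffhq_weights.json")
    print(weights)
```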