├── .gitignore
├── LICENSE
├── README.md
├── overview.jpg
├── style_fusion_simple.py
├── stylefusion.ipynb
├── stylefusion
│   ├── fusion_net.py
│   ├── sf_hierarchy.py
│   └── sf_stylegan2.py
├── stylegan2
│   ├── __init__.py
│   ├── model.py
│   └── op
│       ├── __init__.py
│       ├── conv2d_gradfix.py
│       ├── fused_act.py
│       ├── fused_bias_act.cpp
│       ├── fused_bias_act_kernel.cu
│       ├── upfirdn2d.cpp
│       ├── upfirdn2d.py
│       └── upfirdn2d_kernel.cu
└── weights
    ├── car
    │   └── .placeholder
    ├── car_weights.json
    ├── church
    │   └── .placeholder
    ├── church_weights.json
    ├── ffhq
    │   └── .placeholder
    └── ffhq_weights.json
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | weights/ffhq/*
3 | weights/car/*
4 | weights/church/*
5 | /.ipynb_checkpoints/
6 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Omer Kafri
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # StyleFusion: A Generative Model for Disentangling Spatial Segments
2 |
3 | 
4 |
5 | > **StyleFusion: A Generative Model for Disentangling Spatial Segments**
6 | > Omer Kafri, Or Patashnik, Yuval Alaluf, Daniel Cohen-Or
7 | > https://arxiv.org/abs/2107.07437
8 | >
9 | >**Abstract:** We present StyleFusion, a new mapping architecture for StyleGAN, which takes as input a number of latent codes and fuses them into a single style code. Inserting the resulting style code into a pre-trained StyleGAN generator results in a single harmonized image in which each semantic region is controlled by one of the input latent codes. Effectively, StyleFusion yields a disentangled representation of the image, providing fine-grained control over each region of the generated image. Moreover, to help facilitate global control over the generated image, a special input latent code is incorporated into the fused representation. StyleFusion operates in a hierarchical manner, where each level is tasked with learning to disentangle a pair of image regions (e.g., the car body and wheels). The resulting learned disentanglement allows one to modify both local, fine-grained semantics (e.g., facial features) as well as more global features (e.g., pose and background), providing improved flexibility in the synthesis process. As a natural extension, StyleFusion enables one to perform semantically-aware cross-image mixing of regions that are not necessarily aligned. Finally, we demonstrate how StyleFusion can be paired with existing editing techniques to more faithfully constrain the edit to the user's region of interest.
10 |
11 | ## Citation
12 |
13 | If you use this code for your research, please cite our paper:
14 |
15 | ```
16 | @misc{kafri2021stylefusion,
17 | title={StyleFusion: A Generative Model for Disentangling Spatial Segments},
18 | author={Omer Kafri and Or Patashnik and Yuval Alaluf and Daniel Cohen-Or},
19 | year={2021},
20 | eprint={2107.07437},
21 | archivePrefix={arXiv},
22 | primaryClass={cs.CV}
23 | }
24 | ```
25 |
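26 | ## Inference example (sketch)
27 |
28 | A minimal sketch of driving `StyleFusionSimple` from `style_fusion_simple.py`. It assumes a CUDA device, a pre-trained StyleGAN2 FFHQ checkpoint (the path below is a placeholder), and the fusion-net checkpoints referenced by `weights/ffhq_weights.json`.
29 |
30 | ```
31 | import torch
32 | from style_fusion_simple import StyleFusionSimple, tensor2im
33 |
34 | sf = StyleFusionSimple(
35 |     stylegan_type="ffhq",
36 |     stylegan_weights="weights/ffhq/stylegan2-ffhq.pt",   # placeholder path
37 |     fusion_nets_weights="weights/ffhq_weights.json",
38 | )
39 |
40 | with torch.no_grad():
41 |     base = sf.seed_to_z((1, 0))   # latent controlling the base image
42 |     hair = sf.seed_to_z((2, 0))   # latent whose hair region is injected
43 |     img = sf.generate_img(base, latents_type="z", hair=hair)
44 |
45 | tensor2im(img).save("fused.jpg")
46 | ```
47 |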
--------------------------------------------------------------------------------
/overview.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OmerKafri/StyleFusion/c1aa57267a192b3c50ee501339ed50a0aedb35fe/overview.jpg
--------------------------------------------------------------------------------
/style_fusion_simple.py:
--------------------------------------------------------------------------------
1 | import json
2 | import sys
3 |
4 | import torch
5 | import torch.nn.functional
6 |
7 | sys.path.append(".")
8 | sys.path.append("..")
9 |
10 | from stylefusion.sf_stylegan2 import SFGenerator
11 | from stylefusion.sf_hierarchy import SFHierarchyFFHQ, SFHierarchyCar, SFHierarchyChurch
12 | from PIL import Image
13 |
14 |
15 | def tensor2im(var):
16 | var = var.cpu().detach().transpose(0, 2).transpose(0, 1).numpy()
17 | var = ((var + 1) / 2)
18 | var[var < 0] = 0
19 | var[var > 1] = 1
20 | var = var * 255
21 | return Image.fromarray(var.astype('uint8'))
22 |
23 |
24 | class StyleFusionSimple:
25 | def __init__(self, stylegan_type, stylegan_weights, fusion_nets_weights):
26 | self.stylegan_type = stylegan_type
27 | if self.stylegan_type == "ffhq":
28 | self.truncation = 0.7
29 | self.stylegan_size = 1024
30 | self.stylegan_layers = 18
31 | elif self.stylegan_type == "car":
32 | self.truncation = 0.5
33 | self.stylegan_size = 512
34 | self.stylegan_layers = 16
35 | elif self.stylegan_type == "church":
36 | self.truncation = 0.5
37 | self.stylegan_size = 256
38 | self.stylegan_layers = 14
39 |
40 | self.device = 'cuda:0'
41 |
42 | stylegan_ckpt = torch.load(stylegan_weights, map_location='cpu')
43 | self.original_net = SFGenerator(self.stylegan_size, 512, 8)
44 | self.original_net.load_state_dict(stylegan_ckpt['g_ema'], strict=True)
45 |
46 | self.original_net.to(self.device)
47 |
48 | with torch.no_grad():
49 | self.mean_latent = self.original_net.mean_latent(4096)
50 |
51 | if self.stylegan_type == "ffhq":
52 | self.sf_hierarchy = SFHierarchyFFHQ()
53 | self.base_blender = self.sf_hierarchy.nodes["all"]
54 | elif self.stylegan_type == "car":
55 | self.sf_hierarchy = SFHierarchyCar()
56 | self.base_blender = self.sf_hierarchy.nodes["all"]
57 | elif self.stylegan_type == "church":
58 | self.sf_hierarchy = SFHierarchyChurch()
59 | self.base_blender = self.sf_hierarchy.nodes["all"]
60 |
61 | with open(fusion_nets_weights, 'r') as f:
62 | fusion_nets_paths = json.load(f)
63 |
64 | keys = fusion_nets_paths.keys()
65 |
66 | for key in keys:
67 | self.sf_hierarchy.nodes[key].load_fusion_net(fusion_nets_paths[key])
68 | self.sf_hierarchy.nodes[key].fusion_net.to(self.device)
69 | self.sf_hierarchy.nodes[key].fusion_net.eval()
70 |
71 | def generate_img(self, base_latent, latents_type="z", hair=None, face=None, background=None, all=None, mouth=None,
72 | eyes=None, wheels=None, car=None, bg_top=None, bg_bottom=None):
73 | s_dict = dict()
74 | parts = self.sf_hierarchy.nodes["all"].get_all_active_parts()
75 | for part in parts:
76 | s_dict[part] = self.general_latent_to_s(base_latent, latents_type)
77 |
78 | def swap(value, keys):
79 | if value is None:
80 | return
81 | for k in keys:
82 | s_dict[k] = self.general_latent_to_s(value, latents_type)
83 |
84 | swap(hair, ["bg_hair_clothes", "hair"])
85 | swap(face, ["face", "eyes", "skin_mouth", "mouth", "skin", "shirt"])
86 | swap(background, ["background", "background_top", "background_bottom", "bg"])
87 | swap(all, ["all"])
88 | swap(mouth, ["skin_mouth", "face"])
89 | swap(eyes, ["eyes", "face"])
90 | swap(wheels, ["wheels"])
91 | swap(car, ["car", "body", "wheels", "car_body"])
92 | swap(bg_top, ["background_top"])
93 | swap(bg_bottom, ["background_bottom"])
94 |
95 | return self.s_dict_to_image(s_dict)[0]
96 |
97 | def seed_to_z(self, seed):
98 | torch.manual_seed(seed[0])
99 | z_regular = torch.randn((seed[1] + 1, 1, 512), device=self.device)
100 | return z_regular[seed[1]]
101 |
102 | def z_to_s(self, z):
103 | return self.original_net([z],
104 | truncation=self.truncation, truncation_latent=self.mean_latent,
105 | randomize_noise=False, return_style_vector=True)
106 |
107 | def z_to_w_plus(self, z):
108 | _, res = self.original_net([z],
109 | truncation=self.truncation, truncation_latent=self.mean_latent,
110 | randomize_noise=False, return_latents=True)
111 | return res[0]
112 |
113 | def w_plus_to_s(self, w_plus, truncation):
114 | return self.original_net([w_plus], input_is_latent=True,
115 | truncation=truncation, truncation_latent=self.mean_latent,
116 | randomize_noise=False, return_style_vector=True)
117 |
118 | def general_latent_to_s(self, l, latent_type):
119 | assert latent_type in ["z", "w", "w+", "s"]
120 |
121 | if latent_type == "z":
122 | assert l.size() == (1, 512)
123 | return self.z_to_s(l)
124 | elif latent_type == "w" or latent_type == "w+":
125 | assert l.size() == (1, 512) or l.size() == (1, self.stylegan_layers, 512)
126 | if l.dim() == 2:
127 | return self.w_plus_to_s(l.unsqueeze(0).repeat(1, self.stylegan_layers, 1), truncation=1)
128 | else:
129 | return self.w_plus_to_s(l, truncation=1)
130 | else:
131 | return l
132 |
133 | def s_to_image(self, s):
134 | img, _ = self.original_net([torch.zeros(1, 512, device=self.device)],
135 | randomize_noise=False, style_vector=s)
136 | return img
137 |
138 | def w_plus_to_image(self, w_plus):
139 | s = self.w_plus_to_s(w_plus, truncation=1)
140 | return self.s_to_image(s)
141 |
142 | def z_to_image(self, z):
143 | s = self.z_to_s(z)
144 | return self.s_to_image(s)
145 |
146 | def s_dict_to_image(self, s_dict):
147 | s = self.base_blender.forward(s_dict)
148 | return self.s_to_image(s)
149 |
150 | def w_plus_dict_to_image(self, w_plus_dict, truncation=1):
151 | s_dict = dict()
152 | for key in w_plus_dict.keys():
153 | s_dict[key] = self.w_plus_to_s(w_plus_dict[key], truncation=truncation)
154 | return self.s_dict_to_image(s_dict)
155 |
156 | def z_dict_to_image(self, z_dict):
157 | s_dict = dict()
158 | for key in z_dict.keys():
159 | s_dict[key] = self.z_to_s(z_dict[key])
160 | return self.s_dict_to_image(s_dict)
161 |
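162 | # Appended illustration (not part of the original file): cross-image mixing with
163 | # one z code per top-level FFHQ region via z_dict_to_image. The checkpoint path is
164 | # a placeholder; a CUDA device and the trained ffhq fusion nets are assumed.
165 | if __name__ == "__main__":
166 |     sf = StyleFusionSimple(
167 |         "ffhq",
168 |         stylegan_weights="weights/ffhq/stylegan2-ffhq.pt",   # placeholder path
169 |         fusion_nets_weights="weights/ffhq_weights.json",
170 |     )
171 |     with torch.no_grad():
172 |         z_dict = {
173 |             "all": sf.seed_to_z((0, 0)),              # global/harmonization code
174 |             "face": sf.seed_to_z((1, 0)),             # drives the face region
175 |             "bg_hair_clothes": sf.seed_to_z((2, 0)),  # drives hair/background/clothes
176 |         }
177 |         image = sf.z_dict_to_image(z_dict)
178 |     tensor2im(image[0]).save("mixed.jpg")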
--------------------------------------------------------------------------------
/stylefusion/fusion_net.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | from stylegan2.model import EqualLinear
7 |
8 | MAX_LAYERS = 18
9 | DEFAULT_COMMON_LAYERS = 5
10 | DEFAULT_INDEX_BITS = 10
11 |
12 |
13 | class LatentBlender(nn.Module):
14 | def __init__(self, size, index_bits=DEFAULT_INDEX_BITS, extras=0):
15 | super(LatentBlender, self).__init__()
16 | layers = []
17 |
18 | self.size = size
19 | self.index_bits = index_bits
20 |
21 | sizes = [size * 2 + index_bits * (MAX_LAYERS + 1) + extras * size, size * 5, size * 5, size * 5, size, size]
22 |
23 | for i in range(len(sizes) - 1):
24 | layers.append(
25 | EqualLinear(
26 | sizes[i], sizes[i + 1], lr_mul=0.01, activation='fused_lrelu'
27 | )
28 | )
29 |
30 | self.disentangle_net = nn.Sequential(*layers)
31 |
32 | def forward(self, a, b, i=None, rgb=False, extra=None):
33 | x = torch.cat((a[0], b[0]))
34 | if extra is not None:
35 | x = torch.cat((x, extra))
36 | if i is not None:
37 | indicator = torch.zeros(MAX_LAYERS + 1, device="cuda:0")
38 | indicator[i] = 1
39 | if rgb:
40 | indicator[-1] = 1
41 | x = torch.cat([x, indicator.repeat(self.index_bits)])
42 | x = self.disentangle_net(x)
43 | x = torch.sigmoid(x)
44 | x = x.unsqueeze(0)
45 |
46 | return a[0] + x * (b[0] - a[0]) # Fusion procedure
47 |
48 |
49 | class FusionNet(nn.Module):
50 | def __init__(self, index_bits=DEFAULT_INDEX_BITS, common_layers=DEFAULT_COMMON_LAYERS):
51 | super(FusionNet, self).__init__()
52 |
53 | self.max_pools = [nn.MaxPool1d(512 // x) for x in [512, 256, 128, 64, 32]]
54 | self.upsample = nn.Upsample(size=512, mode='nearest')
55 |
56 | self.blender_segments = LatentBlender(512, index_bits=index_bits, extras=1)
57 | self.blender_common = LatentBlender(512, index_bits=index_bits)
58 | self.common_layers = common_layers
59 |
60 | def up(self, x):
61 | return self.upsample(x.unsqueeze(0)).squeeze(0)
62 |
63 | def pool(self, x, size):
64 | size_index = 9 - int(math.log2(size))
65 | return self.max_pools[size_index](x.unsqueeze(0)).squeeze(0)
66 |
67 | def forward(self, s0, s1, s2, only_common=False, only_segments=False):
68 | if only_common:
69 | assert s1 is None
70 | if only_segments:
71 | assert s2 is None
72 | res = [[],[]]
73 | for rgb in [0, 1]:
74 | for i in range(len(s0[rgb])):
75 | s0_up = self.up(s0[rgb][i])
76 | if not only_common:
77 | s1_up = self.up(s1[rgb][i])
78 | if not only_segments:
79 | s2_up = self.up(s2[rgb][i])
80 |
81 | skip_common_layer = False
82 | if self.common_layers is not None:
83 | skip_common_layer = i >= self.common_layers
84 |
85 | if only_common:
86 | if skip_common_layer:
87 | res[rgb].append(s0[rgb][i])
88 | else:
89 | x = self.blender_common(s0_up, s2_up, i=i, rgb=(rgb==1))
90 | res[rgb].append(self.pool(x, s0[rgb][i].size()[1]))
91 | elif only_segments or skip_common_layer:
92 | x = self.blender_segments(s0_up, s1_up, i=i, rgb=False, extra=torch.zeros_like(s1_up[0]))
93 | res[rgb].append(self.pool(x, s0[rgb][i].size()[1]))
94 | else:
95 | x0 = self.blender_common(s0_up, s2_up, i=i, rgb=False)
96 | x0 = self.pool(x0, s0[rgb][i].size()[1])
97 | x1 = self.blender_common(s1_up, s2_up, i=i, rgb=False)
98 | x1 = self.pool(x1, s0[rgb][i].size()[1])
99 |
100 | x = self.blender_segments(self.up(x0), self.up(x1), i=i, rgb=False, extra=s2_up[0])
101 | res[rgb].append(self.pool(x, s0[rgb][i].size()[1]))
102 |
103 | return res
104 |
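105 | # Appended illustration (not part of the original file): the fusion rule on the
106 | # last line of LatentBlender.forward is a per-channel soft selection between two
107 | # style vectors a and b, with the mask produced by disentangle_net and a sigmoid.
108 | # A toy version with a hand-made mask:
109 | if __name__ == "__main__":
110 |     a = torch.zeros(1, 512)                  # style vector of the first child
111 |     b = torch.ones(1, 512)                   # style vector of the second child
112 |     m = torch.sigmoid(torch.randn(512))      # stands in for disentangle_net output
113 |     fused = a[0] + m * (b[0] - a[0])         # m near 1 follows b, near 0 follows a
114 |     print(fused.shape)                       # torch.Size([512])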
--------------------------------------------------------------------------------
/stylefusion/sf_hierarchy.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from stylefusion.fusion_net import FusionNet
4 |
5 |
6 | class SFHierarchyFFHQ:
7 | def __init__(self):
8 | self.nodes = dict()
9 | self.nodes["clothes"] = SFNode("clothes")
10 | self.nodes["mouth"] = SFNode("mouth")
11 | self.nodes["eyes"] = SFNode("eyes")
12 | self.nodes["bg"] = SFNode("bg")
13 | self.nodes["hair"] = SFNode("hair")
14 | self.nodes["skin"] = SFNode("skin")
15 | self.nodes["skin_mouth"] = SFNode("skin_mouth", child1=self.nodes["mouth"], child2=self.nodes["skin"])
16 | self.nodes["face"] = SFNode("face", child1=self.nodes["skin_mouth"], child2=self.nodes["eyes"])
17 | self.nodes["bg_clothes"] = SFNode("bg_clothes", child1=self.nodes["clothes"], child2=self.nodes["bg"])
18 | self.nodes["bg_hair_clothes"] = SFNode("bg_hair_clothes",
19 | child1=self.nodes["bg_clothes"], child2=self.nodes["hair"])
20 | self.nodes["all"] = SFNode("all", child1=self.nodes["face"], child2=self.nodes["bg_hair_clothes"])
21 |
22 |
23 | class SFHierarchyCar:
24 | def __init__(self):
25 | self.nodes = dict()
26 | self.nodes["wheels"] = SFNode("wheels")
27 | self.nodes["car_body"] = SFNode("car_body")
28 | self.nodes["background_top"] = SFNode("background_top")
29 | self.nodes["background_bottom"] = SFNode("background_bottom")
30 | self.nodes["car"] = SFNode("car", child1=self.nodes["car_body"], child2=self.nodes["wheels"])
31 | self.nodes["background"] = SFNode("background",
32 | child1=self.nodes["background_top"], child2=self.nodes["background_bottom"])
33 | self.nodes["all"] = SFNode("all", child1=self.nodes["car"], child2=self.nodes["background"])
34 |
35 |
36 | class SFHierarchyChurch:
37 | def __init__(self):
38 | self.nodes = dict()
39 | self.nodes["church"] = SFNode("church")
40 | self.nodes["background"] = SFNode("background")
41 | self.nodes["all"] = SFNode("all", child1=self.nodes["church"], child2=self.nodes["background"])
42 |
43 |
44 | class SFNode:
45 | def __init__(self, name, child1=None, child2=None):
46 | self.name = name
47 | self.child1 = child1
48 | self.child2 = child2
49 | self.fusion_net = None
50 | if child1 is None or child2 is None:
51 | assert child1 is None and child2 is None
52 | self._leaf = True
53 | else:
54 | self._leaf = False
55 |
56 | def get_all_parts(self):
57 | if self._leaf:
58 | return [self.name]
59 | else:
60 | return [self.name] + self.child1.get_all_parts() + self.child2.get_all_parts()
61 |
62 | def get_all_active_parts(self):
63 | if self.fusion_net is None:
64 | return [self.name]
65 | else:
66 | return [self.name] + self.child1.get_all_active_parts() + self.child2.get_all_active_parts()
67 |
68 | def get_fusion_nets_amount(self):
69 | if self.fusion_net is None:
70 | return 0
71 | if self._leaf:
72 | return 1
73 | else:
74 | return 1 + self.child1.get_fusion_nets_amount() + self.child2.get_fusion_nets_amount()
75 |
76 | def get_fusion_nets(self):
77 | if self.fusion_net is None:
78 | return []
79 | if self._leaf:
80 | return [self.fusion_net]
81 | else:
82 | return [self.fusion_net] + self.child1.get_fusion_nets() + self.child2.get_fusion_nets()
83 |
84 | def forward(self, s_dict):
85 | if self.fusion_net is None:
86 | return s_dict[self.name]
87 |
88 | if not (self.child1.name in s_dict.keys() and self.child2.name in s_dict.keys()):
89 | return s_dict[self.name]
90 |
91 | return self.fusion_net(
92 | self.child1.forward(s_dict),
93 | self.child2.forward(s_dict),
94 | s_dict[self.name])
95 |
96 | def load_fusion_net(self, path):
97 | data = torch.load(path)
98 |
99 | self.fusion_net = FusionNet()
100 |
101 | if "state_dict" in data.keys():
102 | data = data["state_dict"]
103 |
104 | self.fusion_net.load_state_dict(data)
105 |
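106 | # Appended illustration (not part of the original file): the FFHQ hierarchy is a
107 | # binary tree rooted at "all" whose leaves are single regions. Before any fusion
108 | # nets are loaded every node is inactive, so only the root counts as "active".
109 | if __name__ == "__main__":
110 |     root = SFHierarchyFFHQ().nodes["all"]
111 |     print(root.get_all_parts())            # every region name in the tree
112 |     print(root.get_all_active_parts())     # ['all'] until fusion nets are loaded
113 |     print(root.get_fusion_nets_amount())   # 0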
--------------------------------------------------------------------------------
/stylefusion/sf_stylegan2.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | from stylegan2.model import *
5 |
6 |
7 | class SFModulatedConv2d(ModulatedConv2d):
8 | def forward(self, input, style, style_vector=None):
9 | batch, in_channel, height, width = input.shape
10 |
11 | if not self.fused:
12 | weight = self.scale * self.weight.squeeze(0)
13 | style = self.modulation(style)
14 |
15 | if self.demodulate:
16 | w = weight.unsqueeze(0) * style.view(batch, 1, in_channel, 1, 1)
17 | dcoefs = (w.square().sum((2, 3, 4)) + 1e-8).rsqrt()
18 |
19 | input = input * style.reshape(batch, in_channel, 1, 1)
20 |
21 | if self.upsample:
22 | weight = weight.transpose(0, 1)
23 | out = conv2d_gradfix.conv_transpose2d(
24 | input, weight, padding=0, stride=2
25 | )
26 | out = self.blur(out)
27 |
28 | elif self.downsample:
29 | input = self.blur(input)
30 | out = conv2d_gradfix.conv2d(input, weight, padding=0, stride=2)
31 |
32 | else:
33 | out = conv2d_gradfix.conv2d(input, weight, padding=self.padding)
34 |
35 | if self.demodulate:
36 | out = out * dcoefs.view(batch, -1, 1, 1)
37 |
38 | return out
39 |
40 | if style_vector is None:
41 | style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
42 | else:
43 | style = style_vector.view(batch, 1, in_channel, 1, 1)
44 |
45 | weight = self.scale * self.weight * style
46 |
47 | if self.demodulate:
48 | demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
49 | weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
50 |
51 | weight = weight.view(
52 | batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
53 | )
54 |
55 | if self.upsample:
56 | input = input.view(1, batch * in_channel, height, width)
57 | weight = weight.view(
58 | batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
59 | )
60 | weight = weight.transpose(1, 2).reshape(
61 | batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
62 | )
63 | out = conv2d_gradfix.conv_transpose2d(
64 | input, weight, padding=0, stride=2, groups=batch
65 | )
66 | _, _, height, width = out.shape
67 | out = out.view(batch, self.out_channel, height, width)
68 | out = self.blur(out)
69 |
70 | elif self.downsample:
71 | input = self.blur(input)
72 | _, _, height, width = input.shape
73 | input = input.view(1, batch * in_channel, height, width)
74 | out = conv2d_gradfix.conv2d(
75 | input, weight, padding=0, stride=2, groups=batch
76 | )
77 | _, _, height, width = out.shape
78 | out = out.view(batch, self.out_channel, height, width)
79 |
80 | else:
81 | input = input.view(1, batch * in_channel, height, width)
82 | out = conv2d_gradfix.conv2d(
83 | input, weight, padding=self.padding, groups=batch
84 | )
85 | _, _, height, width = out.shape
86 | out = out.view(batch, self.out_channel, height, width)
87 |
88 | return out
89 |
90 |
91 | class SFStyledConv(StyledConv):
92 | def __init__(
93 | self,
94 | in_channel,
95 | out_channel,
96 | kernel_size,
97 | style_dim,
98 | upsample=False,
99 | blur_kernel=[1, 3, 3, 1],
100 | demodulate=True,
101 | ):
102 | super(StyledConv, self).__init__()
103 |
104 | self.conv = SFModulatedConv2d(
105 | in_channel,
106 | out_channel,
107 | kernel_size,
108 | style_dim,
109 | upsample=upsample,
110 | blur_kernel=blur_kernel,
111 | demodulate=demodulate,
112 | )
113 |
114 | self.noise = NoiseInjection()
115 | # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
116 | # self.activate = ScaledLeakyReLU(0.2)
117 | self.activate = FusedLeakyReLU(out_channel)
118 |
119 | def forward(self, input, style, noise=None, style_vector=None):
120 | out = self.conv(input, style, style_vector=style_vector)
121 | out = self.noise(out, noise=noise)
122 | # out = out + self.bias
123 | out = self.activate(out)
124 |
125 | return out
126 |
127 | def get_style_vector(self, style):
128 | return self.conv.modulation(style)
129 |
130 |
131 | class SFToRGB(ToRGB):
132 | def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
133 | super(ToRGB, self).__init__()
134 |
135 | if upsample:
136 | self.upsample = Upsample(blur_kernel)
137 |
138 | self.conv = SFModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
139 | self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
140 |
141 | def forward(self, input, style, skip=None, style_vector=None):
142 | out = self.conv(input, style, style_vector=style_vector)
143 | out = out + self.bias
144 |
145 | if skip is not None:
146 | skip = self.upsample(skip)
147 |
148 | out = out + skip
149 |
150 | return out
151 |
152 | def get_style_vector(self, style):
153 | return self.conv.modulation(style)
154 |
155 |
156 | class SFGenerator(Generator):
157 | def __init__(
158 | self,
159 | size,
160 | style_dim,
161 | n_mlp,
162 | channel_multiplier=2,
163 | blur_kernel=[1, 3, 3, 1],
164 | lr_mlp=0.01,
165 | ):
166 | super(Generator, self).__init__()
167 |
168 | self.size = size
169 |
170 | self.style_dim = style_dim
171 |
172 | layers = [PixelNorm()]
173 |
174 | for i in range(n_mlp):
175 | layers.append(
176 | EqualLinear(
177 | style_dim, style_dim, lr_mul=lr_mlp, activation="fused_lrelu"
178 | )
179 | )
180 |
181 | self.style = nn.Sequential(*layers)
182 |
183 | self.channels = {
184 | 4: 512,
185 | 8: 512,
186 | 16: 512,
187 | 32: 512,
188 | 64: 256 * channel_multiplier,
189 | 128: 128 * channel_multiplier,
190 | 256: 64 * channel_multiplier,
191 | 512: 32 * channel_multiplier,
192 | 1024: 16 * channel_multiplier,
193 | }
194 |
195 | self.input = ConstantInput(self.channels[4])
196 | self.conv1 = SFStyledConv(
197 | self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
198 | )
199 | self.to_rgb1 = SFToRGB(self.channels[4], style_dim, upsample=False)
200 |
201 | self.log_size = int(math.log(size, 2))
202 | self.num_layers = (self.log_size - 2) * 2 + 1
203 |
204 | self.convs = nn.ModuleList()
205 | self.upsamples = nn.ModuleList()
206 | self.to_rgbs = nn.ModuleList()
207 | self.noises = nn.Module()
208 |
209 | in_channel = self.channels[4]
210 |
211 | for layer_idx in range(self.num_layers):
212 | res = (layer_idx + 5) // 2
213 | shape = [1, 1, 2 ** res, 2 ** res]
214 | self.noises.register_buffer(f"noise_{layer_idx}", torch.randn(*shape))
215 |
216 | for i in range(3, self.log_size + 1):
217 | out_channel = self.channels[2 ** i]
218 |
219 | self.convs.append(
220 | SFStyledConv(
221 | in_channel,
222 | out_channel,
223 | 3,
224 | style_dim,
225 | upsample=True,
226 | blur_kernel=blur_kernel,
227 | )
228 | )
229 |
230 | self.convs.append(
231 | SFStyledConv(
232 | out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
233 | )
234 | )
235 |
236 | self.to_rgbs.append(SFToRGB(out_channel, style_dim))
237 |
238 | in_channel = out_channel
239 |
240 | self.n_latent = self.log_size * 2 - 2
241 |
242 | def forward(
243 | self,
244 | styles,
245 | return_latents=False,
246 | return_style_vector=False,
247 | inject_index=None,
248 | truncation=1,
249 | truncation_latent=None,
250 | input_is_latent=False,
251 | style_vector=None,
252 | noise=None,
253 | randomize_noise=True,
254 | ):
255 | if not input_is_latent:
256 | styles = [self.style(s) for s in styles]
257 |
258 | if noise is None:
259 | if randomize_noise:
260 | noise = [None] * self.num_layers
261 | else:
262 | noise = [
263 | getattr(self.noises, f'noise_{i}') for i in range(self.num_layers)
264 | ]
265 |
266 | if truncation < 1:
267 | style_t = []
268 |
269 | for style in styles:
270 | style_t.append(
271 | truncation_latent + truncation * (style - truncation_latent)
272 | )
273 |
274 | styles = style_t
275 |
276 | if len(styles) < 2:
277 | inject_index = self.n_latent
278 |
279 | if styles[0].ndim < 3:
280 | latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
281 | else:
282 | latent = styles[0]
283 |
284 | else:
285 | if inject_index is None:
286 | inject_index = 2  # random.randint(1, self.n_latent - 1)
287 |
288 | latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
289 | latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
290 |
291 | latent = torch.cat([latent, latent2], 1)
292 |
293 | if return_style_vector:
294 | s = [self.conv1.get_style_vector(latent[:, 0])] + \
295 | [self.convs[i].get_style_vector(latent[:, i + 1]) for i in range(len(self.convs))]
296 | s_rgb = [self.to_rgb1.get_style_vector(latent[:, 1])] + \
297 | [self.to_rgbs[i].get_style_vector(latent[:, i * 2 + 2]) for i in range(len(self.to_rgbs))]
298 | return [s, s_rgb]
299 |
300 | if style_vector is None:
301 | style_vector = [[None] * (len(self.convs) + 1) , [None] * (len(self.to_rgbs) + 1)]
302 |
303 | out = self.input(latent)
304 | out = self.conv1(out, latent[:, 0], noise=noise[0], style_vector=style_vector[0][0])
305 |
306 | skip = self.to_rgb1(out, latent[:, 1], style_vector=style_vector[1][0])
307 |
308 | i = 1
309 | for conv1, conv2, noise1, noise2, to_rgb in zip(
310 | self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
311 | ):
312 | out = conv1(out, latent[:, i], noise=noise1, style_vector=style_vector[0][i])
313 | out = conv2(out, latent[:, i + 1], noise=noise2, style_vector=style_vector[0][i + 1])
314 | skip = to_rgb(out, latent[:, i + 2], skip, style_vector=style_vector[1][i//2 + 1])
315 |
316 | i += 2
317 |
318 | image = skip
319 |
320 | if return_latents:
321 | return image, latent
322 | else:
323 | return image, None
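324 |
325 | # Appended illustration (not part of the original file): SFGenerator can export
326 | # the per-layer style vectors of a latent (return_style_vector=True) and later
327 | # re-synthesize from them (style_vector=...). Requires a CUDA device; the weights
328 | # here are randomly initialized, so the output is untrained noise.
329 | if __name__ == "__main__":
330 |     g = SFGenerator(256, 512, 8).to("cuda")
331 |     z = torch.randn(1, 512, device="cuda")
332 |     with torch.no_grad():
333 |         s = g([z], randomize_noise=False, return_style_vector=True)  # [s_convs, s_rgbs]
334 |         img, _ = g([torch.zeros(1, 512, device="cuda")],
335 |                    randomize_noise=False, style_vector=s)
336 |     print(img.shape)  # torch.Size([1, 3, 256, 256])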
--------------------------------------------------------------------------------
/stylegan2/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OmerKafri/StyleFusion/c1aa57267a192b3c50ee501339ed50a0aedb35fe/stylegan2/__init__.py
--------------------------------------------------------------------------------
/stylegan2/model.py:
--------------------------------------------------------------------------------
1 | import math
2 | import random
3 |
4 | import torch
5 | from torch import nn
6 | from torch.nn import functional as F
7 |
8 | from stylegan2.op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d, conv2d_gradfix
9 |
10 |
11 | class PixelNorm(nn.Module):
12 | def __init__(self):
13 | super().__init__()
14 |
15 | def forward(self, input):
16 | return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
17 |
18 |
19 | def make_kernel(k):
20 | k = torch.tensor(k, dtype=torch.float32)
21 |
22 | if k.ndim == 1:
23 | k = k[None, :] * k[:, None]
24 |
25 | k /= k.sum()
26 |
27 | return k
28 |
29 |
30 | class Upsample(nn.Module):
31 | def __init__(self, kernel, factor=2):
32 | super().__init__()
33 |
34 | self.factor = factor
35 | kernel = make_kernel(kernel) * (factor ** 2)
36 | self.register_buffer("kernel", kernel)
37 |
38 | p = kernel.shape[0] - factor
39 |
40 | pad0 = (p + 1) // 2 + factor - 1
41 | pad1 = p // 2
42 |
43 | self.pad = (pad0, pad1)
44 |
45 | def forward(self, input):
46 | out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
47 |
48 | return out
49 |
50 |
51 | class Downsample(nn.Module):
52 | def __init__(self, kernel, factor=2):
53 | super().__init__()
54 |
55 | self.factor = factor
56 | kernel = make_kernel(kernel)
57 | self.register_buffer("kernel", kernel)
58 |
59 | p = kernel.shape[0] - factor
60 |
61 | pad0 = (p + 1) // 2
62 | pad1 = p // 2
63 |
64 | self.pad = (pad0, pad1)
65 |
66 | def forward(self, input):
67 | out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
68 |
69 | return out
70 |
71 |
72 | class Blur(nn.Module):
73 | def __init__(self, kernel, pad, upsample_factor=1):
74 | super().__init__()
75 |
76 | kernel = make_kernel(kernel)
77 |
78 | if upsample_factor > 1:
79 | kernel = kernel * (upsample_factor ** 2)
80 |
81 | self.register_buffer("kernel", kernel)
82 |
83 | self.pad = pad
84 |
85 | def forward(self, input):
86 | out = upfirdn2d(input, self.kernel, pad=self.pad)
87 |
88 | return out
89 |
90 |
91 | class EqualConv2d(nn.Module):
92 | def __init__(
93 | self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
94 | ):
95 | super().__init__()
96 |
97 | self.weight = nn.Parameter(
98 | torch.randn(out_channel, in_channel, kernel_size, kernel_size)
99 | )
100 | self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
101 |
102 | self.stride = stride
103 | self.padding = padding
104 |
105 | if bias:
106 | self.bias = nn.Parameter(torch.zeros(out_channel))
107 |
108 | else:
109 | self.bias = None
110 |
111 | def forward(self, input):
112 | out = conv2d_gradfix.conv2d(
113 | input,
114 | self.weight * self.scale,
115 | bias=self.bias,
116 | stride=self.stride,
117 | padding=self.padding,
118 | )
119 |
120 | return out
121 |
122 | def __repr__(self):
123 | return (
124 | f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},"
125 | f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})"
126 | )
127 |
128 |
129 | class EqualLinear(nn.Module):
130 | def __init__(
131 | self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
132 | ):
133 | super().__init__()
134 |
135 | self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
136 |
137 | if bias:
138 | self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
139 |
140 | else:
141 | self.bias = None
142 |
143 | self.activation = activation
144 |
145 | self.scale = (1 / math.sqrt(in_dim)) * lr_mul
146 | self.lr_mul = lr_mul
147 |
148 | def forward(self, input):
149 | if self.activation:
150 | out = F.linear(input, self.weight * self.scale)
151 | out = fused_leaky_relu(out, self.bias * self.lr_mul)
152 |
153 | else:
154 | out = F.linear(
155 | input, self.weight * self.scale, bias=self.bias * self.lr_mul
156 | )
157 |
158 | return out
159 |
160 | def __repr__(self):
161 | return (
162 | f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})"
163 | )
164 |
165 |
166 | class ModulatedConv2d(nn.Module):
167 | def __init__(
168 | self,
169 | in_channel,
170 | out_channel,
171 | kernel_size,
172 | style_dim,
173 | demodulate=True,
174 | upsample=False,
175 | downsample=False,
176 | blur_kernel=[1, 3, 3, 1],
177 | fused=True,
178 | ):
179 | super().__init__()
180 |
181 | self.eps = 1e-8
182 | self.kernel_size = kernel_size
183 | self.in_channel = in_channel
184 | self.out_channel = out_channel
185 | self.upsample = upsample
186 | self.downsample = downsample
187 |
188 | if upsample:
189 | factor = 2
190 | p = (len(blur_kernel) - factor) - (kernel_size - 1)
191 | pad0 = (p + 1) // 2 + factor - 1
192 | pad1 = p // 2 + 1
193 |
194 | self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
195 |
196 | if downsample:
197 | factor = 2
198 | p = (len(blur_kernel) - factor) + (kernel_size - 1)
199 | pad0 = (p + 1) // 2
200 | pad1 = p // 2
201 |
202 | self.blur = Blur(blur_kernel, pad=(pad0, pad1))
203 |
204 | fan_in = in_channel * kernel_size ** 2
205 | self.scale = 1 / math.sqrt(fan_in)
206 | self.padding = kernel_size // 2
207 |
208 | self.weight = nn.Parameter(
209 | torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
210 | )
211 |
212 | self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
213 |
214 | self.demodulate = demodulate
215 | self.fused = fused
216 |
217 | def __repr__(self):
218 | return (
219 | f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, "
220 | f"upsample={self.upsample}, downsample={self.downsample})"
221 | )
222 |
223 | def forward(self, input, style):
224 | batch, in_channel, height, width = input.shape
225 |
226 | if not self.fused:
227 | weight = self.scale * self.weight.squeeze(0)
228 | style = self.modulation(style)
229 |
230 | if self.demodulate:
231 | w = weight.unsqueeze(0) * style.view(batch, 1, in_channel, 1, 1)
232 | dcoefs = (w.square().sum((2, 3, 4)) + 1e-8).rsqrt()
233 |
234 | input = input * style.reshape(batch, in_channel, 1, 1)
235 |
236 | if self.upsample:
237 | weight = weight.transpose(0, 1)
238 | out = conv2d_gradfix.conv_transpose2d(
239 | input, weight, padding=0, stride=2
240 | )
241 | out = self.blur(out)
242 |
243 | elif self.downsample:
244 | input = self.blur(input)
245 | out = conv2d_gradfix.conv2d(input, weight, padding=0, stride=2)
246 |
247 | else:
248 | out = conv2d_gradfix.conv2d(input, weight, padding=self.padding)
249 |
250 | if self.demodulate:
251 | out = out * dcoefs.view(batch, -1, 1, 1)
252 |
253 | return out
254 |
255 | style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
256 | weight = self.scale * self.weight * style
257 |
258 | if self.demodulate:
259 | demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
260 | weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
261 |
262 | weight = weight.view(
263 | batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
264 | )
265 |
266 | if self.upsample:
267 | input = input.view(1, batch * in_channel, height, width)
268 | weight = weight.view(
269 | batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
270 | )
271 | weight = weight.transpose(1, 2).reshape(
272 | batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
273 | )
274 | out = conv2d_gradfix.conv_transpose2d(
275 | input, weight, padding=0, stride=2, groups=batch
276 | )
277 | _, _, height, width = out.shape
278 | out = out.view(batch, self.out_channel, height, width)
279 | out = self.blur(out)
280 |
281 | elif self.downsample:
282 | input = self.blur(input)
283 | _, _, height, width = input.shape
284 | input = input.view(1, batch * in_channel, height, width)
285 | out = conv2d_gradfix.conv2d(
286 | input, weight, padding=0, stride=2, groups=batch
287 | )
288 | _, _, height, width = out.shape
289 | out = out.view(batch, self.out_channel, height, width)
290 |
291 | else:
292 | input = input.view(1, batch * in_channel, height, width)
293 | out = conv2d_gradfix.conv2d(
294 | input, weight, padding=self.padding, groups=batch
295 | )
296 | _, _, height, width = out.shape
297 | out = out.view(batch, self.out_channel, height, width)
298 |
299 | return out
300 |
301 |
302 | class NoiseInjection(nn.Module):
303 | def __init__(self):
304 | super().__init__()
305 |
306 | self.weight = nn.Parameter(torch.zeros(1))
307 |
308 | def forward(self, image, noise=None):
309 | if noise is None:
310 | batch, _, height, width = image.shape
311 | noise = image.new_empty(batch, 1, height, width).normal_()
312 |
313 | return image + self.weight * noise
314 |
315 |
316 | class ConstantInput(nn.Module):
317 | def __init__(self, channel, size=4):
318 | super().__init__()
319 |
320 | self.input = nn.Parameter(torch.randn(1, channel, size, size))
321 |
322 | def forward(self, input):
323 | batch = input.shape[0]
324 | out = self.input.repeat(batch, 1, 1, 1)
325 |
326 | return out
327 |
328 |
329 | class StyledConv(nn.Module):
330 | def __init__(
331 | self,
332 | in_channel,
333 | out_channel,
334 | kernel_size,
335 | style_dim,
336 | upsample=False,
337 | blur_kernel=[1, 3, 3, 1],
338 | demodulate=True,
339 | ):
340 | super().__init__()
341 |
342 | self.conv = ModulatedConv2d(
343 | in_channel,
344 | out_channel,
345 | kernel_size,
346 | style_dim,
347 | upsample=upsample,
348 | blur_kernel=blur_kernel,
349 | demodulate=demodulate,
350 | )
351 |
352 | self.noise = NoiseInjection()
353 | # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
354 | # self.activate = ScaledLeakyReLU(0.2)
355 | self.activate = FusedLeakyReLU(out_channel)
356 |
357 | def forward(self, input, style, noise=None):
358 | out = self.conv(input, style)
359 | out = self.noise(out, noise=noise)
360 | # out = out + self.bias
361 | out = self.activate(out)
362 |
363 | return out
364 |
365 |
366 | class ToRGB(nn.Module):
367 | def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
368 | super().__init__()
369 |
370 | if upsample:
371 | self.upsample = Upsample(blur_kernel)
372 |
373 | self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
374 | self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
375 |
376 | def forward(self, input, style, skip=None):
377 | out = self.conv(input, style)
378 | out = out + self.bias
379 |
380 | if skip is not None:
381 | skip = self.upsample(skip)
382 |
383 | out = out + skip
384 |
385 | return out
386 |
387 |
388 | class Generator(nn.Module):
389 | def __init__(
390 | self,
391 | size,
392 | style_dim,
393 | n_mlp,
394 | channel_multiplier=2,
395 | blur_kernel=[1, 3, 3, 1],
396 | lr_mlp=0.01,
397 | ):
398 | super().__init__()
399 |
400 | self.size = size
401 |
402 | self.style_dim = style_dim
403 |
404 | layers = [PixelNorm()]
405 |
406 | for i in range(n_mlp):
407 | layers.append(
408 | EqualLinear(
409 | style_dim, style_dim, lr_mul=lr_mlp, activation="fused_lrelu"
410 | )
411 | )
412 |
413 | self.style = nn.Sequential(*layers)
414 |
415 | self.channels = {
416 | 4: 512,
417 | 8: 512,
418 | 16: 512,
419 | 32: 512,
420 | 64: 256 * channel_multiplier,
421 | 128: 128 * channel_multiplier,
422 | 256: 64 * channel_multiplier,
423 | 512: 32 * channel_multiplier,
424 | 1024: 16 * channel_multiplier,
425 | }
426 |
427 | self.input = ConstantInput(self.channels[4])
428 | self.conv1 = StyledConv(
429 | self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
430 | )
431 | self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)
432 |
433 | self.log_size = int(math.log(size, 2))
434 | self.num_layers = (self.log_size - 2) * 2 + 1
435 |
436 | self.convs = nn.ModuleList()
437 | self.upsamples = nn.ModuleList()
438 | self.to_rgbs = nn.ModuleList()
439 | self.noises = nn.Module()
440 |
441 | in_channel = self.channels[4]
442 |
443 | for layer_idx in range(self.num_layers):
444 | res = (layer_idx + 5) // 2
445 | shape = [1, 1, 2 ** res, 2 ** res]
446 | self.noises.register_buffer(f"noise_{layer_idx}", torch.randn(*shape))
447 |
448 | for i in range(3, self.log_size + 1):
449 | out_channel = self.channels[2 ** i]
450 |
451 | self.convs.append(
452 | StyledConv(
453 | in_channel,
454 | out_channel,
455 | 3,
456 | style_dim,
457 | upsample=True,
458 | blur_kernel=blur_kernel,
459 | )
460 | )
461 |
462 | self.convs.append(
463 | StyledConv(
464 | out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
465 | )
466 | )
467 |
468 | self.to_rgbs.append(ToRGB(out_channel, style_dim))
469 |
470 | in_channel = out_channel
471 |
472 | self.n_latent = self.log_size * 2 - 2
473 |
474 | def make_noise(self):
475 | device = self.input.input.device
476 |
477 | noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]
478 |
479 | for i in range(3, self.log_size + 1):
480 | for _ in range(2):
481 | noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))
482 |
483 | return noises
484 |
485 | def mean_latent(self, n_latent):
486 | latent_in = torch.randn(
487 | n_latent, self.style_dim, device=self.input.input.device
488 | )
489 | latent = self.style(latent_in).mean(0, keepdim=True)
490 |
491 | return latent
492 |
493 | def get_latent(self, input):
494 | return self.style(input)
495 |
496 | def forward(
497 | self,
498 | styles,
499 | return_latents=False,
500 | inject_index=None,
501 | truncation=1,
502 | truncation_latent=None,
503 | input_is_latent=False,
504 | noise=None,
505 | randomize_noise=True,
506 | ):
507 | if not input_is_latent:
508 | styles = [self.style(s) for s in styles]
509 |
510 | if noise is None:
511 | if randomize_noise:
512 | noise = [None] * self.num_layers
513 | else:
514 | noise = [
515 | getattr(self.noises, f"noise_{i}") for i in range(self.num_layers)
516 | ]
517 |
518 | if truncation < 1:
519 | style_t = []
520 |
521 | for style in styles:
522 | style_t.append(
523 | truncation_latent + truncation * (style - truncation_latent)
524 | )
525 |
526 | styles = style_t
527 |
528 | if len(styles) < 2:
529 | inject_index = self.n_latent
530 |
531 | if styles[0].ndim < 3:
532 | latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
533 |
534 | else:
535 | latent = styles[0]
536 |
537 | else:
538 | if inject_index is None:
539 | inject_index = random.randint(1, self.n_latent - 1)
540 |
541 | latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
542 | latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
543 |
544 | latent = torch.cat([latent, latent2], 1)
545 |
546 | out = self.input(latent)
547 | out = self.conv1(out, latent[:, 0], noise=noise[0])
548 |
549 | skip = self.to_rgb1(out, latent[:, 1])
550 |
551 | i = 1
552 | for conv1, conv2, noise1, noise2, to_rgb in zip(
553 | self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
554 | ):
555 | out = conv1(out, latent[:, i], noise=noise1)
556 | out = conv2(out, latent[:, i + 1], noise=noise2)
557 | skip = to_rgb(out, latent[:, i + 2], skip)
558 |
559 | i += 2
560 |
561 | image = skip
562 |
563 | if return_latents:
564 | return image, latent
565 |
566 | else:
567 | return image, None
568 |
569 |
570 | class ConvLayer(nn.Sequential):
571 | def __init__(
572 | self,
573 | in_channel,
574 | out_channel,
575 | kernel_size,
576 | downsample=False,
577 | blur_kernel=[1, 3, 3, 1],
578 | bias=True,
579 | activate=True,
580 | ):
581 | layers = []
582 |
583 | if downsample:
584 | factor = 2
585 | p = (len(blur_kernel) - factor) + (kernel_size - 1)
586 | pad0 = (p + 1) // 2
587 | pad1 = p // 2
588 |
589 | layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
590 |
591 | stride = 2
592 | self.padding = 0
593 |
594 | else:
595 | stride = 1
596 | self.padding = kernel_size // 2
597 |
598 | layers.append(
599 | EqualConv2d(
600 | in_channel,
601 | out_channel,
602 | kernel_size,
603 | padding=self.padding,
604 | stride=stride,
605 | bias=bias and not activate,
606 | )
607 | )
608 |
609 | if activate:
610 | layers.append(FusedLeakyReLU(out_channel, bias=bias))
611 |
612 | super().__init__(*layers)
613 |
614 |
615 | class ResBlock(nn.Module):
616 | def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
617 | super().__init__()
618 |
619 | self.conv1 = ConvLayer(in_channel, in_channel, 3)
620 | self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
621 |
622 | self.skip = ConvLayer(
623 | in_channel, out_channel, 1, downsample=True, activate=False, bias=False
624 | )
625 |
626 | def forward(self, input):
627 | out = self.conv1(input)
628 | out = self.conv2(out)
629 |
630 | skip = self.skip(input)
631 | out = (out + skip) / math.sqrt(2)
632 |
633 | return out
634 |
635 |
636 | class Discriminator(nn.Module):
637 | def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
638 | super().__init__()
639 |
640 | channels = {
641 | 4: 512,
642 | 8: 512,
643 | 16: 512,
644 | 32: 512,
645 | 64: 256 * channel_multiplier,
646 | 128: 128 * channel_multiplier,
647 | 256: 64 * channel_multiplier,
648 | 512: 32 * channel_multiplier,
649 | 1024: 16 * channel_multiplier,
650 | }
651 |
652 | convs = [ConvLayer(3, channels[size], 1)]
653 |
654 | log_size = int(math.log(size, 2))
655 |
656 | in_channel = channels[size]
657 |
658 | for i in range(log_size, 2, -1):
659 | out_channel = channels[2 ** (i - 1)]
660 |
661 | convs.append(ResBlock(in_channel, out_channel, blur_kernel))
662 |
663 | in_channel = out_channel
664 |
665 | self.convs = nn.Sequential(*convs)
666 |
667 | self.stddev_group = 4
668 | self.stddev_feat = 1
669 |
670 | self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
671 | self.final_linear = nn.Sequential(
672 | EqualLinear(channels[4] * 4 * 4, channels[4], activation="fused_lrelu"),
673 | EqualLinear(channels[4], 1),
674 | )
675 |
676 | def forward(self, input):
677 | out = self.convs(input)
678 |
679 | batch, channel, height, width = out.shape
680 | group = min(batch, self.stddev_group)
681 | stddev = out.view(
682 | group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
683 | )
684 | stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
685 | stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
686 | stddev = stddev.repeat(group, 1, height, width)
687 | out = torch.cat([out, stddev], 1)
688 |
689 | out = self.final_conv(out)
690 |
691 | out = out.view(batch, -1)
692 | out = self.final_linear(out)
693 |
694 | return out
695 |
696 |
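697 | # Appended illustration (not part of the original file): sampling from the
698 | # generator with truncation, mirroring how style_fusion_simple.py drives it.
699 | # Requires a CUDA device; weights are random here, so outputs are untrained noise.
700 | if __name__ == "__main__":
701 |     g = Generator(256, 512, 8).to("cuda")
702 |     with torch.no_grad():
703 |         mean_w = g.mean_latent(4096)              # average W latent for truncation
704 |         z = torch.randn(4, 512, device="cuda")
705 |         imgs, _ = g([z], truncation=0.7, truncation_latent=mean_w,
706 |                     randomize_noise=False)
707 |     print(imgs.shape)  # torch.Size([4, 3, 256, 256])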
--------------------------------------------------------------------------------
/stylegan2/op/__init__.py:
--------------------------------------------------------------------------------
1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu
2 | from .upfirdn2d import upfirdn2d
3 |
--------------------------------------------------------------------------------
/stylegan2/op/conv2d_gradfix.py:
--------------------------------------------------------------------------------
1 | import contextlib
2 | import warnings
3 |
4 | import torch
5 | from torch import autograd
6 | from torch.nn import functional as F
7 |
8 | enabled = True
9 | weight_gradients_disabled = False
10 |
11 |
12 | @contextlib.contextmanager
13 | def no_weight_gradients():
14 | global weight_gradients_disabled
15 |
16 | old = weight_gradients_disabled
17 | weight_gradients_disabled = True
18 | yield
19 | weight_gradients_disabled = old
20 |
21 |
22 | def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
23 | if could_use_op(input):
24 | return conv2d_gradfix(
25 | transpose=False,
26 | weight_shape=weight.shape,
27 | stride=stride,
28 | padding=padding,
29 | output_padding=0,
30 | dilation=dilation,
31 | groups=groups,
32 | ).apply(input, weight, bias)
33 |
34 | return F.conv2d(
35 | input=input,
36 | weight=weight,
37 | bias=bias,
38 | stride=stride,
39 | padding=padding,
40 | dilation=dilation,
41 | groups=groups,
42 | )
43 |
44 |
45 | def conv_transpose2d(
46 | input,
47 | weight,
48 | bias=None,
49 | stride=1,
50 | padding=0,
51 | output_padding=0,
52 | groups=1,
53 | dilation=1,
54 | ):
55 | if could_use_op(input):
56 | return conv2d_gradfix(
57 | transpose=True,
58 | weight_shape=weight.shape,
59 | stride=stride,
60 | padding=padding,
61 | output_padding=output_padding,
62 | groups=groups,
63 | dilation=dilation,
64 | ).apply(input, weight, bias)
65 |
66 | return F.conv_transpose2d(
67 | input=input,
68 | weight=weight,
69 | bias=bias,
70 | stride=stride,
71 | padding=padding,
72 | output_padding=output_padding,
73 | dilation=dilation,
74 | groups=groups,
75 | )
76 |
77 |
78 | def could_use_op(input):
79 | if (not enabled) or (not torch.backends.cudnn.enabled):
80 | return False
81 |
82 | if input.device.type != "cuda":
83 | return False
84 |
85 | if any(torch.__version__.startswith(x) for x in ["1.7.", "1.8."]):
86 | return True
87 |
88 | warnings.warn(
89 | f"conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d()."
90 | )
91 |
92 | return False
93 |
94 |
95 | def ensure_tuple(xs, ndim):
96 | xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
97 |
98 | return xs
99 |
100 |
101 | conv2d_gradfix_cache = dict()
102 |
103 |
104 | def conv2d_gradfix(
105 | transpose, weight_shape, stride, padding, output_padding, dilation, groups
106 | ):
107 | ndim = 2
108 | weight_shape = tuple(weight_shape)
109 | stride = ensure_tuple(stride, ndim)
110 | padding = ensure_tuple(padding, ndim)
111 | output_padding = ensure_tuple(output_padding, ndim)
112 | dilation = ensure_tuple(dilation, ndim)
113 |
114 | key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
115 | if key in conv2d_gradfix_cache:
116 | return conv2d_gradfix_cache[key]
117 |
118 | common_kwargs = dict(
119 | stride=stride, padding=padding, dilation=dilation, groups=groups
120 | )
121 |
122 | def calc_output_padding(input_shape, output_shape):
123 | if transpose:
124 | return [0, 0]
125 |
126 | return [
127 | input_shape[i + 2]
128 | - (output_shape[i + 2] - 1) * stride[i]
129 | - (1 - 2 * padding[i])
130 | - dilation[i] * (weight_shape[i + 2] - 1)
131 | for i in range(ndim)
132 | ]
133 |
134 | class Conv2d(autograd.Function):
135 | @staticmethod
136 | def forward(ctx, input, weight, bias):
137 | if not transpose:
138 | out = F.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
139 |
140 | else:
141 | out = F.conv_transpose2d(
142 | input=input,
143 | weight=weight,
144 | bias=bias,
145 | output_padding=output_padding,
146 | **common_kwargs,
147 | )
148 |
149 | ctx.save_for_backward(input, weight)
150 |
151 | return out
152 |
153 | @staticmethod
154 | def backward(ctx, grad_output):
155 | input, weight = ctx.saved_tensors
156 | grad_input, grad_weight, grad_bias = None, None, None
157 |
158 | if ctx.needs_input_grad[0]:
159 | p = calc_output_padding(
160 | input_shape=input.shape, output_shape=grad_output.shape
161 | )
162 | grad_input = conv2d_gradfix(
163 | transpose=(not transpose),
164 | weight_shape=weight_shape,
165 | output_padding=p,
166 | **common_kwargs,
167 | ).apply(grad_output, weight, None)
168 |
169 | if ctx.needs_input_grad[1] and not weight_gradients_disabled:
170 | grad_weight = Conv2dGradWeight.apply(grad_output, input)
171 |
172 | if ctx.needs_input_grad[2]:
173 | grad_bias = grad_output.sum((0, 2, 3))
174 |
175 | return grad_input, grad_weight, grad_bias
176 |
177 | class Conv2dGradWeight(autograd.Function):
178 | @staticmethod
179 | def forward(ctx, grad_output, input):
180 | op = torch._C._jit_get_operation(
181 | "aten::cudnn_convolution_backward_weight"
182 | if not transpose
183 | else "aten::cudnn_convolution_transpose_backward_weight"
184 | )
185 | flags = [
186 | torch.backends.cudnn.benchmark,
187 | torch.backends.cudnn.deterministic,
188 | torch.backends.cudnn.allow_tf32,
189 | ]
190 | grad_weight = op(
191 | weight_shape,
192 | grad_output,
193 | input,
194 | padding,
195 | stride,
196 | dilation,
197 | groups,
198 | *flags,
199 | )
200 | ctx.save_for_backward(grad_output, input)
201 |
202 | return grad_weight
203 |
204 | @staticmethod
205 | def backward(ctx, grad_grad_weight):
206 | grad_output, input = ctx.saved_tensors
207 | grad_grad_output, grad_grad_input = None, None
208 |
209 | if ctx.needs_input_grad[0]:
210 | grad_grad_output = Conv2d.apply(input, grad_grad_weight, None)
211 |
212 | if ctx.needs_input_grad[1]:
213 | p = calc_output_padding(
214 | input_shape=input.shape, output_shape=grad_output.shape
215 | )
216 | grad_grad_input = conv2d_gradfix(
217 | transpose=(not transpose),
218 | weight_shape=weight_shape,
219 | output_padding=p,
220 | **common_kwargs,
221 | ).apply(grad_output, grad_grad_weight, None)
222 |
223 | return grad_grad_output, grad_grad_input
224 |
225 | conv2d_gradfix_cache[key] = Conv2d
226 |
227 | return Conv2d
228 |
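229 | # Appended illustration (not part of the original file): conv2d() is a drop-in
230 | # replacement for torch.nn.functional.conv2d; on CPU tensors, or on PyTorch
231 | # versions other than 1.7/1.8, it falls back to F.conv2d (with a warning in the
232 | # latter case).
233 | if __name__ == "__main__":
234 |     x = torch.randn(1, 3, 8, 8, requires_grad=True)
235 |     w = torch.randn(4, 3, 3, 3)
236 |     y = conv2d(x, w, padding=1)
237 |     y.sum().backward()
238 |     print(y.shape, x.grad.shape)  # torch.Size([1, 4, 8, 8]) torch.Size([1, 3, 8, 8])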
--------------------------------------------------------------------------------
/stylegan2/op/fused_act.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | from torch import nn
5 | from torch.nn import functional as F
6 | from torch.autograd import Function
7 | from torch.utils.cpp_extension import load
8 |
9 |
10 | module_path = os.path.dirname(__file__)
11 | fused = load(
12 | "fused",
13 | sources=[
14 | os.path.join(module_path, "fused_bias_act.cpp"),
15 | os.path.join(module_path, "fused_bias_act_kernel.cu"),
16 | ],
17 | )
18 |
19 |
20 | class FusedLeakyReLUFunctionBackward(Function):
21 | @staticmethod
22 | def forward(ctx, grad_output, out, bias, negative_slope, scale):
23 | ctx.save_for_backward(out)
24 | ctx.negative_slope = negative_slope
25 | ctx.scale = scale
26 |
27 | empty = grad_output.new_empty(0)
28 |
29 | grad_input = fused.fused_bias_act(
30 | grad_output.contiguous(), empty, out, 3, 1, negative_slope, scale
31 | )
32 |
33 | dim = [0]
34 |
35 | if grad_input.ndim > 2:
36 | dim += list(range(2, grad_input.ndim))
37 |
38 | if bias:
39 | grad_bias = grad_input.sum(dim).detach()
40 |
41 | else:
42 | grad_bias = empty
43 |
44 | return grad_input, grad_bias
45 |
46 | @staticmethod
47 | def backward(ctx, gradgrad_input, gradgrad_bias):
48 | out, = ctx.saved_tensors
49 | gradgrad_out = fused.fused_bias_act(
50 | gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale
51 | )
52 |
53 | return gradgrad_out, None, None, None, None
54 |
55 |
56 | class FusedLeakyReLUFunction(Function):
57 | @staticmethod
58 | def forward(ctx, input, bias, negative_slope, scale):
59 | empty = input.new_empty(0)
60 |
61 | ctx.bias = bias is not None
62 |
63 | if bias is None:
64 | bias = empty
65 |
66 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
67 | ctx.save_for_backward(out)
68 | ctx.negative_slope = negative_slope
69 | ctx.scale = scale
70 |
71 | return out
72 |
73 | @staticmethod
74 | def backward(ctx, grad_output):
75 | out, = ctx.saved_tensors
76 |
77 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
78 | grad_output, out, ctx.bias, ctx.negative_slope, ctx.scale
79 | )
80 |
81 | if not ctx.bias:
82 | grad_bias = None
83 |
84 | return grad_input, grad_bias, None, None
85 |
86 |
87 | class FusedLeakyReLU(nn.Module):
88 | def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5):
89 | super().__init__()
90 |
91 | if bias:
92 | self.bias = nn.Parameter(torch.zeros(channel))
93 |
94 | else:
95 | self.bias = None
96 |
97 | self.negative_slope = negative_slope
98 | self.scale = scale
99 |
100 | def forward(self, input):
101 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
102 |
103 |
104 | def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5):
105 | if input.device.type == "cpu":
106 | if bias is not None:
107 | rest_dim = [1] * (input.ndim - bias.ndim - 1)
108 | return (
109 | F.leaky_relu(
110 | input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2
111 | )
112 | * scale
113 | )
114 |
115 | else:
116 | return F.leaky_relu(input, negative_slope=0.2) * scale
117 |
118 | else:
119 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
120 |
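121 | # Appended illustration (not part of the original file): once the extension has
122 | # been built (the load() call above needs the CUDA toolkit), fused_leaky_relu also
123 | # accepts CPU tensors, where it reduces to a biased leaky_relu scaled by sqrt(2).
124 | if __name__ == "__main__":
125 |     act = FusedLeakyReLU(8)          # learnable per-channel bias
126 |     x = torch.randn(2, 8, 4, 4)      # CPU tensor -> pure PyTorch fallback path
127 |     print(act(x).shape)              # torch.Size([2, 8, 4, 4])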
--------------------------------------------------------------------------------
/stylegan2/op/fused_bias_act.cpp:
--------------------------------------------------------------------------------
1 |
2 | #include <ATen/ATen.h>
3 | #include <torch/extension.h>
4 |
5 | torch::Tensor fused_bias_act_op(const torch::Tensor &input,
6 | const torch::Tensor &bias,
7 | const torch::Tensor &refer, int act, int grad,
8 | float alpha, float scale);
9 |
10 | #define CHECK_CUDA(x) \
11 | TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
12 | #define CHECK_CONTIGUOUS(x) \
13 | TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
14 | #define CHECK_INPUT(x) \
15 | CHECK_CUDA(x); \
16 | CHECK_CONTIGUOUS(x)
17 |
18 | torch::Tensor fused_bias_act(const torch::Tensor &input,
19 | const torch::Tensor &bias,
20 | const torch::Tensor &refer, int act, int grad,
21 | float alpha, float scale) {
22 | CHECK_INPUT(input);
23 | CHECK_INPUT(bias);
24 |
25 | at::DeviceGuard guard(input.device());
26 |
27 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
28 | }
29 |
30 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
31 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
32 | }
--------------------------------------------------------------------------------
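
The binding in `fused_bias_act.cpp` above exposes a single entry point; on the Python side (`fused_act.py`) it is always called as `fused.fused_bias_act(input, bias, refer, act, grad, alpha, scale)`, where `act=3` selects leaky ReLU, `grad` selects forward (0) or backward (1), and `refer` carries the saved forward output for the backward pass. A hedged sketch of calling the compiled op directly, assuming it was built as `fused` by `fused_act.py` and that a CUDA device is available:

```python
import torch
import torch.nn.functional as F
from stylegan2.op.fused_act import fused  # the compiled extension module

x = torch.randn(2, 8, 4, 4, device="cuda")
b = torch.zeros(8, device="cuda")
empty = x.new_empty(0)                    # unused `refer` tensor in the forward pass

# act=3 (leaky ReLU), grad=0 (forward), alpha=0.2, scale=sqrt(2)
y = fused.fused_bias_act(x, b, empty, 3, 0, 0.2, 2 ** 0.5)

# Matches the plain PyTorch expression used on the CPU fallback path:
ref = F.leaky_relu(x + b.view(1, -1, 1, 1), 0.2) * 2 ** 0.5
print(torch.allclose(y, ref, atol=1e-5))
```
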
/stylegan2/op/fused_bias_act_kernel.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | //
3 | // This work is made available under the Nvidia Source Code License-NC.
4 | // To view a copy of this license, visit
5 | // https://nvlabs.github.io/stylegan2/license.html
6 |
7 | #include <torch/types.h>
8 |
9 | #include <ATen/ATen.h>
10 | #include <ATen/AccumulateType.h>
11 | #include <ATen/cuda/CUDAApplyUtils.cuh>
12 | #include <ATen/cuda/CUDAContext.h>
13 |
14 |
15 | #include <cuda.h>
16 | #include <cuda_runtime.h>
17 |
18 | template <typename scalar_t>
19 | static __global__ void
20 | fused_bias_act_kernel(scalar_t *out, const scalar_t *p_x, const scalar_t *p_b,
21 | const scalar_t *p_ref, int act, int grad, scalar_t alpha,
22 | scalar_t scale, int loop_x, int size_x, int step_b,
23 | int size_b, int use_bias, int use_ref) {
24 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
25 |
26 | scalar_t zero = 0.0;
27 |
28 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x;
29 | loop_idx++, xi += blockDim.x) {
30 | scalar_t x = p_x[xi];
31 |
32 | if (use_bias) {
33 | x += p_b[(xi / step_b) % size_b];
34 | }
35 |
36 | scalar_t ref = use_ref ? p_ref[xi] : zero;
37 |
38 | scalar_t y;
39 |
40 | switch (act * 10 + grad) {
41 | default:
42 | case 10:
43 | y = x;
44 | break;
45 | case 11:
46 | y = x;
47 | break;
48 | case 12:
49 | y = 0.0;
50 | break;
51 |
52 | case 30:
53 | y = (x > 0.0) ? x : x * alpha;
54 | break;
55 | case 31:
56 | y = (ref > 0.0) ? x : x * alpha;
57 | break;
58 | case 32:
59 | y = 0.0;
60 | break;
61 | }
62 |
63 | out[xi] = y * scale;
64 | }
65 | }
66 |
67 | torch::Tensor fused_bias_act_op(const torch::Tensor &input,
68 | const torch::Tensor &bias,
69 | const torch::Tensor &refer, int act, int grad,
70 | float alpha, float scale) {
71 | int curDevice = -1;
72 | cudaGetDevice(&curDevice);
73 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
74 |
75 | auto x = input.contiguous();
76 | auto b = bias.contiguous();
77 | auto ref = refer.contiguous();
78 |
79 | int use_bias = b.numel() ? 1 : 0;
80 | int use_ref = ref.numel() ? 1 : 0;
81 |
82 | int size_x = x.numel();
83 | int size_b = b.numel();
84 | int step_b = 1;
85 |
86 | for (int i = 1 + 1; i < x.dim(); i++) {
87 | step_b *= x.size(i);
88 | }
89 |
90 | int loop_x = 4;
91 | int block_size = 4 * 32;
92 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
93 |
94 | auto y = torch::empty_like(x);
95 |
96 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(
97 | x.scalar_type(), "fused_bias_act_kernel", [&] {
98 |         fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
99 |             y.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
100 |             b.data_ptr<scalar_t>(), ref.data_ptr<scalar_t>(), act, grad, alpha,
101 | scale, loop_x, size_x, step_b, size_b, use_bias, use_ref);
102 | });
103 |
104 | return y;
105 | }
--------------------------------------------------------------------------------
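
In the kernel above, `act * 10 + grad` selects the branch: `30` is the leaky-ReLU forward (`y = x` if `x > 0`, else `alpha * x`), and `31` is its first-order backward, which keys off the sign of the saved forward output `ref` rather than the original input (valid because the forward output has the same sign as its pre-scale input). Every branch is finally multiplied by `scale`. For reference, a plain PyTorch restatement of the `31` branch (not part of the repository):

```python
import torch

def leaky_relu_backward_ref(grad_out, ref, alpha=0.2, scale=2 ** 0.5):
    # grad_in = grad_out * scale where the forward output was positive,
    # grad_in = grad_out * alpha * scale elsewhere -- mirrors case 31 in the kernel.
    return torch.where(ref > 0, grad_out, grad_out * alpha) * scale

grad_out = torch.randn(4, 3)
ref = torch.randn(4, 3)   # stands in for the saved forward output
print(leaky_relu_backward_ref(grad_out, ref).shape)   # torch.Size([4, 3])
```
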
/stylegan2/op/upfirdn2d.cpp:
--------------------------------------------------------------------------------
1 | #include <ATen/ATen.h>
2 | #include <torch/extension.h>
3 |
4 | torch::Tensor upfirdn2d_op(const torch::Tensor &input,
5 | const torch::Tensor &kernel, int up_x, int up_y,
6 | int down_x, int down_y, int pad_x0, int pad_x1,
7 | int pad_y0, int pad_y1);
8 |
9 | #define CHECK_CUDA(x) \
10 | TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
11 | #define CHECK_CONTIGUOUS(x) \
12 | TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
13 | #define CHECK_INPUT(x) \
14 | CHECK_CUDA(x); \
15 | CHECK_CONTIGUOUS(x)
16 |
17 | torch::Tensor upfirdn2d(const torch::Tensor &input, const torch::Tensor &kernel,
18 | int up_x, int up_y, int down_x, int down_y, int pad_x0,
19 | int pad_x1, int pad_y0, int pad_y1) {
20 | CHECK_INPUT(input);
21 | CHECK_INPUT(kernel);
22 |
23 | at::DeviceGuard guard(input.device());
24 |
25 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1,
26 | pad_y0, pad_y1);
27 | }
28 |
29 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
30 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
31 | }
--------------------------------------------------------------------------------
/stylegan2/op/upfirdn2d.py:
--------------------------------------------------------------------------------
1 | from collections import abc
2 | import os
3 |
4 | import torch
5 | from torch.nn import functional as F
6 | from torch.autograd import Function
7 | from torch.utils.cpp_extension import load
8 |
9 |
10 | module_path = os.path.dirname(__file__)
11 | upfirdn2d_op = load(
12 | "upfirdn2d",
13 | sources=[
14 | os.path.join(module_path, "upfirdn2d.cpp"),
15 | os.path.join(module_path, "upfirdn2d_kernel.cu"),
16 | ],
17 | )
18 |
19 |
20 | class UpFirDn2dBackward(Function):
21 | @staticmethod
22 | def forward(
23 | ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
24 | ):
25 |
26 | up_x, up_y = up
27 | down_x, down_y = down
28 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
29 |
30 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
31 |
32 | grad_input = upfirdn2d_op.upfirdn2d(
33 | grad_output,
34 | grad_kernel,
35 | down_x,
36 | down_y,
37 | up_x,
38 | up_y,
39 | g_pad_x0,
40 | g_pad_x1,
41 | g_pad_y0,
42 | g_pad_y1,
43 | )
44 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
45 |
46 | ctx.save_for_backward(kernel)
47 |
48 | pad_x0, pad_x1, pad_y0, pad_y1 = pad
49 |
50 | ctx.up_x = up_x
51 | ctx.up_y = up_y
52 | ctx.down_x = down_x
53 | ctx.down_y = down_y
54 | ctx.pad_x0 = pad_x0
55 | ctx.pad_x1 = pad_x1
56 | ctx.pad_y0 = pad_y0
57 | ctx.pad_y1 = pad_y1
58 | ctx.in_size = in_size
59 | ctx.out_size = out_size
60 |
61 | return grad_input
62 |
63 | @staticmethod
64 | def backward(ctx, gradgrad_input):
65 | kernel, = ctx.saved_tensors
66 |
67 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
68 |
69 | gradgrad_out = upfirdn2d_op.upfirdn2d(
70 | gradgrad_input,
71 | kernel,
72 | ctx.up_x,
73 | ctx.up_y,
74 | ctx.down_x,
75 | ctx.down_y,
76 | ctx.pad_x0,
77 | ctx.pad_x1,
78 | ctx.pad_y0,
79 | ctx.pad_y1,
80 | )
81 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
82 | gradgrad_out = gradgrad_out.view(
83 | ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
84 | )
85 |
86 | return gradgrad_out, None, None, None, None, None, None, None, None
87 |
88 |
89 | class UpFirDn2d(Function):
90 | @staticmethod
91 | def forward(ctx, input, kernel, up, down, pad):
92 | up_x, up_y = up
93 | down_x, down_y = down
94 | pad_x0, pad_x1, pad_y0, pad_y1 = pad
95 |
96 | kernel_h, kernel_w = kernel.shape
97 | batch, channel, in_h, in_w = input.shape
98 | ctx.in_size = input.shape
99 |
100 | input = input.reshape(-1, in_h, in_w, 1)
101 |
102 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
103 |
104 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
105 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x
106 | ctx.out_size = (out_h, out_w)
107 |
108 | ctx.up = (up_x, up_y)
109 | ctx.down = (down_x, down_y)
110 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
111 |
112 | g_pad_x0 = kernel_w - pad_x0 - 1
113 | g_pad_y0 = kernel_h - pad_y0 - 1
114 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
115 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
116 |
117 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
118 |
119 | out = upfirdn2d_op.upfirdn2d(
120 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
121 | )
122 | # out = out.view(major, out_h, out_w, minor)
123 | out = out.view(-1, channel, out_h, out_w)
124 |
125 | return out
126 |
127 | @staticmethod
128 | def backward(ctx, grad_output):
129 | kernel, grad_kernel = ctx.saved_tensors
130 |
131 | grad_input = None
132 |
133 | if ctx.needs_input_grad[0]:
134 | grad_input = UpFirDn2dBackward.apply(
135 | grad_output,
136 | kernel,
137 | grad_kernel,
138 | ctx.up,
139 | ctx.down,
140 | ctx.pad,
141 | ctx.g_pad,
142 | ctx.in_size,
143 | ctx.out_size,
144 | )
145 |
146 | return grad_input, None, None, None, None
147 |
148 |
149 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
150 | if not isinstance(up, abc.Iterable):
151 | up = (up, up)
152 |
153 | if not isinstance(down, abc.Iterable):
154 | down = (down, down)
155 |
156 | if len(pad) == 2:
157 | pad = (pad[0], pad[1], pad[0], pad[1])
158 |
159 | if input.device.type == "cpu":
160 | out = upfirdn2d_native(input, kernel, *up, *down, *pad)
161 |
162 | else:
163 | out = UpFirDn2d.apply(input, kernel, up, down, pad)
164 |
165 | return out
166 |
167 |
168 | def upfirdn2d_native(
169 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
170 | ):
171 | _, channel, in_h, in_w = input.shape
172 | input = input.reshape(-1, in_h, in_w, 1)
173 |
174 | _, in_h, in_w, minor = input.shape
175 | kernel_h, kernel_w = kernel.shape
176 |
177 | out = input.view(-1, in_h, 1, in_w, 1, minor)
178 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
179 | out = out.view(-1, in_h * up_y, in_w * up_x, minor)
180 |
181 | out = F.pad(
182 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
183 | )
184 | out = out[
185 | :,
186 | max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
187 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
188 | :,
189 | ]
190 |
191 | out = out.permute(0, 3, 1, 2)
192 | out = out.reshape(
193 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
194 | )
195 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
196 | out = F.conv2d(out, w)
197 | out = out.reshape(
198 | -1,
199 | minor,
200 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
201 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
202 | )
203 | out = out.permute(0, 2, 3, 1)
204 | out = out[:, ::down_y, ::down_x, :]
205 |
206 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
207 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x
208 |
209 | return out.view(-1, channel, out_h, out_w)
210 |
--------------------------------------------------------------------------------
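
Usage note for `upfirdn2d.py` above: `upfirdn2d(input, kernel, up, down, pad)` upsamples by zero insertion, pads, applies a 2-D FIR filter, and downsamples, all in NCHW layout with a single 2-D `kernel` shared across channels. The sketch below is illustrative only; it builds the separable [1, 3, 3, 1] binomial kernel that StyleGAN2 uses for blurring (importing the module builds the CUDA extension, but the call falls back to `upfirdn2d_native` on CPU).

```python
import torch
from stylegan2.op.upfirdn2d import upfirdn2d

# Normalized 4-tap binomial blur kernel.
k1d = torch.tensor([1.0, 3.0, 3.0, 1.0])
kernel = k1d[:, None] * k1d[None, :]
kernel = kernel / kernel.sum()

x = torch.randn(1, 3, 16, 16)   # illustrative NCHW input

# Plain blur at the same resolution: total padding of kernel_size - 1 keeps the size.
blurred = upfirdn2d(x, kernel, up=1, down=1, pad=(1, 2))
print(blurred.shape)            # torch.Size([1, 3, 16, 16])

# 2x upsample followed by the blur (kernel scaled by the upsampling factor squared).
up = upfirdn2d(x, kernel * 4, up=2, down=1, pad=(2, 1))
print(up.shape)                 # torch.Size([1, 3, 32, 32])
```
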
/stylegan2/op/upfirdn2d_kernel.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | //
3 | // This work is made available under the Nvidia Source Code License-NC.
4 | // To view a copy of this license, visit
5 | // https://nvlabs.github.io/stylegan2/license.html
6 |
7 | #include <torch/types.h>
8 |
9 | #include <ATen/ATen.h>
10 | #include <ATen/AccumulateType.h>
11 | #include <ATen/cuda/CUDAApplyUtils.cuh>
12 | #include <ATen/cuda/CUDAContext.h>
13 |
14 | #include <cuda.h>
15 | #include <cuda_runtime.h>
16 |
17 | static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
18 | int c = a / b;
19 |
20 | if (c * b > a) {
21 | c--;
22 | }
23 |
24 | return c;
25 | }
26 |
27 | struct UpFirDn2DKernelParams {
28 | int up_x;
29 | int up_y;
30 | int down_x;
31 | int down_y;
32 | int pad_x0;
33 | int pad_x1;
34 | int pad_y0;
35 | int pad_y1;
36 |
37 | int major_dim;
38 | int in_h;
39 | int in_w;
40 | int minor_dim;
41 | int kernel_h;
42 | int kernel_w;
43 | int out_h;
44 | int out_w;
45 | int loop_major;
46 | int loop_x;
47 | };
48 |
49 | template <typename scalar_t>
50 | __global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input,
51 | const scalar_t *kernel,
52 | const UpFirDn2DKernelParams p) {
53 | int minor_idx = blockIdx.x * blockDim.x + threadIdx.x;
54 | int out_y = minor_idx / p.minor_dim;
55 | minor_idx -= out_y * p.minor_dim;
56 | int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y;
57 | int major_idx_base = blockIdx.z * p.loop_major;
58 |
59 | if (out_x_base >= p.out_w || out_y >= p.out_h ||
60 | major_idx_base >= p.major_dim) {
61 | return;
62 | }
63 |
64 | int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0;
65 | int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h);
66 | int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y;
67 | int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y;
68 |
69 | for (int loop_major = 0, major_idx = major_idx_base;
70 | loop_major < p.loop_major && major_idx < p.major_dim;
71 | loop_major++, major_idx++) {
72 | for (int loop_x = 0, out_x = out_x_base;
73 | loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) {
74 | int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0;
75 | int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w);
76 | int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x;
77 | int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x;
78 |
79 | const scalar_t *x_p =
80 | &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim +
81 | minor_idx];
82 | const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x];
83 | int x_px = p.minor_dim;
84 | int k_px = -p.up_x;
85 | int x_py = p.in_w * p.minor_dim;
86 | int k_py = -p.up_y * p.kernel_w;
87 |
88 | scalar_t v = 0.0f;
89 |
90 | for (int y = 0; y < h; y++) {
91 | for (int x = 0; x < w; x++) {
92 |           v += static_cast<scalar_t>(*x_p) * static_cast<scalar_t>(*k_p);
93 | x_p += x_px;
94 | k_p += k_px;
95 | }
96 |
97 | x_p += x_py - w * x_px;
98 | k_p += k_py - w * k_px;
99 | }
100 |
101 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
102 | minor_idx] = v;
103 | }
104 | }
105 | }
106 |
107 | template <typename scalar_t, int up_x, int up_y, int down_x, int down_y,
108 |           int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
109 | __global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input,
110 | const scalar_t *kernel,
111 | const UpFirDn2DKernelParams p) {
112 | const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
113 | const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;
114 |
115 | __shared__ volatile float sk[kernel_h][kernel_w];
116 | __shared__ volatile float sx[tile_in_h][tile_in_w];
117 |
118 | int minor_idx = blockIdx.x;
119 | int tile_out_y = minor_idx / p.minor_dim;
120 | minor_idx -= tile_out_y * p.minor_dim;
121 | tile_out_y *= tile_out_h;
122 | int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
123 | int major_idx_base = blockIdx.z * p.loop_major;
124 |
125 | if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h |
126 | major_idx_base >= p.major_dim) {
127 | return;
128 | }
129 |
130 | for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w;
131 | tap_idx += blockDim.x) {
132 | int ky = tap_idx / kernel_w;
133 | int kx = tap_idx - ky * kernel_w;
134 | scalar_t v = 0.0;
135 |
136 | if (kx < p.kernel_w & ky < p.kernel_h) {
137 | v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
138 | }
139 |
140 | sk[ky][kx] = v;
141 | }
142 |
143 | for (int loop_major = 0, major_idx = major_idx_base;
144 | loop_major < p.loop_major & major_idx < p.major_dim;
145 | loop_major++, major_idx++) {
146 | for (int loop_x = 0, tile_out_x = tile_out_x_base;
147 | loop_x < p.loop_x & tile_out_x < p.out_w;
148 | loop_x++, tile_out_x += tile_out_w) {
149 | int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
150 | int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
151 | int tile_in_x = floor_div(tile_mid_x, up_x);
152 | int tile_in_y = floor_div(tile_mid_y, up_y);
153 |
154 | __syncthreads();
155 |
156 | for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w;
157 | in_idx += blockDim.x) {
158 | int rel_in_y = in_idx / tile_in_w;
159 | int rel_in_x = in_idx - rel_in_y * tile_in_w;
160 | int in_x = rel_in_x + tile_in_x;
161 | int in_y = rel_in_y + tile_in_y;
162 |
163 | scalar_t v = 0.0;
164 |
165 | if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
166 | v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) *
167 | p.minor_dim +
168 | minor_idx];
169 | }
170 |
171 | sx[rel_in_y][rel_in_x] = v;
172 | }
173 |
174 | __syncthreads();
175 | for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w;
176 | out_idx += blockDim.x) {
177 | int rel_out_y = out_idx / tile_out_w;
178 | int rel_out_x = out_idx - rel_out_y * tile_out_w;
179 | int out_x = rel_out_x + tile_out_x;
180 | int out_y = rel_out_y + tile_out_y;
181 |
182 | int mid_x = tile_mid_x + rel_out_x * down_x;
183 | int mid_y = tile_mid_y + rel_out_y * down_y;
184 | int in_x = floor_div(mid_x, up_x);
185 | int in_y = floor_div(mid_y, up_y);
186 | int rel_in_x = in_x - tile_in_x;
187 | int rel_in_y = in_y - tile_in_y;
188 | int kernel_x = (in_x + 1) * up_x - mid_x - 1;
189 | int kernel_y = (in_y + 1) * up_y - mid_y - 1;
190 |
191 | scalar_t v = 0.0;
192 |
193 | #pragma unroll
194 | for (int y = 0; y < kernel_h / up_y; y++)
195 | #pragma unroll
196 | for (int x = 0; x < kernel_w / up_x; x++)
197 | v += sx[rel_in_y + y][rel_in_x + x] *
198 | sk[kernel_y + y * up_y][kernel_x + x * up_x];
199 |
200 | if (out_x < p.out_w & out_y < p.out_h) {
201 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
202 | minor_idx] = v;
203 | }
204 | }
205 | }
206 | }
207 | }
208 |
209 | torch::Tensor upfirdn2d_op(const torch::Tensor &input,
210 | const torch::Tensor &kernel, int up_x, int up_y,
211 | int down_x, int down_y, int pad_x0, int pad_x1,
212 | int pad_y0, int pad_y1) {
213 | int curDevice = -1;
214 | cudaGetDevice(&curDevice);
215 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
216 |
217 | UpFirDn2DKernelParams p;
218 |
219 | auto x = input.contiguous();
220 | auto k = kernel.contiguous();
221 |
222 | p.major_dim = x.size(0);
223 | p.in_h = x.size(1);
224 | p.in_w = x.size(2);
225 | p.minor_dim = x.size(3);
226 | p.kernel_h = k.size(0);
227 | p.kernel_w = k.size(1);
228 | p.up_x = up_x;
229 | p.up_y = up_y;
230 | p.down_x = down_x;
231 | p.down_y = down_y;
232 | p.pad_x0 = pad_x0;
233 | p.pad_x1 = pad_x1;
234 | p.pad_y0 = pad_y0;
235 | p.pad_y1 = pad_y1;
236 |
237 | p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) /
238 | p.down_y;
239 | p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) /
240 | p.down_x;
241 |
242 | auto out =
243 | at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());
244 |
245 | int mode = -1;
246 |
247 | int tile_out_h = -1;
248 | int tile_out_w = -1;
249 |
250 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
251 | p.kernel_h <= 4 && p.kernel_w <= 4) {
252 | mode = 1;
253 | tile_out_h = 16;
254 | tile_out_w = 64;
255 | }
256 |
257 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
258 | p.kernel_h <= 3 && p.kernel_w <= 3) {
259 | mode = 2;
260 | tile_out_h = 16;
261 | tile_out_w = 64;
262 | }
263 |
264 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
265 | p.kernel_h <= 4 && p.kernel_w <= 4) {
266 | mode = 3;
267 | tile_out_h = 16;
268 | tile_out_w = 64;
269 | }
270 |
271 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
272 | p.kernel_h <= 2 && p.kernel_w <= 2) {
273 | mode = 4;
274 | tile_out_h = 16;
275 | tile_out_w = 64;
276 | }
277 |
278 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
279 | p.kernel_h <= 4 && p.kernel_w <= 4) {
280 | mode = 5;
281 | tile_out_h = 8;
282 | tile_out_w = 32;
283 | }
284 |
285 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
286 | p.kernel_h <= 2 && p.kernel_w <= 2) {
287 | mode = 6;
288 | tile_out_h = 8;
289 | tile_out_w = 32;
290 | }
291 |
292 | dim3 block_size;
293 | dim3 grid_size;
294 |
295 | if (tile_out_h > 0 && tile_out_w > 0) {
296 | p.loop_major = (p.major_dim - 1) / 16384 + 1;
297 | p.loop_x = 1;
298 | block_size = dim3(32 * 8, 1, 1);
299 | grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
300 | (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
301 | (p.major_dim - 1) / p.loop_major + 1);
302 | } else {
303 | p.loop_major = (p.major_dim - 1) / 16384 + 1;
304 | p.loop_x = 4;
305 | block_size = dim3(4, 32, 1);
306 | grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1,
307 | (p.out_w - 1) / (p.loop_x * block_size.y) + 1,
308 | (p.major_dim - 1) / p.loop_major + 1);
309 | }
310 |
311 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
312 | switch (mode) {
313 | case 1:
314 |       upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64>
315 |           <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
316 |                                                  x.data_ptr<scalar_t>(),
317 |                                                  k.data_ptr<scalar_t>(), p);
318 |
319 | break;
320 |
321 | case 2:
322 |       upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64>
323 |           <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
324 |                                                  x.data_ptr<scalar_t>(),
325 |                                                  k.data_ptr<scalar_t>(), p);
326 |
327 | break;
328 |
329 | case 3:
330 |       upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64>
331 |           <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
332 |                                                  x.data_ptr<scalar_t>(),
333 |                                                  k.data_ptr<scalar_t>(), p);
334 |
335 | break;
336 |
337 | case 4:
338 |       upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64>
339 |           <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
340 |                                                  x.data_ptr<scalar_t>(),
341 |                                                  k.data_ptr<scalar_t>(), p);
342 |
343 | break;
344 |
345 | case 5:
346 |       upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
347 |           <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
348 |                                                  x.data_ptr<scalar_t>(),
349 |                                                  k.data_ptr<scalar_t>(), p);
350 |
351 | break;
352 |
353 | case 6:
354 |       upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
355 |           <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
356 |                                                  x.data_ptr<scalar_t>(),
357 |                                                  k.data_ptr<scalar_t>(), p);
358 |
359 | break;
360 |
361 | default:
362 |       upfirdn2d_kernel_large<scalar_t><<<grid_size, block_size, 0, stream>>>(
363 |           out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
364 |           k.data_ptr<scalar_t>(), p);
365 | }
366 | });
367 |
368 | return out;
369 | }
--------------------------------------------------------------------------------
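
The output size computed in `upfirdn2d_op` above (and in `upfirdn2d_native`) is `out = (in * up + pad0 + pad1 - kernel + down) // down`, i.e. the usual convolution size formula after zero-insertion upsampling and padding. A small helper, not part of the repository, that restates this arithmetic:

```python
def upfirdn2d_out_size(in_size, kernel, up=1, down=1, pad=(0, 0)):
    """Spatial output size of upfirdn2d along one axis."""
    pad0, pad1 = pad
    return (in_size * up + pad0 + pad1 - kernel + down) // down

# 2x upsample of a 16-pixel axis with a 4-tap kernel and pad (2, 1) -> 32 pixels.
print(upfirdn2d_out_size(16, 4, up=2, down=1, pad=(2, 1)))   # 32
# 2x downsample of a 32-pixel axis with a 4-tap kernel and pad (1, 1) -> 16 pixels.
print(upfirdn2d_out_size(32, 4, up=1, down=2, pad=(1, 1)))   # 16
```
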
/weights/car/.placeholder:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OmerKafri/StyleFusion/c1aa57267a192b3c50ee501339ed50a0aedb35fe/weights/car/.placeholder
--------------------------------------------------------------------------------
/weights/car_weights.json:
--------------------------------------------------------------------------------
1 | {
2 | "all": "weights/car/all.pt",
3 | "car": "weights/car/car.pt",
4 | "background": "weights/car/background.pt"
5 | }
6 |
--------------------------------------------------------------------------------
/weights/church/.placeholder:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OmerKafri/StyleFusion/c1aa57267a192b3c50ee501339ed50a0aedb35fe/weights/church/.placeholder
--------------------------------------------------------------------------------
/weights/church_weights.json:
--------------------------------------------------------------------------------
1 | {
2 | "all": "weights/church/all.pt"
3 | }
4 |
--------------------------------------------------------------------------------
/weights/ffhq/.placeholder:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OmerKafri/StyleFusion/c1aa57267a192b3c50ee501339ed50a0aedb35fe/weights/ffhq/.placeholder
--------------------------------------------------------------------------------
/weights/ffhq_weights.json:
--------------------------------------------------------------------------------
1 | {
2 | "all": "weights/ffhq/all.pt",
3 | "bg_hair_clothes": "weights/ffhq/bg_hair_clothes.pt",
4 | "bg_clothes": "weights/ffhq/bg_clothes.pt",
5 | "face": "weights/ffhq/face.pt",
6 | "skin_mouth": "weights/ffhq/skin_mouth.pt"
7 | }
8 |
--------------------------------------------------------------------------------
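
The `*_weights.json` files above map each StyleFusion region name to a pretrained checkpoint under `weights/`; the `.placeholder` files stand in for checkpoints that are not stored in the repository and must be obtained separately. A minimal sketch of resolving a checkpoint path from one of these files; `torch.load` is shown only for illustration and assumes the `.pt` files are already in place:

```python
import json
import torch

with open("weights/ffhq_weights.json") as f:
    region_to_ckpt = json.load(f)   # e.g. {"all": "weights/ffhq/all.pt", "face": ...}

print(list(region_to_ckpt))         # available region names for the FFHQ model
state = torch.load(region_to_ckpt["face"], map_location="cpu")
```
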