├── main.py
├── scripts
│   └── publish.sh
├── document
│   ├── Function.md
│   ├── StyleGAN.md
│   └── Projects.md
├── sgan
│   ├── __init__.py
│   ├── model.py
│   ├── cache.py
│   ├── utils.py
│   ├── lreq.py
│   └── net.py
├── .gitignore
├── setup.py
├── Readme.md
├── hubconf.py
└── License
/main.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/scripts/publish.sh:
--------------------------------------------------------------------------------
1 | # build the source distribution and upload it in one step
2 | python setup.py sdist upload
--------------------------------------------------------------------------------
/document/Function.md:
--------------------------------------------------------------------------------
1 | ## reinitialize
2 | 
3 | `reinitialize(model=None)` — refreshes the torch.hub listing and clears the in-memory model cache; pass a model name to evict just that model.
4 | 
5 | ## generate
6 | 
7 | `generate(method, num, save=None, batch_size=16)` — samples `num` images from the named model and returns them as `StyleGAN` objects, or writes them to the directory `save`.
8 | 
9 | ## style_interpolate
10 | 
11 | `style_interpolate(a, b, method=None, steps=24)` — spherically interpolates between the latent codes of `a` and `b` and returns the `steps` intermediate images.
--------------------------------------------------------------------------------
/sgan/__init__.py:
--------------------------------------------------------------------------------
1 | from .utils import StyleGAN, generate, style_mix, style_interpolate, image_encode
2 | from .cache import LOADED_MODEL, reinitialize
--------------------------------------------------------------------------------
/document/StyleGAN.md:
--------------------------------------------------------------------------------
1 | ## Attributes
2 | 
3 | ### seed
4 | 
5 | The 1×512 latent tensor (stored as `gene`) that seeds the image; drawn from `torch.randn` when not supplied.
6 | 
7 | ## Methods
8 | 
9 | ### output
10 | 
11 | Runs the generator on first call and caches the resulting image tensor in `data`; later calls reuse the cache.
12 | 
13 | ### save
14 | 
15 | Writes the latent code (`*-gene.txt`) and the rendered image (`*-show.png`) to the given directory.
16 | 
17 | ### show
18 | 
19 | Renders the image inline with matplotlib.
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # ide
2 | .vscode/
3 | .idea/
4 | *.iml
5 | 
6 | # python
7 | stylegan_zoo.egg-info
8 | __pycache__
9 | dist/
10 | eval.py
11 | test.py
12 | 
13 | # test
14 | *-gene.txt
15 | *-show.png
16 | 
17 | # models
18 | *.pkl
19 | *.pth
20 | *.mat
21 | *.wxf
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import setuptools
2 | 
3 | setuptools.setup(
4 |     name='stylegan_zoo',
5 |     author='aster',
6 |     author_email='galaster@foxmail.com',
7 |     url='https://github.com/GalAster/StyleGAN-Zoo',
8 |     version='0.14.0',
9 |     description='A collection of pretrained StyleGAN models with a simple PyTorch interface',
10 | 
11 |     packages=['sgan'],
12 |     install_requires=[
13 |         # torch is deliberately not listed: install it via conda first (see Readme.md)
14 |         'matplotlib',
15 |         'numpy',
16 |         'wolframclient'
17 |     ],
18 | )
--------------------------------------------------------------------------------
/document/Projects.md:
--------------------------------------------------------------------------------
1 | https://generated.photos/
2 | 
3 | ---
4 | 
5 | https://thisvesseldoesnotexist.com/#/about
6 | 
7 | - Vessel
8 | - Qinghua (青花瓷, blue-and-white porcelain)
9 | 
10 | ---
11 | 
12 | https://thissnackdoesnotexist.com/
13 | 
14 | ---
15 | 
16 | https://thisairbnbdoesnotexist.com/
17 | 
18 | ---
19 | 
20 | https://thisrentaldoesnotexist.com/
21 | 
22 | ---
23 | 
24 | https://github.com/xiong-jie-y/kawaii_girl_generator
25 | 
26 | ---
27 | 
28 | https://github.com/ak9250/stylegan-art
29 | 
30 | ---
31 | 
32 | https://github.com/maxbbraun/eboygan
33 | 
34 | ---
35 | 
36 | https://github.com/parameter-pollution/stylegan_paintings
37 | 
38 | ---
39 | 
40 | https://www.gwern.net/Faces
--------------------------------------------------------------------------------
/Readme.md:
--------------------------------------------------------------------------------
1 | StyleGAN Zoo
2 | ============
3 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1HHfyYCfnat4jhOnu34gqotRqzBiDeE_-)
4 | 
5 | Based on https://github.com/podgorskiy/StyleGAN_Blobless
6 | 
7 | Pretrained models are published at https://github.com/GalAster/StyleGAN-Zoo/releases
8 | 
9 | ## Install
10 | 
11 | PyTorch is required but is not installed automatically; install it via conda first:
12 | 
13 | ```sh
14 | conda install pytorch torchvision cudatoolkit=10.1 -c pytorch -y
15 | pip install stylegan_zoo
16 | ```
17 | 
18 | 
19 | ## Start
20 | 
21 | - For Jupyter:
22 | 
23 | ```python
24 | from sgan import StyleGAN
25 | a = StyleGAN('asuka')
26 | a.show()
27 | ```
28 | 
29 | ![](https://user-images.githubusercontent.com/17541209/71554236-b0813300-2a57-11ea-9ee4-fab29d592d9a.png)
30 | 
31 | - For Mathematica (the returned object serializes to an image automatically):
32 | 
33 | ```python
34 | from sgan import StyleGAN
35 | 
36 | StyleGAN('asuka')
37 | ```
38 | 
39 | ![](https://user-images.githubusercontent.com/17541209/71553454-c5a39500-2a4a-11ea-8513-7d9a475c4c46.png)
40 | 
41 | - Batch generation
42 | 
43 | ```python
44 | from sgan import generate
45 | 
46 | generate('asuka', 4)
47 | ```
48 | 
49 | ![](https://user-images.githubusercontent.com/17541209/71593157-df89c880-2b6d-11ea-8455-8dd4d2024671.png)
50 | 
51 | - Style interpolation
52 | 
53 | ```python
54 | from sgan import generate, style_interpolate
55 | 
56 | start, end = generate('asuka', 2)
57 | style_interpolate(start, end, steps=16)
58 | ```
59 | 
60 | ![](https://user-images.githubusercontent.com/17541209/71773895-45c48000-2fa0-11ea-8068-d7e5347a8233.png)
61 | 
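62 | 
63 | - Save to disk
64 | 
65 | A short sketch built on `StyleGAN.save` and the `save` parameter of `generate` (see `sgan/utils.py`); the target directory must already exist, and file names come from a per-run hash of the latent tensor:
66 | 
67 | ```python
68 | from sgan import StyleGAN, generate
69 | 
70 | a = StyleGAN('asuka')
71 | a.save('.')  # writes <hash>-gene.txt (latent code) and <hash>-show.png (image)
72 | 
73 | generate('asuka', 8, save='.')  # batch-generate straight to disk; nothing is returned
74 | ```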
75 | 
76 | ## License
77 | 
78 | | Part         | License                        |
79 | | :----------- | :----------------------------- |
80 | | Code         | [Apache License Version 2.0]() |
81 | | [Asuka]()    | [CC0 - Creative Commons]()     |
82 | | [Horo]()     | [CC0 - Creative Commons]()     |
83 | | [Baby]()     | [CC4.0 Non-Commercial]()       |
84 | | [FFHQ]()     |                                |
85 | | [CelebaHQ]() |                                |
--------------------------------------------------------------------------------
/sgan/model.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Stanislav Pidhorskyi
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | 
16 | import torch
17 | from torch import nn
18 | import random
19 | from sgan.net import Generator, Mapping
20 | import numpy as np
21 | 
22 | 
23 | class DLatent(nn.Module):
24 |     def __init__(self, dlatent_size, layer_count):
25 |         super(DLatent, self).__init__()
26 |         buffer = torch.zeros(layer_count, dlatent_size, dtype=torch.float32)
27 |         self.register_buffer('buff', buffer)
28 | 
29 | 
30 | class Model(nn.Module):
31 |     def __init__(
32 |             self,
33 |             channels=3,
34 |             mapping_layers=8,
35 |             latent_size=512,
36 | 
37 |             startf=32,
38 |             maxf=256,
39 |             layer_count=3,
40 | 
41 |             dlatent_avg_beta=None,
42 |             truncation_psi=None,
43 |             truncation_cutoff=None,
44 |             style_mixing_prob=None,
45 |             mode='normal'
46 |     ):
47 |         super(Model, self).__init__()
48 |         self.mode = mode  # 'normal' or 'asuka'; selects the blob-removal path in Generator
49 |         self.out_layer = layer_count - 1
50 |         self.mapping = Mapping(
51 |             num_layers=2 * layer_count,
52 |             latent_size=latent_size,
53 |             dlatent_size=latent_size,
54 |             mapping_fmaps=latent_size,
55 |             mapping_layers=mapping_layers
56 |         )
57 | 
58 |         self.generator = Generator(
59 |             startf=startf,
60 |             layer_count=layer_count,
61 |             maxf=maxf,
62 |             latent_size=latent_size,
63 |             channels=channels
64 |         )
65 | 
66 |         self.dlatent_avg = DLatent(latent_size, self.mapping.num_layers)
67 |         self.latent_size = latent_size
68 |         self.dlatent_avg_beta = dlatent_avg_beta
69 |         self.truncation_psi = truncation_psi
70 |         self.style_mixing_prob = style_mixing_prob
71 |         self.truncation_cutoff = truncation_cutoff
72 | 
73 |     def generate(self, lod, remove_blob=True, z=None, count=32):
74 |         if z is None:
75 |             z = torch.randn(count, self.latent_size)
76 |         styles = self.mapping(z)
77 | 
78 |         if self.dlatent_avg_beta is not None:  # maintain a running average of the mapped latents
79 |             with torch.no_grad():
80 |                 batch_avg = styles.mean(dim=0)
81 |                 self.dlatent_avg.buff.data.lerp_(batch_avg.data, 1.0 - self.dlatent_avg_beta)
82 | 
83 |         if self.style_mixing_prob is not None:
84 |             if random.random() < self.style_mixing_prob:
85 |                 z2 = torch.randn(count, self.latent_size)
86 |                 styles2 = self.mapping(z2)
87 | 
88 |                 layer_idx = torch.arange(self.mapping.num_layers)[np.newaxis, :, np.newaxis]
89 |                 cur_layers = (lod + 1) * 2
90 |                 mixing_cutoff = random.randint(1, cur_layers)
91 |                 styles = torch.where(layer_idx < mixing_cutoff, styles, styles2)
92 | 
93 |         if self.truncation_psi is not None:  # truncation trick: blend early-layer styles toward the running average
94 |             layer_idx = torch.arange(self.mapping.num_layers)[np.newaxis, :, np.newaxis]
95 |             ones = torch.ones(layer_idx.shape, dtype=torch.float32)
96 |             coefs = torch.where(layer_idx < self.truncation_cutoff, self.truncation_psi * ones, ones)
97 |             styles = torch.lerp(self.dlatent_avg.buff.data, styles, coefs)
98 | 
99 |         rec = self.generator.forward(styles, lod, remove_blob, method=self.mode)
100 |         return rec
101 | 
102 |     def forward(self, z, lod, remove_blob=True):
103 |         return self.generate(lod, remove_blob=remove_blob, z=z)
--------------------------------------------------------------------------------
/sgan/cache.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import re
3 | 
4 | from torch.hub import load as loading
5 | 
6 | LOADED_MODEL = {}
7 | 
8 | # normalized name (and aliases) -> torch.hub entry point in hubconf.py
9 | MODEL_REGISTRY = {
10 |     'asuka': 'style_asuka',
11 |     'horo': 'style_horo',
12 |     'asashio': 'style_asashio',
13 |     'anime': 'style_anime_head',
14 |     'animehead': 'style_anime_head',
15 |     'animeface': 'style_anime_face_a',
16 |     'animefacea': 'style_anime_face_a',
17 |     'animefaceb': 'style_anime_face_b',
18 |     'animefacec': 'style_anime_face_c',
19 |     'animefaced': 'style_anime_face_d',
20 |     'animefacee': 'style_anime_face_e',
21 |     'baby': 'style_baby',
22 |     'wanghong': 'style_wanghong',
23 |     'asianpeople': 'style_asian_people',
24 |     'asian': 'style_asian_star',
25 |     'asianstar': 'style_asian_star',
26 |     'star': 'style_super_star',
27 |     'superstar': 'style_super_star',
28 |     'art': 'style_art_a',
29 |     'arta': 'style_art_a',
30 |     'artb': 'style_art_b',
31 |     'artc': 'style_ukiyoe_faces',
32 |     'ukiyoefaces': 'style_ukiyoe_faces',
33 |     'ffhq': 'style_ffhq',
34 |     'celebahq': 'style_celeba_hq',
35 |     'vessel': 'style_vessel',
36 |     'qinghua': 'style_qinghua',
37 | }
38 | 
39 | 
40 | def get_model(name: str):
41 |     # names are case-insensitive; '-', '_' and spaces are ignored
42 |     m = re.sub('[-_ ]', '', name).lower()
43 |     if m in LOADED_MODEL:
44 |         return LOADED_MODEL[m]
45 |     if m not in MODEL_REGISTRY:
46 |         raise AttributeError('unknown model: %r' % name)
47 |     model = loading('GalAster/StyleGAN-Zoo', MODEL_REGISTRY[m], pretrained=True)
48 |     LOADED_MODEL[m] = model
49 |     return model
50 | 
51 | 
52 | def reinitialize(model=None):
53 |     global LOADED_MODEL
54 |     torch.hub.list('GalAster/StyleGAN-Zoo', force_reload=True)
55 |     if model is None:
56 |         LOADED_MODEL = {}
57 |     else:
58 |         LOADED_MODEL.pop(re.sub('[-_ ]', '', model).lower(), None)
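59 | 
60 | # A quick usage sketch (network access to torch.hub is assumed; names are normalized,
61 | # so 'Anime-Face-B', 'anime_face_b' and 'animefaceb' all resolve to the same entry):
62 | #
63 | #   model = get_model('anime-face-b')  # downloaded once, then served from LOADED_MODEL
64 | #   reinitialize('anime-face-b')       # evict a single cached model
65 | #   reinitialize()                     # drop every cached model and refresh the hub list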
--------------------------------------------------------------------------------
/sgan/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import matplotlib.pyplot as plt
4 | import torch
5 | 
6 | from torchvision.transforms import ToPILImage
7 | from wolframclient.serializers.serializable import WLSerializable
8 | from torchvision.utils import save_image
9 | from numpy import savetxt
10 | from sgan.cache import get_model
11 | 
12 | if torch.cuda.is_available():
13 |     torch.set_default_tensor_type(torch.cuda.FloatTensor)
14 |     DEFAULT_DEVICE = 'cuda'
15 | else:
16 |     DEFAULT_DEVICE = 'cpu'
17 | 
18 | 
19 | class StyleGAN(WLSerializable):
20 |     def __init__(self, method: str, gene=None, data=None):
21 |         self.method = method
22 |         self.data = data
23 |         if gene is None:
24 |             latents = torch.randn(1, 512)
25 |             self.gene = latents  # already a tensor; wrapping it in torch.tensor() only triggers a copy warning
26 |         else:
27 |             self.gene = gene
28 | 
29 |     def output(self):
30 |         if self.data is None:
31 |             with torch.no_grad():
32 |                 model = get_model(self.method)
33 |                 model.to(DEFAULT_DEVICE)
34 |                 self.data = model.generate(model.out_layer, z=self.gene)
35 |         self.gene = self.gene.cpu()
36 |         self.data = self.data.cpu()
37 |         return self.data
38 | 
39 |     def show(self):
40 |         img = self.output()[0].permute(1, 2, 0)
41 |         plt.imshow(img.clamp(-1, 1) * 0.5 + 0.5)
42 | 
43 |     def save(self, path=None):
44 |         img = self.output()[0] * 0.5 + 0.5
45 |         src = self.gene
46 |         name = str(src.__hash__())  # id-based hash: file names are unique per run
47 |         d = '.' if path is None else path
48 |         savetxt(os.path.join(d, name + '-gene.txt'), src, delimiter='\n')
49 |         save_image(img, os.path.join(d, name + '-show.png'))
50 | 
51 |     def clean(self):
52 |         self.data = None
53 | 
54 |     def to_wl(self):
55 |         img = self.output()[0].clamp(-1, 1)
56 |         return ToPILImage()(img * 0.5 + 0.5)
57 | 
58 |     @staticmethod
59 |     def new(method, gene, data):
60 |         return StyleGAN(method, gene=gene, data=data)
61 | 
62 | 
63 | def generate(
64 |         method, num,
65 |         save=None,
66 |         batch_size=16
67 | ):
68 |     # prepare model
69 |     model = get_model(method)
70 |     model.to(DEFAULT_DEVICE)
71 |     # batch eval
72 |     with torch.no_grad():
73 |         gene = []
74 |         data = []
75 |         for i in range(math.ceil(num / batch_size)):
76 |             latents = torch.randn(batch_size, 512).to(DEFAULT_DEVICE)
77 |             batch = model.generate(model.out_layer, z=latents)
78 |             if save is None:
79 |                 gene.append(latents.cpu())
80 |                 data.append(batch.cpu())
81 |             else:
82 |                 o = [StyleGAN.new(method, g.unsqueeze(0), d.unsqueeze(0)) for g, d in zip(latents.cpu(), batch.cpu())]
83 |                 for j in o:
84 |                     j.save(path=save)
85 |     if save is None:
86 |         gene = torch.cat(gene, dim=0)
87 |         data = torch.cat(data, dim=0)
88 |         o = [StyleGAN.new(method, g.unsqueeze(0), d.unsqueeze(0)) for g, d in zip(gene, data)]
89 |         return o[:num]
90 |     else:
91 |         pass  # images were already written to disk; nothing to return
92 | 
93 | 
94 | def as_tensor(o):
95 |     if isinstance(o, StyleGAN):
96 |         o.output()
97 |         return o.gene
98 |     else:
99 |         return o
100 | 
101 | 
102 | def style_mix(model, genes, weights):
103 |     pass  # not implemented yet
104 | 
105 | 
106 | def slerp(start, end, values):  # spherical linear interpolation on the latent sphere
107 |     low_norm = start / torch.norm(start, dim=1, keepdim=True)
108 |     high_norm = end / torch.norm(end, dim=1, keepdim=True)
109 |     omega = torch.acos((low_norm * high_norm).sum(1))
110 |     so = torch.sin(omega)
111 | 
112 |     def interpolate(val):
113 |         s = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * start
114 |         e = (torch.sin(val * omega) / so).unsqueeze(1) * end
115 |         return s + e
116 | 
117 |     return list(map(interpolate, values))
118 | 
119 | 
120 | def style_interpolate(
121 |         a, b,
122 |         method=None,
123 |         steps=24,
124 |         batch_size=16,  # reserved; the whole path currently runs as one batch
125 |         save=None,
126 | ):
127 |     path = slerp(as_tensor(a), as_tensor(b), [x / (steps - 1) for x in range(steps)])
128 |     path = torch.cat(path, dim=0).to(DEFAULT_DEVICE)
129 |     if method is None and isinstance(a, StyleGAN):
130 |         method = a.method
131 |     elif method is None and isinstance(b, StyleGAN):
132 |         method = b.method
133 |     model = get_model(method)
134 |     model.to(DEFAULT_DEVICE)
135 |     with torch.no_grad():
136 |         result = model.generate(model.out_layer, z=path)
137 |     o = [StyleGAN.new(method, g.unsqueeze(0), img.unsqueeze(0)) for g, img in zip(path.cpu(), result.cpu())]
138 |     return o
139 | 
140 | 
141 | def model_settings(
142 |         name: str,
143 |         dlatent_avg_beta=None,
144 |         truncation_psi=None,
145 |         truncation_cutoff=None,
146 |         style_mixing_prob=None,
147 |         random_noise=None,  # reserved; not currently used
148 | ):
149 |     model = get_model(name)
150 |     if truncation_psi is not None:
151 |         model.truncation_psi = truncation_psi
152 |     if dlatent_avg_beta is not None:
153 |         model.dlatent_avg_beta = dlatent_avg_beta
154 |     if truncation_cutoff is not None:
155 |         model.truncation_cutoff = truncation_cutoff
156 |     if style_mixing_prob is not None:
157 |         model.style_mixing_prob = style_mixing_prob
158 | 
159 | 
160 | def image_encode():
161 |     pass  # not implemented yet
162 | 
163 | 
164 | if __name__ == "__main__":
165 |     # test for normal
166 |     '''
167 |     t1 = StyleGAN('asuka')
168 |     t1.show()
169 |     t1.save('.')
170 |     '''
171 |     # test for generate
172 |     '''
173 |     t2 = generate('asuka', 5, batch_size=2)
174 |     t3 = generate('asuka', 2, save='.')
175 |     '''
176 |     # test for interpolate
177 |     '''
178 |     t4, t5 = generate('asuka', 2)
179 |     out = style_interpolate(t4, t5, steps=4)
180 |     '''
--------------------------------------------------------------------------------
/sgan/lreq.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Stanislav Pidhorskyi
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================== 15 | 16 | import torch 17 | from torch import nn 18 | from torch.nn import functional as F 19 | from torch.nn import init 20 | from torch.nn.parameter import Parameter 21 | import numpy as np 22 | 23 | 24 | class Bool: 25 | def __init__(self): 26 | self.value = False 27 | 28 | def __bool__(self): 29 | return self.value 30 | 31 | __nonzero__ = __bool__ 32 | 33 | def set(self, value): 34 | self.value = value 35 | 36 | 37 | use_implicit_lreq = Bool() 38 | use_implicit_lreq.set(True) 39 | 40 | 41 | def is_sequence(arg): 42 | return (not hasattr(arg, "strip") and 43 | hasattr(arg, "__getitem__") or 44 | hasattr(arg, "__iter__")) 45 | 46 | 47 | def make_tuple(x, n): 48 | if is_sequence(x): 49 | return x 50 | return tuple([x for _ in range(n)]) 51 | 52 | 53 | class Linear(nn.Module): 54 | def __init__(self, in_features, out_features, bias=True, gain=np.sqrt(2.0), lrmul=1.0, 55 | implicit_lreq=use_implicit_lreq): 56 | super(Linear, self).__init__() 57 | self.in_features = in_features 58 | self.weight = Parameter(torch.Tensor(out_features, in_features)) 59 | if bias: 60 | self.bias = Parameter(torch.Tensor(out_features)) 61 | else: 62 | self.register_parameter('bias', None) 63 | self.std = 0 64 | self.gain = gain 65 | self.lrmul = lrmul 66 | self.implicit_lreq = implicit_lreq 67 | self.reset_parameters() 68 | 69 | def reset_parameters(self): 70 | self.std = self.gain / np.sqrt(self.in_features) * self.lrmul 71 | if not self.implicit_lreq: 72 | init.normal_(self.weight, mean=0, std=1.0 / self.lrmul) 73 | else: 74 | init.normal_(self.weight, mean=0, std=self.std / self.lrmul) 75 | setattr(self.weight, 'lr_equalization_coef', self.std) 76 | if self.bias is not None: 77 | setattr(self.bias, 'lr_equalization_coef', self.lrmul) 78 | 79 | if self.bias is not None: 80 | with torch.no_grad(): 81 | self.bias.zero_() 82 | 83 | def forward(self, input): 84 | if not self.implicit_lreq: 85 | bias = self.bias 86 | if bias is not None: 87 | bias = bias * self.lrmul 88 | return F.linear(input, self.weight * self.std, bias) 89 | else: 90 | return F.linear(input, self.weight, self.bias) 91 | 92 | 93 | class Conv2d(nn.Module): 94 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, dilation=1, 95 | groups=1, bias=True, gain=np.sqrt(2.0), transpose=False, transform_kernel=False, lrmul=1.0, 96 | implicit_lreq=use_implicit_lreq): 97 | super(Conv2d, self).__init__() 98 | if in_channels % groups != 0: 99 | raise ValueError('in_channels must be divisible by groups') 100 | if out_channels % groups != 0: 101 | raise ValueError('out_channels must be divisible by groups') 102 | self.in_channels = in_channels 103 | self.out_channels = out_channels 104 | self.kernel_size = make_tuple(kernel_size, 2) 105 | self.stride = make_tuple(stride, 2) 106 | self.padding = make_tuple(padding, 2) 107 | self.output_padding = make_tuple(output_padding, 2) 108 | self.dilation = make_tuple(dilation, 2) 109 | self.groups = groups 110 | self.gain = gain 111 | self.lrmul = lrmul 112 | self.transpose = transpose 113 | self.fan_in = np.prod(self.kernel_size) * in_channels // groups 114 | self.transform_kernel = transform_kernel 115 | if transpose: 116 | self.weight = Parameter(torch.Tensor(in_channels, out_channels // groups, *self.kernel_size)) 117 | else: 118 | self.weight = Parameter(torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)) 119 | if bias: 120 | self.bias = 
Parameter(torch.Tensor(out_channels))
121 |         else:
122 |             self.register_parameter('bias', None)
123 |         self.std = 0
124 |         self.implicit_lreq = implicit_lreq
125 |         self.reset_parameters()
126 | 
127 |     def reset_parameters(self):
128 |         self.std = self.gain / np.sqrt(self.fan_in)
129 |         if not self.implicit_lreq:
130 |             init.normal_(self.weight, mean=0, std=1.0 / self.lrmul)
131 |         else:
132 |             init.normal_(self.weight, mean=0, std=self.std / self.lrmul)
133 |             setattr(self.weight, 'lr_equalization_coef', self.std)
134 |             if self.bias is not None:
135 |                 setattr(self.bias, 'lr_equalization_coef', self.lrmul)
136 | 
137 |         if self.bias is not None:
138 |             with torch.no_grad():
139 |                 self.bias.zero_()
140 | 
141 |     def forward(self, x):
142 |         if self.transpose:
143 |             w = self.weight
144 |             if self.transform_kernel:
145 |                 w = F.pad(w, (1, 1, 1, 1), mode='constant')
146 |                 w = w[:, :, 1:, 1:] + w[:, :, :-1, 1:] + w[:, :, 1:, :-1] + w[:, :, :-1, :-1]
147 |             if not self.implicit_lreq:
148 |                 bias = self.bias
149 |                 if bias is not None:
150 |                     bias = bias * self.lrmul
151 |                 return F.conv_transpose2d(x, w * self.std, bias, stride=self.stride,
152 |                                           padding=self.padding, output_padding=self.output_padding,
153 |                                           dilation=self.dilation, groups=self.groups)
154 |             else:
155 |                 return F.conv_transpose2d(x, w, self.bias, stride=self.stride, padding=self.padding,
156 |                                           output_padding=self.output_padding, dilation=self.dilation,
157 |                                           groups=self.groups)
158 |         else:
159 |             w = self.weight
160 |             if self.transform_kernel:
161 |                 w = F.pad(w, (1, 1, 1, 1), mode='constant')
162 |                 w = (w[:, :, 1:, 1:] + w[:, :, :-1, 1:] + w[:, :, 1:, :-1] + w[:, :, :-1, :-1]) * 0.25
163 |             if not self.implicit_lreq:
164 |                 bias = self.bias
165 |                 if bias is not None:
166 |                     bias = bias * self.lrmul
167 |                 return F.conv2d(x, w * self.std, bias, stride=self.stride, padding=self.padding,
168 |                                 dilation=self.dilation, groups=self.groups)
169 |             else:
170 |                 return F.conv2d(x, w, self.bias, stride=self.stride, padding=self.padding,
171 |                                 dilation=self.dilation, groups=self.groups)
172 | 
173 | 
174 | class ConvTranspose2d(Conv2d):
175 |     def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, dilation=1,
176 |                  groups=1, bias=True, gain=np.sqrt(2.0), transform_kernel=False, lrmul=1.0,
177 |                  implicit_lreq=use_implicit_lreq):
178 |         super(ConvTranspose2d, self).__init__(in_channels=in_channels,
179 |                                               out_channels=out_channels,
180 |                                               kernel_size=kernel_size,
181 |                                               stride=stride,
182 |                                               padding=padding,
183 |                                               output_padding=output_padding,
184 |                                               dilation=dilation,
185 |                                               groups=groups,
186 |                                               bias=bias,
187 |                                               gain=gain,
188 |                                               transpose=True,
189 |                                               transform_kernel=transform_kernel,
190 |                                               lrmul=lrmul,
191 |                                               implicit_lreq=implicit_lreq)
192 | 
193 | 
194 | class SeparableConv2d(nn.Module):
195 |     def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, dilation=1,
196 |                  bias=True, gain=np.sqrt(2.0), transpose=False):
197 |         super(SeparableConv2d, self).__init__()
198 |         self.spatial_conv = Conv2d(in_channels, in_channels, kernel_size, stride, padding, output_padding, dilation,
199 |                                    in_channels, False, 1, transpose)
200 |         # 1x1 pointwise conv; bias and gain are passed by keyword so they land in the right parameter slots
201 |         self.channel_conv = Conv2d(in_channels, out_channels, 1, bias=bias, gain=gain)
202 | 
203 |     def forward(self, x):
204 |         return self.channel_conv(self.spatial_conv(x))
205 | 
206 | 
207 | class SeparableConvTranspose2d(SeparableConv2d):  # the positional args below match SeparableConv2d's signature
208 |     def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, dilation=1,
209 |                  bias=True, gain=np.sqrt(2.0)):
210 |         super(SeparableConvTranspose2d, self).__init__(in_channels, out_channels, kernel_size, stride, padding,
211 |                                                        output_padding, dilation, bias, gain, True)
--------------------------------------------------------------------------------
/hubconf.py:
--------------------------------------------------------------------------------
1 | from torch.hub import load_state_dict_from_url as _download
2 | from sgan.model import Model as _m
3 | 
4 | dependencies = ['torch']
5 | 
6 | 
7 | def style_asuka(pretrained=False):
8 |     model = _m(
9 |         layer_count=8,
10 |         startf=32,
11 |         maxf=512,
12 | 
13 |         truncation_psi=0.5,
14 |         truncation_cutoff=8,
15 |         mode='asuka'
16 |     )
17 |     if pretrained:
18 |         checkpoint = 'https://github.com/GalAster/StyleGAN-Zoo/releases/download/v1.0.0/Asuka-512x512-7de0e6.mat'
19 |         model.load_state_dict(_download(checkpoint, progress=True))
20 |     return model
21 | 
22 | 
23 | def style_horo(pretrained=False):
24 |     model = _m(
25 |         layer_count=8,
26 |         startf=32,
27 |         maxf=512,
28 | 
29 |         truncation_psi=0.5,
30 |         truncation_cutoff=8,
31 |         mode='asuka'
32 |     )
33 |     if pretrained:
34 |         checkpoint = 'https://github.com/GalAster/StyleGAN-Zoo/releases/download/v1.1.0/Horo-512x512-822ee4.mat'
35 |         model.load_state_dict(_download(checkpoint, progress=True))
36 |     return model
37 | 
38 | 
39 | def style_asashio(pretrained=False):
40 |     model = _m(
41 |         layer_count=8,
42 |         startf=32,
43 |         maxf=512,
44 | 
45 |         truncation_psi=0.5,
46 |         truncation_cutoff=8,
47 |         mode='asuka'
48 |     )
49 |     if pretrained:
50 |         checkpoint = 'https://github.com/GalAster/StyleGAN-Zoo/releases/download/v2.2.0/Asashio-512x512-a3c21a.mat'
51 |         model.load_state_dict(_download(checkpoint, progress=True))
52 |     return model
53 | 
54 | 
55 | def style_anime_head(pretrained=False):
56 |     model = _m(
57 |         layer_count=8,
58 |         startf=32,
59 |         maxf=512,
60 | 
61 |         truncation_psi=0.5,
62 |         truncation_cutoff=8,
63 |         mode='asuka'
64 |     )
65 |     if pretrained:
66 |         checkpoint = 'https://github.com/GalAster/StyleGAN-Zoo/releases/download/v2.1.0/AnimeHead-512x512-960a82.mat'
67 |         model.load_state_dict(_download(checkpoint, progress=True))
68 |     return model
69 | 
70 | 
71 | def style_anime_face_a(pretrained=False):
72 |     model = _m(
73 |         layer_count=8,
74 |         startf=32,
75 |         maxf=512,
76 | 
77 |         truncation_psi=0.5,
78 |         truncation_cutoff=8,
79 |         mode='asuka'
80 |     )
81 |     if pretrained:
82 |         checkpoint = 'https://github.com/GalAster/StyleGAN-Zoo/releases/download/v2.0.0/AnimeFaceC-512x512-47055c.mat'
83 |         model.load_state_dict(_download(checkpoint, progress=True))
84 |     return model
85 | 
86 | 
87 | def style_anime_face_b(pretrained=False):
88 |     model = _m(
89 |         layer_count=8,
90 |         startf=32,
91 |         maxf=512,
92 | 
93 |         truncation_psi=0.5,
94 |         truncation_cutoff=8,
95 |         mode='asuka'
96 |     )
97 |     if pretrained:
98 |         checkpoint = 'https://github.com/GalAster/StyleGAN-Zoo/releases/download/v2.0.0/AnimeFaceA-512x512-feaff1.mat'
99 |         model.load_state_dict(_download(checkpoint, progress=True))
100 |     return model
101 | 
102 | 
103 | def style_anime_face_c(pretrained=False):
104 |     model = _m(
105 |         layer_count=8,
106 |         startf=32,
107 |         maxf=512,
108 | 
109 |         truncation_psi=0.5,
110 |         truncation_cutoff=8,
111 |         mode='asuka'
112 |     )
113 |     if pretrained:
114 |         checkpoint = 'https://github.com/GalAster/StyleGAN-Zoo/releases/download/v2.0.0/AnimeFaceB-512x512-41bdee.mat'
115 |         model.load_state_dict(_download(checkpoint, progress=True))
116 |     return model
117 | 
118 | 
119 | def style_anime_face_d(pretrained=False):
120 |     model = _m(
121 |         layer_count=8,
122 |         startf=32,
123 |         maxf=512,
124 | 125 | truncation_psi=0.5, 126 | truncation_cutoff=8, 127 | mode='asuka' 128 | ) 129 | if pretrained: 130 | checkpoint = 'https://github.com/GalAster/StyleGAN-Zoo/releases/download/v2.0.0/AnimeFaceD-512x512-3e59ff.mat' 131 | model.load_state_dict(_download(checkpoint, progress=True)) 132 | return model 133 | 134 | 135 | def style_anime_face_e(pretrained=False): 136 | model = _m( 137 | layer_count=8, 138 | startf=32, 139 | maxf=512, 140 | 141 | truncation_psi=0.4, 142 | truncation_cutoff=8, 143 | mode='asuka' 144 | ) 145 | if pretrained: 146 | checkpoint = 'https://github.com/GalAster/StyleGAN-Zoo/releases/download/v2.0.0/AnimeFaceE-512x512-9cfc38.mat' 147 | model.load_state_dict(_download(checkpoint, progress=True)) 148 | return model 149 | 150 | 151 | def style_art_a(pretrained=False): 152 | model = _m( 153 | layer_count=9, 154 | startf=16, 155 | maxf=512, 156 | 157 | truncation_psi=0.75, 158 | truncation_cutoff=8, 159 | mode='normal' 160 | ) 161 | if pretrained: 162 | checkpoint = 'https://github.com/GalAster/StyleGAN-Zoo/releases/download/v2.4.0/WikiArts-1024x1024-439e92.mat' 163 | model.load_state_dict(_download(checkpoint, progress=True)) 164 | return model 165 | 166 | 167 | def style_art_b(pretrained=False): 168 | model = _m( 169 | layer_count=8, 170 | startf=32, 171 | maxf=512, 172 | 173 | truncation_psi=0.5, 174 | truncation_cutoff=8, 175 | mode='asuka' 176 | ) 177 | if pretrained: 178 | checkpoint = 'https://github.com/GalAster/StyleGAN-Zoo/releases/download/v2.4.1/WikiArts-512x512-5955f8.mat' 179 | model.load_state_dict(_download(checkpoint, progress=True)) 180 | return model 181 | 182 | 183 | def style_ukiyoe_faces(pretrained=False): 184 | model = _m( 185 | layer_count=8, 186 | startf=32, 187 | maxf=512, 188 | 189 | truncation_psi=0.5, 190 | truncation_cutoff=8, 191 | mode='asuka' 192 | ) 193 | if pretrained: 194 | checkpoint = 'https://github.com/GalAster/StyleGAN-Zoo/releases/download/v2.5.0/UkiyoeFaces-512x512-e1d576.mat' 195 | model.load_state_dict(_download(checkpoint, progress=True)) 196 | return model 197 | 198 | 199 | def style_ffhq(pretrained=False): 200 | model = _m( 201 | layer_count=9, 202 | startf=16, 203 | maxf=512, 204 | 205 | truncation_psi=0.75, 206 | truncation_cutoff=8, 207 | mode='normal' 208 | ) 209 | if pretrained: 210 | checkpoint = 'https://github.com/GalAster/StyleGAN-Zoo/releases/download/v1.2.0/FFHQ-1024x1024-4a40cc.mat' 211 | model.load_state_dict(_download(checkpoint, progress=True)) 212 | return model 213 | 214 | 215 | def style_celeba_hq(pretrained=False): 216 | model = _m( 217 | layer_count=9, 218 | startf=16, 219 | maxf=512, 220 | 221 | truncation_psi=0.75, 222 | truncation_cutoff=8, 223 | mode='normal' 224 | ) 225 | if pretrained: 226 | checkpoint = 'https://github.com/GalAster/StyleGAN-Zoo/releases/download/v1.3.0/CelebaHQ-1024x1024-b5920a.mat' 227 | model.load_state_dict(_download(checkpoint, progress=True)) 228 | return model 229 | 230 | 231 | def style_baby(pretrained=False): 232 | model = _m( 233 | layer_count=9, 234 | startf=16, 235 | maxf=512, 236 | 237 | truncation_psi=0.75, 238 | truncation_cutoff=8, 239 | mode='normal' 240 | ) 241 | if pretrained: 242 | checkpoint = 'https://github.com/GalAster/StyleGAN-Zoo/releases/download/v1.8.0/Baby-1024x1024-b7d3cd.mat' 243 | model.load_state_dict(_download(checkpoint, progress=True)) 244 | return model 245 | 246 | 247 | def style_wanghong(pretrained=False): 248 | model = _m( 249 | layer_count=9, 250 | startf=16, 251 | maxf=512, 252 | 253 | truncation_psi=0.75, 254 | truncation_cutoff=8, 255 
| mode='normal'
256 |     )
257 |     if pretrained:
258 |         checkpoint = 'https://github.com/GalAster/StyleGAN-Zoo/releases/download/v1.4.0/WangHong-1024x1024-3aff9a.mat'
259 |         model.load_state_dict(_download(checkpoint, progress=True))
260 |     return model
261 | 
262 | 
263 | def style_asian_people(pretrained=False):
264 |     model = _m(
265 |         layer_count=9,
266 |         startf=16,
267 |         maxf=512,
268 | 
269 |         truncation_psi=0.75,
270 |         truncation_cutoff=8,
271 |         mode='normal'
272 |     )
273 |     if pretrained:
274 |         checkpoint = 'https://github.com/GalAster/StyleGAN-Zoo/releases/download/v1.5.0/AsianPeople-1024x1024-82c02f' \
275 |                      '.mat'
276 |         model.load_state_dict(_download(checkpoint, progress=True))
277 |     return model
278 | 
279 | 
280 | def style_asian_star(pretrained=False):
281 |     model = _m(
282 |         layer_count=9,
283 |         startf=16,
284 |         maxf=512,
285 | 
286 |         truncation_psi=0.75,
287 |         truncation_cutoff=8,
288 |         mode='normal'
289 |     )
290 |     if pretrained:
291 |         checkpoint = 'https://github.com/GalAster/StyleGAN-Zoo/releases/download/v1.6.0/AsianStar-1024x1024-aff808.mat'
292 |         model.load_state_dict(_download(checkpoint, progress=True))
293 |     return model
294 | 
295 | 
296 | def style_super_star(pretrained=False):
297 |     model = _m(
298 |         layer_count=9,
299 |         startf=16,
300 |         maxf=512,
301 | 
302 |         truncation_psi=0.75,
303 |         truncation_cutoff=8,
304 |         mode='normal'
305 |     )
306 |     if pretrained:
307 |         checkpoint = 'https://github.com/GalAster/StyleGAN-Zoo/releases/download/v1.7.0/SuperStar-1024x1024-4141b4.mat'
308 |         model.load_state_dict(_download(checkpoint, progress=True))
309 |     return model
310 | 
311 | 
312 | def style_vessel(pretrained=False):
313 |     model = _m(
314 |         layer_count=9,
315 |         startf=16,
316 |         maxf=512,
317 | 
318 |         truncation_psi=0.75,
319 |         truncation_cutoff=8,
320 |         mode='normal'
321 |     )
322 |     if pretrained:
323 |         checkpoint = 'https://github.com/GalAster/StyleGAN-Zoo/releases/download/v2.3.0/Vessel-1024x1024-ea0817.mat'
324 |         model.load_state_dict(_download(checkpoint, progress=True))
325 |     return model
326 | 
327 | 
328 | def style_qinghua(pretrained=False):
329 |     model = _m(
330 |         layer_count=9,
331 |         startf=16,
332 |         maxf=512,
333 | 
334 |         truncation_psi=0.75,
335 |         truncation_cutoff=8,
336 |         mode='normal'
337 |     )
338 |     if pretrained:
339 |         checkpoint = 'https://github.com/GalAster/StyleGAN-Zoo/releases/download/v2.3.1/QingHua-1024x1024-649fe8.mat'
340 |         model.load_state_dict(_download(checkpoint, progress=True))
341 |     return model
--------------------------------------------------------------------------------
/License:
--------------------------------------------------------------------------------
1 | 
2 |                                  Apache License
3 |                            Version 2.0, January 2004
4 |                         https://www.apache.org/licenses/
5 | 
6 |    TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 | 
8 |    1. Definitions.
9 | 
10 |       "License" shall mean the terms and conditions for use, reproduction,
11 |       and distribution as defined by Sections 1 through 9 of this document.
12 | 
13 |       "Licensor" shall mean the copyright owner or entity authorized by
14 |       the copyright owner that is granting the License.
15 | 
16 |       "Legal Entity" shall mean the union of the acting entity and all
17 |       other entities that control, are controlled by, or are under common
18 |       control with that entity. 
For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 
135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | Copyright 2013-2018 Docker, Inc. 180 | 181 | Licensed under the Apache License, Version 2.0 (the "License"); 182 | you may not use this file except in compliance with the License. 183 | You may obtain a copy of the License at 184 | 185 | https://www.apache.org/licenses/LICENSE-2.0 186 | 187 | Unless required by applicable law or agreed to in writing, software 188 | distributed under the License is distributed on an "AS IS" BASIS, 189 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 190 | See the License for the specific language governing permissions and 191 | limitations under the License. 
-------------------------------------------------------------------------------- /sgan/net.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Stanislav Pidhorskyi 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | import sgan.lreq as ln 17 | import numpy as np 18 | import torch 19 | from torch import nn 20 | from torch.nn import functional as F 21 | from torch.nn import init 22 | from torch.nn.parameter import Parameter 23 | 24 | 25 | def pixel_norm(x, epsilon=1e-8): 26 | return x * torch.rsqrt(torch.mean(x.pow(2.0), dim=1, keepdim=True) + epsilon) 27 | 28 | 29 | def style_mod(x, style): 30 | style = style.view(style.shape[0], 2, x.shape[1], 1, 1) 31 | return torch.addcmul(style[:, 1], value=1.0, tensor1=x, tensor2=style[:, 0] + 1) 32 | 33 | 34 | def upscale2d(x, factor=2): 35 | # s = x.shape 36 | # x = torch.reshape(x, [-1, s[1], s[2], 1, s[3], 1]) 37 | # x = x.repeat(1, 1, 1, factor, 1, factor) 38 | # x = torch.reshape(x, [-1, s[1], s[2] * factor, s[3] * factor]) 39 | # return x 40 | return F.interpolate(x, scale_factor=factor, mode='bilinear', align_corners=True) 41 | 42 | 43 | class Blur(nn.Module): 44 | def __init__(self, channels): 45 | super(Blur, self).__init__() 46 | f = np.array([1, 2, 1], dtype=np.float32) 47 | f = f[:, np.newaxis] * f[np.newaxis, :] 48 | f /= np.sum(f) 49 | kernel = torch.Tensor(f).view(1, 1, 3, 3).repeat(channels, 1, 1, 1) 50 | self.register_buffer('weight', kernel) 51 | self.groups = channels 52 | 53 | def forward(self, x): 54 | return F.conv2d(x, weight=self.weight, groups=self.groups, padding=1) 55 | 56 | 57 | class DecodeBlock(nn.Module): 58 | def __init__(self, inputs, outputs, latent_size, has_first_conv=True, fused_scale=True, layer=0): 59 | super(DecodeBlock, self).__init__() 60 | self.has_first_conv = has_first_conv 61 | self.inputs = inputs 62 | self.has_first_conv = has_first_conv 63 | self.fused_scale = fused_scale 64 | if has_first_conv: 65 | if fused_scale: 66 | self.conv_1 = ln.ConvTranspose2d(inputs, outputs, 3, 2, 1, bias=False, transform_kernel=True) 67 | else: 68 | self.conv_1 = ln.Conv2d(inputs, outputs, 3, 1, 1, bias=False) 69 | 70 | self.blur = Blur(outputs) 71 | self.noise_weight_1 = nn.Parameter(torch.Tensor(1, outputs, 1, 1)) 72 | self.noise_weight_1.data.zero_() 73 | self.bias_1 = nn.Parameter(torch.Tensor(1, outputs, 1, 1)) 74 | self.instance_norm_1 = nn.InstanceNorm2d(outputs, affine=False, eps=1e-8) 75 | self.style_1 = ln.Linear(latent_size, 2 * outputs, gain=1) 76 | 77 | self.conv_2 = ln.Conv2d(outputs, outputs, 3, 1, 1, bias=False) 78 | self.noise_weight_2 = nn.Parameter(torch.Tensor(1, outputs, 1, 1)) 79 | self.noise_weight_2.data.zero_() 80 | self.bias_2 = nn.Parameter(torch.Tensor(1, outputs, 1, 1)) 81 | self.instance_norm_2 = nn.InstanceNorm2d(outputs, affine=False, eps=1e-8) 82 | self.style_2 = ln.Linear(latent_size, 2 * outputs, gain=1) 83 | 
84 | self.layer = layer 85 | 86 | self.c = -1 87 | 88 | with torch.no_grad(): 89 | self.bias_1.zero_() 90 | self.bias_2.zero_() 91 | 92 | def set(self, c): 93 | self.c = c 94 | 95 | def forward(self, x, s1, s2): 96 | # TODO: disable random when interpolate 97 | if self.has_first_conv: 98 | if not self.fused_scale: 99 | x = upscale2d(x) 100 | x = self.conv_1(x) 101 | x = self.blur(x) 102 | 103 | x = torch.addcmul( 104 | x, value=1.0, tensor1=self.noise_weight_1, 105 | tensor2=torch.randn([x.shape[0], 1, x.shape[2], x.shape[3]]) 106 | ) 107 | 108 | x = x + self.bias_1 109 | 110 | x = F.leaky_relu(x, 0.2) 111 | 112 | x = self.instance_norm_1(x) 113 | 114 | x = style_mod(x, self.style_1(s1)) 115 | 116 | x = self.conv_2(x) 117 | 118 | x = torch.addcmul( 119 | x, value=1.0, tensor1=self.noise_weight_2, 120 | tensor2=torch.randn([x.shape[0], 1, x.shape[2], x.shape[3]]) 121 | ) 122 | 123 | x = x + self.bias_2 124 | 125 | x = F.leaky_relu(x, 0.2) 126 | x = self.instance_norm_2(x) 127 | 128 | x = style_mod(x, self.style_2(s2)) 129 | 130 | return x 131 | 132 | def forward_double(self, x, _x, s1, s2): 133 | if self.has_first_conv: 134 | if not self.fused_scale: 135 | x = upscale2d(x) 136 | _x = upscale2d(_x) 137 | x = self.conv_1(x) 138 | _x = self.conv_1(_x) 139 | 140 | x = self.blur(x) 141 | _x = self.blur(_x) 142 | 143 | n1 = torch.randn([int(x.shape[0]), 1, int(x.shape[2]), int(x.shape[3])]) 144 | x = torch.addcmul( 145 | x, value=1.0, tensor1=self.noise_weight_1, 146 | tensor2=n1 147 | ) 148 | 149 | _x = torch.addcmul( 150 | _x, value=1.0, tensor1=self.noise_weight_1, 151 | tensor2=n1 152 | ) 153 | 154 | x = x + self.bias_1 155 | _x = _x + self.bias_1 156 | 157 | x = F.leaky_relu(x, 0.2) 158 | _x = F.leaky_relu(_x, 0.2) 159 | 160 | std = x.std(axis=[2, 3], keepdim=True) 161 | mean = x.mean(axis=[2, 3], keepdim=True) 162 | 163 | x = (x - mean) / std 164 | _x = (_x - mean) / std 165 | 166 | x = style_mod(x, self.style_1(s1)) 167 | _x = style_mod(_x, self.style_1(s1)) 168 | 169 | x = self.conv_2(x) 170 | _x = self.conv_2(_x) 171 | 172 | n2 = torch.randn([int(x.shape[0]), 1, int(x.shape[2]), int(x.shape[3])]) 173 | 174 | x = torch.addcmul(x, value=1.0, tensor1=self.noise_weight_2, 175 | tensor2=n2) 176 | 177 | _x = torch.addcmul(_x, value=1.0, tensor1=self.noise_weight_2, 178 | tensor2=n2) 179 | 180 | x = x + self.bias_2 181 | _x = _x + self.bias_2 182 | 183 | x = F.leaky_relu(x, 0.2) 184 | _x = F.leaky_relu(_x, 0.2) 185 | 186 | std = x.std(axis=[2, 3], keepdim=True) 187 | mean = x.mean(axis=[2, 3], keepdim=True) 188 | 189 | x = (x - mean) / std 190 | _x = (_x - mean) / std 191 | 192 | x = style_mod(x, self.style_2(s2)) 193 | _x = style_mod(_x, self.style_2(s2)) 194 | 195 | return x, _x 196 | 197 | 198 | class ToRGB(nn.Module): 199 | def __init__(self, inputs, channels): 200 | super(ToRGB, self).__init__() 201 | self.inputs = inputs 202 | self.channels = channels 203 | self.to_rgb = ln.Conv2d(inputs, channels, 1, 1, 0, gain=1) 204 | 205 | def forward(self, x): 206 | x = self.to_rgb(x) 207 | return x 208 | 209 | 210 | class Generator(nn.Module): 211 | def __init__(self, startf=32, maxf=256, layer_count=3, latent_size=128, channels=3): 212 | super(Generator, self).__init__() 213 | self.maxf = maxf 214 | self.startf = startf 215 | self.layer_count = layer_count 216 | 217 | self.channels = channels 218 | self.latent_size = latent_size 219 | 220 | mul = 2 ** (self.layer_count - 1) 221 | 222 | inputs = min(self.maxf, startf * mul) 223 | self.const = Parameter(torch.Tensor(1, inputs, 4, 4)) 224 | self.zeros 
= torch.zeros(1, 1, 1, 1) 225 | init.ones_(self.const) 226 | 227 | self.layer_to_resolution = [0 for _ in range(layer_count)] 228 | resolution = 2 229 | 230 | self.style_sizes = [] 231 | 232 | to_rgb = nn.ModuleList() 233 | 234 | self.decode_block: nn.ModuleList[DecodeBlock] = nn.ModuleList() 235 | for i in range(self.layer_count): 236 | outputs = min(self.maxf, startf * mul) 237 | 238 | has_first_conv = i != 0 239 | fused_scale = resolution * 2 >= 128 240 | 241 | block = DecodeBlock(inputs, outputs, latent_size, has_first_conv, fused_scale=fused_scale, layer=i) 242 | 243 | resolution *= 2 244 | self.layer_to_resolution[i] = resolution 245 | 246 | self.style_sizes += [2 * (inputs if has_first_conv else outputs), 2 * outputs] 247 | 248 | to_rgb.append(ToRGB(outputs, channels)) 249 | 250 | # print("decode_block%d %s styles in: %dl out resolution: %d" % ((i + 1), millify(count_parameters(block)), outputs, resolution)) 251 | self.decode_block.append(block) 252 | inputs = outputs 253 | mul //= 2 254 | 255 | self.to_rgb = to_rgb 256 | 257 | def decode(self, styles, lod, remove_blob=True): 258 | x = self.const 259 | _x = None 260 | for i in range(lod + 1): 261 | if i < 4 or not remove_blob: 262 | x = self.decode_block[i].forward(x, styles[:, 2 * i + 0], styles[:, 2 * i + 1]) 263 | if remove_blob and i == 3: 264 | _x = x.clone() 265 | _x[x > 300.0] = 0 266 | 267 | # plt.hist((torch.max(torch.max(_x, dim=2)[0], dim=2)[0]).cpu().flatten().numpy(), bins=300) 268 | # plt.show() 269 | # exit() 270 | else: 271 | x, _x = self.decode_block[i].forward_double(x, _x, styles[:, 2 * i + 0], styles[:, 2 * i + 1]) 272 | 273 | if _x is not None: 274 | x = _x 275 | if lod == 8: 276 | x = self.to_rgb[lod](x) 277 | else: 278 | x = x.max(dim=1, keepdim=True)[0] 279 | x = x - x.min() 280 | x = x / x.max() 281 | x = torch.pow(x, 1.0 / 2.2) 282 | x = x.repeat(1, 3, 1, 1) 283 | return x 284 | 285 | def decode_asuka(self, styles, lod, remove_blob=True): 286 | x = self.const 287 | _x = None 288 | prune_at_layer = 1 289 | 290 | for i in range(lod + 1): 291 | if i <= prune_at_layer or not remove_blob: 292 | x = self.decode_block[i].forward(x, styles[:, 2 * i + 0], styles[:, 2 * i + 1]) 293 | if remove_blob and i == prune_at_layer: 294 | _x = x.clone() 295 | # TODO: seems not enough 296 | ch80 = _x[:, 80] 297 | ch73 = _x[:, 73] 298 | 299 | ch80[ch80 > torch.max(ch80) * 0.9] = 0 300 | ch73[ch73 > torch.max(ch73) * 0.9] = 0 301 | 302 | _x[:, 80] = ch80 303 | _x[:, 73] = ch73 304 | 305 | # plt.hist((torch.max(torch.max(x, dim=2)[0], dim=2)[0]).cpu().flatten().numpy(), bins=300) 306 | # plt.show() 307 | else: 308 | x, _x = self.decode_block[i].forward_double(x, _x, styles[:, 2 * i + 0], styles[:, 2 * i + 1]) 309 | 310 | if _x is not None: 311 | x = _x 312 | if lod == 7: 313 | x = self.to_rgb[lod](x) 314 | else: 315 | x = x.max(dim=1, keepdim=True)[0] 316 | x = x - x.min() 317 | x = x / x.max() 318 | x = torch.pow(x, 1.0 / 2.2) 319 | x = x.repeat(1, 3, 1, 1) 320 | return x 321 | 322 | def forward(self, styles, lod, remove_blob=True, method='normal'): 323 | if method == 'asuka': 324 | return self.decode_asuka(styles, lod, remove_blob) 325 | else: 326 | return self.decode(styles, lod, remove_blob) 327 | 328 | 329 | class MappingBlock(nn.Module): 330 | def __init__(self, inputs, output, lrmul): 331 | super(MappingBlock, self).__init__() 332 | self.fc = ln.Linear(inputs, output, lrmul=lrmul) 333 | 334 | def forward(self, x): 335 | x = F.leaky_relu(self.fc(x), 0.2) 336 | return x 337 | 338 | 339 | class Mapping(nn.Module): 340 | def 
__init__(self, num_layers, mapping_layers=5, latent_size=256, dlatent_size=256, mapping_fmaps=256): 341 | super(Mapping, self).__init__() 342 | inputs = latent_size 343 | self.mapping_layers = mapping_layers 344 | self.num_layers = num_layers 345 | for i in range(mapping_layers): 346 | outputs = dlatent_size if i == mapping_layers - 1 else mapping_fmaps 347 | block = MappingBlock(inputs, outputs, lrmul=0.01) 348 | inputs = outputs 349 | setattr(self, "block_%d" % (i + 1), block) 350 | 351 | def forward(self, z): 352 | x = pixel_norm(z) 353 | 354 | for i in range(self.mapping_layers): 355 | x = getattr(self, "block_%d" % (i + 1))(x) 356 | 357 | return x.view(x.shape[0], 1, x.shape[1]).repeat(1, self.num_layers, 1) 358 | --------------------------------------------------------------------------------
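For reference, a minimal sketch of how the pieces above fit together; it mirrors what sgan.utils.generate does internally (get_model lives in sgan/cache.py, Model in sgan/model.py) and assumes torch.hub can reach the GalAster/StyleGAN-Zoo release files:

```python
import torch
from sgan.cache import get_model

model = get_model('asuka')   # hubconf.py builds the Model and loads the .mat checkpoint
z = torch.randn(1, 512)      # a latent code; Mapping turns it into per-layer styles
with torch.no_grad():
    img = model.generate(model.out_layer, z=z)   # (1, 3, 512, 512), values roughly in [-1, 1]
img = img.clamp(-1, 1) * 0.5 + 0.5               # rescale to [0, 1] for display or saving
```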