├── LICENSE ├── README.md ├── barlow_twins.py ├── dataset.py ├── doc ├── arch.png └── sample_afhq.png ├── environment.yml ├── generate.py ├── latent_interpolation.py ├── metric_impl ├── dc.py ├── ppl.py └── pr.py ├── metrics.py ├── model ├── hyper_mod.py └── stylegan.py ├── prepare_data.py ├── pretrained_converter.py ├── self_alignment.py ├── train.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Kim Seonghyeon, Hector Laria 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Transferring Unconditional to Conditional GANs with Hyper-Modulation 2 | -------------------------------------------------------------------- 3 | 4 | _GANs have matured in recent years and are able to generate high-resolution, realistic images. However, the computational resources and the data required for the training of high-quality GANs are enormous, and the study of transfer learning of these models is therefore an urgent topic. Many of the available high-quality pretrained GANs are unconditional (like StyleGAN). For many applications, however, conditional GANs are preferable, because they provide more control over the generation process, despite often suffering more training difficulties. Therefore, in this paper, we focus on transferring from high-quality pretrained unconditional GANs to conditional GANs. This requires architectural adaptation of the pretrained GAN to perform the conditioning. To this end, we propose hyper-modulated generative networks that allow for shared and complementary supervision. To prevent the additional weights of the hypernetwork to overfit, with subsequent mode collapse on small target domains, we introduce a self-initialization procedure that does not require any real data to initialize the hypernetwork parameters. To further improve the sample efficiency of the transfer, we apply contrastive learning in the discriminator, which effectively works on very limited batch sizes. In extensive experiments, we validate the efficiency of the hypernetworks, self-initialization and contrastive loss for knowledge transfer on standard benchmarks._ 5 | 6 |

7 | ![Model architecture](doc/arch.png) 8 |

9 | 10 | Official implementation of [Hyper-Modulation](https://arxiv.org/abs/2112.02219). 11 | 12 |

13 | ![Sample of the model trained on AFHQ](doc/sample_afhq.png) 14 |
15 | Given a pre-trained model, our hyper-modulation method learns to use its knowledge to produce out-of-domain generalizations. For example, trained on human faces, it quickly learns animal faces (above) or further domains such as objects and places (see paper). No finetuning is needed. It also preserves, and responds to, the style mechanism of the source domain.

16 | 17 | 18 | ### Environment 19 | Please install the required packages with conda and activate the environment 20 | ```bash 21 | conda env create -f environment.yml 22 | conda activate hypermod 23 | ``` 24 | 25 | ### Base model 26 | Pretrained unconditional models are available in the [base code repository](https://github.com/rosinality/style-based-gan-pytorch); specifically, we used [this one](https://drive.google.com/open?id=1QlXFPIOFzsJyjZ1AtfpnVhqW4Z0r8GLZ), trained on the FFHQ dataset. 27 | Let's assume it is downloaded to `<pretrained model path>`. 28 | 29 | ### Data 30 | To replicate the stated results, please download [AFHQ(v1)](https://github.com/clovaai/stargan-v2/blob/master/README.md#animal-faces-hq-dataset-afhq). 31 | Additional datasets such as FFHQ, CelebA-HQ, Flowers102, and Places365 can be downloaded from their respective sources. 32 | 33 | Once downloaded and extracted, the dataset must be preprocessed into an LMDB file with the following command 34 | ```bash 35 | python prepare_data.py <dataset path>/train/ --labels --sizes 256 --out <dataset lmdb path> 36 | ``` 37 | The directory given to `prepare_data.py` must contain one folder per class, with each class's samples inside its respective folder, as expected by the [ImageFolder](https://pytorch.org/vision/stable/datasets.html#torchvision.datasets.ImageFolder) dataset. 38 | 39 | Additionally, we can precompute the Inception features of the dataset to calculate the metrics later. 40 | ```bash 41 | python metrics.py <dataset lmdb path> --output_path <features lmdb path> --batch_size 32 --operation extract-feats 42 | ``` 43 | 44 | ### Training 45 | 46 | #### Self-alignment 47 | 48 | ```bash 49 | python self_alignment.py --batch_size 32 --name <run name> --checkpoint_interval 1000 --iterations 10000 50 | ``` 51 | The model checkpoints will be saved under the `checkpoint/self_align/` folder (configurable with `--checkpoint_dir`). 52 | Losses are also logged and can be visualized by pointing TensorBoard to the `runs/self_align/` folder (configurable with `--log_dir`). 53 | 54 | To assess quality, we can compute metrics for each checkpoint with 55 | ```bash 56 | python metrics.py checkpoint/self_align/ --precomputed_feats <features lmdb path> --batch_size 32 --operation metrics 57 | ``` 58 | For each checkpoint, a `.metric` file will be created containing the computed scores. The function `class_mean()` from `utils.py` can be used to read the whole directory. 59 | One would normally want to plot the metric evolution over time, which can be done as 60 | ```python 61 | import matplotlib.pyplot as plt 62 | from utils import class_mean 63 | 64 | p = 'checkpoint/self_align/' 65 | plt.plot(*class_mean(p, metric='fid', return_steps=True)) 66 | ``` 67 | We then pick a checkpoint with good enough scores. We will name it `<self-aligned checkpoint>`. 
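For instance, to pick the checkpoint with the best (lowest) FID programmatically, here is a minimal sketch assuming, as in the snippet above, that `class_mean(..., return_steps=True)` returns the checkpoint steps together with the class-averaged scores:

```python
from utils import class_mean

p = 'checkpoint/self_align/'
steps, fids = class_mean(p, metric='fid', return_steps=True)
# Pair each checkpoint step with its FID and keep the lowest one
best_step, best_fid = min(zip(steps, fids), key=lambda pair: pair[1])
print(f'Best checkpoint so far: step {best_step} (FID {best_fid:.2f})')
```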
68 | 69 | #### Transfer learning 70 | 71 | To run the actual training, we can do 72 | ```bash 73 | python train.py --ckpt <self-aligned checkpoint> --max_size 256 --loss r1 --const_alpha 1 --phase 3000000 --batch_default 10 --init_size 256 --mixing --name <run name> --checkpoint_interval 1000 --iterations 200002 74 | ``` 75 | 76 | For the finetuned version 77 | ```bash 78 | python train.py --finetune --ckpt <checkpoint> --max_size 256 --loss r1 --const_alpha 1 --phase 3000000 --batch_default 10 --init_size 256 --mixing --name <run name> --checkpoint_interval 1000 --iterations 200002 79 | ``` 80 | 81 | For the version with contrastive learning 82 | ```bash 83 | python train.py --contrastive --ckpt <self-aligned checkpoint> --max_size 256 --loss r1 --const_alpha 1 --phase 3000000 --batch_default 10 --init_size 256 --mixing --name <run name> --checkpoint_interval 1000 --iterations 200002 84 | ``` 85 | 86 | Generated images and loss curves can be monitored with `tensorboard --logdir runs/` by default. 87 | Checkpoints are saved in `checkpoint/`; metrics can be computed and plotted with the `metrics.py` script and `class_mean` in the same way as for self-alignment. 88 | 89 | Additionally, the perceptual path length can be computed as 90 | ```bash 91 | python metric_impl/ppl.py <checkpoint> 92 | ``` 93 | 94 | ### Generation 95 | 96 | To use the final models, we can directly run `generate.py` or `latent_interpolation.py`. A minimal sketch of programmatic generation is included after the acknowledgments below. 97 | 98 | To generate a figure like the interpolation from Supplementary Material section 7 (Latent space), run 99 | ```bash 100 | python latent_interpolation.py <checkpoint> --output interpolation_class_noise.png --num_interpolations 5 --type class_noise 101 | ``` 102 | 103 | ### Acknowledgments 104 | 105 | The project is based on [this StyleGAN implementation](https://github.com/rosinality/style-based-gan-pytorch). We used the [PIQ](https://github.com/photosynthesis-team/piq) package for metrics, and the PPL implementation was taken from [here](https://github.com/rosinality/stylegan2-pytorch/blob/master/ppl.py). 
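As referenced above, here is a minimal sketch of using a trained checkpoint programmatically, following the calls made in `generate.py` and `latent_interpolation.py`. The `Namespace` fields and the `<checkpoint>` path are placeholders; adapt them to whatever `utils.load_checkpoint` actually expects in your setup.

```python
from argparse import Namespace

import torch
from torchvision.utils import save_image

from utils import load_checkpoint

# Hypothetical argument namespace; load_checkpoint() reads its fields the same way
# the scripts in this repo do (checkpoint_path, device, from_self_align).
args = Namespace(checkpoint_path='<checkpoint>', device='cuda', from_self_align=False)
generator, class_network = load_checkpoint(args)
generator.eval()
class_network.eval()

with torch.no_grad():
    # Class embedding for class 0, repeated for a batch of 8 samples
    task = class_network(torch.tensor([0] * 8, device=args.device))
    z = torch.randn(8, 512, device=args.device)      # 512 is the default code size
    imgs = generator(z, step=6, alpha=1, task=task)  # step=6 corresponds to 256x256

save_image(imgs, 'sample_class0.png', nrow=4, normalize=True, range=(-1, 1))
```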
106 | 107 | ### Citation 108 | 109 | ``` 110 | @InProceedings{Laria_2022_CVPR, 111 | author = {Laria, H\'ector and Wang, Yaxing and van de Weijer, Joost and Raducanu, Bogdan}, 112 | title = {Transferring Unconditional to Conditional GANs With Hyper-Modulation}, 113 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR) Workshops}, 114 | month = {June}, 115 | year = {2022}, 116 | pages = {3840-3849} 117 | } 118 | ``` 119 | -------------------------------------------------------------------------------- /barlow_twins.py: -------------------------------------------------------------------------------- 1 | """Base Barlow Twins implementation (BarlowTwinsLoss) taken from 2 | https://github.com/facebookresearch/barlowtwins/blob/main/main.py""" 3 | from typing import Tuple 4 | import torch 5 | from torch import nn 6 | from kornia import augmentation as K 7 | 8 | 9 | class DiffTransform(nn.Module): 10 | def __init__(self, crop_resize: int = 224): 11 | super().__init__() 12 | 13 | self.transform = K.AugmentationSequential( 14 | K.Normalize(mean=torch.tensor(-1), std=torch.tensor(2)), # from [-1, 1] to [0, 1] 15 | K.RandomResizedCrop((crop_resize, crop_resize), resample='BICUBIC'), 16 | K.RandomHorizontalFlip(p=0.5), 17 | K.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.8), 18 | K.RandomGrayscale(p=0.2), 19 | K.RandomGaussianBlur(kernel_size=(3, 3), sigma=(1.05, 1.05), p=0.5), 20 | K.RandomSolarize(thresholds=(0, 0.5), additions=0, p=0.0), 21 | K.Normalize(mean=torch.tensor(0.5), std=torch.tensor(0.5)), # back to [-1, 1] 22 | ) 23 | self.transform_prime = K.AugmentationSequential( 24 | K.Normalize(mean=torch.tensor(-1), std=torch.tensor(2)), 25 | K.RandomResizedCrop((crop_resize, crop_resize), resample='BICUBIC'), 26 | K.RandomHorizontalFlip(p=0.5), 27 | K.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.8), 28 | K.RandomGrayscale(p=0.2), 29 | K.RandomGaussianBlur(kernel_size=(3, 3), sigma=(1.05, 1.05), p=0.1), 30 | K.RandomSolarize(thresholds=(0, 0.5), additions=0, p=0.2), 31 | K.Normalize(mean=torch.tensor(0.5), std=torch.tensor(0.5)), 32 | ) 33 | 34 | def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 35 | y1 = self.transform(x) 36 | y2 = self.transform_prime(x) 37 | return y1, y2 38 | 39 | 40 | class BarlowTwins(nn.Module): 41 | def __init__(self, num_feats, lambd=5e-3, sizes=(512, 512, 512, 512), use_projector=True): 42 | super().__init__() 43 | self.lambd = lambd 44 | 45 | # projector 46 | if not use_projector: 47 | self.projector = nn.Identity() 48 | else: 49 | layers = [] 50 | for i in range(len(sizes) - 2): 51 | layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=False)) 52 | layers.append(nn.BatchNorm1d(sizes[i + 1])) 53 | layers.append(nn.ReLU(inplace=True)) 54 | layers.append(nn.Linear(sizes[-2], sizes[-1], bias=False)) 55 | self.projector = nn.Sequential(*layers) 56 | 57 | # normalization layer for the representations z1 and z2 58 | self.bn = nn.BatchNorm1d(num_feats, affine=False) 59 | 60 | def forward(self, z1, z2): 61 | z1 = self.projector(z1) 62 | z2 = self.projector(z2) 63 | 64 | return self.bn(z1), self.bn(z2) 65 | 66 | @staticmethod 67 | def off_diagonal(x): 68 | # return a flattened view of the off-diagonal elements of a square matrix 69 | n, m = x.shape 70 | assert n == m 71 | return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten() 72 | 73 | def compute_final_loss(self, z1, z2): 74 | # empirical cross-correlation matrix 75 | c = z1.T @ z2 76 | 77 | # sum the 
cross-correlation matrix between all gpus 78 | c.div_(z1.size(0)) 79 | 80 | on_diag = torch.diagonal(c).add_(-1).pow_(2).sum() 81 | off_diag = self.off_diagonal(c).pow_(2).sum() 82 | loss = on_diag + self.lambd * off_diag 83 | return loss 84 | 85 | 86 | def compute_contrastive(discriminator, images, barlow_twins, transform, c, step, alpha, weighting=0.01): 87 | y1, y2 = transform(images) 88 | hierarchical_feats1 = discriminator(y1, c=c, step=step, alpha=alpha, return_hierarchical=True) 89 | hierarchical_feats2 = discriminator(y2, c=c, step=step, alpha=alpha, return_hierarchical=True) 90 | 91 | feats1 = hierarchical_feats1[-2].squeeze(-1).squeeze(-1) 92 | feats2 = hierarchical_feats2[-2].squeeze(-1).squeeze(-1) 93 | 94 | # hack to compute the cross-corr matrix among gpus on DP 95 | z1, z2 = barlow_twins(feats1, feats2) # gather zs from all gpus 96 | final_loss = barlow_twins.module.compute_final_loss(z1, z2) # .module = out of DP 97 | 98 | return final_loss * weighting 99 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | from io import BytesIO 2 | 3 | import lmdb 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | import torch 7 | 8 | 9 | class MultiResolutionDataset(Dataset): 10 | def __init__(self, path, transform, resolution=8, selected_classes=None, class_samples=None, random_class_sampling=False, length=None, one_hot_label=False): 11 | self.path = path 12 | self.env = None 13 | self.open_lmdb() 14 | 15 | if not self.env: 16 | raise IOError('Cannot open lmdb dataset', path) 17 | self.filter_classes = None 18 | 19 | with self.env.begin(write=False) as txn: 20 | self.length = int(txn.get('length'.encode('utf-8')).decode('utf-8')) 21 | self.num_classes = txn.get('num_classes'.encode('utf-8')) 22 | 23 | if selected_classes is not None or class_samples is not None: 24 | if selected_classes is None: 25 | selected_classes = list(range(int(self.num_classes.decode('utf-8')))) if self.num_classes is not None else [0] 26 | self.filter_classes = self.filtering(txn, selected_classes, class_samples, random_sampling=random_class_sampling) 27 | self.length = len(self.filter_classes) 28 | self.num_classes = str(len(selected_classes)).encode('utf-8') 29 | 30 | if self.num_classes is not None: 31 | self.num_classes = int(self.num_classes.decode('utf-8')) 32 | assert length is None or length <= self.length, f'There are not enough samples in the dataset. {length} asked, {self.length} in total.' 
33 | if length is not None: 34 | self.length = length 35 | self.resolution = resolution 36 | self.transform = transform 37 | self.one_hot_label = one_hot_label 38 | 39 | self.close_lmdb() 40 | 41 | def open_lmdb(self): 42 | self.env = lmdb.open( 43 | self.path, 44 | max_readers=32, 45 | readonly=True, 46 | lock=False, 47 | readahead=False, 48 | meminit=False, 49 | ) 50 | 51 | def close_lmdb(self): 52 | self.env.close() 53 | self.env = None 54 | 55 | def __len__(self): 56 | return self.length 57 | 58 | def __getitem__(self, index): 59 | if self.env is None: 60 | self.open_lmdb() 61 | 62 | if self.filter_classes is not None: 63 | index = self.filter_classes[index] 64 | 65 | with self.env.begin(write=False) as txn: 66 | key = f'{self.resolution}-{str(index).zfill(5)}'.encode('utf-8') 67 | img_bytes = txn.get(key) 68 | 69 | key = f'label-{str(index).zfill(5)}'.encode('utf-8') 70 | label = txn.get(key) 71 | 72 | buffer = BytesIO(img_bytes) 73 | img = Image.open(buffer) 74 | img = self.transform(img) 75 | 76 | if label is not None: 77 | label = int(label.decode('utf-8')) 78 | if self.one_hot_label: 79 | label = torch.nn.functional.one_hot(torch.tensor(label), num_classes=self.num_classes).float() 80 | 81 | return img, label 82 | 83 | @staticmethod 84 | def filtering(txn, selected_classes, class_samples, random_sampling=False): 85 | def get_label(txn, i): 86 | return int(txn.get(f'label-{str(i).zfill(5)}'.encode('utf-8')).decode('utf-8')) 87 | 88 | from collections import defaultdict 89 | import itertools 90 | import random 91 | class_ids = defaultdict(list) 92 | length = int(txn.get('length'.encode('utf-8')).decode('utf-8')) 93 | 94 | # separate classes 95 | for i in range(length): 96 | class_ids[get_label(txn, i)].append(i) 97 | 98 | # drop unselected classes 99 | for i in list(class_ids.keys()): 100 | if i not in selected_classes: 101 | class_ids.pop(i) 102 | 103 | # reduce the samples per class 104 | for i in class_ids.keys(): 105 | if random_sampling: 106 | class_ids[i] = random.sample(class_ids[i], k=class_samples) 107 | else: 108 | class_ids[i] = class_ids[i][:class_samples] 109 | 110 | # flatten 111 | return list(itertools.chain(*class_ids.values())) 112 | -------------------------------------------------------------------------------- /doc/arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hecoding/Hyper-Modulation/3c07d8052414470b2c8d10b265d522a1ed5d9b14/doc/arch.png -------------------------------------------------------------------------------- /doc/sample_afhq.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hecoding/Hyper-Modulation/3c07d8052414470b2c8d10b265d522a1ed5d9b14/doc/sample_afhq.png -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: hypermod 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=main 8 | - absl-py=0.11.0=py38h578d9bd_0 9 | - aiohttp=3.7.4=py38h27cfd23_1 10 | - async-timeout=3.0.1=py_1000 11 | - attrs=20.3.0=pyhd3deb0d_0 12 | - blas=1.0=mkl 13 | - blinker=1.4=py_1 14 | - brotlipy=0.7.0=py38h8df0ef7_1001 15 | - c-ares=1.17.1=h36c2ea0_0 16 | - ca-certificates=2021.1.19=h06a4308_0 17 | - cachetools=4.2.1=pyhd8ed1ab_0 18 | - certifi=2020.12.5=py38h06a4308_0 19 | - cffi=1.14.5=py38h261ae71_0 20 | - chardet=3.0.4=py38h924ce5b_1008 21 | - 
click=7.1.2=pyh9f0ad1d_0 22 | - cryptography=2.9.2=py38h766eaa4_0 23 | - cudatoolkit=10.2.89=hfd86e86_1 24 | - freetype=2.10.4=h5ab3b9f_0 25 | - fsspec=0.8.7=pyhd8ed1ab_0 26 | - future=0.18.2=py38h578d9bd_3 27 | - google-auth=1.24.0=pyhd3deb0d_0 28 | - google-auth-oauthlib=0.4.1=py_2 29 | - grpcio=1.33.2=py38heead2fc_2 30 | - idna=2.10=pyh9f0ad1d_0 31 | - importlib-metadata=3.7.0=py38h578d9bd_0 32 | - intel-openmp=2020.2=254 33 | - jpeg=9b=h024ee3a_2 34 | - lcms2=2.11=h396b838_0 35 | - ld_impl_linux-64=2.33.1=h53a641e_7 36 | - libedit=3.1.20191231=h14c3975_1 37 | - libffi=3.3=he6710b0_2 38 | - libgcc-ng=9.1.0=hdf63c60_0 39 | - libpng=1.6.37=hbc83047_0 40 | - libprotobuf=3.14.0=h8c45485_0 41 | - libstdcxx-ng=9.1.0=hdf63c60_0 42 | - libtiff=4.1.0=h2733197_1 43 | - libuv=1.40.0=h7b6447c_0 44 | - lmdb=0.9.28=h2531618_0 45 | - lz4-c=1.9.3=h2531618_0 46 | - markdown=3.3.4=pyhd8ed1ab_0 47 | - mkl=2020.2=256 48 | - mkl-service=2.3.0=py38he904b0f_0 49 | - mkl_fft=1.3.0=py38h54f3939_0 50 | - mkl_random=1.1.1=py38h0573a6f_0 51 | - multidict=5.1.0=py38h27cfd23_2 52 | - ncurses=6.2=he6710b0_1 53 | - ninja=1.10.2=py38hff7bd54_0 54 | - numpy=1.19.2=py38h54aff64_0 55 | - numpy-base=1.19.2=py38hfa32c7d_0 56 | - oauthlib=3.0.1=py_0 57 | - olefile=0.46=py_0 58 | - openssl=1.1.1j=h27cfd23_0 59 | - packaging=20.9=pyh44b312d_0 60 | - pillow=8.1.1=py38he98fc37_0 61 | - pip=21.0.1=py38h06a4308_0 62 | - protobuf=3.14.0=py38h2531618_1 63 | - pyasn1=0.4.8=py_0 64 | - pyasn1-modules=0.2.7=py_0 65 | - pycparser=2.20=pyh9f0ad1d_2 66 | - pyjwt=2.0.1=pyhd8ed1ab_0 67 | - pyopenssl=19.1.0=py38_0 68 | - pyparsing=2.4.7=pyh9f0ad1d_0 69 | - pysocks=1.7.1=py38h578d9bd_3 70 | - python=3.8.8=hdb3f193_4 71 | - python-lmdb=1.1.1=py38h2531618_1 72 | - python_abi=3.8=1_cp38 73 | - pytorch=1.7.1=py3.8_cuda10.2.89_cudnn7.6.5_0 74 | - pytorch-lightning=1.2.1=pyhd8ed1ab_0 75 | - pyyaml=5.3.1=py38h8df0ef7_1 76 | - readline=8.1=h27cfd23_0 77 | - requests=2.25.1=pyhd3deb0d_0 78 | - requests-oauthlib=1.3.0=pyh9f0ad1d_0 79 | - rsa=4.7.2=pyh44b312d_0 80 | - setuptools=52.0.0=py38h06a4308_0 81 | - six=1.15.0=py38h06a4308_0 82 | - sqlite=3.33.0=h62c20be_0 83 | - tensorboard=2.4.1=pyhd8ed1ab_0 84 | - tensorboard-plugin-wit=1.8.0=pyh44b312d_0 85 | - tk=8.6.10=hbc83047_0 86 | - torchaudio=0.7.2=py38 87 | - torchvision=0.8.2=py38_cu102 88 | - tqdm=4.58.0=pyhd8ed1ab_0 89 | - typing-extensions=3.7.4.3=0 90 | - typing_extensions=3.7.4.3=py_0 91 | - urllib3=1.26.3=pyhd8ed1ab_0 92 | - werkzeug=1.0.1=pyh9f0ad1d_0 93 | - wheel=0.36.2=pyhd3eb1b0_0 94 | - xz=5.2.5=h7b6447c_0 95 | - yaml=0.2.5=h516909a_0 96 | - yarl=1.6.3=py38h25fe258_0 97 | - zipp=3.4.0=py_0 98 | - zlib=1.2.11=h7b6447c_3 99 | - zstd=1.4.5=h9ceee32_0 100 | -------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import math 3 | 4 | import torch 5 | from torchvision import utils 6 | 7 | from utils import load_checkpoint 8 | 9 | 10 | @torch.no_grad() 11 | def get_mean_style(generator, device): 12 | mean_style = None 13 | 14 | for i in range(10): 15 | style = generator.mean_style(torch.randn(1024, 512).to(device)) 16 | 17 | if mean_style is None: 18 | mean_style = style 19 | 20 | else: 21 | mean_style += style 22 | 23 | mean_style /= 10 24 | return mean_style 25 | 26 | @torch.no_grad() 27 | def sample(generator, class_network, n_class, step, mean_style, n_sample, device, seed): 28 | rng = torch.Generator() 29 | rng.manual_seed(seed) 30 | class_ = 
class_network(torch.tensor([n_class] * n_sample, device=device)) 31 | image = generator( 32 | torch.randn(n_sample, 512, generator=rng).to(device), 33 | step=step, 34 | alpha=1, 35 | mean_style=mean_style, 36 | style_weight=0.7, 37 | task=class_, 38 | ) 39 | 40 | return image 41 | 42 | @torch.no_grad() 43 | def style_mixing(generator, class_network, n_class, step, mean_style, n_source, n_target, device, seed): 44 | rng = torch.Generator() 45 | rng.manual_seed(seed) 46 | source_code = torch.randn(n_source, 512, generator=rng).to(device) 47 | target_code = torch.randn(n_target, 512, generator=rng).to(device) 48 | class_source = class_network(torch.tensor([n_class] * n_source, device=device)) 49 | class_target = class_network(torch.tensor([n_class] * n_target, device=device)) 50 | 51 | shape = 4 * 2 ** step 52 | alpha = 1 53 | 54 | images = [torch.ones(1, 3, shape, shape).to(device) * -1] 55 | 56 | source_image = generator( 57 | source_code, step=step, alpha=alpha, mean_style=mean_style, style_weight=0.7, task=class_source, 58 | ) 59 | target_image = generator( 60 | target_code, step=step, alpha=alpha, mean_style=mean_style, style_weight=0.7, task=class_target, 61 | ) 62 | 63 | images.append(source_image) 64 | 65 | for i in range(n_target): 66 | image = generator( 67 | [target_code[i].unsqueeze(0).repeat(n_source, 1), source_code], 68 | step=step, 69 | alpha=alpha, 70 | mean_style=mean_style, 71 | style_weight=0.7, 72 | mixing_range=(0, 1), 73 | task=class_source, 74 | ) 75 | images.append(target_image[i].unsqueeze(0)) 76 | images.append(image) 77 | 78 | images = torch.cat(images, 0) 79 | 80 | return images 81 | 82 | 83 | def main(args): 84 | print('Loading model...', end=' ') 85 | generator, class_network = load_checkpoint(args) 86 | generator.eval() 87 | print('Loaded') 88 | 89 | mean_style = get_mean_style(generator, args.device) 90 | 91 | step = int(math.log(args.size, 2)) - 2 92 | 93 | print('Generating sample') 94 | img = sample(generator, class_network, args.class_num, step, mean_style, args.n_row * args.n_col, args.device, args.seed) 95 | utils.save_image(img, args.out, nrow=args.n_col, normalize=True, range=(-1, 1)) 96 | 97 | if not args.no_mixing: 98 | print('Generating style mixing') 99 | for j in range(20): 100 | img = style_mixing(generator, class_network, args.class_num, step, mean_style, args.n_col, args.n_row, args.device, args.seed) 101 | utils.save_image( 102 | img, f'{args.out}_sample_mixing_{j}.png', nrow=args.n_col + 1, normalize=True, range=(-1, 1) 103 | ) 104 | 105 | 106 | if __name__ == '__main__': 107 | parser = argparse.ArgumentParser() 108 | parser.add_argument('--seed', type=int, default=2147483647, help='RNG seed') 109 | parser.add_argument('--size', type=int, default=1024, help='size of the image') 110 | parser.add_argument('--n_row', type=int, default=3, help='number of rows of sample matrix') 111 | parser.add_argument('--n_col', type=int, default=5, help='number of columns of sample matrix') 112 | parser.add_argument('checkpoint_path', type=str, help='path to checkpoint file') 113 | parser.add_argument('--device', type=str, default='cuda', help='') 114 | parser.add_argument('--no_mixing', action='store_true', help='Dont generate style mixing samples') 115 | parser.add_argument('--class_num', type=int, default=0, help='Which class to generate') 116 | parser.add_argument('--out', type=str, default='sample.png', help='') 117 | 118 | main(parser.parse_args()) 119 | -------------------------------------------------------------------------------- 
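A typical invocation of the script above, with placeholder paths; `--size` should match the resolution the checkpoint was trained at (256 for the AFHQ setup described in the README):

```bash
python generate.py <checkpoint> --size 256 --class_num 0 --n_row 3 --n_col 5 --out sample.png --no_mixing
```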
/latent_interpolation.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | from torchvision.utils import save_image 4 | from utils import seed_everything, load_hparams, load_checkpoint 5 | 6 | 7 | def slerp(low, high, val): 8 | # https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3 9 | low_norm = low / torch.norm(low, dim=-1, keepdim=True) 10 | high_norm = high / torch.norm(high, dim=-1, keepdim=True) 11 | omega = torch.acos((low_norm * high_norm).sum(-1)) 12 | so = torch.sin(omega) 13 | res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(-1) * low + (torch.sin(val * omega) / so).unsqueeze(-1) * high 14 | return res 15 | 16 | 17 | def class_to_class(args, hparams, class_network, generator, device): 18 | images = [] 19 | 20 | print(f'Generating interpolation with {args.steps} steps between class {args.class_1} and {args.class_2}, seed {args.seed}') 21 | with torch.no_grad(): 22 | class1 = class_network(torch.tensor(args.class_1, device=device)) 23 | class2 = class_network(torch.tensor(args.class_2, device=device)) 24 | for interpolation_i in range(args.num_interpolations): 25 | z = torch.randn(1, hparams.code_size, device=device) 26 | for alpha in torch.linspace(0, 1, steps=args.steps, device=device): 27 | if args.interpolation == 'linear': 28 | class_interpolation = class1.lerp(class2, alpha) 29 | else: 30 | class_interpolation = slerp(class1, class2, alpha) 31 | image = generator(z, step=6, alpha=1, task=class_interpolation) 32 | images.append(image.cpu()) 33 | 34 | return images 35 | 36 | 37 | def z_to_z(args, hparams, class_network, generator, device): 38 | images = [] 39 | 40 | print(f'Generating stuff') 41 | with torch.no_grad(): 42 | class1 = class_network(torch.tensor(args.class_1, device=device)) 43 | for interpolation_i in range(args.num_interpolations): 44 | z1 = torch.randn(1, hparams.code_size, device=device) 45 | z2 = torch.randn(1, hparams.code_size, device=device) 46 | for alpha in torch.linspace(0, 1, steps=args.steps, device=device): 47 | if args.interpolation == 'linear': 48 | z_interpolation = z1.lerp(z2, alpha) 49 | else: 50 | z_interpolation = slerp(z1, z2, alpha) 51 | image = generator(z_interpolation, step=6, alpha=1, task=class1) 52 | images.append(image.cpu()) 53 | 54 | return images 55 | 56 | 57 | def class_noise(args, hparams, class_network, generator, device): 58 | images = [] 59 | 60 | print(f'Generating stuff, seed {args.seed}') 61 | with torch.no_grad(): 62 | class1 = class_network(torch.tensor(args.class_1, device=device)) 63 | class2 = class_network(torch.tensor(args.class_2, device=device)) 64 | z1 = torch.randn(1, hparams.code_size, device=device) 65 | z2 = torch.randn(1, hparams.code_size, device=device) 66 | for alpha_z in torch.linspace(0, 1, steps=args.steps, device=device): 67 | if args.interpolation == 'linear': 68 | z_interpolation = z1.lerp(z2, alpha_z) 69 | else: 70 | z_interpolation = slerp(z1, z2, alpha_z) 71 | for alpha_class in torch.linspace(0, 1, steps=args.steps, device=device): 72 | if args.interpolation == 'linear': 73 | class_interpolation = class1.lerp(class2, alpha_class) 74 | else: 75 | class_interpolation = slerp(class1, class2, alpha_class) 76 | image = generator(z_interpolation, step=6, alpha=1, task=class_interpolation) 77 | images.append(image.cpu()) 78 | 79 | return images 80 | 81 | 82 | def random_z(args, hparams, class_network, generator, device): 83 | images = [] 84 | 85 | print(f'Generating {args.steps * 
args.num_interpolations} from class {args.class_1} with random z, seed {args.seed}') 86 | with torch.no_grad(): 87 | class1 = class_network(torch.tensor(args.class_1, device=device)) 88 | for i in range(args.steps * args.num_interpolations): 89 | z = torch.randn(1, hparams.code_size, device=device) 90 | image = generator(z, step=6, alpha=1, task=class1) 91 | images.append(image.cpu()) 92 | 93 | return images 94 | 95 | 96 | def noise_shift(args, hparams, class_network, generator, device): 97 | images = [] 98 | 99 | print(f'Generating stuff') 100 | with torch.no_grad(): 101 | class1 = class_network(torch.tensor(args.class_1, device=device)) 102 | for interpolation_i in range(args.num_interpolations): 103 | z = torch.randn(1, hparams.code_size, device=device) 104 | for alpha in torch.linspace(-1, 1, steps=args.steps, device=device): 105 | image = generator(z + alpha, step=6, alpha=1, task=class1) 106 | images.append(image.cpu()) 107 | 108 | return images 109 | 110 | 111 | def class_shift(args, hparams, class_network, generator, device): 112 | images = [] 113 | 114 | print(f'Generating stuff') 115 | with torch.no_grad(): 116 | class1 = class_network(torch.tensor(args.class_1, device=device)) 117 | for interpolation_i in range(args.num_interpolations): 118 | z = torch.randn(1, hparams.code_size, device=device) 119 | for alpha in torch.linspace(-0.1, 0.1, steps=args.steps, device=device): 120 | image = generator(z, step=6, alpha=1, task=class1 + alpha) 121 | images.append(image.cpu()) 122 | 123 | return images 124 | 125 | 126 | def main(args): 127 | seed_everything(args.seed) 128 | device = args.device 129 | print(f'Loading model {args.checkpoint_path}') 130 | args.from_self_align = False 131 | generator, class_network = load_checkpoint(args) 132 | hparams = load_hparams(args.checkpoint_path) 133 | print('Loaded') 134 | class_network.eval() 135 | generator.eval() 136 | 137 | if args.type == 'class': 138 | images = class_to_class(args, hparams, class_network, generator, device) 139 | elif args.type == 'noise': 140 | images = z_to_z(args, hparams, class_network, generator, device) 141 | elif args.type == 'class_noise': 142 | images = class_noise(args, hparams, class_network, generator, device) 143 | elif args.type == 'random_z': 144 | images = random_z(args, hparams, class_network, generator, device) 145 | elif args.type == 'z_shift': 146 | images = noise_shift(args, hparams, class_network, generator, device) 147 | elif args.type == 'class_shift': 148 | images = class_shift(args, hparams, class_network, generator, device) 149 | else: 150 | raise ValueError 151 | 152 | save_image(torch.cat(images, 0), args.output, nrow=args.steps, padding=0, normalize=True, range=(-1, 1)) 153 | print(f'Saved to {args.output}') 154 | 155 | 156 | def get_args(): 157 | parser = argparse.ArgumentParser(description='Progressive Growing of GANs') 158 | 159 | parser.add_argument('checkpoint_path', type=str, help='model to use') 160 | parser.add_argument('--output', default='interpolation.png', type=str, help='Output file') 161 | parser.add_argument('--steps', default=10, type=int, help='Number of interpolations') 162 | parser.add_argument('--seed', default=2147483647, type=int, help='Random seed') 163 | parser.add_argument('--device', default='cuda', type=str, help='Device to use') 164 | parser.add_argument('--class_1', default=0, type=int, help='First class for interpolation') 165 | parser.add_argument('--class_2', default=1, type=int, help='Second class for interpolation') 166 | parser.add_argument('--interpolation', 
default='linear', type=str, choices=['linear', 'spherical'], help='Interpolation type') 167 | parser.add_argument('--num_interpolations', default=1, type=int, help='How many interpolations to perform') 168 | parser.add_argument('--type', default='class', type=str, help='Which interpolation to do') 169 | 170 | return parser.parse_args() 171 | 172 | 173 | if __name__ == '__main__': 174 | main(get_args()) 175 | -------------------------------------------------------------------------------- /metric_impl/dc.py: -------------------------------------------------------------------------------- 1 | r"""PyTorch implementation of Density and Coverage (D&C). Based on Reliable Fidelity and Diversity Metrics for 2 | Generative Models https://arxiv.org/abs/2002.09797 and repository 3 | https://github.com/clovaai/generative-evaluation-prdc/blob/master/prdc/prdc.py 4 | """ 5 | from typing import Tuple 6 | import torch 7 | 8 | from piq.base import BaseFeatureMetric 9 | from piq.utils import _validate_input 10 | 11 | from metric_impl.pr import _compute_nearest_neighbour_distances, _compute_pairwise_distance 12 | 13 | 14 | class DC(BaseFeatureMetric): 15 | r"""Interface of Density and Coverage. 16 | It's computed for a whole set of data and uses features from encoder instead of images itself to decrease 17 | computation cost. Density and Coverage can compare two data distributions with different number of samples. 18 | But dimensionalities should match, otherwise it won't be possible to correctly compute statistics. 19 | 20 | Args: 21 | real_features: Samples from data distribution. Shape :math:`(N_x, D)` 22 | fake_features: Samples from generated distribution. Shape :math:`(N_y, D)` 23 | 24 | Returns: 25 | density: Scalar value of the density of image sets features. 26 | coverage: Scalar value of the coverage of image sets features. 27 | 28 | References: 29 | Ferjad Naeem M. et al. (2020). 30 | Reliable Fidelity and Diversity Metrics for Generative Models. 31 | International Conference on Machine Learning, 32 | https://arxiv.org/abs/2002.09797 33 | """ 34 | 35 | def __init__(self, nearest_k: int = 5) -> None: 36 | r""" 37 | Args: 38 | nearest_k: Nearest neighbor to compute the non-parametric representation. Shape :math:`1` 39 | """ 40 | super(DC, self).__init__() 41 | 42 | self.nearest_k = nearest_k 43 | 44 | def compute_metric(self, y_features: torch.Tensor, x_features: torch.Tensor) \ 45 | -> Tuple[torch.Tensor, torch.Tensor]: 46 | r"""Creates non-parametric representations of the manifolds of real and generated data and computes 47 | the density and coverage between them. 48 | 49 | Args: 50 | real_features: Samples from data distribution. Shape :math:`(N_x, D)` 51 | fake_features: Samples from fake distribution. Shape :math:`(N_x, D)` 52 | Returns: 53 | precision: Scalar value of the precision of the generated images. 54 | recall: Scalar value of the recall of the generated images. 
55 | """ 56 | # _validate_input([real_features, fake_features], dim_range=(2, 2), size_range=(1, 2)) 57 | real_features = y_features 58 | fake_features = x_features 59 | real_nearest_neighbour_distances = _compute_nearest_neighbour_distances(real_features, self.nearest_k) 60 | distance_real_fake = _compute_pairwise_distance(real_features, fake_features) 61 | 62 | density = (1 / self.nearest_k) * ( 63 | distance_real_fake < real_nearest_neighbour_distances.unsqueeze(1) 64 | ).sum(dim=0).float().mean() 65 | 66 | coverage = ( 67 | distance_real_fake.min(dim=1)[0] < real_nearest_neighbour_distances 68 | ).float().mean() 69 | 70 | return density, coverage 71 | -------------------------------------------------------------------------------- /metric_impl/ppl.py: -------------------------------------------------------------------------------- 1 | # From https://github.com/rosinality/stylegan2-pytorch/blob/master/ppl.py 2 | import argparse 3 | 4 | import torch 5 | from torch.nn import functional as F 6 | import numpy as np 7 | from tqdm import tqdm 8 | 9 | import lpips 10 | from utils import load_hparams, load_checkpoint 11 | 12 | 13 | def normalize(x): 14 | return x / torch.sqrt(x.pow(2).sum(-1, keepdim=True)) 15 | 16 | 17 | def slerp(a, b, t): 18 | a = normalize(a) 19 | b = normalize(b) 20 | d = (a * b).sum(-1, keepdim=True) 21 | p = t * torch.acos(d) 22 | c = normalize(b - d * a) 23 | d = a * torch.cos(p) + c * torch.sin(p) 24 | 25 | return normalize(d) 26 | 27 | 28 | def lerp(a, b, t): 29 | return a + (b - a) * t 30 | 31 | 32 | if __name__ == "__main__": 33 | parser = argparse.ArgumentParser(description="Perceptual Path Length calculator") 34 | 35 | parser.add_argument( 36 | "--space", choices=["z", "w"], default="w", help="space that PPL calculated with" 37 | ) 38 | parser.add_argument( 39 | "--batch", type=int, default=64, help="batch size for the models" 40 | ) 41 | parser.add_argument( 42 | "--n_sample", 43 | type=int, 44 | default=5000, 45 | help="number of the samples for calculating PPL", 46 | ) 47 | parser.add_argument( 48 | "--size", type=int, default=256, help="output image sizes of the generator" 49 | ) 50 | parser.add_argument( 51 | "--eps", type=float, default=1e-4, help="epsilon for numerical stability" 52 | ) 53 | parser.add_argument( 54 | "--crop", action="store_true", help="apply center crop to the images" 55 | ) 56 | parser.add_argument( 57 | "--sampling", 58 | default="end", 59 | choices=["end", "full"], 60 | help="set endpoint sampling method", 61 | ) 62 | parser.add_argument('checkpoint_path', type=str, help='model to use') 63 | parser.add_argument('--device', type=str, default='cuda', help='model to use') 64 | parser.add_argument('--class_1', default=0, type=int, help='First class for interpolation') 65 | parser.add_argument('--class_2', default=1, type=int, help='Second class for interpolation') 66 | 67 | args = parser.parse_args() 68 | args.from_self_align = False 69 | device = args.device 70 | 71 | g, c_net = load_checkpoint(args) 72 | hparams = load_hparams(args.checkpoint_path) 73 | g.eval() 74 | c_net.eval() 75 | 76 | percept = lpips.PerceptualLoss( 77 | model="net-lin", net="vgg", use_gpu=device.startswith("cuda") 78 | ) 79 | 80 | distances = [] 81 | 82 | n_batch = args.n_sample // args.batch 83 | resid = args.n_sample - (n_batch * args.batch) 84 | batch_sizes = [args.batch] * n_batch + [resid] 85 | 86 | with torch.no_grad(): 87 | for batch in tqdm(batch_sizes): 88 | 89 | z = torch.randn([batch * 2, hparams.code_size], device=device) 90 | if args.sampling == "full": 
91 | lerp_t = torch.rand(batch, device=device) 92 | else: 93 | lerp_t = torch.zeros(batch, device=device) 94 | 95 | if args.space == "w": 96 | latent_t0, latent_t1 = c_net(torch.tensor(batch * [args.class_1], device=device)), c_net(torch.tensor(batch * [args.class_2], device=device)) 97 | latent_e0 = lerp(latent_t0, latent_t1, lerp_t[:, None]) 98 | latent_e1 = lerp(latent_t0, latent_t1, lerp_t[:, None] + args.eps) 99 | latent_e = torch.stack([latent_e0, latent_e1], 1).view(*[batch * 2, hparams.task_size]) 100 | 101 | # image, _ = g([latent_e]) 102 | image = g(z, step=6, alpha=1, task=latent_e) 103 | 104 | if args.crop: 105 | c = image.shape[2] // 8 106 | image = image[:, :, c * 3 : c * 7, c * 2 : c * 6] 107 | 108 | factor = image.shape[2] // 256 109 | 110 | if factor > 1: 111 | image = F.interpolate( 112 | image, size=(256, 256), mode="bilinear", align_corners=False 113 | ) 114 | 115 | dist = percept(image[::2], image[1::2]).view(image.shape[0] // 2) / ( 116 | args.eps ** 2 117 | ) 118 | distances.append(dist.to("cpu").numpy()) 119 | 120 | distances = np.concatenate(distances, 0) 121 | 122 | lo = np.percentile(distances, 1, interpolation="lower") 123 | hi = np.percentile(distances, 99, interpolation="higher") 124 | filtered_dist = np.extract( 125 | np.logical_and(lo <= distances, distances <= hi), distances 126 | ) 127 | 128 | print("ppl:", filtered_dist.mean()) 129 | -------------------------------------------------------------------------------- /metric_impl/pr.py: -------------------------------------------------------------------------------- 1 | r"""PyTorch implementation of Improved Precision and Recall (P&R). Based on Improved Precision and Recall Metric for 2 | Assessing Generative Models https://arxiv.org/abs/1904.06991 and repository 3 | https://github.com/clovaai/generative-evaluation-prdc/blob/master/prdc/prdc.py 4 | """ 5 | from typing import Tuple, Optional 6 | import torch 7 | 8 | from piq.base import BaseFeatureMetric 9 | from piq.utils import _validate_input 10 | 11 | 12 | def _compute_pairwise_distance(data_x: torch.Tensor, data_y: Optional[torch.Tensor] = None) -> torch.Tensor: 13 | r"""Compute Euclidean distance between :math:`x` and :math:`y`. 14 | 15 | Args: 16 | data_x: Tensor of shape :math:`(N, feature_dim)` 17 | data_y: Tensor of shape :math:`(N, feature_dim)` 18 | Returns: 19 | Tensor of shape :math:`(N, N)` of pairwise distances. 20 | """ 21 | if data_y is None: 22 | data_y = data_x 23 | dists = torch.cdist(data_x, data_y, p=2) 24 | return dists 25 | 26 | 27 | def _get_kth_value(unsorted: torch.Tensor, k: int, axis: int = -1) -> torch.Tensor: 28 | r""" 29 | Args: 30 | unsorted: Tensor of any dimensionality. 31 | k: Int of the :math:`k`-th value to retrieve. 32 | Returns: 33 | kth values along the designated axis. 34 | """ 35 | k_smallests = torch.topk(unsorted, k, dim=axis, largest=False)[0] 36 | kth_values = k_smallests.max(dim=axis)[0] 37 | return kth_values 38 | 39 | 40 | def _compute_nearest_neighbour_distances(input_features: torch.Tensor, nearest_k: int) -> torch.Tensor: 41 | r"""Compute K-nearest neighbour distances. 42 | 43 | Args: 44 | input_features: Tensor of shape :math:`(N, feature_dim)` 45 | nearest_k: Int of the :math:`k`-th nearest neighbour. 46 | Returns: 47 | Distances to :math:`k`-th nearest neighbours. 
48 | """ 49 | distances = _compute_pairwise_distance(input_features) 50 | radii = _get_kth_value(distances, k=nearest_k + 1, axis=-1) 51 | return radii 52 | 53 | 54 | class PR(BaseFeatureMetric): 55 | r"""Interface of Improved Precision and Recall. 56 | It's computed for a whole set of data and uses features from encoder instead of images itself to decrease 57 | computation cost. Precision and Recall can compare two data distributions with different number of samples. 58 | But dimensionalities should match, otherwise it won't be possible to correctly compute statistics. 59 | 60 | Args: 61 | real_features: Samples from data distribution. Shape :math:`(N_x, D)` 62 | fake_features: Samples from generated distribution. Shape :math:`(N_y, D)` 63 | 64 | Returns: 65 | precision: Scalar value of the precision of image sets features. 66 | recall: Scalar value of the recall of image sets features. 67 | 68 | References: 69 | Kynkäänniemi T. et al. (2019). 70 | Improved Precision and Recall Metric for Assessing Generative Models. 71 | Advances in Neural Information Processing Systems, 72 | https://arxiv.org/abs/1904.06991 73 | """ 74 | 75 | def __init__(self, nearest_k: int = 5) -> None: 76 | r""" 77 | Args: 78 | nearest_k: Nearest neighbor to compute the non-parametric representation. Shape :math:`1` 79 | """ 80 | super(PR, self).__init__() 81 | 82 | self.nearest_k = nearest_k 83 | 84 | def compute_metric(self, y_features: torch.Tensor, x_features: torch.Tensor) \ 85 | -> Tuple[torch.Tensor, torch.Tensor]: 86 | r"""Creates non-parametric representations of the manifolds of real and generated data and computes 87 | the precision and recall between them. 88 | 89 | Args: 90 | real_features: Samples from data distribution. Shape :math:`(N_x, D)` 91 | fake_features: Samples from fake distribution. Shape :math:`(N_x, D)` 92 | Returns: 93 | precision: Scalar value of the precision of the generated images. 94 | recall: Scalar value of the recall of the generated images. 
95 | """ 96 | # _validate_input([real_features, fake_features], dim_range=(2, 2), size_range=(1, 2)) 97 | real_features = y_features 98 | fake_features = x_features 99 | real_nearest_neighbour_distances = _compute_nearest_neighbour_distances(real_features, self.nearest_k) 100 | fake_nearest_neighbour_distances = _compute_nearest_neighbour_distances(fake_features, self.nearest_k) 101 | distance_real_fake = _compute_pairwise_distance(real_features, fake_features) 102 | 103 | precision = ( 104 | distance_real_fake < real_nearest_neighbour_distances.unsqueeze(1) 105 | ).any(dim=0).float().mean() 106 | 107 | recall = ( 108 | distance_real_fake < fake_nearest_neighbour_distances.unsqueeze(0) 109 | ).any(dim=1).float().mean() 110 | 111 | return precision, recall 112 | -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | from typing import Optional, Tuple, Union 4 | from io import BytesIO 5 | 6 | from tqdm import tqdm, trange 7 | import numpy as np 8 | import lmdb 9 | import torch 10 | from torchvision.transforms import ToTensor 11 | from dataset import MultiResolutionDataset 12 | from torch.utils.data import DataLoader 13 | from piq.feature_extractors import InceptionV3 14 | from piq import FID, KID 15 | 16 | from metric_impl.pr import PR 17 | from metric_impl.dc import DC 18 | from utils import load_hparams, load_checkpoint 19 | 20 | 21 | def to_bytes(img): 22 | buffer = BytesIO() 23 | torch.save(img, buffer) 24 | return buffer.getvalue() 25 | 26 | 27 | def get_z(num_samples=1, code_dim=512, device='cuda'): 28 | return torch.randn(num_samples, code_dim, device=device) 29 | 30 | 31 | def clamp_img(img: torch.Tensor, range: Optional[Tuple[int, int]] = (-1, 1)) -> None: 32 | img.clamp_(range[0], range[1]) 33 | 34 | 35 | def featurize(extractor, imgs): 36 | with torch.no_grad(): 37 | feats = extractor(imgs) 38 | return feats[0].squeeze(-1).squeeze(-1) 39 | 40 | 41 | def pick_samples(orig_size, selection_size, method): 42 | assert method in ['first', 'random'] 43 | if selection_size is None: 44 | return torch.arange(orig_size) 45 | if selection_size > orig_size: 46 | selection_size = orig_size 47 | 48 | if method == 'first': 49 | return torch.arange(selection_size) 50 | elif method == 'random': 51 | return torch.from_numpy(np.random.choice(orig_size, size=selection_size, replace=False)) 52 | 53 | 54 | def load_feats(path, classes=1, load_n=None, sampling_method='first', device='cpu'): 55 | """ 56 | Load precomputed inception features from an LMDB file. 57 | 58 | :param path: Path to the features. 59 | :param classes: Int with the number of classes to load or iterable consisting of class indices. 60 | :param load_n: Will load N samples per class. To pick which ones, see sampling_method param. 61 | :param sampling_method: Pick the first N samples for each class, or pick N random ones. 62 | :param device: Which device to load features to. 
63 | """ 64 | if isinstance(classes, int): 65 | classes_iter = range(classes) 66 | elif isinstance(classes, tuple) or isinstance(classes, list): 67 | classes_iter = classes 68 | else: 69 | raise ValueError 70 | 71 | feats = {} 72 | with lmdb.open(path, max_readers=32, readonly=True, lock=False, readahead=False, meminit=False) as env: 73 | for c in classes_iter: 74 | with env.begin(write=False) as txn: 75 | key = f'{c}'.encode('utf-8') 76 | feat = txn.get(key) 77 | feat = torch.load(BytesIO(feat), map_location=device) 78 | # in the loop bc classes can have different number of samples 79 | idx = pick_samples(orig_size=feat.shape[0], selection_size=load_n, method=sampling_method) 80 | # clone for keeping only the slice in memory, not the whole underlying tensor storage 81 | feats[c] = feat[idx].clone() 82 | 83 | return feats 84 | 85 | 86 | def generate(z, model, task, classes, step=6, device='cuda'): 87 | if isinstance(classes, tuple) or isinstance(classes, list) or isinstance(classes, range): 88 | classes = torch.tensor(classes, device=device) 89 | elif isinstance(classes, torch.Tensor): 90 | if classes.device != device: 91 | classes = classes.to(device) 92 | else: 93 | raise ValueError("Wrong classes type") 94 | 95 | with torch.no_grad(): 96 | task_input = task(classes) 97 | output = model(z, task=task_input, step=step) 98 | return output 99 | 100 | 101 | def generate_class(model, task, class_i, num_samples=10, code_dim=512, device='cuda'): 102 | if isinstance(class_i, torch.Tensor): 103 | conditioning = class_i 104 | else: 105 | conditioning = [class_i] * num_samples 106 | z = get_z(num_samples, code_dim=code_dim, device=device) 107 | return generate(z, model, task, conditioning, device=device) 108 | 109 | 110 | def generate_metrics_by_class(generator, task, metrics: dict, real_feats: Union[torch.Tensor, dict], num_classes=1, num_samples=10_000, batch_size=2, code_dim=512, feats_dim=2048, device='cuda'): 111 | metric_results_classes = {} 112 | feature_extractor = InceptionV3(normalize_input=False) # already (-1, 1) range from generator output 113 | feature_extractor.to(device) 114 | feature_extractor.eval() 115 | 116 | for c in trange(num_classes, desc='Classes', leave=False, disable=num_classes == 1): 117 | metric_results = {} 118 | fake_feats = torch.empty((num_samples, feats_dim), device=device) 119 | last_batch_size = num_samples % batch_size 120 | 121 | for idx, i in enumerate(trange(0, num_samples, batch_size, desc='Samples', leave=False)): 122 | current_bs = last_batch_size if idx == num_samples // batch_size else batch_size 123 | imgs = generate_class(generator, task, c, current_bs, code_dim, device=device) 124 | clamp_img(imgs) # sometimes images go beyond their range, thus Inception extractor will complain 125 | fake_feats[i:i + current_bs] = featurize(feature_extractor, imgs) 126 | 127 | for metric_name, metric in metrics.items(): 128 | fake_i = None if metric_name != 'kid' else real_feats[c].shape[0] # kid needs reals >= fakes 129 | score = metric(x_features=fake_feats[:fake_i], y_features=real_feats[c]) # predicted_feats, target_feats 130 | if isinstance(score, tuple): 131 | score = {f'{i}': s.item() for i, s in enumerate(score)} 132 | elif isinstance(score, torch.Tensor): 133 | score = score.item() 134 | else: 135 | raise RuntimeError('Unknown metric score type') 136 | metric_results[metric_name] = score 137 | 138 | metric_results_classes[c] = metric_results 139 | 140 | return metric_results_classes 141 | 142 | 143 | def generate_metrics_mixed_class(generator, task, metrics: 
dict, real_feats: Union[torch.Tensor, dict], num_classes=1, num_samples=10_000, batch_size=2, code_dim=512, feats_dim=2048, device='cuda'): 144 | feature_extractor = InceptionV3(normalize_input=False) # already (-1, 1) range from generator output 145 | feature_extractor.to(device) 146 | feature_extractor.eval() 147 | 148 | metric_results = {} 149 | real_feats = real_feats[0] 150 | fake_feats = torch.empty((num_samples, feats_dim), device=device) 151 | last_batch_size = num_samples % batch_size 152 | 153 | for idx, i in enumerate(trange(0, num_samples, batch_size, desc='Samples', leave=False)): 154 | current_bs = last_batch_size if idx == num_samples // batch_size else batch_size 155 | c = torch.from_numpy(np.random.choice(num_classes, size=current_bs)).to(device) 156 | imgs = generate_class(generator, task, c, current_bs, code_dim, device=device) 157 | clamp_img(imgs) # sometimes images go beyond their range, thus Inception extractor will complain 158 | fake_feats[i:i + current_bs] = featurize(feature_extractor, imgs) 159 | 160 | for metric_name, metric in metrics.items(): 161 | fake_i = None if metric_name != 'kid' else real_feats.shape[0] # kid needs reals >= fakes 162 | score = metric(x_features=fake_feats[:fake_i], y_features=real_feats) # predicted_feats, target_feats 163 | if isinstance(score, tuple): 164 | score = {f'{i}': s.item() for i, s in enumerate(score)} 165 | elif isinstance(score, torch.Tensor): 166 | score = score.item() 167 | else: 168 | raise RuntimeError('Unknown metric score type') 169 | metric_results[metric_name] = score 170 | 171 | return {0: metric_results} 172 | 173 | 174 | def compute_feats_from_dataset(db_path, batch_size=1, output_path='output_lmdb', device='cuda'): 175 | feats_dim = 2048 176 | resolution = 256 177 | feature_extractor = InceptionV3(normalize_input=True) 178 | feature_extractor.to(device) 179 | feature_extractor.eval() 180 | 181 | dataset = MultiResolutionDataset(db_path, transform=ToTensor(), resolution=resolution) 182 | num_classes = dataset.num_classes 183 | with lmdb.open(output_path, map_size=1024 ** 4, readahead=False) as env: 184 | for c in trange(num_classes, desc='Classes'): 185 | dataset = MultiResolutionDataset(db_path, transform=ToTensor(), resolution=resolution, selected_classes=[c]) 186 | feats_c = torch.empty((dataset.length, feats_dim), device=device) 187 | for i, (imgs, labels) in enumerate(tqdm(DataLoader(dataset, batch_size=batch_size, num_workers=4), desc='Batches', leave=False)): 188 | feats_c[i * batch_size:i * batch_size + imgs.size(0)] = featurize(feature_extractor, imgs.to(device)) 189 | with env.begin(write=True) as txn: 190 | key = f'{c}'.encode('utf-8') 191 | txn.put(key, to_bytes(feats_c.cpu())) 192 | 193 | 194 | def compute_metrics_per_ckpt(args): 195 | """ 196 | Compute specified metrics class-wise or for all classes and save results in a pickle. It will be run for each 197 | checkpoint it finds, creating a lock file for parallel computation (cheap but useful). 
198 | """ 199 | import pickle 200 | mix_dir = 'mixed' if args.classes == 'interclass' else '' 201 | metrics = {'fid': FID(), 'kid': KID(), 'pr': PR(), 'dc': DC()} 202 | hparams = load_hparams(args.checkpoint_path) 203 | if not hasattr(hparams, 'num_classes'): 204 | hparams.num_classes = 1 205 | # 1 class and random if mix classes bc the precomputed feats dataset supposed to be passed has all classes in 206 | # the first class, so getting the first 10k will likely just get all samples from one/two classes 207 | real_feats = load_feats(args.precomputed_feats, classes=1 if args.classes == 'interclass' else hparams.num_classes, 208 | load_n=10_000, sampling_method='random' if args.classes == 'interclass' else 'first', 209 | device=args.device) 210 | 211 | checkpoint_folder = Path(args.checkpoint_path) 212 | if args.classes == 'interclass': 213 | (checkpoint_folder / mix_dir).mkdir(exist_ok=True) 214 | 215 | for epoch in tqdm(list(sorted(checkpoint_folder.glob('*.model'), reverse=True)), desc='Epoch'): 216 | output_score = epoch.parent / mix_dir / (epoch.stem + '.metric') 217 | epoch_lock = epoch.parent / mix_dir / (epoch.stem + '.lock') 218 | if output_score.exists() or epoch_lock.exists(): 219 | continue 220 | epoch_lock.touch(exist_ok=False) 221 | 222 | args.checkpoint_path = str(epoch) 223 | try: 224 | generator, task = load_checkpoint(args, hparams) 225 | except Exception as e: 226 | epoch_lock.unlink() 227 | raise e 228 | 229 | if args.classes == 'interclass': 230 | scores = generate_metrics_mixed_class(generator, task, metrics, real_feats, 231 | num_classes=hparams.num_classes, num_samples=args.num_samples, 232 | batch_size=args.batch_size, code_dim=hparams.code_size, 233 | feats_dim=args.feats_dim, device=args.device) 234 | elif args.classes == 'intraclass': 235 | scores = generate_metrics_by_class(generator, task, metrics, real_feats, 236 | num_classes=hparams.num_classes, num_samples=args.num_samples, 237 | batch_size=args.batch_size, code_dim=hparams.code_size, 238 | feats_dim=args.feats_dim, device=args.device) 239 | else: 240 | raise ValueError 241 | 242 | with open(output_score, 'wb') as f: 243 | pickle.dump(scores, f) 244 | 245 | if args.remove_checkpoint_afterwards: 246 | epoch.unlink() 247 | epoch_lock.unlink() 248 | 249 | 250 | def compute_metrics_biggan(imgs_path, num_samples=10_000, batch_size=2, feats_dim=2048, device='cuda'): 251 | """ 252 | Read BigGAN generations (npz file) and compute metrics. 
253 | """ 254 | metrics = {'fid': FID(), 'kid': KID(), 'pr': PR(), 'dc': DC()} 255 | np_imgs = np.load(imgs_path) 256 | print(f'Loading {imgs_path}') 257 | y = np_imgs['y'] 258 | num_classes = len(np.unique(y)) 259 | assert np.array_equal(np.arange(num_classes), np.unique(y)), f'Discontinuous number of classes in {imgs_path}' 260 | 261 | x = {} 262 | for c in trange(num_classes, desc='Loading classes'): 263 | idcs = np.where(y == c)[0][:num_samples] 264 | if len(idcs) < num_samples: 265 | print(f'Class {c} has less number of samples than requested ({num_samples})') 266 | x[c] = np_imgs['x'][idcs] 267 | 268 | real_feats = load_feats(args.precomputed_feats, classes=num_classes, load_n=10_000, device=args.device) 269 | 270 | metric_results_classes = {} 271 | feature_extractor = InceptionV3(normalize_input=True) # [0, 1] to [-1, 1] 272 | feature_extractor.to(device) 273 | feature_extractor.eval() 274 | 275 | for c in trange(num_classes, desc='Classes', leave=False, disable=num_classes == 1): 276 | metric_results = {} 277 | fake_feats = torch.empty((num_samples, feats_dim), device=device) 278 | last_batch_size = num_samples % batch_size 279 | 280 | for idx, i in enumerate(trange(0, num_samples, batch_size, desc='Samples', leave=False)): 281 | current_bs = last_batch_size if idx == num_samples // batch_size else batch_size 282 | imgs = torch.from_numpy(x[c][i:i + current_bs]).to(device) / 255 283 | fake_feats[i:i + current_bs] = featurize(feature_extractor, imgs) 284 | 285 | for metric_name, metric in metrics.items(): 286 | score = metric(real_feats[c], fake_feats) 287 | if isinstance(score, tuple): 288 | score = {f'{i}': s.item() for i, s in enumerate(score)} 289 | elif isinstance(score, torch.Tensor): 290 | score = score.item() 291 | else: 292 | raise RuntimeError('Unknown metric score type') 293 | metric_results[metric_name] = score 294 | 295 | metric_results_classes[c] = metric_results 296 | 297 | return metric_results_classes 298 | 299 | 300 | if __name__ == '__main__': 301 | parser = argparse.ArgumentParser() 302 | parser.add_argument('checkpoint_path', type=str, help='Path of specified dataset') 303 | parser.add_argument('--output_path', type=str, default='Output LMDB for precomputed features') 304 | parser.add_argument('--precomputed_feats', type=str, default=None, help='If the function needs it') 305 | parser.add_argument('--device', type=str, default='cuda') 306 | parser.add_argument('--num_samples', type=int, default=10_000) 307 | parser.add_argument('--batch_size', type=int, default=32) 308 | parser.add_argument('--feats_dim', type=int, default=2048) 309 | parser.add_argument('--remove_checkpoint_afterwards', action='store_true', help='Remove computing metrics') 310 | parser.add_argument('--operation', type=str, default='metrics', choices=['extract-feats', 'metrics', 'metrics-biggan']) 311 | parser.add_argument('--classes', type=str, default='intraclass', choices=['interclass', 'intraclass'], help='Evaluation regarding conditioning') 312 | args = parser.parse_args() 313 | 314 | if args.operation == 'extract-feats': 315 | compute_feats_from_dataset(args.checkpoint_path, batch_size=args.batch_size, output_path=args.output_path, device=args.device) 316 | elif args.operation == 'metrics': 317 | compute_metrics_per_ckpt(args) 318 | elif args.operation == 'metrics-biggan': 319 | print(compute_metrics_biggan(args.checkpoint_path, batch_size=args.batch_size)) 320 | -------------------------------------------------------------------------------- /model/hyper_mod.py: 
-------------------------------------------------------------------------------- 1 | import random 2 | from math import sqrt 3 | 4 | import torch 5 | from torch import nn 6 | from torch.nn import functional as F 7 | 8 | from model.stylegan import PixelNorm, equal_lr, EqualLinear, EqualConv2d, ConstantInput, Blur, NoiseInjection,\ 9 | AdaptiveInstanceNorm, ConvBlock 10 | 11 | 12 | class SequentialWithParams(nn.Sequential): 13 | def forward(self, x): 14 | x, task = x 15 | for m in self: 16 | if isinstance(m, (AdaptiveFilterModulation, FusedUpsampleAdaFM)): 17 | x = m(x, task) 18 | else: 19 | x = m(x) 20 | 21 | return x 22 | 23 | 24 | def equalize(weight): 25 | """Batched equalization. Since all kernels have the same shape we can apply the scaling to all batches.""" 26 | fan_in = weight[0].data.size(1) * weight[0].data[0][0].numel() 27 | weight *= sqrt(2 / fan_in) 28 | 29 | 30 | class AdaptiveFilterModulation(nn.Module): 31 | def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, task_dim=0): 32 | super().__init__() 33 | self.stride = stride 34 | self.padding = padding 35 | 36 | # these two will be normalized and initialized when loading from a pretrained model 37 | self.register_buffer('W', torch.empty((out_channel, in_channel, kernel_size, kernel_size))) 38 | self.register_buffer('b', torch.empty(out_channel)) 39 | 40 | self.task = EqualLinear(task_dim, out_channel * in_channel * 2) 41 | self.task.linear.bias.data[:] = 0 42 | 43 | self.task_bias_beta = EqualLinear(task_dim, out_channel) 44 | self.task_bias_beta.linear.bias.data[:] = 0 45 | 46 | def forward(self, x, task): 47 | gamma, beta = self.task(task).view(-1, self.W.shape[0], self.W.shape[1], 2).unsqueeze(4).chunk(2, 3) 48 | bias_beta = self.task_bias_beta(task) 49 | 50 | W_i = self.W * gamma + beta 51 | equalize(W_i) 52 | b_i = self.b + bias_beta 53 | out = conv1to1(x, W_i, bias=b_i, stride=self.stride, padding=self.padding) 54 | return out 55 | 56 | 57 | class FusedUpsampleAdaFM(nn.Module): 58 | def __init__(self, in_channel, out_channel, kernel_size, stride=2, padding=0, task_dim=0): 59 | super().__init__() 60 | self.stride = stride 61 | self.padding = padding 62 | 63 | self.register_buffer('W', torch.empty((in_channel, out_channel, kernel_size, kernel_size))) 64 | self.register_buffer('b', torch.empty(out_channel)) 65 | 66 | self.task = EqualLinear(task_dim, out_channel * in_channel * 2) 67 | self.task.linear.bias.data[:] = 0 68 | 69 | self.task_bias_beta = EqualLinear(task_dim, out_channel) 70 | self.task_bias_beta.linear.bias.data[:] = 0 71 | 72 | def forward(self, x, task): 73 | weight = F.pad(self.W, [1, 1, 1, 1]) 74 | weight = ( 75 | weight[:, :, 1:, 1:] 76 | + weight[:, :, :-1, 1:] 77 | + weight[:, :, 1:, :-1] 78 | + weight[:, :, :-1, :-1] 79 | ) / 4 80 | 81 | gamma, beta = self.task(task).view(-1, weight.shape[0], weight.shape[1], 2).unsqueeze(4).chunk(2, 3) 82 | bias_beta = self.task_bias_beta(task) 83 | 84 | W_i = weight * gamma + beta 85 | equalize(W_i) 86 | b_i = self.b + bias_beta 87 | out = conv1to1(x, W_i, bias=b_i, stride=self.stride, padding=self.padding, transposed=True, mode='upsample') 88 | return out 89 | 90 | 91 | def conv1to1(x, weights, bias=None, stride=1, padding=0, transposed=False, mode='normal'): 92 | """Apply a different kernel to a different sample. result[i] = conv(input[i], weights[i]). 
93 | From https://discuss.pytorch.org/t/how-to-apply-different-kernels-to-each-example-in-a-batch-when-using-convolution/84848/3""" 94 | F_conv = F.conv_transpose2d if transposed else F.conv2d 95 | N, C, H, W = x.shape 96 | KN, KO, KI, KH, KW = weights.shape 97 | assert N == KN 98 | 99 | # group weights 100 | weights = weights.view(-1, KI, KH, KW) 101 | bias = bias.view(-1) 102 | 103 | # move batch dim into channels 104 | x = x.view(1, -1, H, W) 105 | 106 | # Apply grouped conv 107 | outputs_grouped = F_conv(x, weights, bias=bias, stride=stride, padding=padding, groups=N) 108 | if mode == 'upsample' or mode == 'downsample' or (outputs_grouped.shape[2] == 1 and outputs_grouped.shape[3] == 1): 109 | outputs_grouped = outputs_grouped.view(N, -1, outputs_grouped.shape[2], outputs_grouped.shape[3]) 110 | else: 111 | outputs_grouped = outputs_grouped.view(N, KO, H, W) 112 | return outputs_grouped 113 | 114 | 115 | class StyledConvBlock(nn.Module): 116 | def __init__( 117 | self, 118 | in_channel, 119 | out_channel, 120 | kernel_size=3, 121 | padding=1, 122 | style_dim=512, 123 | initial=False, 124 | upsample=False, 125 | fused=False, 126 | task_dim=512, 127 | ): 128 | super().__init__() 129 | 130 | if initial: 131 | self.conv1 = SequentialWithParams( 132 | ConstantInput(in_channel) 133 | ) 134 | 135 | else: 136 | if upsample: 137 | if fused: 138 | self.conv1 = SequentialWithParams( 139 | FusedUpsampleAdaFM(in_channel, out_channel, kernel_size, padding=padding, task_dim=task_dim), 140 | Blur(out_channel), 141 | ) 142 | 143 | else: 144 | self.conv1 = SequentialWithParams( 145 | nn.Upsample(scale_factor=2, mode='nearest'), 146 | AdaptiveFilterModulation(in_channel, out_channel, kernel_size, padding=padding, task_dim=task_dim), 147 | Blur(out_channel), 148 | ) 149 | 150 | else: 151 | self.conv1 = SequentialWithParams( 152 | AdaptiveFilterModulation(in_channel, out_channel, kernel_size, padding=padding, task_dim=task_dim)) 153 | 154 | self.noise1 = equal_lr(NoiseInjection(out_channel)) 155 | self.adain1 = AdaptiveInstanceNorm(out_channel, style_dim) 156 | self.lrelu1 = nn.LeakyReLU(0.2) 157 | 158 | self.conv2 = AdaptiveFilterModulation(out_channel, out_channel, kernel_size, padding=padding, task_dim=task_dim) 159 | self.noise2 = equal_lr(NoiseInjection(out_channel)) 160 | self.adain2 = AdaptiveInstanceNorm(out_channel, style_dim) 161 | self.lrelu2 = nn.LeakyReLU(0.2) 162 | 163 | def forward(self, input, style, noise, task): 164 | out = self.conv1((input, task)) 165 | out = self.noise1(out, noise) 166 | out = self.lrelu1(out) 167 | out = self.adain1(out, style) 168 | 169 | out = self.conv2(out, task) 170 | out = self.noise2(out, noise) 171 | out = self.lrelu2(out) 172 | out = self.adain2(out, style) 173 | 174 | return out 175 | 176 | 177 | class Generator(nn.Module): 178 | def __init__(self, code_dim, task_dim, fused=True): 179 | super().__init__() 180 | 181 | self.progression = nn.ModuleList( 182 | [ 183 | StyledConvBlock(512, 512, 3, 1, initial=True, task_dim=task_dim), # 4 184 | StyledConvBlock(512, 512, 3, 1, upsample=True, task_dim=task_dim), # 8 185 | StyledConvBlock(512, 512, 3, 1, upsample=True, task_dim=task_dim), # 16 186 | StyledConvBlock(512, 512, 3, 1, upsample=True, task_dim=task_dim), # 32 187 | StyledConvBlock(512, 256, 3, 1, upsample=True, task_dim=task_dim), # 64 188 | StyledConvBlock(256, 128, 3, 1, upsample=True, fused=fused, task_dim=task_dim), # 128 189 | StyledConvBlock(128, 64, 3, 1, upsample=True, fused=fused, task_dim=task_dim), # 256 190 | StyledConvBlock(64, 32, 3, 1, 
upsample=True, fused=fused, task_dim=task_dim), # 512 191 | StyledConvBlock(32, 16, 3, 1, upsample=True, fused=fused, task_dim=task_dim), # 1024 192 | ] 193 | ) 194 | 195 | self.to_rgb = nn.ModuleList( 196 | [ 197 | EqualConv2d(512, 3, 1), 198 | EqualConv2d(512, 3, 1), 199 | EqualConv2d(512, 3, 1), 200 | EqualConv2d(512, 3, 1), 201 | EqualConv2d(256, 3, 1), 202 | EqualConv2d(128, 3, 1), 203 | EqualConv2d(64, 3, 1), 204 | EqualConv2d(32, 3, 1), 205 | EqualConv2d(16, 3, 1), 206 | ] 207 | ) 208 | 209 | def forward(self, style, noise, step=0, alpha=-1, mixing_range=(-1, -1), task=None, return_hierarchical=False): 210 | out = noise[0] 211 | hierarchical_out = [] 212 | 213 | if len(style) < 2: 214 | inject_index = [len(self.progression) + 1] 215 | 216 | else: 217 | inject_index = sorted(random.sample(list(range(step)), len(style) - 1)) 218 | 219 | crossover = 0 220 | 221 | for i, (conv, to_rgb) in enumerate(zip(self.progression, self.to_rgb)): 222 | if mixing_range == (-1, -1): 223 | if crossover < len(inject_index) and i > inject_index[crossover]: 224 | crossover = min(crossover + 1, len(style)) 225 | 226 | style_step = style[crossover] 227 | 228 | else: 229 | if mixing_range[0] <= i <= mixing_range[1]: 230 | style_step = style[1] 231 | 232 | else: 233 | style_step = style[0] 234 | 235 | if i > 0 and step > 0: 236 | out_prev = out 237 | 238 | out = conv(out, style_step, noise[i], task) 239 | if return_hierarchical: 240 | hierarchical_out.append(out) 241 | 242 | if i == step: 243 | out = to_rgb(out) 244 | 245 | if i > 0 and 0 <= alpha < 1: 246 | skip_rgb = self.to_rgb[i - 1](out_prev) 247 | skip_rgb = F.interpolate(skip_rgb, scale_factor=2, mode='nearest') 248 | out = (1 - alpha) * skip_rgb + alpha * out 249 | 250 | break 251 | 252 | if return_hierarchical: 253 | return out, hierarchical_out 254 | else: 255 | return out 256 | 257 | 258 | class StyledGenerator(nn.Module): 259 | def __init__(self, code_dim=512, n_mlp=8, task_dim=512): 260 | super().__init__() 261 | 262 | self.generator = Generator(code_dim, task_dim) 263 | 264 | layers = [PixelNorm()] 265 | for i in range(n_mlp): 266 | layers.append(EqualLinear(code_dim, code_dim)) 267 | layers.append(nn.LeakyReLU(0.2)) 268 | 269 | self.style = nn.Sequential(*layers) 270 | 271 | def forward( 272 | self, 273 | input, 274 | noise=None, 275 | step=0, 276 | alpha=-1, 277 | mean_style=None, 278 | style_weight=0, 279 | mixing_range=(-1, -1), 280 | task=None, 281 | return_hierarchical=False, 282 | ): 283 | styles = [] 284 | if type(input) not in (list, tuple): 285 | input = [input] 286 | 287 | for i in input: 288 | styles.append(self.style(i)) 289 | 290 | batch = input[0].shape[0] 291 | 292 | if noise is None: 293 | noise = [] 294 | 295 | for i in range(step + 1): 296 | size = 4 * 2 ** i 297 | noise.append(torch.randn(batch, 1, size, size, device=input[0].device)) 298 | 299 | if mean_style is not None: 300 | styles_norm = [] 301 | 302 | for style in styles: 303 | styles_norm.append(mean_style + style_weight * (style - mean_style)) 304 | 305 | styles = styles_norm 306 | 307 | return self.generator(styles, noise, step, alpha, mixing_range=mixing_range, task=task, 308 | return_hierarchical=return_hierarchical) 309 | 310 | def mean_style(self, input): 311 | style = self.style(input).mean(0, keepdim=True) 312 | 313 | return style 314 | 315 | 316 | class Task(nn.Module): 317 | def __init__(self, code_dim=512, n_mlp=4, num_labels=0): 318 | super().__init__() 319 | 320 | layers = [equal_lr(nn.Embedding(num_labels, code_dim))] 321 | for i in range(n_mlp): 322 | 
layers.append(EqualLinear(code_dim, code_dim)) 323 | layers.append(nn.LeakyReLU(0.2)) 324 | 325 | self.task = nn.Sequential(*layers) 326 | 327 | def forward(self, x): 328 | return self.task(x) 329 | 330 | 331 | class Discriminator(nn.Module): 332 | def __init__(self, num_classes, fused=True, from_rgb_activate=False): 333 | super().__init__() 334 | 335 | self.progression = nn.ModuleList( 336 | [ 337 | ConvBlock(16, 32, 3, 1, downsample=True, fused=fused), # 512 338 | ConvBlock(32, 64, 3, 1, downsample=True, fused=fused), # 256 339 | ConvBlock(64, 128, 3, 1, downsample=True, fused=fused), # 128 340 | ConvBlock(128, 256, 3, 1, downsample=True, fused=fused), # 64 341 | ConvBlock(256, 512, 3, 1, downsample=True), # 32 342 | ConvBlock(512, 512, 3, 1, downsample=True), # 16 343 | ConvBlock(512, 512, 3, 1, downsample=True), # 8 344 | ConvBlock(512, 512, 3, 1, downsample=True), # 4 345 | ConvBlock(513, 512, 3, 1, 4, 0), 346 | ] 347 | ) 348 | 349 | def make_from_rgb(out_channel): 350 | if from_rgb_activate: 351 | return nn.Sequential(EqualConv2d(3, out_channel, 1), nn.LeakyReLU(0.2)) 352 | 353 | else: 354 | return EqualConv2d(3, out_channel, 1) 355 | 356 | self.from_rgb = nn.ModuleList( 357 | [ 358 | make_from_rgb(16), 359 | make_from_rgb(32), 360 | make_from_rgb(64), 361 | make_from_rgb(128), 362 | make_from_rgb(256), 363 | make_from_rgb(512), 364 | make_from_rgb(512), 365 | make_from_rgb(512), 366 | make_from_rgb(512), 367 | ] 368 | ) 369 | 370 | self.n_layer = len(self.progression) 371 | 372 | self.final = EqualConv2d(512, num_classes, 1) 373 | 374 | def forward(self, input, c, step=0, alpha=-1, return_hierarchical=False): 375 | hierarchical_out = [] 376 | for i in range(step, -1, -1): 377 | index = self.n_layer - i - 1 378 | 379 | if i == step: 380 | out = self.from_rgb[index](input) 381 | 382 | if i == 0: 383 | out_std = torch.sqrt(out.var(0, unbiased=False) + 1e-8) 384 | mean_std = out_std.mean() 385 | mean_std = mean_std.expand(out.size(0), 1, 4, 4) 386 | out = torch.cat([out, mean_std], 1) 387 | 388 | out = self.progression[index](out) 389 | if return_hierarchical: 390 | hierarchical_out.append(out) 391 | 392 | if i > 0: 393 | if i == step and 0 <= alpha < 1: 394 | skip_rgb = F.avg_pool2d(input, 2) 395 | skip_rgb = self.from_rgb[index + 1](skip_rgb) 396 | 397 | out = (1 - alpha) * skip_rgb + alpha * out 398 | 399 | 400 | out = self.final(out) 401 | out = out[torch.arange(out.size(0)), c].squeeze(-1) 402 | 403 | if return_hierarchical: 404 | hierarchical_out.append(out) 405 | return hierarchical_out 406 | 407 | return out 408 | -------------------------------------------------------------------------------- /model/stylegan.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch import nn 4 | from torch.nn import init 5 | from torch.nn import functional as F 6 | from torch.autograd import Function 7 | 8 | from math import sqrt 9 | 10 | import random 11 | 12 | 13 | def init_linear(linear): 14 | init.xavier_normal(linear.weight) 15 | linear.bias.data.zero_() 16 | 17 | 18 | def init_conv(conv, glu=True): 19 | init.kaiming_normal(conv.weight) 20 | if conv.bias is not None: 21 | conv.bias.data.zero_() 22 | 23 | 24 | class EqualLR: 25 | def __init__(self, name): 26 | self.name = name 27 | 28 | def compute_weight(self, module): 29 | weight = getattr(module, self.name + '_orig') 30 | fan_in = weight.data.size(1) * weight.data[0][0].numel() 31 | 32 | return weight * sqrt(2 / fan_in) 33 | 34 | @staticmethod 35 | def apply(module, name): 36 
| fn = EqualLR(name) 37 | 38 | weight = getattr(module, name) 39 | del module._parameters[name] 40 | module.register_parameter(name + '_orig', nn.Parameter(weight.data)) 41 | module.register_forward_pre_hook(fn) 42 | 43 | return fn 44 | 45 | def __call__(self, module, input): 46 | weight = self.compute_weight(module) 47 | setattr(module, self.name, weight) 48 | 49 | 50 | def equal_lr(module, name='weight'): 51 | EqualLR.apply(module, name) 52 | 53 | return module 54 | 55 | 56 | class FusedUpsample(nn.Module): 57 | def __init__(self, in_channel, out_channel, kernel_size, padding=0): 58 | super().__init__() 59 | 60 | weight = torch.randn(in_channel, out_channel, kernel_size, kernel_size) 61 | bias = torch.zeros(out_channel) 62 | 63 | fan_in = in_channel * kernel_size * kernel_size 64 | self.multiplier = sqrt(2 / fan_in) 65 | 66 | self.weight = nn.Parameter(weight) 67 | self.bias = nn.Parameter(bias) 68 | 69 | self.pad = padding 70 | 71 | def forward(self, input): 72 | weight = F.pad(self.weight * self.multiplier, [1, 1, 1, 1]) 73 | weight = ( 74 | weight[:, :, 1:, 1:] 75 | + weight[:, :, :-1, 1:] 76 | + weight[:, :, 1:, :-1] 77 | + weight[:, :, :-1, :-1] 78 | ) / 4 79 | 80 | out = F.conv_transpose2d(input, weight, self.bias, stride=2, padding=self.pad) 81 | 82 | return out 83 | 84 | 85 | class FusedDownsample(nn.Module): 86 | def __init__(self, in_channel, out_channel, kernel_size, padding=0): 87 | super().__init__() 88 | 89 | weight = torch.randn(out_channel, in_channel, kernel_size, kernel_size) 90 | bias = torch.zeros(out_channel) 91 | 92 | fan_in = in_channel * kernel_size * kernel_size 93 | self.multiplier = sqrt(2 / fan_in) 94 | 95 | self.weight = nn.Parameter(weight) 96 | self.bias = nn.Parameter(bias) 97 | 98 | self.pad = padding 99 | 100 | def forward(self, input): 101 | weight = F.pad(self.weight * self.multiplier, [1, 1, 1, 1]) 102 | weight = ( 103 | weight[:, :, 1:, 1:] 104 | + weight[:, :, :-1, 1:] 105 | + weight[:, :, 1:, :-1] 106 | + weight[:, :, :-1, :-1] 107 | ) / 4 108 | 109 | out = F.conv2d(input, weight, self.bias, stride=2, padding=self.pad) 110 | 111 | return out 112 | 113 | 114 | class PixelNorm(nn.Module): 115 | def __init__(self): 116 | super().__init__() 117 | 118 | def forward(self, input): 119 | return input / torch.sqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8) 120 | 121 | 122 | class BlurFunctionBackward(Function): 123 | @staticmethod 124 | def forward(ctx, grad_output, kernel, kernel_flip): 125 | ctx.save_for_backward(kernel, kernel_flip) 126 | 127 | grad_input = F.conv2d( 128 | grad_output, kernel_flip, padding=1, groups=grad_output.shape[1] 129 | ) 130 | 131 | return grad_input 132 | 133 | @staticmethod 134 | def backward(ctx, gradgrad_output): 135 | kernel, kernel_flip = ctx.saved_tensors 136 | 137 | grad_input = F.conv2d( 138 | gradgrad_output, kernel, padding=1, groups=gradgrad_output.shape[1] 139 | ) 140 | 141 | return grad_input, None, None 142 | 143 | 144 | class BlurFunction(Function): 145 | @staticmethod 146 | def forward(ctx, input, kernel, kernel_flip): 147 | ctx.save_for_backward(kernel, kernel_flip) 148 | 149 | output = F.conv2d(input, kernel, padding=1, groups=input.shape[1]) 150 | 151 | return output 152 | 153 | @staticmethod 154 | def backward(ctx, grad_output): 155 | kernel, kernel_flip = ctx.saved_tensors 156 | 157 | grad_input = BlurFunctionBackward.apply(grad_output, kernel, kernel_flip) 158 | 159 | return grad_input, None, None 160 | 161 | 162 | blur = BlurFunction.apply 163 | 164 | 165 | class Blur(nn.Module): 166 | def 
__init__(self, channel): 167 | super().__init__() 168 | 169 | weight = torch.tensor([[1, 2, 1], [2, 4, 2], [1, 2, 1]], dtype=torch.float32) 170 | weight = weight.view(1, 1, 3, 3) 171 | weight = weight / weight.sum() 172 | weight_flip = torch.flip(weight, [2, 3]) 173 | 174 | self.register_buffer('weight', weight.repeat(channel, 1, 1, 1)) 175 | self.register_buffer('weight_flip', weight_flip.repeat(channel, 1, 1, 1)) 176 | 177 | def forward(self, input): 178 | return blur(input, self.weight, self.weight_flip) 179 | # return F.conv2d(input, self.weight, padding=1, groups=input.shape[1]) 180 | 181 | 182 | class EqualConv2d(nn.Module): 183 | def __init__(self, *args, **kwargs): 184 | super().__init__() 185 | 186 | conv = nn.Conv2d(*args, **kwargs) 187 | conv.weight.data.normal_() 188 | conv.bias.data.zero_() 189 | self.conv = equal_lr(conv) 190 | 191 | def forward(self, input): 192 | return self.conv(input) 193 | 194 | 195 | class EqualLinear(nn.Module): 196 | def __init__(self, in_dim, out_dim): 197 | super().__init__() 198 | 199 | linear = nn.Linear(in_dim, out_dim) 200 | linear.weight.data.normal_() 201 | linear.bias.data.zero_() 202 | 203 | self.linear = equal_lr(linear) 204 | 205 | def forward(self, input): 206 | return self.linear(input) 207 | 208 | 209 | class ConvBlock(nn.Module): 210 | def __init__( 211 | self, 212 | in_channel, 213 | out_channel, 214 | kernel_size, 215 | padding, 216 | kernel_size2=None, 217 | padding2=None, 218 | downsample=False, 219 | fused=False, 220 | ): 221 | super().__init__() 222 | 223 | pad1 = padding 224 | pad2 = padding 225 | if padding2 is not None: 226 | pad2 = padding2 227 | 228 | kernel1 = kernel_size 229 | kernel2 = kernel_size 230 | if kernel_size2 is not None: 231 | kernel2 = kernel_size2 232 | 233 | self.conv1 = nn.Sequential( 234 | EqualConv2d(in_channel, out_channel, kernel1, padding=pad1), 235 | nn.LeakyReLU(0.2), 236 | ) 237 | 238 | if downsample: 239 | if fused: 240 | self.conv2 = nn.Sequential( 241 | Blur(out_channel), 242 | FusedDownsample(out_channel, out_channel, kernel2, padding=pad2), 243 | nn.LeakyReLU(0.2), 244 | ) 245 | 246 | else: 247 | self.conv2 = nn.Sequential( 248 | Blur(out_channel), 249 | EqualConv2d(out_channel, out_channel, kernel2, padding=pad2), 250 | nn.AvgPool2d(2), 251 | nn.LeakyReLU(0.2), 252 | ) 253 | 254 | else: 255 | self.conv2 = nn.Sequential( 256 | EqualConv2d(out_channel, out_channel, kernel2, padding=pad2), 257 | nn.LeakyReLU(0.2), 258 | ) 259 | 260 | def forward(self, input): 261 | out = self.conv1(input) 262 | out = self.conv2(out) 263 | 264 | return out 265 | 266 | 267 | class AdaptiveInstanceNorm(nn.Module): 268 | def __init__(self, in_channel, style_dim): 269 | super().__init__() 270 | 271 | self.norm = nn.InstanceNorm2d(in_channel) 272 | self.style = EqualLinear(style_dim, in_channel * 2) 273 | 274 | self.style.linear.bias.data[:in_channel] = 1 275 | self.style.linear.bias.data[in_channel:] = 0 276 | 277 | def forward(self, input, style): 278 | style = self.style(style).unsqueeze(2).unsqueeze(3) 279 | gamma, beta = style.chunk(2, 1) 280 | 281 | out = self.norm(input) 282 | out = gamma * out + beta 283 | 284 | return out 285 | 286 | 287 | class NoiseInjection(nn.Module): 288 | def __init__(self, channel): 289 | super().__init__() 290 | 291 | self.weight = nn.Parameter(torch.zeros(1, channel, 1, 1)) 292 | 293 | def forward(self, image, noise): 294 | return image + self.weight * noise 295 | 296 | 297 | class ConstantInput(nn.Module): 298 | def __init__(self, channel, size=4): 299 | super().__init__() 300 | 301 
| self.input = nn.Parameter(torch.randn(1, channel, size, size)) 302 | 303 | def forward(self, input): 304 | batch = input.shape[0] 305 | out = self.input.repeat(batch, 1, 1, 1) 306 | 307 | return out 308 | 309 | 310 | class StyledConvBlock(nn.Module): 311 | def __init__( 312 | self, 313 | in_channel, 314 | out_channel, 315 | kernel_size=3, 316 | padding=1, 317 | style_dim=512, 318 | initial=False, 319 | upsample=False, 320 | fused=False, 321 | ): 322 | super().__init__() 323 | 324 | if initial: 325 | self.conv1 = ConstantInput(in_channel) 326 | 327 | else: 328 | if upsample: 329 | if fused: 330 | self.conv1 = nn.Sequential( 331 | FusedUpsample( 332 | in_channel, out_channel, kernel_size, padding=padding 333 | ), 334 | Blur(out_channel), 335 | ) 336 | 337 | else: 338 | self.conv1 = nn.Sequential( 339 | nn.Upsample(scale_factor=2, mode='nearest'), 340 | EqualConv2d( 341 | in_channel, out_channel, kernel_size, padding=padding 342 | ), 343 | Blur(out_channel), 344 | ) 345 | 346 | else: 347 | self.conv1 = EqualConv2d( 348 | in_channel, out_channel, kernel_size, padding=padding 349 | ) 350 | 351 | self.noise1 = equal_lr(NoiseInjection(out_channel)) 352 | self.adain1 = AdaptiveInstanceNorm(out_channel, style_dim) 353 | self.lrelu1 = nn.LeakyReLU(0.2) 354 | 355 | self.conv2 = EqualConv2d(out_channel, out_channel, kernel_size, padding=padding) 356 | self.noise2 = equal_lr(NoiseInjection(out_channel)) 357 | self.adain2 = AdaptiveInstanceNorm(out_channel, style_dim) 358 | self.lrelu2 = nn.LeakyReLU(0.2) 359 | 360 | def forward(self, input, style, noise): 361 | out = self.conv1(input) 362 | out = self.noise1(out, noise) 363 | out = self.lrelu1(out) 364 | out = self.adain1(out, style) 365 | 366 | out = self.conv2(out) 367 | out = self.noise2(out, noise) 368 | out = self.lrelu2(out) 369 | out = self.adain2(out, style) 370 | 371 | return out 372 | 373 | 374 | class Generator(nn.Module): 375 | def __init__(self, code_dim, fused=True): 376 | super().__init__() 377 | 378 | self.progression = nn.ModuleList( 379 | [ 380 | StyledConvBlock(512, 512, 3, 1, initial=True), # 4 381 | StyledConvBlock(512, 512, 3, 1, upsample=True), # 8 382 | StyledConvBlock(512, 512, 3, 1, upsample=True), # 16 383 | StyledConvBlock(512, 512, 3, 1, upsample=True), # 32 384 | StyledConvBlock(512, 256, 3, 1, upsample=True), # 64 385 | StyledConvBlock(256, 128, 3, 1, upsample=True, fused=fused), # 128 386 | StyledConvBlock(128, 64, 3, 1, upsample=True, fused=fused), # 256 387 | StyledConvBlock(64, 32, 3, 1, upsample=True, fused=fused), # 512 388 | StyledConvBlock(32, 16, 3, 1, upsample=True, fused=fused), # 1024 389 | ] 390 | ) 391 | 392 | self.to_rgb = nn.ModuleList( 393 | [ 394 | EqualConv2d(512, 3, 1), 395 | EqualConv2d(512, 3, 1), 396 | EqualConv2d(512, 3, 1), 397 | EqualConv2d(512, 3, 1), 398 | EqualConv2d(256, 3, 1), 399 | EqualConv2d(128, 3, 1), 400 | EqualConv2d(64, 3, 1), 401 | EqualConv2d(32, 3, 1), 402 | EqualConv2d(16, 3, 1), 403 | ] 404 | ) 405 | 406 | # self.blur = Blur() 407 | 408 | def forward(self, style, noise, step=0, alpha=-1, mixing_range=(-1, -1), return_hierarchical=False): 409 | out = noise[0] 410 | hierarchical_out = [] 411 | 412 | if len(style) < 2: 413 | inject_index = [len(self.progression) + 1] 414 | 415 | else: 416 | inject_index = sorted(random.sample(list(range(step)), len(style) - 1)) 417 | 418 | crossover = 0 419 | 420 | for i, (conv, to_rgb) in enumerate(zip(self.progression, self.to_rgb)): 421 | if mixing_range == (-1, -1): 422 | if crossover < len(inject_index) and i > inject_index[crossover]: 423 | 
crossover = min(crossover + 1, len(style)) 424 | 425 | style_step = style[crossover] 426 | 427 | else: 428 | if mixing_range[0] <= i <= mixing_range[1]: 429 | style_step = style[1] 430 | 431 | else: 432 | style_step = style[0] 433 | 434 | if i > 0 and step > 0: 435 | out_prev = out 436 | 437 | out = conv(out, style_step, noise[i]) 438 | if return_hierarchical: 439 | hierarchical_out.append(out) 440 | 441 | if i == step: 442 | out = to_rgb(out) 443 | 444 | if i > 0 and 0 <= alpha < 1: 445 | skip_rgb = self.to_rgb[i - 1](out_prev) 446 | skip_rgb = F.interpolate(skip_rgb, scale_factor=2, mode='nearest') 447 | out = (1 - alpha) * skip_rgb + alpha * out 448 | 449 | break 450 | 451 | if return_hierarchical: 452 | return out, hierarchical_out 453 | else: 454 | return out 455 | 456 | 457 | class StyledGenerator(nn.Module): 458 | def __init__(self, code_dim=512, n_mlp=8, c_dim=0): 459 | super().__init__() 460 | 461 | self.generator = Generator(code_dim) 462 | 463 | layers = [PixelNorm()] 464 | for i in range(n_mlp): 465 | layers.append(EqualLinear(code_dim + (code_dim if i == 0 and c_dim > 0 else 0), code_dim)) 466 | layers.append(nn.LeakyReLU(0.2)) 467 | 468 | self.style = nn.Sequential(*layers) 469 | self.embed = EqualLinear(c_dim, code_dim) if c_dim > 0 else None 470 | 471 | def forward( 472 | self, 473 | input, 474 | noise=None, 475 | step=0, 476 | alpha=-1, 477 | mean_style=None, 478 | style_weight=0, 479 | mixing_range=(-1, -1), 480 | return_hierarchical=False, 481 | c=None, 482 | ): 483 | styles = [] 484 | if type(input) not in (list, tuple): 485 | input = [input] 486 | 487 | for i in input: 488 | if c is not None: 489 | i = torch.cat([i, self.embed(c)], dim=1) 490 | styles.append(self.style(i)) 491 | 492 | batch = input[0].shape[0] 493 | 494 | if noise is None: 495 | noise = [] 496 | 497 | for i in range(step + 1): 498 | size = 4 * 2 ** i 499 | noise.append(torch.randn(batch, 1, size, size, device=input[0].device)) 500 | 501 | if mean_style is not None: 502 | styles_norm = [] 503 | 504 | for style in styles: 505 | styles_norm.append(mean_style + style_weight * (style - mean_style)) 506 | 507 | styles = styles_norm 508 | 509 | return self.generator(styles, noise, step, alpha, mixing_range=mixing_range, 510 | return_hierarchical=return_hierarchical) 511 | 512 | def mean_style(self, input): 513 | style = self.style(input).mean(0, keepdim=True) 514 | 515 | return style 516 | 517 | 518 | class Discriminator(nn.Module): 519 | def __init__(self, fused=True, from_rgb_activate=False, num_labels=0): 520 | super().__init__() 521 | 522 | self.progression = nn.ModuleList( 523 | [ 524 | ConvBlock(16, 32, 3, 1, downsample=True, fused=fused), # 512 525 | ConvBlock(32, 64, 3, 1, downsample=True, fused=fused), # 256 526 | ConvBlock(64, 128, 3, 1, downsample=True, fused=fused), # 128 527 | ConvBlock(128, 256, 3, 1, downsample=True, fused=fused), # 64 528 | ConvBlock(256, 512, 3, 1, downsample=True), # 32 529 | ConvBlock(512, 512, 3, 1, downsample=True), # 16 530 | ConvBlock(512, 512, 3, 1, downsample=True), # 8 531 | ConvBlock(512, 512, 3, 1, downsample=True), # 4 532 | ConvBlock(513, 512, 3, 1, 4, 0), 533 | ] 534 | ) 535 | 536 | def make_from_rgb(out_channel): 537 | if from_rgb_activate: 538 | return nn.Sequential(EqualConv2d(3, out_channel, 1), nn.LeakyReLU(0.2)) 539 | 540 | else: 541 | return EqualConv2d(3, out_channel, 1) 542 | 543 | self.from_rgb = nn.ModuleList( 544 | [ 545 | make_from_rgb(16), 546 | make_from_rgb(32), 547 | make_from_rgb(64), 548 | make_from_rgb(128), 549 | make_from_rgb(256), 550 | 
make_from_rgb(512), 551 | make_from_rgb(512), 552 | make_from_rgb(512), 553 | make_from_rgb(512), 554 | ] 555 | ) 556 | 557 | # self.blur = Blur() 558 | 559 | self.n_layer = len(self.progression) 560 | 561 | self.linear = EqualLinear(512, 1 if num_labels == 0 else num_labels) 562 | 563 | def forward(self, input, step=0, alpha=-1, c=None): 564 | for i in range(step, -1, -1): 565 | index = self.n_layer - i - 1 566 | 567 | if i == step: 568 | out = self.from_rgb[index](input) 569 | 570 | if i == 0: 571 | out_std = torch.sqrt(out.var(0, unbiased=False) + 1e-8) 572 | mean_std = out_std.mean() 573 | mean_std = mean_std.expand(out.size(0), 1, 4, 4) 574 | out = torch.cat([out, mean_std], 1) 575 | 576 | out = self.progression[index](out) 577 | 578 | if i > 0: 579 | if i == step and 0 <= alpha < 1: 580 | skip_rgb = F.avg_pool2d(input, 2) 581 | skip_rgb = self.from_rgb[index + 1](skip_rgb) 582 | 583 | out = (1 - alpha) * skip_rgb + alpha * out 584 | 585 | out = out.squeeze(2).squeeze(2) 586 | # print(input.size(), out.size(), step) 587 | out = self.linear(out) 588 | 589 | if c is not None: 590 | out = (out * c).sum(dim=1) # keepdims? 591 | 592 | return out 593 | -------------------------------------------------------------------------------- /prepare_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from io import BytesIO 3 | import multiprocessing 4 | from functools import partial 5 | 6 | from PIL import Image 7 | import lmdb 8 | from tqdm import tqdm 9 | from torchvision import datasets 10 | from torchvision.transforms import functional as trans_fn 11 | 12 | 13 | def resize_and_convert(img, size, quality=100): 14 | img = trans_fn.resize(img, size, Image.LANCZOS) 15 | img = trans_fn.center_crop(img, size) 16 | buffer = BytesIO() 17 | img.save(buffer, format='jpeg', quality=quality) 18 | val = buffer.getvalue() 19 | 20 | return val 21 | 22 | 23 | def resize_multiple(img, sizes=(8, 16, 32, 64, 128, 256, 512, 1024), quality=100): 24 | imgs = [] 25 | 26 | for size in sizes: 27 | imgs.append(resize_and_convert(img, size, quality)) 28 | 29 | return imgs 30 | 31 | 32 | def resize_worker(img_file, sizes): 33 | i, file, label = img_file 34 | img = Image.open(file) 35 | img = img.convert('RGB') 36 | out = resize_multiple(img, sizes=sizes) 37 | 38 | return i, out, label 39 | 40 | 41 | def prepare(transaction, dataset, n_worker, sizes=(8, 16, 32, 64, 128, 256, 512, 1024)): 42 | resize_fn = partial(resize_worker, sizes=sizes) 43 | 44 | files = sorted(dataset.imgs, key=lambda x: x[0]) 45 | files = [(i, file, label) for i, (file, label) in enumerate(files)] 46 | num_classes = len(set(label for i, f, label in files)) 47 | total = 0 48 | 49 | with multiprocessing.Pool(n_worker) as pool: 50 | for i, imgs, label in tqdm(pool.imap_unordered(resize_fn, files), total=len(files)): 51 | for size, img in zip(sizes, imgs): 52 | key = f'{size}-{str(i).zfill(5)}'.encode('utf-8') 53 | transaction.put(key, img) 54 | 55 | key = f'label-{str(i).zfill(5)}'.encode('utf-8') 56 | transaction.put(key, str(label).encode('utf-8')) 57 | 58 | total += 1 59 | 60 | transaction.put('length'.encode('utf-8'), str(total).encode('utf-8')) 61 | transaction.put('num_classes'.encode('utf-8'), str(num_classes).encode('utf-8')) 62 | 63 | 64 | if __name__ == '__main__': 65 | parser = argparse.ArgumentParser() 66 | parser.add_argument('--out', type=str) 67 | parser.add_argument('--n_worker', type=int, default=8) 68 | parser.add_argument('path', type=str) 69 | parser.add_argument('--labels', 
action='store_true', help='Add labels too') 70 | parser.add_argument('--sizes', nargs='+', type=int, default=(8, 16, 32, 64, 128, 256, 512, 1024)) 71 | parser.add_argument('--min_size', type=int, default=None, help='Minimum image size to not drop it') 72 | 73 | args = parser.parse_args() 74 | 75 | if args.min_size is None: 76 | valid_file = None 77 | else: 78 | valid_file = lambda x: min(Image.open(x).size) >= args.min_size 79 | 80 | imgset = datasets.ImageFolder(args.path, is_valid_file=valid_file) 81 | 82 | with lmdb.open(args.out, map_size=1024 ** 4, readahead=False) as env: 83 | with env.begin(write=True) as txn: 84 | prepare(txn, imgset, args.n_worker, args.sizes) 85 | -------------------------------------------------------------------------------- /pretrained_converter.py: -------------------------------------------------------------------------------- 1 | """Utilities to fiddle with the models. Convert a pre-trained StyleGAN. Freeze all but hypernetwork layers.""" 2 | from math import sqrt 3 | 4 | 5 | def conv(num, layer=None, parameter='weight', fused=False): 6 | return f'conv{num}.{"" if layer is None else str(layer) + "."}{"" if fused else "conv."}{parameter}' 7 | 8 | 9 | def equalization_norm(w): 10 | fan_in = w.data.size(1) * w.data[0][0].numel() 11 | w = w * sqrt(2 / fan_in) 12 | return w 13 | 14 | 15 | def normalize_conv2d_weight(w): 16 | w = equalization_norm(w) 17 | w_mu = w.mean((2, 3), keepdim=True) 18 | w_std = w.std((2, 3), keepdim=True) 19 | w = (w - w_mu) / w_std 20 | return w 21 | 22 | 23 | def normalize_conv2d_bias(w): 24 | w_mu = w.mean() 25 | w_std = w.std() 26 | w = (w - w_mu) / w_std 27 | return w 28 | 29 | 30 | def rename_and_norm(state_new, k, parameter, fused=False): 31 | minus_layers = 1 if fused else 2 32 | 33 | if parameter == 'weight': 34 | state_new[k.rsplit('.', minus_layers)[0] + '.W'] = normalize_conv2d_weight(state_new.pop(k)) 35 | if parameter == 'bias': 36 | state_new[k.rsplit('.', minus_layers)[0] + '.b'] = normalize_conv2d_bias(state_new.pop(k)) 37 | 38 | 39 | def convert_generator(state): 40 | state_new = state.copy() # should be a shallow copy. I.e. tensors in the dict values should be referenced. 
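# The loop below walks the pretrained StyleGAN state dict and, for every progression
# convolution, bakes the equalized-lr scale into the stored weight, standardizes each
# filter over its spatial dimensions (and the bias across its channels), and renames
# 'weight'/'bias' to the 'W'/'b' buffer names expected by AdaptiveFilterModulation /
# FusedUpsampleAdaFM in model/hyper_mod.py, so the converted dict can be loaded with
# load_state_dict(..., strict=False) and then verified with assert_loaded_keys().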
41 | for k, v in state.items(): 42 | parts = k.split('.') 43 | module = parts[0] 44 | 45 | if module == 'generator': 46 | module_gen = parts[1] 47 | if module_gen == 'progression': 48 | s = k.split('.', 3)[-1] 49 | 50 | if s == conv(1, 0, parameter='weight', fused=True): 51 | rename_and_norm(state_new, k, 'weight', fused=True) 52 | 53 | if s == conv(1, 0, parameter='bias', fused=True): 54 | rename_and_norm(state_new, k, 'bias', fused=True) 55 | 56 | if s == conv(1, 1, parameter='weight_orig'): 57 | rename_and_norm(state_new, k, 'weight') 58 | 59 | if s == conv(1, 1, parameter='bias'): 60 | rename_and_norm(state_new, k, 'bias') 61 | 62 | if s == conv(2, parameter='weight_orig'): 63 | rename_and_norm(state_new, k, 'weight') 64 | 65 | if s == conv(2, parameter='bias'): 66 | rename_and_norm(state_new, k, 'bias') 67 | 68 | state_new['generator.progression.0.conv1.0.input'] = state_new.pop('generator.progression.0.conv1.input') 69 | 70 | return state_new 71 | 72 | 73 | def assert_loaded_keys(missing, unexpected): 74 | def check(k): 75 | return not ( 76 | k.endswith('gamma') or k.endswith('beta') or k.endswith('bias_beta') 77 | or k.find('task') >= 0 or k.startswith('style.1.linear.') or k.startswith('final.') 78 | ) 79 | missing = [k for k in missing if check(k)] 80 | assert len(missing) == 0 81 | unexpected = [k for k in unexpected if not k.startswith('linear.')] 82 | assert len(unexpected) == 0 83 | 84 | 85 | def freeze(model): 86 | for param in model.parameters(): 87 | param.requires_grad = False 88 | 89 | 90 | def freeze_layers(*models): 91 | """Freeze selected generator and discriminator weights""" 92 | from model.hyper_mod import ConstantInput, StyledGenerator, Discriminator 93 | from torch.nn import DataParallel 94 | from torch.nn.parallel.distributed import DistributedDataParallel 95 | 96 | for model in models: 97 | if isinstance(model, DataParallel) or isinstance(model, DistributedDataParallel): 98 | model = model.module 99 | 100 | if isinstance(model, StyledGenerator): 101 | for styled_block in model.generator.progression: 102 | if isinstance(styled_block.conv1[0], ConstantInput): 103 | freeze(styled_block.conv1) 104 | 105 | # no need to freeze conv1 nor conv2 bc parameters have been already converted to buffers 106 | freeze(styled_block.noise1) 107 | freeze(styled_block.adain1) 108 | freeze(styled_block.noise2) 109 | freeze(styled_block.adain2) 110 | 111 | for to_rgb_i in model.generator.to_rgb: 112 | freeze(to_rgb_i) 113 | 114 | assert len(model.style) == 8 * 2 + 1 # 8 layers 115 | freeze(model.style) 116 | elif isinstance(model, Discriminator): 117 | for from_rgb_i in model.from_rgb: 118 | freeze(from_rgb_i) 119 | elif model is not None: 120 | print(type(model)) 121 | raise ValueError 122 | -------------------------------------------------------------------------------- /self_alignment.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | import torch 4 | from pathlib import Path 5 | from tqdm import trange 6 | from torch import nn, optim 7 | from torchvision.utils import make_grid 8 | from torch.utils.tensorboard import SummaryWriter 9 | 10 | from model.stylegan import StyledGenerator as OriginalStyledGenerator 11 | from utils import save_hparams, get_run_id 12 | 13 | 14 | def get_z(batch_size, code_size, mixing, device, fixed_seed=False): 15 | rng = None 16 | if fixed_seed: 17 | rng = torch.Generator(device=device) 18 | rng.manual_seed(2147483647) 19 | if mixing and random.random() < 0.9: 20 | return 
[torch.randn(batch_size, code_size, generator=rng, device=device), 21 | torch.randn(batch_size, code_size, generator=rng, device=device)] 22 | else: 23 | return torch.randn(batch_size, code_size, generator=rng, device=device) 24 | 25 | 26 | def get_noise(step, batch_size, device, fixed_seed=False): 27 | rng = None 28 | if fixed_seed: 29 | rng = torch.Generator(device=device) 30 | rng.manual_seed(2147483647) 31 | noise = [] 32 | for i in range(step + 1): 33 | size = 4 * 2 ** i 34 | noise.append(torch.randn(batch_size, 1, size, size, generator=rng, device=device)) 35 | 36 | return noise 37 | 38 | 39 | def log_images(writer, i, generator, task, generator_original, code_size, step, device, batch_size=9): 40 | with torch.no_grad(): 41 | z = get_z(batch_size=batch_size, code_size=code_size, mixing=False, device=device, fixed_seed=True) 42 | noise = get_noise(step, batch_size, device, fixed_seed=True) 43 | 44 | images_orig = generator_original(z, noise=noise, step=step, alpha=1, return_hierarchical=False) 45 | writer.add_image('sample_orig', make_grid(images_orig, nrow=3, normalize=True, range=(-1, 1)), i) 46 | 47 | task_in = task(torch.zeros(batch_size, device=device, dtype=torch.long)) 48 | images = generator(z, noise=noise, step=step, alpha=1, task=task_in, return_hierarchical=False) 49 | writer.add_image('sample_new', make_grid(images, nrow=3, normalize=True, range=(-1, 1)), i) 50 | 51 | 52 | def train(generator, task, generator_original, g_optimizer, criterion, iterations, args, 53 | batch_size=1, code_size=512, step=6, mixing=False, name=None, device='cpu', checkpoint_interval=5_000): 54 | name = 'unnamed' if name is None else name 55 | run_dir = f'{args.run_dir}/{get_run_id(args.run_dir):05d}-{name}' 56 | Path(run_dir).mkdir(parents=True) 57 | save_hparams(run_dir, args) 58 | writer = SummaryWriter(run_dir) 59 | alpha = 1 60 | for i in trange(iterations): 61 | g_optimizer.zero_grad() 62 | z = get_z(batch_size, code_size, mixing, device) 63 | noise = get_noise(step, batch_size, device) 64 | with torch.no_grad(): 65 | _, g_l_orig = generator_original(z, noise=noise, step=step, alpha=alpha, return_hierarchical=True) 66 | task_in = task(torch.zeros(batch_size, device=device, dtype=torch.long)) 67 | _, g_l = generator(z, noise=noise, step=step, alpha=alpha, task=task_in, return_hierarchical=True) 68 | 69 | assert len(g_l) == len(g_l_orig) 70 | g_loss = 0 71 | for input, target in zip(g_l, g_l_orig): 72 | g_loss += criterion(input, target) 73 | 74 | g_loss.backward() 75 | g_optimizer.step() 76 | 77 | writer.add_scalar('loss/g', g_loss, i) 78 | if i % args.log_interval == 0: 79 | log_images(writer, i, generator, task, generator_original, code_size, step, device) 80 | 81 | if i % checkpoint_interval == 0: 82 | torch.save({ 83 | 'generator': generator.state_dict(), 84 | 'task': task.state_dict(), 85 | }, f'{run_dir}/{i:06d}.model') 86 | 87 | 88 | def get_args(): 89 | parser = argparse.ArgumentParser(description='Self-alignment') 90 | 91 | parser.add_argument('checkpoint', type=str) 92 | parser.add_argument('--step', default=6, type=int, help='6 = 256px') 93 | parser.add_argument('--iterations', type=int, default=50_000, help='number of samples used') 94 | parser.add_argument('--lr', default=0.002, type=float, help='learning rate') 95 | parser.add_argument('--code_size', default=512, type=int) 96 | parser.add_argument('--n_mlp_style', default=8, type=int) 97 | parser.add_argument('--n_mlp_task', default=4, type=int) 98 | parser.add_argument('--task_size', default=64, type=int) 99 | 
parser.add_argument('--batch_size', default=32, type=int, help='max image size') 100 | parser.add_argument('--mixing', action='store_true', help='use mixing regularization') 101 | parser.add_argument('--device', default='cuda', type=str) 102 | parser.add_argument('--name', type=str, default=None) 103 | parser.add_argument('--run_dir', type=str, default='data/training-runs/self_align') 104 | parser.add_argument('--checkpoint_interval', type=int, default=5_000, help='number of samples used') 105 | parser.add_argument('--log_interval', type=int, default=500) 106 | 107 | return parser.parse_args() 108 | 109 | 110 | def main(args): 111 | from model.hyper_mod import StyledGenerator, Task 112 | from pretrained_converter import convert_generator, assert_loaded_keys, freeze_layers 113 | 114 | args.origin = 'self_align' 115 | g_running_orig = OriginalStyledGenerator(args.code_size).to(args.device) 116 | g_running_orig.train(False) 117 | generator = StyledGenerator(code_dim=args.code_size, task_dim=args.task_size, n_mlp=args.n_mlp_style).to(args.device) 118 | freeze_layers(generator) 119 | task = Task(args.task_size, n_mlp=args.n_mlp_task, num_labels=1).to(args.device) 120 | 121 | g_optimizer = optim.Adam(generator.generator.parameters(), lr=args.lr, betas=(0.0, 0.99)) 122 | g_optimizer.add_param_group( 123 | { 124 | 'params': generator.style.parameters(), 125 | 'lr': args.lr * 0.01, 126 | 'mult': 0.01, 127 | } 128 | ) 129 | g_optimizer.add_param_group( 130 | { 131 | 'params': task.parameters(), 132 | 'lr': args.lr * 0.01, 133 | 'mult': 0.001, 134 | } 135 | ) 136 | 137 | ckpt = torch.load(args.checkpoint) 138 | g_running_orig.load_state_dict(ckpt['g_running']) 139 | missing, unexpected = generator.load_state_dict(convert_generator(ckpt['g_running']), strict=False) 140 | assert_loaded_keys(missing, unexpected) 141 | del ckpt 142 | 143 | criterion = nn.L1Loss() 144 | 145 | train(generator=generator, task=task, generator_original=g_running_orig, g_optimizer=g_optimizer, 146 | criterion=criterion, iterations=args.iterations, args=args, batch_size=args.batch_size, 147 | code_size=args.code_size, step=args.step, mixing=args.mixing, name=args.name, device=args.device, 148 | checkpoint_interval=args.checkpoint_interval) 149 | 150 | 151 | if __name__ == '__main__': 152 | main(get_args()) 153 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | import random 4 | import math 5 | 6 | from tqdm import tqdm 7 | 8 | import torch 9 | from torch import nn, optim 10 | from torch.nn import functional as F 11 | from torch.autograd import grad 12 | from torchvision import transforms 13 | from torch.utils.tensorboard import SummaryWriter 14 | 15 | from dataset import MultiResolutionDataset 16 | from model.hyper_mod import StyledGenerator, Task, Discriminator 17 | from pretrained_converter import convert_generator, assert_loaded_keys, freeze_layers 18 | from barlow_twins import BarlowTwins, DiffTransform as ContrTransform, compute_contrastive 19 | from utils import seed_everything, save_hparams, sample_data, adjust_lr, not_frozen_params, requires_grad, accumulate, log_images, get_run_id 20 | 21 | 22 | def train(args, dataset, generator, discriminator, task, g_running, t_running, g_optimizer, d_optimizer, 23 | contr_transform, contr_criterion): 24 | step = int(math.log2(args.init_size)) - 2 25 | resolution = 4 * 2 ** step 26 | loader = sample_data( 27 
| dataset, args.batch.get(resolution, args.batch_default), resolution 28 | ) 29 | data_loader = iter(loader) 30 | 31 | adjust_lr(g_optimizer, args.lr.get(resolution, 0.002)) 32 | adjust_lr(d_optimizer, args.lr.get(resolution, 0.002)) 33 | 34 | pbar = tqdm(range(args.iterations)) 35 | 36 | grad_map_generator = not_frozen_params(generator) 37 | grad_map_discriminator = not_frozen_params(discriminator) 38 | requires_grad(generator, False) 39 | requires_grad(discriminator, True, grad_map=grad_map_discriminator) 40 | 41 | disc_loss_val = 0 42 | gen_loss_val = 0 43 | grad_loss_val = 0 44 | 45 | used_sample = 0 46 | num_imgs = 0 47 | 48 | max_step = int(math.log2(args.max_size)) - 2 49 | final_progress = False 50 | 51 | run_dir = f'{args.run_dir}/{get_run_id(args.run_dir):05d}-{args.name}' 52 | Path(run_dir).mkdir(parents=True) 53 | save_hparams(run_dir, args) 54 | if not args.no_tb: 55 | writer = SummaryWriter(run_dir) 56 | 57 | for i in pbar: 58 | d_optimizer.zero_grad() 59 | 60 | if args.const_alpha is not None: 61 | alpha = args.const_alpha 62 | else: 63 | alpha = min(1, 1 / args.phase * (used_sample + 1)) 64 | 65 | if (resolution == args.init_size and args.ckpt is None) or final_progress: 66 | if args.const_alpha is not None: 67 | alpha = args.const_alpha 68 | else: 69 | alpha = 1 70 | 71 | if used_sample > args.phase * 2: 72 | used_sample = 0 73 | step += 1 74 | 75 | if step > max_step: 76 | step = max_step 77 | final_progress = True 78 | ckpt_step = step + 1 79 | 80 | else: 81 | if args.const_alpha is not None: 82 | alpha = args.const_alpha 83 | else: 84 | alpha = 0 85 | ckpt_step = step 86 | 87 | resolution = 4 * 2 ** step 88 | 89 | loader = sample_data( 90 | dataset, args.batch.get(resolution, args.batch_default), resolution 91 | ) 92 | data_loader = iter(loader) 93 | 94 | if not args.no_checkpointing: 95 | torch.save( 96 | { 97 | 'generator': generator.module.state_dict(), 98 | 'discriminator': discriminator.module.state_dict(), 99 | 'g_optimizer': g_optimizer.state_dict(), 100 | 'd_optimizer': d_optimizer.state_dict(), 101 | 'g_running': g_running.state_dict(), 102 | 'task_g': task.module.state_dict(), 103 | 't_running_g': t_running.state_dict(), 104 | }, 105 | f'{run_dir}/train_step-{ckpt_step}.model', 106 | ) 107 | 108 | adjust_lr(g_optimizer, args.lr.get(resolution, 0.002)) 109 | adjust_lr(d_optimizer, args.lr.get(resolution, 0.002)) 110 | 111 | try: 112 | real_image, real_label = next(data_loader) 113 | 114 | except (OSError, StopIteration): 115 | data_loader = iter(loader) 116 | real_image, real_label = next(data_loader) 117 | 118 | used_sample += real_image.shape[0] 119 | num_imgs += real_image.shape[0] 120 | 121 | b_size = real_image.size(0) 122 | real_image = real_image.cuda() 123 | real_label = real_label.cuda() 124 | 125 | if args.loss == 'wgan-gp': 126 | real_predict = discriminator(real_image, c=real_label, step=step, alpha=alpha) 127 | real_predict = real_predict.mean() - 0.001 * (real_predict ** 2).mean() 128 | (-real_predict).backward() 129 | 130 | elif args.loss == 'r1': 131 | real_image.requires_grad = True 132 | real_scores = discriminator(real_image, c=real_label, step=step, alpha=alpha) 133 | real_predict = F.softplus(-real_scores).mean() 134 | real_predict.backward(retain_graph=True) 135 | 136 | grad_real = grad( 137 | outputs=real_scores.sum(), inputs=real_image, create_graph=True 138 | )[0] 139 | grad_penalty = ( 140 | grad_real.view(grad_real.size(0), -1).norm(2, dim=1) ** 2 141 | ).mean() 142 | grad_penalty = 10 / 2 * grad_penalty 143 | grad_penalty.backward() 
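# The penalty above is the R1 regularizer, (gamma / 2) * E[ ||grad_x D(x)||^2 ] evaluated
# on real images only, with gamma = 10 (hence the 10 / 2 factor). The 'wgan-gp' branch
# further below instead penalizes the gradient norm on real/fake interpolates.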
144 | if i % 10 == 0: 145 | grad_loss_val = grad_penalty.item() 146 | 147 | if args.contrastive: 148 | contr_loss_real = compute_contrastive(discriminator, real_image, contr_criterion, contr_transform, c=real_label, 149 | step=step, alpha=alpha, weighting=args.contr_weighting) 150 | contr_loss_real.backward() 151 | 152 | if args.mixing and random.random() < 0.9: 153 | gen_in11, gen_in12, gen_in21, gen_in22 = torch.randn( 154 | 4, b_size, args.code_size, device='cuda' 155 | ).chunk(4, 0) 156 | gen_in1 = [gen_in11.squeeze(0), gen_in12.squeeze(0)] 157 | gen_in2 = [gen_in21.squeeze(0), gen_in22.squeeze(0)] 158 | 159 | else: 160 | gen_in1, gen_in2 = torch.randn(2, b_size, args.code_size, device='cuda').chunk( 161 | 2, 0 162 | ) 163 | gen_in1 = gen_in1.squeeze(0) 164 | gen_in2 = gen_in2.squeeze(0) 165 | 166 | fake_image = generator(gen_in1, step=step, alpha=alpha, task=task(real_label)) 167 | fake_predict = discriminator(fake_image, c=real_label, step=step, alpha=alpha) 168 | 169 | if args.loss == 'wgan-gp': 170 | fake_predict = fake_predict.mean() 171 | fake_predict.backward() 172 | 173 | eps = torch.rand(b_size, 1, 1, 1).cuda() 174 | x_hat = eps * real_image.data + (1 - eps) * fake_image.data 175 | x_hat.requires_grad = True 176 | hat_predict = discriminator(x_hat, c=real_label, step=step, alpha=alpha) 177 | grad_x_hat = grad( 178 | outputs=hat_predict.sum(), inputs=x_hat, create_graph=True 179 | )[0] 180 | grad_penalty = ( 181 | (grad_x_hat.view(grad_x_hat.size(0), -1).norm(2, dim=1) - 1) ** 2 182 | ).mean() 183 | grad_penalty = 10 * grad_penalty 184 | grad_penalty.backward() 185 | if i % 10 == 0: 186 | grad_loss_val = grad_penalty.item() 187 | disc_loss_val = (-real_predict + fake_predict).item() 188 | 189 | elif args.loss == 'r1': 190 | fake_predict = F.softplus(fake_predict).mean() 191 | fake_predict.backward() 192 | if i % 10 == 0: 193 | disc_loss_val = (real_predict + fake_predict).item() 194 | 195 | if args.contrastive: 196 | fake_image = generator(gen_in1, step=step, alpha=alpha, task=task(real_label)) 197 | contr_loss_fake = compute_contrastive(discriminator, fake_image, contr_criterion, contr_transform, 198 | c=real_label, step=step, alpha=alpha, weighting=args.contr_weighting) 199 | contr_loss_fake.backward() 200 | 201 | d_optimizer.step() 202 | 203 | if (i + 1) % args.n_critic == 0: 204 | g_optimizer.zero_grad() 205 | 206 | requires_grad(generator, True, grad_map=grad_map_generator) 207 | requires_grad(discriminator, False) 208 | 209 | fake_image = generator(gen_in2, step=step, alpha=alpha, task=task(real_label)) 210 | 211 | predict = discriminator(fake_image, c=real_label, step=step, alpha=alpha) 212 | 213 | if args.loss == 'wgan-gp': 214 | loss = -predict.mean() 215 | 216 | elif args.loss == 'r1': 217 | loss = F.softplus(-predict).mean() 218 | 219 | if i % 10 == 0: 220 | gen_loss_val = loss.item() 221 | 222 | loss.backward() 223 | g_optimizer.step() 224 | accumulate(g_running, generator.module, grad_map=grad_map_generator) 225 | accumulate(t_running, task.module) 226 | 227 | requires_grad(generator, False) 228 | requires_grad(discriminator, True, grad_map=grad_map_discriminator) 229 | 230 | if i % args.img_log_interval == 0: 231 | log_images(f'{run_dir}/fakes{i:06d}.png', alpha, args, dataset, resolution, real_image.size(0), step, g_running, t_running) 232 | 233 | if i % args.checkpoint_interval == 0 and i >= args.begin_checkpointing_at and not args.no_checkpointing: 234 | torch.save( 235 | { 236 | 'generator': generator.module.state_dict(), 237 | 'discriminator': 
discriminator.module.state_dict(), 238 | # 'g_optimizer': g_optimizer.state_dict(), 239 | # 'd_optimizer': d_optimizer.state_dict(), 240 | 'g_running': g_running.state_dict(), 241 | 'task_g': task.module.state_dict(), 242 | 't_running_g': t_running.state_dict(), 243 | }, 244 | f'{run_dir}/{i:06d}.model', 245 | ) 246 | 247 | state_msg = ( 248 | f'Size: {4 * 2 ** step}; G: {gen_loss_val:.3f}; D: {disc_loss_val:.3f};' 249 | f' Grad: {grad_loss_val:.3f}; Alpha: {alpha:.5f}' 250 | ) 251 | 252 | pbar.set_description(state_msg) 253 | if not args.no_tb: 254 | writer.add_scalar('1.loss/G', gen_loss_val, i) 255 | writer.add_scalar('1.loss/D', disc_loss_val, i) 256 | writer.add_scalar('1.loss/grad', grad_loss_val, i) 257 | if args.contrastive: 258 | writer.add_scalar('1.loss/BT_real', contr_loss_real, i) 259 | writer.add_scalar('1.loss/BT_fake', contr_loss_fake, i) 260 | writer.add_scalar('alpha', alpha, i) 261 | writer.add_scalar('resolution', 4 * 2 ** step, i) 262 | writer.add_scalar('batch_size', loader.batch_size, i) 263 | writer.add_scalar('lr/discriminator', d_optimizer.param_groups[0]['lr'], i) 264 | writer.add_scalar('lr/generator', g_optimizer.param_groups[0]['lr'], i) 265 | writer.add_scalar('lr/task', g_optimizer.param_groups[1]['lr'], i) 266 | writer.add_scalar('kimgs', num_imgs / 1000, i) 267 | 268 | 269 | def get_args(): 270 | parser = argparse.ArgumentParser(description='Hyper-Modulation') 271 | 272 | # GAN config 273 | parser.add_argument('path', type=str, help='path of specified dataset') 274 | parser.add_argument('--iterations', default=3_000_000, type=int, help='training iterations') 275 | parser.add_argument('--phase', type=int, default=600_000, help='number of samples used for each training phases') 276 | parser.add_argument('--lr', default=0.001, type=float, help='learning rate') 277 | parser.add_argument('--sched', action='store_true', help='use lr scheduling') 278 | parser.add_argument('--init_size', default=8, type=int, help='initial image size') 279 | parser.add_argument('--max_size', default=1024, type=int, help='max image size') 280 | parser.add_argument('--batch_default', default=32, type=int, help='batch size if no lr scheduling is activated') 281 | parser.add_argument('--n_critic', default=1, type=int, help='Number of critic iterations in a global iteration') 282 | parser.add_argument('--ckpt', default=None, type=str, help='load from previous checkpoints') 283 | parser.add_argument('--no_from_rgb_activate', action='store_true', help='use activate in from_rgb (original implementation)') 284 | parser.add_argument('--mixing', action='store_true', help='use mixing regularization') 285 | parser.add_argument('--loss', type=str, default='wgan-gp', choices=['wgan-gp', 'r1'], help='class of gan loss') 286 | parser.add_argument('--const_alpha', default=None, type=float, help='Whether to use a constant alpha') 287 | # Logging config 288 | parser.add_argument('--name', default=None, type=str, help='Name of the experiment') 289 | parser.add_argument('--run_dir', default='data/training-runs', type=str, help='') 290 | parser.add_argument('--no_tb', action='store_true', help='dont log in tensorboard') 291 | parser.add_argument('--no_checkpointing', action='store_true', help='dont save checkpoints') 292 | parser.add_argument('--begin_checkpointing_at', default=0, type=int, help='dont save model before a training point') 293 | parser.add_argument('--checkpoint_interval', default=10_000, type=int, help='how often to save the model') 294 | parser.add_argument('--img_log_interval', 
default=1_000, type=int, help='how often to generate images') 295 | # Dataset config 296 | parser.add_argument('--filter_classes', default=None, nargs='+', type=int, help='Use only n classes') 297 | parser.add_argument('--dataset_length', default=None, type=int, help='Constrain the number of samples') 298 | parser.add_argument('--dataset_class_samples', default=None, type=int, help='Constrain the number of samples') 299 | parser.add_argument('--dataset_random_class_sampling', action='store_true', help='Random sample the subset of class samples') 300 | # Hyper-modulation config 301 | parser.add_argument('--n_mlp_style', default=8, type=int, help='') 302 | parser.add_argument('--n_mlp_task', default=4, type=int, help='Task net depth') 303 | parser.add_argument('--code_size', default=512, type=int, help='') 304 | parser.add_argument('--task_size', default=64, type=int, help='task width') 305 | parser.add_argument('--finetune', action='store_true', help='Finetune underlying pretrained model') 306 | # Contrastive config 307 | parser.add_argument('--contrastive', action='store_true', help='Enable contrastive training') 308 | parser.add_argument('--contr_weighting', default=0.01, type=float, help='Contrastive training weighting in the final loss') 309 | parser.add_argument('--contr_feats', default=512, type=int, help='Dimensionality of features passed to the contrastive loss') 310 | 311 | return parser.parse_args() 312 | 313 | 314 | def main(args): 315 | seed_everything(42) 316 | torch.backends.cudnn.benchmark = True 317 | contr_transform = None 318 | contr_criterion = None 319 | 320 | print(f'loss: {args.loss}') 321 | 322 | transform = transforms.Compose( 323 | [ 324 | transforms.RandomHorizontalFlip(), 325 | transforms.ToTensor(), 326 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True), 327 | ] 328 | ) 329 | if args.contrastive: 330 | contr_transform = ContrTransform(crop_resize=256) 331 | 332 | dataset = MultiResolutionDataset(args.path, transform, selected_classes=args.filter_classes, class_samples=args.dataset_class_samples, random_class_sampling=args.dataset_random_class_sampling, length=args.dataset_length) 333 | print(f'dataset len: {len(dataset)}') 334 | print(f'dataset classes: {dataset.num_classes}') 335 | args.num_classes = dataset.num_classes 336 | print(f'contrastive training: {args.contrastive}') 337 | 338 | generator = nn.DataParallel(StyledGenerator(code_dim=args.code_size, task_dim=args.task_size, n_mlp=args.n_mlp_style)).cuda() 339 | discriminator = nn.DataParallel( 340 | Discriminator(num_classes=args.num_classes, from_rgb_activate=not args.no_from_rgb_activate) 341 | ).cuda() 342 | task = nn.DataParallel(Task(args.task_size, n_mlp=args.n_mlp_task, num_labels=dataset.num_classes)).cuda() # class network 343 | g_running = StyledGenerator(code_dim=args.code_size, task_dim=args.task_size, n_mlp=args.n_mlp_style).cuda() 344 | g_running.train(False) 345 | t_running = Task(args.task_size, n_mlp=args.n_mlp_task, num_labels=dataset.num_classes).cuda() 346 | t_running.train(False) 347 | if args.contrastive: 348 | contr_criterion = nn.DataParallel(BarlowTwins(num_feats=args.contr_feats, use_projector=False)).cuda() 349 | 350 | g_optimizer = optim.Adam( 351 | generator.module.generator.parameters(), lr=args.lr, betas=(0.0, 0.99) 352 | ) 353 | g_optimizer.add_param_group( 354 | { 355 | 'params': task.parameters(), 356 | 'lr': args.lr * 0.01, 357 | 'mult': 0.001, 358 | } 359 | ) 360 | d_optimizer = optim.Adam(discriminator.parameters(), lr=args.lr, betas=(0.0, 0.99)) 361 
| if args.contrastive: 362 | d_optimizer.add_param_group({ 363 | 'params': contr_criterion.parameters(), 364 | 'lr': args.lr, 365 | }) 366 | 367 | accumulate(g_running, generator.module, 0) 368 | accumulate(t_running, task.module, 0) 369 | 370 | if args.ckpt is not None: 371 | ckpt = torch.load(args.ckpt) 372 | if 'task_g' in ckpt: 373 | ckpt['task'] = ckpt['task_g'] 374 | 375 | if 'task' in ckpt: # load self-aligned checkpoint 376 | generator.module.load_state_dict(ckpt['generator']) 377 | # copy original class embedding for all new classes 378 | ckpt['task']['task.0.weight_orig'] = ckpt['task']['task.0.weight_orig'].repeat((dataset.num_classes, 1)) 379 | task.module.load_state_dict(ckpt['task']) 380 | 381 | g_running.load_state_dict(ckpt['generator']) 382 | t_running.load_state_dict(ckpt['task']) 383 | del ckpt # avoid OOM 384 | 385 | # load vanilla discriminator 386 | ckpt = torch.load('data/stylegan-256px-new.model') 387 | missing, unexpected = discriminator.module.load_state_dict(ckpt['discriminator'], strict=False) 388 | unexpected = [k for k in unexpected if not k.startswith('linear.')] 389 | assert len(unexpected) == 0 390 | missing = [k for k in missing if not k.startswith('final.')] 391 | assert len(missing) == 0 392 | del ckpt # avoid OOM 393 | else: # begin training from pretrained vanilla stylegan 394 | missing, unexpected = generator.module.load_state_dict(convert_generator(ckpt['generator']), strict=False) 395 | assert_loaded_keys(missing, unexpected) 396 | missing, unexpected = g_running.load_state_dict(convert_generator(ckpt['g_running']), strict=False) 397 | assert_loaded_keys(missing, unexpected) 398 | 399 | # load vanilla discriminator 400 | missing, unexpected = discriminator.module.load_state_dict(ckpt['discriminator'], strict=False) 401 | unexpected = [k for k in unexpected if not k.startswith('linear.')] 402 | assert len(unexpected) == 0 403 | missing = [k for k in missing if not k.startswith('final.')] 404 | assert len(missing) == 0 405 | del ckpt # avoid OOM 406 | 407 | if not args.finetune: 408 | # layer freezing can be done here (after the optimizers are declared), as long as it happens before the forward pass 409 | freeze_layers(generator, discriminator, g_running) 410 | 411 | if args.sched: 412 | args.lr = {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003} 413 | args.batch = {4: 512, 8: 256, 16: 128, 32: 64, 64: 30, 128: 20, 256: 10} 414 | else: 415 | args.lr = {} 416 | args.batch = {} 417 | 418 | args.gen_sample = {512: (8, 4), 1024: (4, 2)} 419 | 420 | train(args, dataset, generator, discriminator, task, g_running, t_running, g_optimizer, d_optimizer, 421 | contr_transform, contr_criterion) 422 | 423 | 424 | if __name__ == '__main__': 425 | main(get_args()) 426 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import yaml 3 | from pathlib import Path 4 | import math 5 | import pickle 6 | 7 | import numpy as np 8 | import torch 9 | from torch.utils.data import DataLoader 10 | from torchvision.utils import save_image 11 | from pytorch_lightning import seed_everything as seed_everything_pl 12 | 13 | HPARAMS_FILENAME = 'hparams.yml' 14 | 15 | 16 | # ------------------------------------ I/O ------------------------------------ 17 | 18 | 19 | def args_to_yaml(path, args, exist_ok=False): 20 | file = Path(path) 21 | 22 | if file.exists(): 23 | if exist_ok: 24 | return 25 | else: 26 | raise FileExistsError(f'File already exists {file}') 27
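# dump the argparse namespace as YAML so the run configuration can later be restored with yaml_to_args / load_hparams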
| 28 | with file.open('w') as f: 29 | yaml.dump(args.__dict__, f, 30 | default_flow_style=False, 31 | sort_keys=True) 32 | 33 | 34 | def yaml_to_args(path): 35 | with open(path, 'r') as f: 36 | hparams = yaml.full_load(f) 37 | 38 | return argparse.Namespace(**hparams) 39 | 40 | 41 | def save_hparams(path, args, exist_ok=False): 42 | hparams_file = Path(path) / HPARAMS_FILENAME 43 | args_to_yaml(hparams_file, args, exist_ok=exist_ok) 44 | 45 | 46 | def load_hparams(path): 47 | path = Path(path) 48 | if not path.is_dir(): 49 | path = path.parent 50 | 51 | hparams_file = path / HPARAMS_FILENAME 52 | 53 | if hparams_file.exists(): 54 | return yaml_to_args(hparams_file) 55 | 56 | return None 57 | 58 | 59 | def num_to_one_hot(t, num_classes=3): 60 | return torch.nn.functional.one_hot(torch.tensor(t) if isinstance(t, int) or isinstance(t, list) else t, num_classes=num_classes).float() 61 | 62 | 63 | def load_checkpoint(args, hparams=None): 64 | if hparams is None: 65 | hparams = load_hparams(args.checkpoint_path) 66 | if hasattr(hparams, 'original_model') and hparams.original_model: 67 | from model.stylegan import StyledGenerator 68 | model = StyledGenerator(code_dim=hparams.code_size, n_mlp=hparams.n_mlp, c_dim=hparams.num_classes).to(args.device) 69 | checkpoint = torch.load(args.checkpoint_path, map_location=args.device) 70 | model.load_state_dict(checkpoint['g_running'] if 'g_running' in checkpoint else checkpoint) 71 | return model, num_to_one_hot 72 | from model.hyper_mod import StyledGenerator, Task 73 | 74 | checkpoint = torch.load(args.checkpoint_path, map_location=args.device) 75 | model = StyledGenerator(code_dim=hparams.code_size, task_dim=hparams.task_size, n_mlp=hparams.n_mlp_style).to(args.device) 76 | task = Task(hparams.task_size, n_mlp=hparams.n_mlp_task, num_labels=hparams.num_classes).to(args.device) 77 | 78 | from_self_align = False 79 | if hasattr(hparams, 'origin') and hparams.origin == 'self_align': 80 | from_self_align = True 81 | 82 | if from_self_align: 83 | model.load_state_dict(checkpoint['generator']) 84 | task.load_state_dict(checkpoint['task']) 85 | else: 86 | model.load_state_dict(checkpoint['g_running']) 87 | task.load_state_dict(checkpoint['t_running_g']) 88 | return model, task 89 | 90 | 91 | # ------------------------------------ Logging ------------------------------------ 92 | 93 | 94 | def get_run_id(outdir): 95 | import os, re 96 | # From StyleGAN repo 97 | # Pick output directory. 
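# run directories under outdir are expected to begin with a numeric id; return one past the largest existing id (0 if the directory is empty or missing)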
98 | prev_run_dirs = [] 99 | if os.path.isdir(outdir): 100 | prev_run_dirs = [x for x in os.listdir(outdir) if os.path.isdir(os.path.join(outdir, x))] 101 | prev_run_ids = [re.match(r'^\d+', x) for x in prev_run_dirs] 102 | prev_run_ids = [int(x.group()) for x in prev_run_ids if x is not None] 103 | return max(prev_run_ids, default=-1) + 1 104 | 105 | 106 | def log_images(fname, alpha, args, dataset, resolution, batch_size, step, generator, task, device='cuda', seed=2147483647): 107 | if args.no_tb: 108 | return 109 | rng = torch.Generator(device=device) 110 | rng.manual_seed(seed) 111 | images = [] 112 | default_gen_n_classes = 8 113 | default_gen_n_samples = 4 114 | gen_n_classes, gen_n_samples = args.gen_sample.get(resolution, (default_gen_n_classes, default_gen_n_samples)) 115 | if dataset.num_classes < gen_n_classes: 116 | # cycle through the classes 117 | gen_class_range = list(range(dataset.num_classes)) * math.ceil(gen_n_classes / dataset.num_classes) 118 | gen_class_range = gen_class_range[:gen_n_classes] 119 | else: 120 | gen_class_range = list(range(gen_n_classes)) 121 | gen_class_range = gen_class_range * gen_n_samples 122 | with torch.no_grad(): 123 | for j in range(0, len(gen_class_range), batch_size): 124 | gen_classes = gen_class_range[j:j + batch_size] 125 | images.append( 126 | generator( 127 | torch.randn(len(gen_classes), args.code_size, generator=rng, device=rng.device), 128 | step=step, alpha=alpha, task=task(torch.tensor(gen_classes).to(device)) 129 | ).data.cpu() 130 | ) 131 | 132 | save_image(torch.cat(images, 0), fname, nrow=gen_n_classes, normalize=True, range=(-1, 1)) 133 | 134 | 135 | # ------------------------------------ Training ------------------------------------ 136 | 137 | 138 | def seed_everything(seed): # cleaner imports 139 | seed_everything_pl(seed) 140 | 141 | 142 | def not_frozen_params(model): 143 | require = {} 144 | for name, param in model.named_parameters(): 145 | require[name] = param.requires_grad 146 | 147 | return require 148 | 149 | 150 | def requires_grad(model, flag=True, grad_map=None): 151 | for n, p in model.named_parameters(): 152 | if flag and grad_map is not None and not grad_map[n]: # filter those params which were originally frozen 153 | continue 154 | 155 | p.requires_grad = flag 156 | 157 | 158 | def accumulate(model1, model2, decay=0.999, grad_map=None): 159 | par1 = dict(model1.named_parameters()) 160 | par2 = dict(model2.named_parameters()) 161 | 162 | for k in par1.keys(): 163 | if grad_map is not None and not grad_map['module.' 
+ k]: # filter out frozen params 164 | continue 165 | 166 | par1[k].data.mul_(decay).add_(par2[k].data, alpha=1 - decay) 167 | 168 | 169 | def sample_data(dataset, batch_size, image_size=4): 170 | dataset.resolution = image_size 171 | loader = DataLoader(dataset, shuffle=True, batch_size=batch_size, num_workers=1, drop_last=True) 172 | 173 | return loader 174 | 175 | 176 | def adjust_lr(optimizer, lr): 177 | for group in optimizer.param_groups: 178 | mult = group.get('mult', 1) 179 | group['lr'] = lr * mult 180 | 181 | 182 | # ------------------------------------ Metrics ------------------------------------ 183 | 184 | 185 | def load_scores(path): 186 | ms = {} 187 | for m_f in Path(path).glob('*.metric'): 188 | with open(m_f, 'rb') as f: 189 | data = pickle.load(f) 190 | ms[int(m_f.stem)] = data 191 | return dict(sorted(ms.items())) 192 | 193 | 194 | def get_submetric(s): 195 | parts = s.rsplit('.', 1) 196 | if len(parts) == 1: 197 | return s, None 198 | if parts[1] == '': 199 | raise ValueError('dot is used to mark the sub-metric index, do not use it at the end') 200 | return parts[0], parts[1] 201 | 202 | 203 | def class_mean(ms, metric='fid', return_steps=False): 204 | ms = load_scores(ms) 205 | steps = list(ms.keys()) 206 | class_m = 0 207 | classes = ms[steps[0]].keys() 208 | metric, sub_metric = get_submetric(metric) 209 | for class_i in classes: 210 | if sub_metric is not None: 211 | class_m += np.array([v[class_i][metric][sub_metric] for k, v in ms.items()]) 212 | else: 213 | class_m += np.array([v[class_i][metric] for k, v in ms.items()]) 214 | 215 | if return_steps: 216 | return steps, class_m / len(classes) 217 | return class_m / len(classes) 218 | --------------------------------------------------------------------------------
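Usage note: the metric helpers in utils.py (`load_scores`, `get_submetric`, `class_mean`) aggregate the per-class scores pickled as `<step>.metric` files inside a run directory. The sketch below is a minimal, hypothetical example of averaging a per-class FID curve over training and picking the best checkpoint; the run directory name `runs/00000` and the presence of a per-class `fid` entry in the pickles are assumptions, not guarantees of the repository.

# Hypothetical usage sketch (run directory name and metric key are assumptions).
from utils import class_mean

# steps: sorted training iterations; mean_fid: FID averaged over classes at each step
steps, mean_fid = class_mean('runs/00000', metric='fid', return_steps=True)
best_fid, best_step = min(zip(mean_fid, steps))
print(f'best mean FID {best_fid:.3f} at step {best_step}')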