├── LICENSE
├── README.md
├── barlow_twins.py
├── dataset.py
├── doc
├── arch.png
└── sample_afhq.png
├── environment.yml
├── generate.py
├── latent_interpolation.py
├── metric_impl
├── dc.py
├── ppl.py
└── pr.py
├── metrics.py
├── model
├── hyper_mod.py
└── stylegan.py
├── prepare_data.py
├── pretrained_converter.py
├── self_alignment.py
├── train.py
└── utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Kim Seonghyeon, Hector Laria
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | Transferring Unconditional to Conditional GANs with Hyper-Modulation
2 | --------------------------------------------------------------------
3 |
4 | _GANs have matured in recent years and are able to generate high-resolution, realistic images. However, the computational resources and the data required for the training of high-quality GANs are enormous, and the study of transfer learning of these models is therefore an urgent topic. Many of the available high-quality pretrained GANs are unconditional (like StyleGAN). For many applications, however, conditional GANs are preferable, because they provide more control over the generation process, despite often suffering more training difficulties. Therefore, in this paper, we focus on transferring from high-quality pretrained unconditional GANs to conditional GANs. This requires architectural adaptation of the pretrained GAN to perform the conditioning. To this end, we propose hyper-modulated generative networks that allow for shared and complementary supervision. To prevent the additional weights of the hypernetwork to overfit, with subsequent mode collapse on small target domains, we introduce a self-initialization procedure that does not require any real data to initialize the hypernetwork parameters. To further improve the sample efficiency of the transfer, we apply contrastive learning in the discriminator, which effectively works on very limited batch sizes. In extensive experiments, we validate the efficiency of the hypernetworks, self-initialization and contrastive loss for knowledge transfer on standard benchmarks._
5 |
6 |
7 |
8 |
9 |
10 | Official implementation of [Hyper-Modulation](https://arxiv.org/abs/2112.02219).
11 |
12 |
13 |
14 |
15 | Given a pre-trained model, our hyper-modulation method learns to use its knowledge to produce out-of-domain generalizations. I.e., trained on human faces, it quickly learns animal faces (above) or further domains like objects and places (paper). No finetuning is needed. It also preserves and responds to the source domain style mechanism.
16 |
17 |
18 | ### Environment
19 | Please install the required packages with conda and activate the environment
20 | ```bash
21 | conda env create -f environment.yml
22 | conda activate hypermod
23 | ```
24 |
25 | ### Base model
26 | Pretrained unconditional models are available in the [base code repository](https://github.com/rosinality/style-based-gan-pytorch), concretely we used [this one](https://drive.google.com/open?id=1QlXFPIOFzsJyjZ1AtfpnVhqW4Z0r8GLZ) trained on FFHQ dataset.
27 | Let's assume it's downloaded in ``.
28 |
29 | ### Data
30 | To replicate the stated results, please download [AFHQ(v1)](https://github.com/clovaai/stargan-v2/blob/master/README.md#animal-faces-hq-dataset-afhq).
31 | Additional datasets like FFHQ, CelebA-HQ, Flowers102standard, Places356 can be downloaded from their respective sources.
32 |
33 | Once downloaded and extracted, the dataset must be preprocessed into an LMDB file with the following command
34 | ```bash
35 | python prepare_data.py /train/ --labels --sizes 256 --out
36 | ```
37 | The directory given to `prepare_data.py` must contain a folder per class, class samples within each respective folder as specified by [ImageFolder](https://pytorch.org/vision/stable/datasets.html#torchvision.datasets.ImageFolder) dataset.
38 |
39 | Additionally, we can compute the inception features of the dataset to calculate the metrics later.
40 | ```bash
41 | python metrics.py --output_path --batch_size 32 --operation extract-feats
42 | ```
43 |
44 | ### Training
45 |
46 | #### Self-alignment
47 |
48 | ```bash
49 | python self_alignment.py --batch_size 32 --name --checkpoint_interval 1000 --iterations 10000
50 | ```
51 | The model checkpoints will begin to get saved under `checkpoint/self_align/` folder (can be specified with `--checkpoint_dir`).
52 | Additional losses are also saved, can be visualized pointing tensorboard to folder `runs/self_align/` (can be specified with `--log_dir`).
53 |
54 | To assess the quality we can start computing metrics for each checkpoint with
55 | ```bash
56 | python metrics.py checkpoint/self_align/ --precomputed_feats --batch_size 32 --operation metrics
57 | ```
58 | For each checkpoint, a `.metric` file will be created containing the computed scores. The function `class_mean()` from `utils.py` can be used to read the whole directory.
59 | One would normally want to plot the metric evolution over time, which can be done as
60 | ```python
61 | import matplotlib.pyplot as plt
62 | from utils import class_mean
63 |
64 | p = 'checkpoint/self_align/'
65 | plt.plot(*class_mean(p, metric='fid', return_steps=True))
66 | ```
67 | We then pick a checkpoint with good enough scores. We will name it ``.
68 |
69 | #### Transfer learning
70 |
71 | To run the actual training, we can do
72 | ```bash
73 | python train.py --ckpt --max_size 256 --loss r1 --const_alpha 1 --phase 3000000 --batch_default 10 --init_size 256 --mixing --name --checkpoint_interval 1000 --iterations 200002
74 | ```
75 |
76 | For the finetuned version
77 | ```bash
78 | python train.py --finetune --ckpt --max_size 256 --loss r1 --const_alpha 1 --phase 3000000 --batch_default 10 --init_size 256 --mixing --name --checkpoint_interval 1000 --iterations 200002
79 | ```
80 |
81 | For the version with contrastive learning
82 | ```bash
83 | python train.py --contrastive --ckpt --max_size 256 --loss r1 --const_alpha 1 --phase 3000000 --batch_default 10 --init_size 256 --mixing --name --checkpoint_interval 1000 --iterations 200002
84 | ```
85 |
86 | Generated images and loss curves are found in `tensorboard --logdir runs/` by default.
87 | Checkpoints are saved in `checkpoint/`, where metrics can be computed and plotted by using `metrics.py` script and `class_mean` plot the same way as for self-alignment.
88 |
89 | Additionally, perceptual path length can be computed as
90 | ```bash
91 | python ppl.py --checkpoint_path
92 | ```
93 |
94 | ### Generation
95 |
96 | To make use of the final models, we can directly make use of `generate.py` or `latent_interpolation.py`.
97 |
98 | To generate a figure like the interpolation from Supplementary Material section 7 (Latent space), run
99 | ```bash
100 | python latent_interpolation.py --output interpolation_class_noise.png --num_interpolations 5 --type class_noise
101 | ```
102 |
103 | ### Acknowledgments
104 |
105 | The project is based on [this StyleGAN implementation](https://github.com/rosinality/style-based-gan-pytorch). We used [PIQ](https://github.com/photosynthesis-team/piq) package for metrics, and the PPL implementation was taken from [here](https://github.com/rosinality/stylegan2-pytorch/blob/master/ppl.py).
106 |
107 | ### Citation
108 |
109 | ```
110 | @InProceedings{Laria_2022_CVPR,
111 | author = {Laria, H\'ector and Wang, Yaxing and van de Weijer, Joost and Raducanu, Bogdan},
112 | title = {Transferring Unconditional to Conditional GANs With Hyper-Modulation},
113 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR) Workshops},
114 | month = {June},
115 | year = {2022},
116 | pages = {3840-3849}
117 | }
118 | ```
119 |
--------------------------------------------------------------------------------
/barlow_twins.py:
--------------------------------------------------------------------------------
1 | """Base Barlow Twins implementation (BarlowTwinsLoss) taken from
2 | https://github.com/facebookresearch/barlowtwins/blob/main/main.py"""
3 | from typing import Tuple
4 | import torch
5 | from torch import nn
6 | from kornia import augmentation as K
7 |
8 |
9 | class DiffTransform(nn.Module):
10 | def __init__(self, crop_resize: int = 224):
11 | super().__init__()
12 |
13 | self.transform = K.AugmentationSequential(
14 | K.Normalize(mean=torch.tensor(-1), std=torch.tensor(2)), # from [-1, 1] to [0, 1]
15 | K.RandomResizedCrop((crop_resize, crop_resize), resample='BICUBIC'),
16 | K.RandomHorizontalFlip(p=0.5),
17 | K.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.8),
18 | K.RandomGrayscale(p=0.2),
19 | K.RandomGaussianBlur(kernel_size=(3, 3), sigma=(1.05, 1.05), p=0.5),
20 | K.RandomSolarize(thresholds=(0, 0.5), additions=0, p=0.0),
21 | K.Normalize(mean=torch.tensor(0.5), std=torch.tensor(0.5)), # back to [-1, 1]
22 | )
23 | self.transform_prime = K.AugmentationSequential(
24 | K.Normalize(mean=torch.tensor(-1), std=torch.tensor(2)),
25 | K.RandomResizedCrop((crop_resize, crop_resize), resample='BICUBIC'),
26 | K.RandomHorizontalFlip(p=0.5),
27 | K.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.8),
28 | K.RandomGrayscale(p=0.2),
29 | K.RandomGaussianBlur(kernel_size=(3, 3), sigma=(1.05, 1.05), p=0.1),
30 | K.RandomSolarize(thresholds=(0, 0.5), additions=0, p=0.2),
31 | K.Normalize(mean=torch.tensor(0.5), std=torch.tensor(0.5)),
32 | )
33 |
34 | def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
35 | y1 = self.transform(x)
36 | y2 = self.transform_prime(x)
37 | return y1, y2
38 |
39 |
40 | class BarlowTwins(nn.Module):
41 | def __init__(self, num_feats, lambd=5e-3, sizes=(512, 512, 512, 512), use_projector=True):
42 | super().__init__()
43 | self.lambd = lambd
44 |
45 | # projector
46 | if not use_projector:
47 | self.projector = nn.Identity()
48 | else:
49 | layers = []
50 | for i in range(len(sizes) - 2):
51 | layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=False))
52 | layers.append(nn.BatchNorm1d(sizes[i + 1]))
53 | layers.append(nn.ReLU(inplace=True))
54 | layers.append(nn.Linear(sizes[-2], sizes[-1], bias=False))
55 | self.projector = nn.Sequential(*layers)
56 |
57 | # normalization layer for the representations z1 and z2
58 | self.bn = nn.BatchNorm1d(num_feats, affine=False)
59 |
60 | def forward(self, z1, z2):
61 | z1 = self.projector(z1)
62 | z2 = self.projector(z2)
63 |
64 | return self.bn(z1), self.bn(z2)
65 |
66 | @staticmethod
67 | def off_diagonal(x):
68 | # return a flattened view of the off-diagonal elements of a square matrix
69 | n, m = x.shape
70 | assert n == m
71 | return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()
72 |
73 | def compute_final_loss(self, z1, z2):
74 | # empirical cross-correlation matrix
75 | c = z1.T @ z2
76 |
77 | # sum the cross-correlation matrix between all gpus
78 | c.div_(z1.size(0))
79 |
80 | on_diag = torch.diagonal(c).add_(-1).pow_(2).sum()
81 | off_diag = self.off_diagonal(c).pow_(2).sum()
82 | loss = on_diag + self.lambd * off_diag
83 | return loss
84 |
85 |
86 | def compute_contrastive(discriminator, images, barlow_twins, transform, c, step, alpha, weighting=0.01):
87 | y1, y2 = transform(images)
88 | hierarchical_feats1 = discriminator(y1, c=c, step=step, alpha=alpha, return_hierarchical=True)
89 | hierarchical_feats2 = discriminator(y2, c=c, step=step, alpha=alpha, return_hierarchical=True)
90 |
91 | feats1 = hierarchical_feats1[-2].squeeze(-1).squeeze(-1)
92 | feats2 = hierarchical_feats2[-2].squeeze(-1).squeeze(-1)
93 |
94 | # hack to compute the cross-corr matrix among gpus on DP
95 | z1, z2 = barlow_twins(feats1, feats2) # gather zs from all gpus
96 | final_loss = barlow_twins.module.compute_final_loss(z1, z2) # .module = out of DP
97 |
98 | return final_loss * weighting
99 |
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | from io import BytesIO
2 |
3 | import lmdb
4 | from PIL import Image
5 | from torch.utils.data import Dataset
6 | import torch
7 |
8 |
9 | class MultiResolutionDataset(Dataset):
10 | def __init__(self, path, transform, resolution=8, selected_classes=None, class_samples=None, random_class_sampling=False, length=None, one_hot_label=False):
11 | self.path = path
12 | self.env = None
13 | self.open_lmdb()
14 |
15 | if not self.env:
16 | raise IOError('Cannot open lmdb dataset', path)
17 | self.filter_classes = None
18 |
19 | with self.env.begin(write=False) as txn:
20 | self.length = int(txn.get('length'.encode('utf-8')).decode('utf-8'))
21 | self.num_classes = txn.get('num_classes'.encode('utf-8'))
22 |
23 | if selected_classes is not None or class_samples is not None:
24 | if selected_classes is None:
25 | selected_classes = list(range(int(self.num_classes.decode('utf-8')))) if self.num_classes is not None else [0]
26 | self.filter_classes = self.filtering(txn, selected_classes, class_samples, random_sampling=random_class_sampling)
27 | self.length = len(self.filter_classes)
28 | self.num_classes = str(len(selected_classes)).encode('utf-8')
29 |
30 | if self.num_classes is not None:
31 | self.num_classes = int(self.num_classes.decode('utf-8'))
32 | assert length is None or length <= self.length, f'There are not enough samples in the dataset. {length} asked, {self.length} in total.'
33 | if length is not None:
34 | self.length = length
35 | self.resolution = resolution
36 | self.transform = transform
37 | self.one_hot_label = one_hot_label
38 |
39 | self.close_lmdb()
40 |
41 | def open_lmdb(self):
42 | self.env = lmdb.open(
43 | self.path,
44 | max_readers=32,
45 | readonly=True,
46 | lock=False,
47 | readahead=False,
48 | meminit=False,
49 | )
50 |
51 | def close_lmdb(self):
52 | self.env.close()
53 | self.env = None
54 |
55 | def __len__(self):
56 | return self.length
57 |
58 | def __getitem__(self, index):
59 | if self.env is None:
60 | self.open_lmdb()
61 |
62 | if self.filter_classes is not None:
63 | index = self.filter_classes[index]
64 |
65 | with self.env.begin(write=False) as txn:
66 | key = f'{self.resolution}-{str(index).zfill(5)}'.encode('utf-8')
67 | img_bytes = txn.get(key)
68 |
69 | key = f'label-{str(index).zfill(5)}'.encode('utf-8')
70 | label = txn.get(key)
71 |
72 | buffer = BytesIO(img_bytes)
73 | img = Image.open(buffer)
74 | img = self.transform(img)
75 |
76 | if label is not None:
77 | label = int(label.decode('utf-8'))
78 | if self.one_hot_label:
79 | label = torch.nn.functional.one_hot(torch.tensor(label), num_classes=self.num_classes).float()
80 |
81 | return img, label
82 |
83 | @staticmethod
84 | def filtering(txn, selected_classes, class_samples, random_sampling=False):
85 | def get_label(txn, i):
86 | return int(txn.get(f'label-{str(i).zfill(5)}'.encode('utf-8')).decode('utf-8'))
87 |
88 | from collections import defaultdict
89 | import itertools
90 | import random
91 | class_ids = defaultdict(list)
92 | length = int(txn.get('length'.encode('utf-8')).decode('utf-8'))
93 |
94 | # separate classes
95 | for i in range(length):
96 | class_ids[get_label(txn, i)].append(i)
97 |
98 | # drop unselected classes
99 | for i in list(class_ids.keys()):
100 | if i not in selected_classes:
101 | class_ids.pop(i)
102 |
103 | # reduce the samples per class
104 | for i in class_ids.keys():
105 | if random_sampling:
106 | class_ids[i] = random.sample(class_ids[i], k=class_samples)
107 | else:
108 | class_ids[i] = class_ids[i][:class_samples]
109 |
110 | # flatten
111 | return list(itertools.chain(*class_ids.values()))
112 |
--------------------------------------------------------------------------------
/doc/arch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hecoding/Hyper-Modulation/3c07d8052414470b2c8d10b265d522a1ed5d9b14/doc/arch.png
--------------------------------------------------------------------------------
/doc/sample_afhq.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hecoding/Hyper-Modulation/3c07d8052414470b2c8d10b265d522a1ed5d9b14/doc/sample_afhq.png
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: hypermod
2 | channels:
3 | - pytorch
4 | - conda-forge
5 | - defaults
6 | dependencies:
7 | - _libgcc_mutex=0.1=main
8 | - absl-py=0.11.0=py38h578d9bd_0
9 | - aiohttp=3.7.4=py38h27cfd23_1
10 | - async-timeout=3.0.1=py_1000
11 | - attrs=20.3.0=pyhd3deb0d_0
12 | - blas=1.0=mkl
13 | - blinker=1.4=py_1
14 | - brotlipy=0.7.0=py38h8df0ef7_1001
15 | - c-ares=1.17.1=h36c2ea0_0
16 | - ca-certificates=2021.1.19=h06a4308_0
17 | - cachetools=4.2.1=pyhd8ed1ab_0
18 | - certifi=2020.12.5=py38h06a4308_0
19 | - cffi=1.14.5=py38h261ae71_0
20 | - chardet=3.0.4=py38h924ce5b_1008
21 | - click=7.1.2=pyh9f0ad1d_0
22 | - cryptography=2.9.2=py38h766eaa4_0
23 | - cudatoolkit=10.2.89=hfd86e86_1
24 | - freetype=2.10.4=h5ab3b9f_0
25 | - fsspec=0.8.7=pyhd8ed1ab_0
26 | - future=0.18.2=py38h578d9bd_3
27 | - google-auth=1.24.0=pyhd3deb0d_0
28 | - google-auth-oauthlib=0.4.1=py_2
29 | - grpcio=1.33.2=py38heead2fc_2
30 | - idna=2.10=pyh9f0ad1d_0
31 | - importlib-metadata=3.7.0=py38h578d9bd_0
32 | - intel-openmp=2020.2=254
33 | - jpeg=9b=h024ee3a_2
34 | - lcms2=2.11=h396b838_0
35 | - ld_impl_linux-64=2.33.1=h53a641e_7
36 | - libedit=3.1.20191231=h14c3975_1
37 | - libffi=3.3=he6710b0_2
38 | - libgcc-ng=9.1.0=hdf63c60_0
39 | - libpng=1.6.37=hbc83047_0
40 | - libprotobuf=3.14.0=h8c45485_0
41 | - libstdcxx-ng=9.1.0=hdf63c60_0
42 | - libtiff=4.1.0=h2733197_1
43 | - libuv=1.40.0=h7b6447c_0
44 | - lmdb=0.9.28=h2531618_0
45 | - lz4-c=1.9.3=h2531618_0
46 | - markdown=3.3.4=pyhd8ed1ab_0
47 | - mkl=2020.2=256
48 | - mkl-service=2.3.0=py38he904b0f_0
49 | - mkl_fft=1.3.0=py38h54f3939_0
50 | - mkl_random=1.1.1=py38h0573a6f_0
51 | - multidict=5.1.0=py38h27cfd23_2
52 | - ncurses=6.2=he6710b0_1
53 | - ninja=1.10.2=py38hff7bd54_0
54 | - numpy=1.19.2=py38h54aff64_0
55 | - numpy-base=1.19.2=py38hfa32c7d_0
56 | - oauthlib=3.0.1=py_0
57 | - olefile=0.46=py_0
58 | - openssl=1.1.1j=h27cfd23_0
59 | - packaging=20.9=pyh44b312d_0
60 | - pillow=8.1.1=py38he98fc37_0
61 | - pip=21.0.1=py38h06a4308_0
62 | - protobuf=3.14.0=py38h2531618_1
63 | - pyasn1=0.4.8=py_0
64 | - pyasn1-modules=0.2.7=py_0
65 | - pycparser=2.20=pyh9f0ad1d_2
66 | - pyjwt=2.0.1=pyhd8ed1ab_0
67 | - pyopenssl=19.1.0=py38_0
68 | - pyparsing=2.4.7=pyh9f0ad1d_0
69 | - pysocks=1.7.1=py38h578d9bd_3
70 | - python=3.8.8=hdb3f193_4
71 | - python-lmdb=1.1.1=py38h2531618_1
72 | - python_abi=3.8=1_cp38
73 | - pytorch=1.7.1=py3.8_cuda10.2.89_cudnn7.6.5_0
74 | - pytorch-lightning=1.2.1=pyhd8ed1ab_0
75 | - pyyaml=5.3.1=py38h8df0ef7_1
76 | - readline=8.1=h27cfd23_0
77 | - requests=2.25.1=pyhd3deb0d_0
78 | - requests-oauthlib=1.3.0=pyh9f0ad1d_0
79 | - rsa=4.7.2=pyh44b312d_0
80 | - setuptools=52.0.0=py38h06a4308_0
81 | - six=1.15.0=py38h06a4308_0
82 | - sqlite=3.33.0=h62c20be_0
83 | - tensorboard=2.4.1=pyhd8ed1ab_0
84 | - tensorboard-plugin-wit=1.8.0=pyh44b312d_0
85 | - tk=8.6.10=hbc83047_0
86 | - torchaudio=0.7.2=py38
87 | - torchvision=0.8.2=py38_cu102
88 | - tqdm=4.58.0=pyhd8ed1ab_0
89 | - typing-extensions=3.7.4.3=0
90 | - typing_extensions=3.7.4.3=py_0
91 | - urllib3=1.26.3=pyhd8ed1ab_0
92 | - werkzeug=1.0.1=pyh9f0ad1d_0
93 | - wheel=0.36.2=pyhd3eb1b0_0
94 | - xz=5.2.5=h7b6447c_0
95 | - yaml=0.2.5=h516909a_0
96 | - yarl=1.6.3=py38h25fe258_0
97 | - zipp=3.4.0=py_0
98 | - zlib=1.2.11=h7b6447c_3
99 | - zstd=1.4.5=h9ceee32_0
100 |
--------------------------------------------------------------------------------
/generate.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import math
3 |
4 | import torch
5 | from torchvision import utils
6 |
7 | from utils import load_checkpoint
8 |
9 |
10 | @torch.no_grad()
11 | def get_mean_style(generator, device):
12 | mean_style = None
13 |
14 | for i in range(10):
15 | style = generator.mean_style(torch.randn(1024, 512).to(device))
16 |
17 | if mean_style is None:
18 | mean_style = style
19 |
20 | else:
21 | mean_style += style
22 |
23 | mean_style /= 10
24 | return mean_style
25 |
26 | @torch.no_grad()
27 | def sample(generator, class_network, n_class, step, mean_style, n_sample, device, seed):
28 | rng = torch.Generator()
29 | rng.manual_seed(seed)
30 | class_ = class_network(torch.tensor([n_class] * n_sample, device=device))
31 | image = generator(
32 | torch.randn(n_sample, 512, generator=rng).to(device),
33 | step=step,
34 | alpha=1,
35 | mean_style=mean_style,
36 | style_weight=0.7,
37 | task=class_,
38 | )
39 |
40 | return image
41 |
42 | @torch.no_grad()
43 | def style_mixing(generator, class_network, n_class, step, mean_style, n_source, n_target, device, seed):
44 | rng = torch.Generator()
45 | rng.manual_seed(seed)
46 | source_code = torch.randn(n_source, 512, generator=rng).to(device)
47 | target_code = torch.randn(n_target, 512, generator=rng).to(device)
48 | class_source = class_network(torch.tensor([n_class] * n_source, device=device))
49 | class_target = class_network(torch.tensor([n_class] * n_target, device=device))
50 |
51 | shape = 4 * 2 ** step
52 | alpha = 1
53 |
54 | images = [torch.ones(1, 3, shape, shape).to(device) * -1]
55 |
56 | source_image = generator(
57 | source_code, step=step, alpha=alpha, mean_style=mean_style, style_weight=0.7, task=class_source,
58 | )
59 | target_image = generator(
60 | target_code, step=step, alpha=alpha, mean_style=mean_style, style_weight=0.7, task=class_target,
61 | )
62 |
63 | images.append(source_image)
64 |
65 | for i in range(n_target):
66 | image = generator(
67 | [target_code[i].unsqueeze(0).repeat(n_source, 1), source_code],
68 | step=step,
69 | alpha=alpha,
70 | mean_style=mean_style,
71 | style_weight=0.7,
72 | mixing_range=(0, 1),
73 | task=class_source,
74 | )
75 | images.append(target_image[i].unsqueeze(0))
76 | images.append(image)
77 |
78 | images = torch.cat(images, 0)
79 |
80 | return images
81 |
82 |
83 | def main(args):
84 | print('Loading model...', end=' ')
85 | generator, class_network = load_checkpoint(args)
86 | generator.eval()
87 | print('Loaded')
88 |
89 | mean_style = get_mean_style(generator, args.device)
90 |
91 | step = int(math.log(args.size, 2)) - 2
92 |
93 | print('Generating sample')
94 | img = sample(generator, class_network, args.class_num, step, mean_style, args.n_row * args.n_col, args.device, args.seed)
95 | utils.save_image(img, args.out, nrow=args.n_col, normalize=True, range=(-1, 1))
96 |
97 | if not args.no_mixing:
98 | print('Generating style mixing')
99 | for j in range(20):
100 | img = style_mixing(generator, class_network, args.class_num, step, mean_style, args.n_col, args.n_row, args.device, args.seed)
101 | utils.save_image(
102 | img, f'{args.out}_sample_mixing_{j}.png', nrow=args.n_col + 1, normalize=True, range=(-1, 1)
103 | )
104 |
105 |
106 | if __name__ == '__main__':
107 | parser = argparse.ArgumentParser()
108 | parser.add_argument('--seed', type=int, default=2147483647, help='RNG seed')
109 | parser.add_argument('--size', type=int, default=1024, help='size of the image')
110 | parser.add_argument('--n_row', type=int, default=3, help='number of rows of sample matrix')
111 | parser.add_argument('--n_col', type=int, default=5, help='number of columns of sample matrix')
112 | parser.add_argument('checkpoint_path', type=str, help='path to checkpoint file')
113 | parser.add_argument('--device', type=str, default='cuda', help='')
114 | parser.add_argument('--no_mixing', action='store_true', help='Dont generate style mixing samples')
115 | parser.add_argument('--class_num', type=int, default=0, help='Which class to generate')
116 | parser.add_argument('--out', type=str, default='sample.png', help='')
117 |
118 | main(parser.parse_args())
119 |
--------------------------------------------------------------------------------
/latent_interpolation.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | from torchvision.utils import save_image
4 | from utils import seed_everything, load_hparams, load_checkpoint
5 |
6 |
7 | def slerp(low, high, val):
8 | # https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
9 | low_norm = low / torch.norm(low, dim=-1, keepdim=True)
10 | high_norm = high / torch.norm(high, dim=-1, keepdim=True)
11 | omega = torch.acos((low_norm * high_norm).sum(-1))
12 | so = torch.sin(omega)
13 | res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(-1) * low + (torch.sin(val * omega) / so).unsqueeze(-1) * high
14 | return res
15 |
16 |
17 | def class_to_class(args, hparams, class_network, generator, device):
18 | images = []
19 |
20 | print(f'Generating interpolation with {args.steps} steps between class {args.class_1} and {args.class_2}, seed {args.seed}')
21 | with torch.no_grad():
22 | class1 = class_network(torch.tensor(args.class_1, device=device))
23 | class2 = class_network(torch.tensor(args.class_2, device=device))
24 | for interpolation_i in range(args.num_interpolations):
25 | z = torch.randn(1, hparams.code_size, device=device)
26 | for alpha in torch.linspace(0, 1, steps=args.steps, device=device):
27 | if args.interpolation == 'linear':
28 | class_interpolation = class1.lerp(class2, alpha)
29 | else:
30 | class_interpolation = slerp(class1, class2, alpha)
31 | image = generator(z, step=6, alpha=1, task=class_interpolation)
32 | images.append(image.cpu())
33 |
34 | return images
35 |
36 |
37 | def z_to_z(args, hparams, class_network, generator, device):
38 | images = []
39 |
40 | print(f'Generating stuff')
41 | with torch.no_grad():
42 | class1 = class_network(torch.tensor(args.class_1, device=device))
43 | for interpolation_i in range(args.num_interpolations):
44 | z1 = torch.randn(1, hparams.code_size, device=device)
45 | z2 = torch.randn(1, hparams.code_size, device=device)
46 | for alpha in torch.linspace(0, 1, steps=args.steps, device=device):
47 | if args.interpolation == 'linear':
48 | z_interpolation = z1.lerp(z2, alpha)
49 | else:
50 | z_interpolation = slerp(z1, z2, alpha)
51 | image = generator(z_interpolation, step=6, alpha=1, task=class1)
52 | images.append(image.cpu())
53 |
54 | return images
55 |
56 |
57 | def class_noise(args, hparams, class_network, generator, device):
58 | images = []
59 |
60 | print(f'Generating stuff, seed {args.seed}')
61 | with torch.no_grad():
62 | class1 = class_network(torch.tensor(args.class_1, device=device))
63 | class2 = class_network(torch.tensor(args.class_2, device=device))
64 | z1 = torch.randn(1, hparams.code_size, device=device)
65 | z2 = torch.randn(1, hparams.code_size, device=device)
66 | for alpha_z in torch.linspace(0, 1, steps=args.steps, device=device):
67 | if args.interpolation == 'linear':
68 | z_interpolation = z1.lerp(z2, alpha_z)
69 | else:
70 | z_interpolation = slerp(z1, z2, alpha_z)
71 | for alpha_class in torch.linspace(0, 1, steps=args.steps, device=device):
72 | if args.interpolation == 'linear':
73 | class_interpolation = class1.lerp(class2, alpha_class)
74 | else:
75 | class_interpolation = slerp(class1, class2, alpha_class)
76 | image = generator(z_interpolation, step=6, alpha=1, task=class_interpolation)
77 | images.append(image.cpu())
78 |
79 | return images
80 |
81 |
82 | def random_z(args, hparams, class_network, generator, device):
83 | images = []
84 |
85 | print(f'Generating {args.steps * args.num_interpolations} from class {args.class_1} with random z, seed {args.seed}')
86 | with torch.no_grad():
87 | class1 = class_network(torch.tensor(args.class_1, device=device))
88 | for i in range(args.steps * args.num_interpolations):
89 | z = torch.randn(1, hparams.code_size, device=device)
90 | image = generator(z, step=6, alpha=1, task=class1)
91 | images.append(image.cpu())
92 |
93 | return images
94 |
95 |
96 | def noise_shift(args, hparams, class_network, generator, device):
97 | images = []
98 |
99 | print(f'Generating stuff')
100 | with torch.no_grad():
101 | class1 = class_network(torch.tensor(args.class_1, device=device))
102 | for interpolation_i in range(args.num_interpolations):
103 | z = torch.randn(1, hparams.code_size, device=device)
104 | for alpha in torch.linspace(-1, 1, steps=args.steps, device=device):
105 | image = generator(z + alpha, step=6, alpha=1, task=class1)
106 | images.append(image.cpu())
107 |
108 | return images
109 |
110 |
111 | def class_shift(args, hparams, class_network, generator, device):
112 | images = []
113 |
114 | print(f'Generating stuff')
115 | with torch.no_grad():
116 | class1 = class_network(torch.tensor(args.class_1, device=device))
117 | for interpolation_i in range(args.num_interpolations):
118 | z = torch.randn(1, hparams.code_size, device=device)
119 | for alpha in torch.linspace(-0.1, 0.1, steps=args.steps, device=device):
120 | image = generator(z, step=6, alpha=1, task=class1 + alpha)
121 | images.append(image.cpu())
122 |
123 | return images
124 |
125 |
126 | def main(args):
127 | seed_everything(args.seed)
128 | device = args.device
129 | print(f'Loading model {args.checkpoint_path}')
130 | args.from_self_align = False
131 | generator, class_network = load_checkpoint(args)
132 | hparams = load_hparams(args.checkpoint_path)
133 | print('Loaded')
134 | class_network.eval()
135 | generator.eval()
136 |
137 | if args.type == 'class':
138 | images = class_to_class(args, hparams, class_network, generator, device)
139 | elif args.type == 'noise':
140 | images = z_to_z(args, hparams, class_network, generator, device)
141 | elif args.type == 'class_noise':
142 | images = class_noise(args, hparams, class_network, generator, device)
143 | elif args.type == 'random_z':
144 | images = random_z(args, hparams, class_network, generator, device)
145 | elif args.type == 'z_shift':
146 | images = noise_shift(args, hparams, class_network, generator, device)
147 | elif args.type == 'class_shift':
148 | images = class_shift(args, hparams, class_network, generator, device)
149 | else:
150 | raise ValueError
151 |
152 | save_image(torch.cat(images, 0), args.output, nrow=args.steps, padding=0, normalize=True, range=(-1, 1))
153 | print(f'Saved to {args.output}')
154 |
155 |
156 | def get_args():
157 | parser = argparse.ArgumentParser(description='Progressive Growing of GANs')
158 |
159 | parser.add_argument('checkpoint_path', type=str, help='model to use')
160 | parser.add_argument('--output', default='interpolation.png', type=str, help='Output file')
161 | parser.add_argument('--steps', default=10, type=int, help='Number of interpolations')
162 | parser.add_argument('--seed', default=2147483647, type=int, help='Random seed')
163 | parser.add_argument('--device', default='cuda', type=str, help='Device to use')
164 | parser.add_argument('--class_1', default=0, type=int, help='First class for interpolation')
165 | parser.add_argument('--class_2', default=1, type=int, help='Second class for interpolation')
166 | parser.add_argument('--interpolation', default='linear', type=str, choices=['linear', 'spherical'], help='Interpolation type')
167 | parser.add_argument('--num_interpolations', default=1, type=int, help='How many interpolations to perform')
168 | parser.add_argument('--type', default='class', type=str, help='Which interpolation to do')
169 |
170 | return parser.parse_args()
171 |
172 |
173 | if __name__ == '__main__':
174 | main(get_args())
175 |
--------------------------------------------------------------------------------
/metric_impl/dc.py:
--------------------------------------------------------------------------------
1 | r"""PyTorch implementation of Density and Coverage (D&C). Based on Reliable Fidelity and Diversity Metrics for
2 | Generative Models https://arxiv.org/abs/2002.09797 and repository
3 | https://github.com/clovaai/generative-evaluation-prdc/blob/master/prdc/prdc.py
4 | """
5 | from typing import Tuple
6 | import torch
7 |
8 | from piq.base import BaseFeatureMetric
9 | from piq.utils import _validate_input
10 |
11 | from metric_impl.pr import _compute_nearest_neighbour_distances, _compute_pairwise_distance
12 |
13 |
14 | class DC(BaseFeatureMetric):
15 | r"""Interface of Density and Coverage.
16 | It's computed for a whole set of data and uses features from encoder instead of images itself to decrease
17 | computation cost. Density and Coverage can compare two data distributions with different number of samples.
18 | But dimensionalities should match, otherwise it won't be possible to correctly compute statistics.
19 |
20 | Args:
21 | real_features: Samples from data distribution. Shape :math:`(N_x, D)`
22 | fake_features: Samples from generated distribution. Shape :math:`(N_y, D)`
23 |
24 | Returns:
25 | density: Scalar value of the density of image sets features.
26 | coverage: Scalar value of the coverage of image sets features.
27 |
28 | References:
29 | Ferjad Naeem M. et al. (2020).
30 | Reliable Fidelity and Diversity Metrics for Generative Models.
31 | International Conference on Machine Learning,
32 | https://arxiv.org/abs/2002.09797
33 | """
34 |
35 | def __init__(self, nearest_k: int = 5) -> None:
36 | r"""
37 | Args:
38 | nearest_k: Nearest neighbor to compute the non-parametric representation. Shape :math:`1`
39 | """
40 | super(DC, self).__init__()
41 |
42 | self.nearest_k = nearest_k
43 |
44 | def compute_metric(self, y_features: torch.Tensor, x_features: torch.Tensor) \
45 | -> Tuple[torch.Tensor, torch.Tensor]:
46 | r"""Creates non-parametric representations of the manifolds of real and generated data and computes
47 | the density and coverage between them.
48 |
49 | Args:
50 | real_features: Samples from data distribution. Shape :math:`(N_x, D)`
51 | fake_features: Samples from fake distribution. Shape :math:`(N_x, D)`
52 | Returns:
53 | precision: Scalar value of the precision of the generated images.
54 | recall: Scalar value of the recall of the generated images.
55 | """
56 | # _validate_input([real_features, fake_features], dim_range=(2, 2), size_range=(1, 2))
57 | real_features = y_features
58 | fake_features = x_features
59 | real_nearest_neighbour_distances = _compute_nearest_neighbour_distances(real_features, self.nearest_k)
60 | distance_real_fake = _compute_pairwise_distance(real_features, fake_features)
61 |
62 | density = (1 / self.nearest_k) * (
63 | distance_real_fake < real_nearest_neighbour_distances.unsqueeze(1)
64 | ).sum(dim=0).float().mean()
65 |
66 | coverage = (
67 | distance_real_fake.min(dim=1)[0] < real_nearest_neighbour_distances
68 | ).float().mean()
69 |
70 | return density, coverage
71 |
--------------------------------------------------------------------------------
/metric_impl/ppl.py:
--------------------------------------------------------------------------------
1 | # From https://github.com/rosinality/stylegan2-pytorch/blob/master/ppl.py
2 | import argparse
3 |
4 | import torch
5 | from torch.nn import functional as F
6 | import numpy as np
7 | from tqdm import tqdm
8 |
9 | import lpips
10 | from utils import load_hparams, load_checkpoint
11 |
12 |
13 | def normalize(x):
14 | return x / torch.sqrt(x.pow(2).sum(-1, keepdim=True))
15 |
16 |
17 | def slerp(a, b, t):
18 | a = normalize(a)
19 | b = normalize(b)
20 | d = (a * b).sum(-1, keepdim=True)
21 | p = t * torch.acos(d)
22 | c = normalize(b - d * a)
23 | d = a * torch.cos(p) + c * torch.sin(p)
24 |
25 | return normalize(d)
26 |
27 |
28 | def lerp(a, b, t):
29 | return a + (b - a) * t
30 |
31 |
32 | if __name__ == "__main__":
33 | parser = argparse.ArgumentParser(description="Perceptual Path Length calculator")
34 |
35 | parser.add_argument(
36 | "--space", choices=["z", "w"], default="w", help="space that PPL calculated with"
37 | )
38 | parser.add_argument(
39 | "--batch", type=int, default=64, help="batch size for the models"
40 | )
41 | parser.add_argument(
42 | "--n_sample",
43 | type=int,
44 | default=5000,
45 | help="number of the samples for calculating PPL",
46 | )
47 | parser.add_argument(
48 | "--size", type=int, default=256, help="output image sizes of the generator"
49 | )
50 | parser.add_argument(
51 | "--eps", type=float, default=1e-4, help="epsilon for numerical stability"
52 | )
53 | parser.add_argument(
54 | "--crop", action="store_true", help="apply center crop to the images"
55 | )
56 | parser.add_argument(
57 | "--sampling",
58 | default="end",
59 | choices=["end", "full"],
60 | help="set endpoint sampling method",
61 | )
62 | parser.add_argument('checkpoint_path', type=str, help='model to use')
63 | parser.add_argument('--device', type=str, default='cuda', help='model to use')
64 | parser.add_argument('--class_1', default=0, type=int, help='First class for interpolation')
65 | parser.add_argument('--class_2', default=1, type=int, help='Second class for interpolation')
66 |
67 | args = parser.parse_args()
68 | args.from_self_align = False
69 | device = args.device
70 |
71 | g, c_net = load_checkpoint(args)
72 | hparams = load_hparams(args.checkpoint_path)
73 | g.eval()
74 | c_net.eval()
75 |
76 | percept = lpips.PerceptualLoss(
77 | model="net-lin", net="vgg", use_gpu=device.startswith("cuda")
78 | )
79 |
80 | distances = []
81 |
82 | n_batch = args.n_sample // args.batch
83 | resid = args.n_sample - (n_batch * args.batch)
84 | batch_sizes = [args.batch] * n_batch + [resid]
85 |
86 | with torch.no_grad():
87 | for batch in tqdm(batch_sizes):
88 |
89 | z = torch.randn([batch * 2, hparams.code_size], device=device)
90 | if args.sampling == "full":
91 | lerp_t = torch.rand(batch, device=device)
92 | else:
93 | lerp_t = torch.zeros(batch, device=device)
94 |
95 | if args.space == "w":
96 | latent_t0, latent_t1 = c_net(torch.tensor(batch * [args.class_1], device=device)), c_net(torch.tensor(batch * [args.class_2], device=device))
97 | latent_e0 = lerp(latent_t0, latent_t1, lerp_t[:, None])
98 | latent_e1 = lerp(latent_t0, latent_t1, lerp_t[:, None] + args.eps)
99 | latent_e = torch.stack([latent_e0, latent_e1], 1).view(*[batch * 2, hparams.task_size])
100 |
101 | # image, _ = g([latent_e])
102 | image = g(z, step=6, alpha=1, task=latent_e)
103 |
104 | if args.crop:
105 | c = image.shape[2] // 8
106 | image = image[:, :, c * 3 : c * 7, c * 2 : c * 6]
107 |
108 | factor = image.shape[2] // 256
109 |
110 | if factor > 1:
111 | image = F.interpolate(
112 | image, size=(256, 256), mode="bilinear", align_corners=False
113 | )
114 |
115 | dist = percept(image[::2], image[1::2]).view(image.shape[0] // 2) / (
116 | args.eps ** 2
117 | )
118 | distances.append(dist.to("cpu").numpy())
119 |
120 | distances = np.concatenate(distances, 0)
121 |
122 | lo = np.percentile(distances, 1, interpolation="lower")
123 | hi = np.percentile(distances, 99, interpolation="higher")
124 | filtered_dist = np.extract(
125 | np.logical_and(lo <= distances, distances <= hi), distances
126 | )
127 |
128 | print("ppl:", filtered_dist.mean())
129 |
--------------------------------------------------------------------------------
/metric_impl/pr.py:
--------------------------------------------------------------------------------
1 | r"""PyTorch implementation of Improved Precision and Recall (P&R). Based on Improved Precision and Recall Metric for
2 | Assessing Generative Models https://arxiv.org/abs/1904.06991 and repository
3 | https://github.com/clovaai/generative-evaluation-prdc/blob/master/prdc/prdc.py
4 | """
5 | from typing import Tuple, Optional
6 | import torch
7 |
8 | from piq.base import BaseFeatureMetric
9 | from piq.utils import _validate_input
10 |
11 |
12 | def _compute_pairwise_distance(data_x: torch.Tensor, data_y: Optional[torch.Tensor] = None) -> torch.Tensor:
13 | r"""Compute Euclidean distance between :math:`x` and :math:`y`.
14 |
15 | Args:
16 | data_x: Tensor of shape :math:`(N, feature_dim)`
17 | data_y: Tensor of shape :math:`(N, feature_dim)`
18 | Returns:
19 | Tensor of shape :math:`(N, N)` of pairwise distances.
20 | """
21 | if data_y is None:
22 | data_y = data_x
23 | dists = torch.cdist(data_x, data_y, p=2)
24 | return dists
25 |
26 |
27 | def _get_kth_value(unsorted: torch.Tensor, k: int, axis: int = -1) -> torch.Tensor:
28 | r"""
29 | Args:
30 | unsorted: Tensor of any dimensionality.
31 | k: Int of the :math:`k`-th value to retrieve.
32 | Returns:
33 | kth values along the designated axis.
34 | """
35 | k_smallests = torch.topk(unsorted, k, dim=axis, largest=False)[0]
36 | kth_values = k_smallests.max(dim=axis)[0]
37 | return kth_values
38 |
39 |
40 | def _compute_nearest_neighbour_distances(input_features: torch.Tensor, nearest_k: int) -> torch.Tensor:
41 | r"""Compute K-nearest neighbour distances.
42 |
43 | Args:
44 | input_features: Tensor of shape :math:`(N, feature_dim)`
45 | nearest_k: Int of the :math:`k`-th nearest neighbour.
46 | Returns:
47 | Distances to :math:`k`-th nearest neighbours.
48 | """
49 | distances = _compute_pairwise_distance(input_features)
50 | radii = _get_kth_value(distances, k=nearest_k + 1, axis=-1)
51 | return radii
52 |
53 |
54 | class PR(BaseFeatureMetric):
55 | r"""Interface of Improved Precision and Recall.
56 | It's computed for a whole set of data and uses features from encoder instead of images itself to decrease
57 | computation cost. Precision and Recall can compare two data distributions with different number of samples.
58 | But dimensionalities should match, otherwise it won't be possible to correctly compute statistics.
59 |
60 | Args:
61 | real_features: Samples from data distribution. Shape :math:`(N_x, D)`
62 | fake_features: Samples from generated distribution. Shape :math:`(N_y, D)`
63 |
64 | Returns:
65 | precision: Scalar value of the precision of image sets features.
66 | recall: Scalar value of the recall of image sets features.
67 |
68 | References:
69 | Kynkäänniemi T. et al. (2019).
70 | Improved Precision and Recall Metric for Assessing Generative Models.
71 | Advances in Neural Information Processing Systems,
72 | https://arxiv.org/abs/1904.06991
73 | """
74 |
75 | def __init__(self, nearest_k: int = 5) -> None:
76 | r"""
77 | Args:
78 | nearest_k: Nearest neighbor to compute the non-parametric representation. Shape :math:`1`
79 | """
80 | super(PR, self).__init__()
81 |
82 | self.nearest_k = nearest_k
83 |
84 | def compute_metric(self, y_features: torch.Tensor, x_features: torch.Tensor) \
85 | -> Tuple[torch.Tensor, torch.Tensor]:
86 | r"""Creates non-parametric representations of the manifolds of real and generated data and computes
87 | the precision and recall between them.
88 |
89 | Args:
90 | real_features: Samples from data distribution. Shape :math:`(N_x, D)`
91 | fake_features: Samples from fake distribution. Shape :math:`(N_x, D)`
92 | Returns:
93 | precision: Scalar value of the precision of the generated images.
94 | recall: Scalar value of the recall of the generated images.
95 | """
96 | # _validate_input([real_features, fake_features], dim_range=(2, 2), size_range=(1, 2))
97 | real_features = y_features
98 | fake_features = x_features
99 | real_nearest_neighbour_distances = _compute_nearest_neighbour_distances(real_features, self.nearest_k)
100 | fake_nearest_neighbour_distances = _compute_nearest_neighbour_distances(fake_features, self.nearest_k)
101 | distance_real_fake = _compute_pairwise_distance(real_features, fake_features)
102 |
103 | precision = (
104 | distance_real_fake < real_nearest_neighbour_distances.unsqueeze(1)
105 | ).any(dim=0).float().mean()
106 |
107 | recall = (
108 | distance_real_fake < fake_nearest_neighbour_distances.unsqueeze(0)
109 | ).any(dim=1).float().mean()
110 |
111 | return precision, recall
112 |
--------------------------------------------------------------------------------
/metrics.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from pathlib import Path
3 | from typing import Optional, Tuple, Union
4 | from io import BytesIO
5 |
6 | from tqdm import tqdm, trange
7 | import numpy as np
8 | import lmdb
9 | import torch
10 | from torchvision.transforms import ToTensor
11 | from dataset import MultiResolutionDataset
12 | from torch.utils.data import DataLoader
13 | from piq.feature_extractors import InceptionV3
14 | from piq import FID, KID
15 |
16 | from metric_impl.pr import PR
17 | from metric_impl.dc import DC
18 | from utils import load_hparams, load_checkpoint
19 |
20 |
21 | def to_bytes(img):
22 | buffer = BytesIO()
23 | torch.save(img, buffer)
24 | return buffer.getvalue()
25 |
26 |
27 | def get_z(num_samples=1, code_dim=512, device='cuda'):
28 | return torch.randn(num_samples, code_dim, device=device)
29 |
30 |
31 | def clamp_img(img: torch.Tensor, range: Optional[Tuple[int, int]] = (-1, 1)) -> None:
32 | img.clamp_(range[0], range[1])
33 |
34 |
35 | def featurize(extractor, imgs):
36 | with torch.no_grad():
37 | feats = extractor(imgs)
38 | return feats[0].squeeze(-1).squeeze(-1)
39 |
40 |
41 | def pick_samples(orig_size, selection_size, method):
42 | assert method in ['first', 'random']
43 | if selection_size is None:
44 | return torch.arange(orig_size)
45 | if selection_size > orig_size:
46 | selection_size = orig_size
47 |
48 | if method == 'first':
49 | return torch.arange(selection_size)
50 | elif method == 'random':
51 | return torch.from_numpy(np.random.choice(orig_size, size=selection_size, replace=False))
52 |
53 |
54 | def load_feats(path, classes=1, load_n=None, sampling_method='first', device='cpu'):
55 | """
56 | Load precomputed inception features from an LMDB file.
57 |
58 | :param path: Path to the features.
59 | :param classes: Int with the number of classes to load or iterable consisting of class indices.
60 | :param load_n: Will load N samples per class. To pick which ones, see sampling_method param.
61 | :param sampling_method: Pick the first N samples for each class, or pick N random ones.
62 | :param device: Which device to load features to.
63 | """
64 | if isinstance(classes, int):
65 | classes_iter = range(classes)
66 | elif isinstance(classes, tuple) or isinstance(classes, list):
67 | classes_iter = classes
68 | else:
69 | raise ValueError
70 |
71 | feats = {}
72 | with lmdb.open(path, max_readers=32, readonly=True, lock=False, readahead=False, meminit=False) as env:
73 | for c in classes_iter:
74 | with env.begin(write=False) as txn:
75 | key = f'{c}'.encode('utf-8')
76 | feat = txn.get(key)
77 | feat = torch.load(BytesIO(feat), map_location=device)
78 | # in the loop bc classes can have different number of samples
79 | idx = pick_samples(orig_size=feat.shape[0], selection_size=load_n, method=sampling_method)
80 | # clone for keeping only the slice in memory, not the whole underlying tensor storage
81 | feats[c] = feat[idx].clone()
82 |
83 | return feats
84 |
85 |
86 | def generate(z, model, task, classes, step=6, device='cuda'):
87 | if isinstance(classes, tuple) or isinstance(classes, list) or isinstance(classes, range):
88 | classes = torch.tensor(classes, device=device)
89 | elif isinstance(classes, torch.Tensor):
90 | if classes.device != device:
91 | classes = classes.to(device)
92 | else:
93 | raise ValueError("Wrong classes type")
94 |
95 | with torch.no_grad():
96 | task_input = task(classes)
97 | output = model(z, task=task_input, step=step)
98 | return output
99 |
100 |
101 | def generate_class(model, task, class_i, num_samples=10, code_dim=512, device='cuda'):
102 | if isinstance(class_i, torch.Tensor):
103 | conditioning = class_i
104 | else:
105 | conditioning = [class_i] * num_samples
106 | z = get_z(num_samples, code_dim=code_dim, device=device)
107 | return generate(z, model, task, conditioning, device=device)
108 |
109 |
110 | def generate_metrics_by_class(generator, task, metrics: dict, real_feats: Union[torch.Tensor, dict], num_classes=1, num_samples=10_000, batch_size=2, code_dim=512, feats_dim=2048, device='cuda'):
111 | metric_results_classes = {}
112 | feature_extractor = InceptionV3(normalize_input=False) # already (-1, 1) range from generator output
113 | feature_extractor.to(device)
114 | feature_extractor.eval()
115 |
116 | for c in trange(num_classes, desc='Classes', leave=False, disable=num_classes == 1):
117 | metric_results = {}
118 | fake_feats = torch.empty((num_samples, feats_dim), device=device)
119 | last_batch_size = num_samples % batch_size
120 |
121 | for idx, i in enumerate(trange(0, num_samples, batch_size, desc='Samples', leave=False)):
122 | current_bs = last_batch_size if idx == num_samples // batch_size else batch_size
123 | imgs = generate_class(generator, task, c, current_bs, code_dim, device=device)
124 | clamp_img(imgs) # sometimes images go beyond their range, thus Inception extractor will complain
125 | fake_feats[i:i + current_bs] = featurize(feature_extractor, imgs)
126 |
127 | for metric_name, metric in metrics.items():
128 | fake_i = None if metric_name != 'kid' else real_feats[c].shape[0] # kid needs reals >= fakes
129 | score = metric(x_features=fake_feats[:fake_i], y_features=real_feats[c]) # predicted_feats, target_feats
130 | if isinstance(score, tuple):
131 | score = {f'{i}': s.item() for i, s in enumerate(score)}
132 | elif isinstance(score, torch.Tensor):
133 | score = score.item()
134 | else:
135 | raise RuntimeError('Unknown metric score type')
136 | metric_results[metric_name] = score
137 |
138 | metric_results_classes[c] = metric_results
139 |
140 | return metric_results_classes
141 |
142 |
143 | def generate_metrics_mixed_class(generator, task, metrics: dict, real_feats: Union[torch.Tensor, dict], num_classes=1, num_samples=10_000, batch_size=2, code_dim=512, feats_dim=2048, device='cuda'):
144 | feature_extractor = InceptionV3(normalize_input=False) # already (-1, 1) range from generator output
145 | feature_extractor.to(device)
146 | feature_extractor.eval()
147 |
148 | metric_results = {}
149 | real_feats = real_feats[0]
150 | fake_feats = torch.empty((num_samples, feats_dim), device=device)
151 | last_batch_size = num_samples % batch_size
152 |
153 | for idx, i in enumerate(trange(0, num_samples, batch_size, desc='Samples', leave=False)):
154 | current_bs = last_batch_size if idx == num_samples // batch_size else batch_size
155 | c = torch.from_numpy(np.random.choice(num_classes, size=current_bs)).to(device)
156 | imgs = generate_class(generator, task, c, current_bs, code_dim, device=device)
157 | clamp_img(imgs) # sometimes images go beyond their range, thus Inception extractor will complain
158 | fake_feats[i:i + current_bs] = featurize(feature_extractor, imgs)
159 |
160 | for metric_name, metric in metrics.items():
161 | fake_i = None if metric_name != 'kid' else real_feats.shape[0] # kid needs reals >= fakes
162 | score = metric(x_features=fake_feats[:fake_i], y_features=real_feats) # predicted_feats, target_feats
163 | if isinstance(score, tuple):
164 | score = {f'{i}': s.item() for i, s in enumerate(score)}
165 | elif isinstance(score, torch.Tensor):
166 | score = score.item()
167 | else:
168 | raise RuntimeError('Unknown metric score type')
169 | metric_results[metric_name] = score
170 |
171 | return {0: metric_results}
172 |
173 |
174 | def compute_feats_from_dataset(db_path, batch_size=1, output_path='output_lmdb', device='cuda'):
175 | feats_dim = 2048
176 | resolution = 256
177 | feature_extractor = InceptionV3(normalize_input=True)
178 | feature_extractor.to(device)
179 | feature_extractor.eval()
180 |
181 | dataset = MultiResolutionDataset(db_path, transform=ToTensor(), resolution=resolution)
182 | num_classes = dataset.num_classes
183 | with lmdb.open(output_path, map_size=1024 ** 4, readahead=False) as env:
184 | for c in trange(num_classes, desc='Classes'):
185 | dataset = MultiResolutionDataset(db_path, transform=ToTensor(), resolution=resolution, selected_classes=[c])
186 | feats_c = torch.empty((dataset.length, feats_dim), device=device)
187 | for i, (imgs, labels) in enumerate(tqdm(DataLoader(dataset, batch_size=batch_size, num_workers=4), desc='Batches', leave=False)):
188 | feats_c[i * batch_size:i * batch_size + imgs.size(0)] = featurize(feature_extractor, imgs.to(device))
189 | with env.begin(write=True) as txn:
190 | key = f'{c}'.encode('utf-8')
191 | txn.put(key, to_bytes(feats_c.cpu()))
192 |
193 |
194 | def compute_metrics_per_ckpt(args):
195 | """
196 | Compute specified metrics class-wise or for all classes and save results in a pickle. It will be run for each
197 | checkpoint it finds, creating a lock file for parallel computation (cheap but useful).
198 | """
199 | import pickle
200 | mix_dir = 'mixed' if args.classes == 'interclass' else ''
201 | metrics = {'fid': FID(), 'kid': KID(), 'pr': PR(), 'dc': DC()}
202 | hparams = load_hparams(args.checkpoint_path)
203 | if not hasattr(hparams, 'num_classes'):
204 | hparams.num_classes = 1
205 | # 1 class and random if mix classes bc the precomputed feats dataset supposed to be passed has all classes in
206 | # the first class, so getting the first 10k will likely just get all samples from one/two classes
207 | real_feats = load_feats(args.precomputed_feats, classes=1 if args.classes == 'interclass' else hparams.num_classes,
208 | load_n=10_000, sampling_method='random' if args.classes == 'interclass' else 'first',
209 | device=args.device)
210 |
211 | checkpoint_folder = Path(args.checkpoint_path)
212 | if args.classes == 'interclass':
213 | (checkpoint_folder / mix_dir).mkdir(exist_ok=True)
214 |
215 | for epoch in tqdm(list(sorted(checkpoint_folder.glob('*.model'), reverse=True)), desc='Epoch'):
216 | output_score = epoch.parent / mix_dir / (epoch.stem + '.metric')
217 | epoch_lock = epoch.parent / mix_dir / (epoch.stem + '.lock')
218 | if output_score.exists() or epoch_lock.exists():
219 | continue
220 | epoch_lock.touch(exist_ok=False)
221 |
222 | args.checkpoint_path = str(epoch)
223 | try:
224 | generator, task = load_checkpoint(args, hparams)
225 | except Exception as e:
226 | epoch_lock.unlink()
227 | raise e
228 |
229 | if args.classes == 'interclass':
230 | scores = generate_metrics_mixed_class(generator, task, metrics, real_feats,
231 | num_classes=hparams.num_classes, num_samples=args.num_samples,
232 | batch_size=args.batch_size, code_dim=hparams.code_size,
233 | feats_dim=args.feats_dim, device=args.device)
234 | elif args.classes == 'intraclass':
235 | scores = generate_metrics_by_class(generator, task, metrics, real_feats,
236 | num_classes=hparams.num_classes, num_samples=args.num_samples,
237 | batch_size=args.batch_size, code_dim=hparams.code_size,
238 | feats_dim=args.feats_dim, device=args.device)
239 | else:
240 | raise ValueError
241 |
242 | with open(output_score, 'wb') as f:
243 | pickle.dump(scores, f)
244 |
245 | if args.remove_checkpoint_afterwards:
246 | epoch.unlink()
247 | epoch_lock.unlink()
248 |
249 |
250 | def compute_metrics_biggan(imgs_path, num_samples=10_000, batch_size=2, feats_dim=2048, device='cuda'):
251 | """
252 | Read BigGAN generations (npz file) and compute metrics.
253 | """
254 | metrics = {'fid': FID(), 'kid': KID(), 'pr': PR(), 'dc': DC()}
255 | np_imgs = np.load(imgs_path)
256 | print(f'Loading {imgs_path}')
257 | y = np_imgs['y']
258 | num_classes = len(np.unique(y))
259 | assert np.array_equal(np.arange(num_classes), np.unique(y)), f'Discontinuous number of classes in {imgs_path}'
260 |
261 | x = {}
262 | for c in trange(num_classes, desc='Loading classes'):
263 | idcs = np.where(y == c)[0][:num_samples]
264 | if len(idcs) < num_samples:
265 | print(f'Class {c} has less number of samples than requested ({num_samples})')
266 | x[c] = np_imgs['x'][idcs]
267 |
268 | real_feats = load_feats(args.precomputed_feats, classes=num_classes, load_n=10_000, device=args.device)
269 |
270 | metric_results_classes = {}
271 | feature_extractor = InceptionV3(normalize_input=True) # [0, 1] to [-1, 1]
272 | feature_extractor.to(device)
273 | feature_extractor.eval()
274 |
275 | for c in trange(num_classes, desc='Classes', leave=False, disable=num_classes == 1):
276 | metric_results = {}
277 | fake_feats = torch.empty((num_samples, feats_dim), device=device)
278 | last_batch_size = num_samples % batch_size
279 |
280 | for idx, i in enumerate(trange(0, num_samples, batch_size, desc='Samples', leave=False)):
281 | current_bs = last_batch_size if idx == num_samples // batch_size else batch_size
282 | imgs = torch.from_numpy(x[c][i:i + current_bs]).to(device) / 255
283 | fake_feats[i:i + current_bs] = featurize(feature_extractor, imgs)
284 |
285 | for metric_name, metric in metrics.items():
286 | score = metric(real_feats[c], fake_feats)
287 | if isinstance(score, tuple):
288 | score = {f'{i}': s.item() for i, s in enumerate(score)}
289 | elif isinstance(score, torch.Tensor):
290 | score = score.item()
291 | else:
292 | raise RuntimeError('Unknown metric score type')
293 | metric_results[metric_name] = score
294 |
295 | metric_results_classes[c] = metric_results
296 |
297 | return metric_results_classes
298 |
299 |
300 | if __name__ == '__main__':
301 | parser = argparse.ArgumentParser()
302 | parser.add_argument('checkpoint_path', type=str, help='Path of specified dataset')
303 | parser.add_argument('--output_path', type=str, default='Output LMDB for precomputed features')
304 | parser.add_argument('--precomputed_feats', type=str, default=None, help='If the function needs it')
305 | parser.add_argument('--device', type=str, default='cuda')
306 | parser.add_argument('--num_samples', type=int, default=10_000)
307 | parser.add_argument('--batch_size', type=int, default=32)
308 | parser.add_argument('--feats_dim', type=int, default=2048)
309 | parser.add_argument('--remove_checkpoint_afterwards', action='store_true', help='Remove computing metrics')
310 | parser.add_argument('--operation', type=str, default='metrics', choices=['extract-feats', 'metrics', 'metrics-biggan'])
311 | parser.add_argument('--classes', type=str, default='intraclass', choices=['interclass', 'intraclass'], help='Evaluation regarding conditioning')
312 | args = parser.parse_args()
313 |
314 | if args.operation == 'extract-feats':
315 | compute_feats_from_dataset(args.checkpoint_path, batch_size=args.batch_size, output_path=args.output_path, device=args.device)
316 | elif args.operation == 'metrics':
317 | compute_metrics_per_ckpt(args)
318 | elif args.operation == 'metrics-biggan':
319 | print(compute_metrics_biggan(args.checkpoint_path, batch_size=args.batch_size))
320 |
--------------------------------------------------------------------------------
/model/hyper_mod.py:
--------------------------------------------------------------------------------
1 | import random
2 | from math import sqrt
3 |
4 | import torch
5 | from torch import nn
6 | from torch.nn import functional as F
7 |
8 | from model.stylegan import PixelNorm, equal_lr, EqualLinear, EqualConv2d, ConstantInput, Blur, NoiseInjection,\
9 | AdaptiveInstanceNorm, ConvBlock
10 |
11 |
12 | class SequentialWithParams(nn.Sequential):
13 | def forward(self, x):
14 | x, task = x
15 | for m in self:
16 | if isinstance(m, (AdaptiveFilterModulation, FusedUpsampleAdaFM)):
17 | x = m(x, task)
18 | else:
19 | x = m(x)
20 |
21 | return x
22 |
23 |
24 | def equalize(weight):
25 | """Batched equalization. Since all kernels have the same shape we can apply the scaling to all batches."""
26 | fan_in = weight[0].data.size(1) * weight[0].data[0][0].numel()
27 | weight *= sqrt(2 / fan_in)
28 |
29 |
30 | class AdaptiveFilterModulation(nn.Module):
31 | def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, task_dim=0):
32 | super().__init__()
33 | self.stride = stride
34 | self.padding = padding
35 |
36 | # these two will be normalized and initialized when loading from a pretrained model
37 | self.register_buffer('W', torch.empty((out_channel, in_channel, kernel_size, kernel_size)))
38 | self.register_buffer('b', torch.empty(out_channel))
39 |
40 | self.task = EqualLinear(task_dim, out_channel * in_channel * 2)
41 | self.task.linear.bias.data[:] = 0
42 |
43 | self.task_bias_beta = EqualLinear(task_dim, out_channel)
44 | self.task_bias_beta.linear.bias.data[:] = 0
45 |
46 | def forward(self, x, task):
47 | gamma, beta = self.task(task).view(-1, self.W.shape[0], self.W.shape[1], 2).unsqueeze(4).chunk(2, 3)
48 | bias_beta = self.task_bias_beta(task)
49 |
50 | W_i = self.W * gamma + beta
51 | equalize(W_i)
52 | b_i = self.b + bias_beta
53 | out = conv1to1(x, W_i, bias=b_i, stride=self.stride, padding=self.padding)
54 | return out
55 |
56 |
57 | class FusedUpsampleAdaFM(nn.Module):
58 | def __init__(self, in_channel, out_channel, kernel_size, stride=2, padding=0, task_dim=0):
59 | super().__init__()
60 | self.stride = stride
61 | self.padding = padding
62 |
63 | self.register_buffer('W', torch.empty((in_channel, out_channel, kernel_size, kernel_size)))
64 | self.register_buffer('b', torch.empty(out_channel))
65 |
66 | self.task = EqualLinear(task_dim, out_channel * in_channel * 2)
67 | self.task.linear.bias.data[:] = 0
68 |
69 | self.task_bias_beta = EqualLinear(task_dim, out_channel)
70 | self.task_bias_beta.linear.bias.data[:] = 0
71 |
72 | def forward(self, x, task):
73 | weight = F.pad(self.W, [1, 1, 1, 1])
74 | weight = (
75 | weight[:, :, 1:, 1:]
76 | + weight[:, :, :-1, 1:]
77 | + weight[:, :, 1:, :-1]
78 | + weight[:, :, :-1, :-1]
79 | ) / 4
80 |
81 | gamma, beta = self.task(task).view(-1, weight.shape[0], weight.shape[1], 2).unsqueeze(4).chunk(2, 3)
82 | bias_beta = self.task_bias_beta(task)
83 |
84 | W_i = weight * gamma + beta
85 | equalize(W_i)
86 | b_i = self.b + bias_beta
87 | out = conv1to1(x, W_i, bias=b_i, stride=self.stride, padding=self.padding, transposed=True, mode='upsample')
88 | return out
89 |
90 |
91 | def conv1to1(x, weights, bias=None, stride=1, padding=0, transposed=False, mode='normal'):
92 | """Apply a different kernel to a different sample. result[i] = conv(input[i], weights[i]).
93 | From https://discuss.pytorch.org/t/how-to-apply-different-kernels-to-each-example-in-a-batch-when-using-convolution/84848/3"""
94 | F_conv = F.conv_transpose2d if transposed else F.conv2d
95 | N, C, H, W = x.shape
96 | KN, KO, KI, KH, KW = weights.shape
97 | assert N == KN
98 |
99 | # group weights
100 | weights = weights.view(-1, KI, KH, KW)
101 | bias = bias.view(-1)
102 |
103 | # move batch dim into channels
104 | x = x.view(1, -1, H, W)
105 |
106 | # Apply grouped conv
107 | outputs_grouped = F_conv(x, weights, bias=bias, stride=stride, padding=padding, groups=N)
108 | if mode == 'upsample' or mode == 'downsample' or (outputs_grouped.shape[2] == 1 and outputs_grouped.shape[3] == 1):
109 | outputs_grouped = outputs_grouped.view(N, -1, outputs_grouped.shape[2], outputs_grouped.shape[3])
110 | else:
111 | outputs_grouped = outputs_grouped.view(N, KO, H, W)
112 | return outputs_grouped
113 |
114 |
115 | class StyledConvBlock(nn.Module):
116 | def __init__(
117 | self,
118 | in_channel,
119 | out_channel,
120 | kernel_size=3,
121 | padding=1,
122 | style_dim=512,
123 | initial=False,
124 | upsample=False,
125 | fused=False,
126 | task_dim=512,
127 | ):
128 | super().__init__()
129 |
130 | if initial:
131 | self.conv1 = SequentialWithParams(
132 | ConstantInput(in_channel)
133 | )
134 |
135 | else:
136 | if upsample:
137 | if fused:
138 | self.conv1 = SequentialWithParams(
139 | FusedUpsampleAdaFM(in_channel, out_channel, kernel_size, padding=padding, task_dim=task_dim),
140 | Blur(out_channel),
141 | )
142 |
143 | else:
144 | self.conv1 = SequentialWithParams(
145 | nn.Upsample(scale_factor=2, mode='nearest'),
146 | AdaptiveFilterModulation(in_channel, out_channel, kernel_size, padding=padding, task_dim=task_dim),
147 | Blur(out_channel),
148 | )
149 |
150 | else:
151 | self.conv1 = SequentialWithParams(
152 | AdaptiveFilterModulation(in_channel, out_channel, kernel_size, padding=padding, task_dim=task_dim))
153 |
154 | self.noise1 = equal_lr(NoiseInjection(out_channel))
155 | self.adain1 = AdaptiveInstanceNorm(out_channel, style_dim)
156 | self.lrelu1 = nn.LeakyReLU(0.2)
157 |
158 | self.conv2 = AdaptiveFilterModulation(out_channel, out_channel, kernel_size, padding=padding, task_dim=task_dim)
159 | self.noise2 = equal_lr(NoiseInjection(out_channel))
160 | self.adain2 = AdaptiveInstanceNorm(out_channel, style_dim)
161 | self.lrelu2 = nn.LeakyReLU(0.2)
162 |
163 | def forward(self, input, style, noise, task):
164 | out = self.conv1((input, task))
165 | out = self.noise1(out, noise)
166 | out = self.lrelu1(out)
167 | out = self.adain1(out, style)
168 |
169 | out = self.conv2(out, task)
170 | out = self.noise2(out, noise)
171 | out = self.lrelu2(out)
172 | out = self.adain2(out, style)
173 |
174 | return out
175 |
176 |
177 | class Generator(nn.Module):
178 | def __init__(self, code_dim, task_dim, fused=True):
179 | super().__init__()
180 |
181 | self.progression = nn.ModuleList(
182 | [
183 | StyledConvBlock(512, 512, 3, 1, initial=True, task_dim=task_dim), # 4
184 | StyledConvBlock(512, 512, 3, 1, upsample=True, task_dim=task_dim), # 8
185 | StyledConvBlock(512, 512, 3, 1, upsample=True, task_dim=task_dim), # 16
186 | StyledConvBlock(512, 512, 3, 1, upsample=True, task_dim=task_dim), # 32
187 | StyledConvBlock(512, 256, 3, 1, upsample=True, task_dim=task_dim), # 64
188 | StyledConvBlock(256, 128, 3, 1, upsample=True, fused=fused, task_dim=task_dim), # 128
189 | StyledConvBlock(128, 64, 3, 1, upsample=True, fused=fused, task_dim=task_dim), # 256
190 | StyledConvBlock(64, 32, 3, 1, upsample=True, fused=fused, task_dim=task_dim), # 512
191 | StyledConvBlock(32, 16, 3, 1, upsample=True, fused=fused, task_dim=task_dim), # 1024
192 | ]
193 | )
194 |
195 | self.to_rgb = nn.ModuleList(
196 | [
197 | EqualConv2d(512, 3, 1),
198 | EqualConv2d(512, 3, 1),
199 | EqualConv2d(512, 3, 1),
200 | EqualConv2d(512, 3, 1),
201 | EqualConv2d(256, 3, 1),
202 | EqualConv2d(128, 3, 1),
203 | EqualConv2d(64, 3, 1),
204 | EqualConv2d(32, 3, 1),
205 | EqualConv2d(16, 3, 1),
206 | ]
207 | )
208 |
209 | def forward(self, style, noise, step=0, alpha=-1, mixing_range=(-1, -1), task=None, return_hierarchical=False):
210 | out = noise[0]
211 | hierarchical_out = []
212 |
213 | if len(style) < 2:
214 | inject_index = [len(self.progression) + 1]
215 |
216 | else:
217 | inject_index = sorted(random.sample(list(range(step)), len(style) - 1))
218 |
219 | crossover = 0
220 |
221 | for i, (conv, to_rgb) in enumerate(zip(self.progression, self.to_rgb)):
222 | if mixing_range == (-1, -1):
223 | if crossover < len(inject_index) and i > inject_index[crossover]:
224 | crossover = min(crossover + 1, len(style))
225 |
226 | style_step = style[crossover]
227 |
228 | else:
229 | if mixing_range[0] <= i <= mixing_range[1]:
230 | style_step = style[1]
231 |
232 | else:
233 | style_step = style[0]
234 |
235 | if i > 0 and step > 0:
236 | out_prev = out
237 |
238 | out = conv(out, style_step, noise[i], task)
239 | if return_hierarchical:
240 | hierarchical_out.append(out)
241 |
242 | if i == step:
243 | out = to_rgb(out)
244 |
245 | if i > 0 and 0 <= alpha < 1:
246 | skip_rgb = self.to_rgb[i - 1](out_prev)
247 | skip_rgb = F.interpolate(skip_rgb, scale_factor=2, mode='nearest')
248 | out = (1 - alpha) * skip_rgb + alpha * out
249 |
250 | break
251 |
252 | if return_hierarchical:
253 | return out, hierarchical_out
254 | else:
255 | return out
256 |
257 |
258 | class StyledGenerator(nn.Module):
259 | def __init__(self, code_dim=512, n_mlp=8, task_dim=512):
260 | super().__init__()
261 |
262 | self.generator = Generator(code_dim, task_dim)
263 |
264 | layers = [PixelNorm()]
265 | for i in range(n_mlp):
266 | layers.append(EqualLinear(code_dim, code_dim))
267 | layers.append(nn.LeakyReLU(0.2))
268 |
269 | self.style = nn.Sequential(*layers)
270 |
271 | def forward(
272 | self,
273 | input,
274 | noise=None,
275 | step=0,
276 | alpha=-1,
277 | mean_style=None,
278 | style_weight=0,
279 | mixing_range=(-1, -1),
280 | task=None,
281 | return_hierarchical=False,
282 | ):
283 | styles = []
284 | if type(input) not in (list, tuple):
285 | input = [input]
286 |
287 | for i in input:
288 | styles.append(self.style(i))
289 |
290 | batch = input[0].shape[0]
291 |
292 | if noise is None:
293 | noise = []
294 |
295 | for i in range(step + 1):
296 | size = 4 * 2 ** i
297 | noise.append(torch.randn(batch, 1, size, size, device=input[0].device))
298 |
299 | if mean_style is not None:
300 | styles_norm = []
301 |
302 | for style in styles:
303 | styles_norm.append(mean_style + style_weight * (style - mean_style))
304 |
305 | styles = styles_norm
306 |
307 | return self.generator(styles, noise, step, alpha, mixing_range=mixing_range, task=task,
308 | return_hierarchical=return_hierarchical)
309 |
310 | def mean_style(self, input):
311 | style = self.style(input).mean(0, keepdim=True)
312 |
313 | return style
314 |
315 |
316 | class Task(nn.Module):
317 | def __init__(self, code_dim=512, n_mlp=4, num_labels=0):
318 | super().__init__()
319 |
320 | layers = [equal_lr(nn.Embedding(num_labels, code_dim))]
321 | for i in range(n_mlp):
322 | layers.append(EqualLinear(code_dim, code_dim))
323 | layers.append(nn.LeakyReLU(0.2))
324 |
325 | self.task = nn.Sequential(*layers)
326 |
327 | def forward(self, x):
328 | return self.task(x)
329 |
330 |
331 | class Discriminator(nn.Module):
332 | def __init__(self, num_classes, fused=True, from_rgb_activate=False):
333 | super().__init__()
334 |
335 | self.progression = nn.ModuleList(
336 | [
337 | ConvBlock(16, 32, 3, 1, downsample=True, fused=fused), # 512
338 | ConvBlock(32, 64, 3, 1, downsample=True, fused=fused), # 256
339 | ConvBlock(64, 128, 3, 1, downsample=True, fused=fused), # 128
340 | ConvBlock(128, 256, 3, 1, downsample=True, fused=fused), # 64
341 | ConvBlock(256, 512, 3, 1, downsample=True), # 32
342 | ConvBlock(512, 512, 3, 1, downsample=True), # 16
343 | ConvBlock(512, 512, 3, 1, downsample=True), # 8
344 | ConvBlock(512, 512, 3, 1, downsample=True), # 4
345 | ConvBlock(513, 512, 3, 1, 4, 0),
346 | ]
347 | )
348 |
349 | def make_from_rgb(out_channel):
350 | if from_rgb_activate:
351 | return nn.Sequential(EqualConv2d(3, out_channel, 1), nn.LeakyReLU(0.2))
352 |
353 | else:
354 | return EqualConv2d(3, out_channel, 1)
355 |
356 | self.from_rgb = nn.ModuleList(
357 | [
358 | make_from_rgb(16),
359 | make_from_rgb(32),
360 | make_from_rgb(64),
361 | make_from_rgb(128),
362 | make_from_rgb(256),
363 | make_from_rgb(512),
364 | make_from_rgb(512),
365 | make_from_rgb(512),
366 | make_from_rgb(512),
367 | ]
368 | )
369 |
370 | self.n_layer = len(self.progression)
371 |
372 | self.final = EqualConv2d(512, num_classes, 1)
373 |
374 | def forward(self, input, c, step=0, alpha=-1, return_hierarchical=False):
375 | hierarchical_out = []
376 | for i in range(step, -1, -1):
377 | index = self.n_layer - i - 1
378 |
379 | if i == step:
380 | out = self.from_rgb[index](input)
381 |
382 | if i == 0:
383 | out_std = torch.sqrt(out.var(0, unbiased=False) + 1e-8)
384 | mean_std = out_std.mean()
385 | mean_std = mean_std.expand(out.size(0), 1, 4, 4)
386 | out = torch.cat([out, mean_std], 1)
387 |
388 | out = self.progression[index](out)
389 | if return_hierarchical:
390 | hierarchical_out.append(out)
391 |
392 | if i > 0:
393 | if i == step and 0 <= alpha < 1:
394 | skip_rgb = F.avg_pool2d(input, 2)
395 | skip_rgb = self.from_rgb[index + 1](skip_rgb)
396 |
397 | out = (1 - alpha) * skip_rgb + alpha * out
398 |
399 |
400 | out = self.final(out)
401 | out = out[torch.arange(out.size(0)), c].squeeze(-1)
402 |
403 | if return_hierarchical:
404 | hierarchical_out.append(out)
405 | return hierarchical_out
406 |
407 | return out
408 |
--------------------------------------------------------------------------------
/model/stylegan.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch import nn
4 | from torch.nn import init
5 | from torch.nn import functional as F
6 | from torch.autograd import Function
7 |
8 | from math import sqrt
9 |
10 | import random
11 |
12 |
13 | def init_linear(linear):
14 | init.xavier_normal(linear.weight)
15 | linear.bias.data.zero_()
16 |
17 |
18 | def init_conv(conv, glu=True):
19 | init.kaiming_normal(conv.weight)
20 | if conv.bias is not None:
21 | conv.bias.data.zero_()
22 |
23 |
24 | class EqualLR:
25 | def __init__(self, name):
26 | self.name = name
27 |
28 | def compute_weight(self, module):
29 | weight = getattr(module, self.name + '_orig')
30 | fan_in = weight.data.size(1) * weight.data[0][0].numel()
31 |
32 | return weight * sqrt(2 / fan_in)
33 |
34 | @staticmethod
35 | def apply(module, name):
36 | fn = EqualLR(name)
37 |
38 | weight = getattr(module, name)
39 | del module._parameters[name]
40 | module.register_parameter(name + '_orig', nn.Parameter(weight.data))
41 | module.register_forward_pre_hook(fn)
42 |
43 | return fn
44 |
45 | def __call__(self, module, input):
46 | weight = self.compute_weight(module)
47 | setattr(module, self.name, weight)
48 |
49 |
50 | def equal_lr(module, name='weight'):
51 | EqualLR.apply(module, name)
52 |
53 | return module
54 |
55 |
56 | class FusedUpsample(nn.Module):
57 | def __init__(self, in_channel, out_channel, kernel_size, padding=0):
58 | super().__init__()
59 |
60 | weight = torch.randn(in_channel, out_channel, kernel_size, kernel_size)
61 | bias = torch.zeros(out_channel)
62 |
63 | fan_in = in_channel * kernel_size * kernel_size
64 | self.multiplier = sqrt(2 / fan_in)
65 |
66 | self.weight = nn.Parameter(weight)
67 | self.bias = nn.Parameter(bias)
68 |
69 | self.pad = padding
70 |
71 | def forward(self, input):
72 | weight = F.pad(self.weight * self.multiplier, [1, 1, 1, 1])
73 | weight = (
74 | weight[:, :, 1:, 1:]
75 | + weight[:, :, :-1, 1:]
76 | + weight[:, :, 1:, :-1]
77 | + weight[:, :, :-1, :-1]
78 | ) / 4
79 |
80 | out = F.conv_transpose2d(input, weight, self.bias, stride=2, padding=self.pad)
81 |
82 | return out
83 |
84 |
85 | class FusedDownsample(nn.Module):
86 | def __init__(self, in_channel, out_channel, kernel_size, padding=0):
87 | super().__init__()
88 |
89 | weight = torch.randn(out_channel, in_channel, kernel_size, kernel_size)
90 | bias = torch.zeros(out_channel)
91 |
92 | fan_in = in_channel * kernel_size * kernel_size
93 | self.multiplier = sqrt(2 / fan_in)
94 |
95 | self.weight = nn.Parameter(weight)
96 | self.bias = nn.Parameter(bias)
97 |
98 | self.pad = padding
99 |
100 | def forward(self, input):
101 | weight = F.pad(self.weight * self.multiplier, [1, 1, 1, 1])
102 | weight = (
103 | weight[:, :, 1:, 1:]
104 | + weight[:, :, :-1, 1:]
105 | + weight[:, :, 1:, :-1]
106 | + weight[:, :, :-1, :-1]
107 | ) / 4
108 |
109 | out = F.conv2d(input, weight, self.bias, stride=2, padding=self.pad)
110 |
111 | return out
112 |
113 |
114 | class PixelNorm(nn.Module):
115 | def __init__(self):
116 | super().__init__()
117 |
118 | def forward(self, input):
119 | return input / torch.sqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
120 |
121 |
122 | class BlurFunctionBackward(Function):
123 | @staticmethod
124 | def forward(ctx, grad_output, kernel, kernel_flip):
125 | ctx.save_for_backward(kernel, kernel_flip)
126 |
127 | grad_input = F.conv2d(
128 | grad_output, kernel_flip, padding=1, groups=grad_output.shape[1]
129 | )
130 |
131 | return grad_input
132 |
133 | @staticmethod
134 | def backward(ctx, gradgrad_output):
135 | kernel, kernel_flip = ctx.saved_tensors
136 |
137 | grad_input = F.conv2d(
138 | gradgrad_output, kernel, padding=1, groups=gradgrad_output.shape[1]
139 | )
140 |
141 | return grad_input, None, None
142 |
143 |
144 | class BlurFunction(Function):
145 | @staticmethod
146 | def forward(ctx, input, kernel, kernel_flip):
147 | ctx.save_for_backward(kernel, kernel_flip)
148 |
149 | output = F.conv2d(input, kernel, padding=1, groups=input.shape[1])
150 |
151 | return output
152 |
153 | @staticmethod
154 | def backward(ctx, grad_output):
155 | kernel, kernel_flip = ctx.saved_tensors
156 |
157 | grad_input = BlurFunctionBackward.apply(grad_output, kernel, kernel_flip)
158 |
159 | return grad_input, None, None
160 |
161 |
162 | blur = BlurFunction.apply
163 |
164 |
165 | class Blur(nn.Module):
166 | def __init__(self, channel):
167 | super().__init__()
168 |
169 | weight = torch.tensor([[1, 2, 1], [2, 4, 2], [1, 2, 1]], dtype=torch.float32)
170 | weight = weight.view(1, 1, 3, 3)
171 | weight = weight / weight.sum()
172 | weight_flip = torch.flip(weight, [2, 3])
173 |
174 | self.register_buffer('weight', weight.repeat(channel, 1, 1, 1))
175 | self.register_buffer('weight_flip', weight_flip.repeat(channel, 1, 1, 1))
176 |
177 | def forward(self, input):
178 | return blur(input, self.weight, self.weight_flip)
179 | # return F.conv2d(input, self.weight, padding=1, groups=input.shape[1])
180 |
181 |
182 | class EqualConv2d(nn.Module):
183 | def __init__(self, *args, **kwargs):
184 | super().__init__()
185 |
186 | conv = nn.Conv2d(*args, **kwargs)
187 | conv.weight.data.normal_()
188 | conv.bias.data.zero_()
189 | self.conv = equal_lr(conv)
190 |
191 | def forward(self, input):
192 | return self.conv(input)
193 |
194 |
195 | class EqualLinear(nn.Module):
196 | def __init__(self, in_dim, out_dim):
197 | super().__init__()
198 |
199 | linear = nn.Linear(in_dim, out_dim)
200 | linear.weight.data.normal_()
201 | linear.bias.data.zero_()
202 |
203 | self.linear = equal_lr(linear)
204 |
205 | def forward(self, input):
206 | return self.linear(input)
207 |
208 |
209 | class ConvBlock(nn.Module):
210 | def __init__(
211 | self,
212 | in_channel,
213 | out_channel,
214 | kernel_size,
215 | padding,
216 | kernel_size2=None,
217 | padding2=None,
218 | downsample=False,
219 | fused=False,
220 | ):
221 | super().__init__()
222 |
223 | pad1 = padding
224 | pad2 = padding
225 | if padding2 is not None:
226 | pad2 = padding2
227 |
228 | kernel1 = kernel_size
229 | kernel2 = kernel_size
230 | if kernel_size2 is not None:
231 | kernel2 = kernel_size2
232 |
233 | self.conv1 = nn.Sequential(
234 | EqualConv2d(in_channel, out_channel, kernel1, padding=pad1),
235 | nn.LeakyReLU(0.2),
236 | )
237 |
238 | if downsample:
239 | if fused:
240 | self.conv2 = nn.Sequential(
241 | Blur(out_channel),
242 | FusedDownsample(out_channel, out_channel, kernel2, padding=pad2),
243 | nn.LeakyReLU(0.2),
244 | )
245 |
246 | else:
247 | self.conv2 = nn.Sequential(
248 | Blur(out_channel),
249 | EqualConv2d(out_channel, out_channel, kernel2, padding=pad2),
250 | nn.AvgPool2d(2),
251 | nn.LeakyReLU(0.2),
252 | )
253 |
254 | else:
255 | self.conv2 = nn.Sequential(
256 | EqualConv2d(out_channel, out_channel, kernel2, padding=pad2),
257 | nn.LeakyReLU(0.2),
258 | )
259 |
260 | def forward(self, input):
261 | out = self.conv1(input)
262 | out = self.conv2(out)
263 |
264 | return out
265 |
266 |
267 | class AdaptiveInstanceNorm(nn.Module):
268 | def __init__(self, in_channel, style_dim):
269 | super().__init__()
270 |
271 | self.norm = nn.InstanceNorm2d(in_channel)
272 | self.style = EqualLinear(style_dim, in_channel * 2)
273 |
274 | self.style.linear.bias.data[:in_channel] = 1
275 | self.style.linear.bias.data[in_channel:] = 0
276 |
277 | def forward(self, input, style):
278 | style = self.style(style).unsqueeze(2).unsqueeze(3)
279 | gamma, beta = style.chunk(2, 1)
280 |
281 | out = self.norm(input)
282 | out = gamma * out + beta
283 |
284 | return out
285 |
286 |
287 | class NoiseInjection(nn.Module):
288 | def __init__(self, channel):
289 | super().__init__()
290 |
291 | self.weight = nn.Parameter(torch.zeros(1, channel, 1, 1))
292 |
293 | def forward(self, image, noise):
294 | return image + self.weight * noise
295 |
296 |
297 | class ConstantInput(nn.Module):
298 | def __init__(self, channel, size=4):
299 | super().__init__()
300 |
301 | self.input = nn.Parameter(torch.randn(1, channel, size, size))
302 |
303 | def forward(self, input):
304 | batch = input.shape[0]
305 | out = self.input.repeat(batch, 1, 1, 1)
306 |
307 | return out
308 |
309 |
310 | class StyledConvBlock(nn.Module):
311 | def __init__(
312 | self,
313 | in_channel,
314 | out_channel,
315 | kernel_size=3,
316 | padding=1,
317 | style_dim=512,
318 | initial=False,
319 | upsample=False,
320 | fused=False,
321 | ):
322 | super().__init__()
323 |
324 | if initial:
325 | self.conv1 = ConstantInput(in_channel)
326 |
327 | else:
328 | if upsample:
329 | if fused:
330 | self.conv1 = nn.Sequential(
331 | FusedUpsample(
332 | in_channel, out_channel, kernel_size, padding=padding
333 | ),
334 | Blur(out_channel),
335 | )
336 |
337 | else:
338 | self.conv1 = nn.Sequential(
339 | nn.Upsample(scale_factor=2, mode='nearest'),
340 | EqualConv2d(
341 | in_channel, out_channel, kernel_size, padding=padding
342 | ),
343 | Blur(out_channel),
344 | )
345 |
346 | else:
347 | self.conv1 = EqualConv2d(
348 | in_channel, out_channel, kernel_size, padding=padding
349 | )
350 |
351 | self.noise1 = equal_lr(NoiseInjection(out_channel))
352 | self.adain1 = AdaptiveInstanceNorm(out_channel, style_dim)
353 | self.lrelu1 = nn.LeakyReLU(0.2)
354 |
355 | self.conv2 = EqualConv2d(out_channel, out_channel, kernel_size, padding=padding)
356 | self.noise2 = equal_lr(NoiseInjection(out_channel))
357 | self.adain2 = AdaptiveInstanceNorm(out_channel, style_dim)
358 | self.lrelu2 = nn.LeakyReLU(0.2)
359 |
360 | def forward(self, input, style, noise):
361 | out = self.conv1(input)
362 | out = self.noise1(out, noise)
363 | out = self.lrelu1(out)
364 | out = self.adain1(out, style)
365 |
366 | out = self.conv2(out)
367 | out = self.noise2(out, noise)
368 | out = self.lrelu2(out)
369 | out = self.adain2(out, style)
370 |
371 | return out
372 |
373 |
374 | class Generator(nn.Module):
375 | def __init__(self, code_dim, fused=True):
376 | super().__init__()
377 |
378 | self.progression = nn.ModuleList(
379 | [
380 | StyledConvBlock(512, 512, 3, 1, initial=True), # 4
381 | StyledConvBlock(512, 512, 3, 1, upsample=True), # 8
382 | StyledConvBlock(512, 512, 3, 1, upsample=True), # 16
383 | StyledConvBlock(512, 512, 3, 1, upsample=True), # 32
384 | StyledConvBlock(512, 256, 3, 1, upsample=True), # 64
385 | StyledConvBlock(256, 128, 3, 1, upsample=True, fused=fused), # 128
386 | StyledConvBlock(128, 64, 3, 1, upsample=True, fused=fused), # 256
387 | StyledConvBlock(64, 32, 3, 1, upsample=True, fused=fused), # 512
388 | StyledConvBlock(32, 16, 3, 1, upsample=True, fused=fused), # 1024
389 | ]
390 | )
391 |
392 | self.to_rgb = nn.ModuleList(
393 | [
394 | EqualConv2d(512, 3, 1),
395 | EqualConv2d(512, 3, 1),
396 | EqualConv2d(512, 3, 1),
397 | EqualConv2d(512, 3, 1),
398 | EqualConv2d(256, 3, 1),
399 | EqualConv2d(128, 3, 1),
400 | EqualConv2d(64, 3, 1),
401 | EqualConv2d(32, 3, 1),
402 | EqualConv2d(16, 3, 1),
403 | ]
404 | )
405 |
406 | # self.blur = Blur()
407 |
408 | def forward(self, style, noise, step=0, alpha=-1, mixing_range=(-1, -1), return_hierarchical=False):
409 | out = noise[0]
410 | hierarchical_out = []
411 |
412 | if len(style) < 2:
413 | inject_index = [len(self.progression) + 1]
414 |
415 | else:
416 | inject_index = sorted(random.sample(list(range(step)), len(style) - 1))
417 |
418 | crossover = 0
419 |
420 | for i, (conv, to_rgb) in enumerate(zip(self.progression, self.to_rgb)):
421 | if mixing_range == (-1, -1):
422 | if crossover < len(inject_index) and i > inject_index[crossover]:
423 | crossover = min(crossover + 1, len(style))
424 |
425 | style_step = style[crossover]
426 |
427 | else:
428 | if mixing_range[0] <= i <= mixing_range[1]:
429 | style_step = style[1]
430 |
431 | else:
432 | style_step = style[0]
433 |
434 | if i > 0 and step > 0:
435 | out_prev = out
436 |
437 | out = conv(out, style_step, noise[i])
438 | if return_hierarchical:
439 | hierarchical_out.append(out)
440 |
441 | if i == step:
442 | out = to_rgb(out)
443 |
444 | if i > 0 and 0 <= alpha < 1:
445 | skip_rgb = self.to_rgb[i - 1](out_prev)
446 | skip_rgb = F.interpolate(skip_rgb, scale_factor=2, mode='nearest')
447 | out = (1 - alpha) * skip_rgb + alpha * out
448 |
449 | break
450 |
451 | if return_hierarchical:
452 | return out, hierarchical_out
453 | else:
454 | return out
455 |
456 |
457 | class StyledGenerator(nn.Module):
458 | def __init__(self, code_dim=512, n_mlp=8, c_dim=0):
459 | super().__init__()
460 |
461 | self.generator = Generator(code_dim)
462 |
463 | layers = [PixelNorm()]
464 | for i in range(n_mlp):
465 | layers.append(EqualLinear(code_dim + (code_dim if i == 0 and c_dim > 0 else 0), code_dim))
466 | layers.append(nn.LeakyReLU(0.2))
467 |
468 | self.style = nn.Sequential(*layers)
469 | self.embed = EqualLinear(c_dim, code_dim) if c_dim > 0 else None
470 |
471 | def forward(
472 | self,
473 | input,
474 | noise=None,
475 | step=0,
476 | alpha=-1,
477 | mean_style=None,
478 | style_weight=0,
479 | mixing_range=(-1, -1),
480 | return_hierarchical=False,
481 | c=None,
482 | ):
483 | styles = []
484 | if type(input) not in (list, tuple):
485 | input = [input]
486 |
487 | for i in input:
488 | if c is not None:
489 | i = torch.cat([i, self.embed(c)], dim=1)
490 | styles.append(self.style(i))
491 |
492 | batch = input[0].shape[0]
493 |
494 | if noise is None:
495 | noise = []
496 |
497 | for i in range(step + 1):
498 | size = 4 * 2 ** i
499 | noise.append(torch.randn(batch, 1, size, size, device=input[0].device))
500 |
501 | if mean_style is not None:
502 | styles_norm = []
503 |
504 | for style in styles:
505 | styles_norm.append(mean_style + style_weight * (style - mean_style))
506 |
507 | styles = styles_norm
508 |
509 | return self.generator(styles, noise, step, alpha, mixing_range=mixing_range,
510 | return_hierarchical=return_hierarchical)
511 |
512 | def mean_style(self, input):
513 | style = self.style(input).mean(0, keepdim=True)
514 |
515 | return style
516 |
517 |
518 | class Discriminator(nn.Module):
519 | def __init__(self, fused=True, from_rgb_activate=False, num_labels=0):
520 | super().__init__()
521 |
522 | self.progression = nn.ModuleList(
523 | [
524 | ConvBlock(16, 32, 3, 1, downsample=True, fused=fused), # 512
525 | ConvBlock(32, 64, 3, 1, downsample=True, fused=fused), # 256
526 | ConvBlock(64, 128, 3, 1, downsample=True, fused=fused), # 128
527 | ConvBlock(128, 256, 3, 1, downsample=True, fused=fused), # 64
528 | ConvBlock(256, 512, 3, 1, downsample=True), # 32
529 | ConvBlock(512, 512, 3, 1, downsample=True), # 16
530 | ConvBlock(512, 512, 3, 1, downsample=True), # 8
531 | ConvBlock(512, 512, 3, 1, downsample=True), # 4
532 | ConvBlock(513, 512, 3, 1, 4, 0),
533 | ]
534 | )
535 |
536 | def make_from_rgb(out_channel):
537 | if from_rgb_activate:
538 | return nn.Sequential(EqualConv2d(3, out_channel, 1), nn.LeakyReLU(0.2))
539 |
540 | else:
541 | return EqualConv2d(3, out_channel, 1)
542 |
543 | self.from_rgb = nn.ModuleList(
544 | [
545 | make_from_rgb(16),
546 | make_from_rgb(32),
547 | make_from_rgb(64),
548 | make_from_rgb(128),
549 | make_from_rgb(256),
550 | make_from_rgb(512),
551 | make_from_rgb(512),
552 | make_from_rgb(512),
553 | make_from_rgb(512),
554 | ]
555 | )
556 |
557 | # self.blur = Blur()
558 |
559 | self.n_layer = len(self.progression)
560 |
561 | self.linear = EqualLinear(512, 1 if num_labels == 0 else num_labels)
562 |
563 | def forward(self, input, step=0, alpha=-1, c=None):
564 | for i in range(step, -1, -1):
565 | index = self.n_layer - i - 1
566 |
567 | if i == step:
568 | out = self.from_rgb[index](input)
569 |
570 | if i == 0:
571 | out_std = torch.sqrt(out.var(0, unbiased=False) + 1e-8)
572 | mean_std = out_std.mean()
573 | mean_std = mean_std.expand(out.size(0), 1, 4, 4)
574 | out = torch.cat([out, mean_std], 1)
575 |
576 | out = self.progression[index](out)
577 |
578 | if i > 0:
579 | if i == step and 0 <= alpha < 1:
580 | skip_rgb = F.avg_pool2d(input, 2)
581 | skip_rgb = self.from_rgb[index + 1](skip_rgb)
582 |
583 | out = (1 - alpha) * skip_rgb + alpha * out
584 |
585 | out = out.squeeze(2).squeeze(2)
586 | # print(input.size(), out.size(), step)
587 | out = self.linear(out)
588 |
589 | if c is not None:
590 | out = (out * c).sum(dim=1) # keepdims?
591 |
592 | return out
593 |
--------------------------------------------------------------------------------
/prepare_data.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from io import BytesIO
3 | import multiprocessing
4 | from functools import partial
5 |
6 | from PIL import Image
7 | import lmdb
8 | from tqdm import tqdm
9 | from torchvision import datasets
10 | from torchvision.transforms import functional as trans_fn
11 |
12 |
13 | def resize_and_convert(img, size, quality=100):
14 | img = trans_fn.resize(img, size, Image.LANCZOS)
15 | img = trans_fn.center_crop(img, size)
16 | buffer = BytesIO()
17 | img.save(buffer, format='jpeg', quality=quality)
18 | val = buffer.getvalue()
19 |
20 | return val
21 |
22 |
23 | def resize_multiple(img, sizes=(8, 16, 32, 64, 128, 256, 512, 1024), quality=100):
24 | imgs = []
25 |
26 | for size in sizes:
27 | imgs.append(resize_and_convert(img, size, quality))
28 |
29 | return imgs
30 |
31 |
32 | def resize_worker(img_file, sizes):
33 | i, file, label = img_file
34 | img = Image.open(file)
35 | img = img.convert('RGB')
36 | out = resize_multiple(img, sizes=sizes)
37 |
38 | return i, out, label
39 |
40 |
41 | def prepare(transaction, dataset, n_worker, sizes=(8, 16, 32, 64, 128, 256, 512, 1024)):
42 | resize_fn = partial(resize_worker, sizes=sizes)
43 |
44 | files = sorted(dataset.imgs, key=lambda x: x[0])
45 | files = [(i, file, label) for i, (file, label) in enumerate(files)]
46 | num_classes = len(set(label for i, f, label in files))
47 | total = 0
48 |
49 | with multiprocessing.Pool(n_worker) as pool:
50 | for i, imgs, label in tqdm(pool.imap_unordered(resize_fn, files), total=len(files)):
51 | for size, img in zip(sizes, imgs):
52 | key = f'{size}-{str(i).zfill(5)}'.encode('utf-8')
53 | transaction.put(key, img)
54 |
55 | key = f'label-{str(i).zfill(5)}'.encode('utf-8')
56 | transaction.put(key, str(label).encode('utf-8'))
57 |
58 | total += 1
59 |
60 | transaction.put('length'.encode('utf-8'), str(total).encode('utf-8'))
61 | transaction.put('num_classes'.encode('utf-8'), str(num_classes).encode('utf-8'))
62 |
63 |
64 | if __name__ == '__main__':
65 | parser = argparse.ArgumentParser()
66 | parser.add_argument('--out', type=str)
67 | parser.add_argument('--n_worker', type=int, default=8)
68 | parser.add_argument('path', type=str)
69 | parser.add_argument('--labels', action='store_true', help='Add labels too')
70 | parser.add_argument('--sizes', nargs='+', type=int, default=(8, 16, 32, 64, 128, 256, 512, 1024))
71 | parser.add_argument('--min_size', type=int, default=None, help='Minimum image size to not drop it')
72 |
73 | args = parser.parse_args()
74 |
75 | if args.min_size is None:
76 | valid_file = None
77 | else:
78 | valid_file = lambda x: min(Image.open(x).size) >= args.min_size
79 |
80 | imgset = datasets.ImageFolder(args.path, is_valid_file=valid_file)
81 |
82 | with lmdb.open(args.out, map_size=1024 ** 4, readahead=False) as env:
83 | with env.begin(write=True) as txn:
84 | prepare(txn, imgset, args.n_worker, args.sizes)
85 |
--------------------------------------------------------------------------------
/pretrained_converter.py:
--------------------------------------------------------------------------------
1 | """Utilities to fiddle with the models. Convert a pre-trained StyleGAN. Freeze all but hypernetwork layers."""
2 | from math import sqrt
3 |
4 |
5 | def conv(num, layer=None, parameter='weight', fused=False):
6 | return f'conv{num}.{"" if layer is None else str(layer) + "."}{"" if fused else "conv."}{parameter}'
7 |
8 |
9 | def equalization_norm(w):
10 | fan_in = w.data.size(1) * w.data[0][0].numel()
11 | w = w * sqrt(2 / fan_in)
12 | return w
13 |
14 |
15 | def normalize_conv2d_weight(w):
16 | w = equalization_norm(w)
17 | w_mu = w.mean((2, 3), keepdim=True)
18 | w_std = w.std((2, 3), keepdim=True)
19 | w = (w - w_mu) / w_std
20 | return w
21 |
22 |
23 | def normalize_conv2d_bias(w):
24 | w_mu = w.mean()
25 | w_std = w.std()
26 | w = (w - w_mu) / w_std
27 | return w
28 |
29 |
30 | def rename_and_norm(state_new, k, parameter, fused=False):
31 | minus_layers = 1 if fused else 2
32 |
33 | if parameter == 'weight':
34 | state_new[k.rsplit('.', minus_layers)[0] + '.W'] = normalize_conv2d_weight(state_new.pop(k))
35 | if parameter == 'bias':
36 | state_new[k.rsplit('.', minus_layers)[0] + '.b'] = normalize_conv2d_bias(state_new.pop(k))
37 |
38 |
39 | def convert_generator(state):
40 | state_new = state.copy() # should be a shallow copy. I.e. tensors in the dict values should be referenced.
41 | for k, v in state.items():
42 | parts = k.split('.')
43 | module = parts[0]
44 |
45 | if module == 'generator':
46 | module_gen = parts[1]
47 | if module_gen == 'progression':
48 | s = k.split('.', 3)[-1]
49 |
50 | if s == conv(1, 0, parameter='weight', fused=True):
51 | rename_and_norm(state_new, k, 'weight', fused=True)
52 |
53 | if s == conv(1, 0, parameter='bias', fused=True):
54 | rename_and_norm(state_new, k, 'bias', fused=True)
55 |
56 | if s == conv(1, 1, parameter='weight_orig'):
57 | rename_and_norm(state_new, k, 'weight')
58 |
59 | if s == conv(1, 1, parameter='bias'):
60 | rename_and_norm(state_new, k, 'bias')
61 |
62 | if s == conv(2, parameter='weight_orig'):
63 | rename_and_norm(state_new, k, 'weight')
64 |
65 | if s == conv(2, parameter='bias'):
66 | rename_and_norm(state_new, k, 'bias')
67 |
68 | state_new['generator.progression.0.conv1.0.input'] = state_new.pop('generator.progression.0.conv1.input')
69 |
70 | return state_new
71 |
72 |
73 | def assert_loaded_keys(missing, unexpected):
74 | def check(k):
75 | return not (
76 | k.endswith('gamma') or k.endswith('beta') or k.endswith('bias_beta')
77 | or k.find('task') >= 0 or k.startswith('style.1.linear.') or k.startswith('final.')
78 | )
79 | missing = [k for k in missing if check(k)]
80 | assert len(missing) == 0
81 | unexpected = [k for k in unexpected if not k.startswith('linear.')]
82 | assert len(unexpected) == 0
83 |
84 |
85 | def freeze(model):
86 | for param in model.parameters():
87 | param.requires_grad = False
88 |
89 |
90 | def freeze_layers(*models):
91 | """Freeze selected generator and discriminator weights"""
92 | from model.hyper_mod import ConstantInput, StyledGenerator, Discriminator
93 | from torch.nn import DataParallel
94 | from torch.nn.parallel.distributed import DistributedDataParallel
95 |
96 | for model in models:
97 | if isinstance(model, DataParallel) or isinstance(model, DistributedDataParallel):
98 | model = model.module
99 |
100 | if isinstance(model, StyledGenerator):
101 | for styled_block in model.generator.progression:
102 | if isinstance(styled_block.conv1[0], ConstantInput):
103 | freeze(styled_block.conv1)
104 |
105 | # no need to freeze conv1 nor conv2 bc parameters have been already converted to buffers
106 | freeze(styled_block.noise1)
107 | freeze(styled_block.adain1)
108 | freeze(styled_block.noise2)
109 | freeze(styled_block.adain2)
110 |
111 | for to_rgb_i in model.generator.to_rgb:
112 | freeze(to_rgb_i)
113 |
114 | assert len(model.style) == 8 * 2 + 1 # 8 layers
115 | freeze(model.style)
116 | elif isinstance(model, Discriminator):
117 | for from_rgb_i in model.from_rgb:
118 | freeze(from_rgb_i)
119 | elif model is not None:
120 | print(type(model))
121 | raise ValueError
122 |
--------------------------------------------------------------------------------
/self_alignment.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import random
3 | import torch
4 | from pathlib import Path
5 | from tqdm import trange
6 | from torch import nn, optim
7 | from torchvision.utils import make_grid
8 | from torch.utils.tensorboard import SummaryWriter
9 |
10 | from model.stylegan import StyledGenerator as OriginalStyledGenerator
11 | from utils import save_hparams, get_run_id
12 |
13 |
14 | def get_z(batch_size, code_size, mixing, device, fixed_seed=False):
15 | rng = None
16 | if fixed_seed:
17 | rng = torch.Generator(device=device)
18 | rng.manual_seed(2147483647)
19 | if mixing and random.random() < 0.9:
20 | return [torch.randn(batch_size, code_size, generator=rng, device=device),
21 | torch.randn(batch_size, code_size, generator=rng, device=device)]
22 | else:
23 | return torch.randn(batch_size, code_size, generator=rng, device=device)
24 |
25 |
26 | def get_noise(step, batch_size, device, fixed_seed=False):
27 | rng = None
28 | if fixed_seed:
29 | rng = torch.Generator(device=device)
30 | rng.manual_seed(2147483647)
31 | noise = []
32 | for i in range(step + 1):
33 | size = 4 * 2 ** i
34 | noise.append(torch.randn(batch_size, 1, size, size, generator=rng, device=device))
35 |
36 | return noise
37 |
38 |
39 | def log_images(writer, i, generator, task, generator_original, code_size, step, device, batch_size=9):
40 | with torch.no_grad():
41 | z = get_z(batch_size=batch_size, code_size=code_size, mixing=False, device=device, fixed_seed=True)
42 | noise = get_noise(step, batch_size, device, fixed_seed=True)
43 |
44 | images_orig = generator_original(z, noise=noise, step=step, alpha=1, return_hierarchical=False)
45 | writer.add_image('sample_orig', make_grid(images_orig, nrow=3, normalize=True, range=(-1, 1)), i)
46 |
47 | task_in = task(torch.zeros(batch_size, device=device, dtype=torch.long))
48 | images = generator(z, noise=noise, step=step, alpha=1, task=task_in, return_hierarchical=False)
49 | writer.add_image('sample_new', make_grid(images, nrow=3, normalize=True, range=(-1, 1)), i)
50 |
51 |
52 | def train(generator, task, generator_original, g_optimizer, criterion, iterations, args,
53 | batch_size=1, code_size=512, step=6, mixing=False, name=None, device='cpu', checkpoint_interval=5_000):
54 | name = 'unnamed' if name is None else name
55 | run_dir = f'{args.run_dir}/{get_run_id(args.run_dir):05d}-{name}'
56 | Path(run_dir).mkdir(parents=True)
57 | save_hparams(run_dir, args)
58 | writer = SummaryWriter(run_dir)
59 | alpha = 1
60 | for i in trange(iterations):
61 | g_optimizer.zero_grad()
62 | z = get_z(batch_size, code_size, mixing, device)
63 | noise = get_noise(step, batch_size, device)
64 | with torch.no_grad():
65 | _, g_l_orig = generator_original(z, noise=noise, step=step, alpha=alpha, return_hierarchical=True)
66 | task_in = task(torch.zeros(batch_size, device=device, dtype=torch.long))
67 | _, g_l = generator(z, noise=noise, step=step, alpha=alpha, task=task_in, return_hierarchical=True)
68 |
69 | assert len(g_l) == len(g_l_orig)
70 | g_loss = 0
71 | for input, target in zip(g_l, g_l_orig):
72 | g_loss += criterion(input, target)
73 |
74 | g_loss.backward()
75 | g_optimizer.step()
76 |
77 | writer.add_scalar('loss/g', g_loss, i)
78 | if i % args.log_interval == 0:
79 | log_images(writer, i, generator, task, generator_original, code_size, step, device)
80 |
81 | if i % checkpoint_interval == 0:
82 | torch.save({
83 | 'generator': generator.state_dict(),
84 | 'task': task.state_dict(),
85 | }, f'{run_dir}/{i:06d}.model')
86 |
87 |
88 | def get_args():
89 | parser = argparse.ArgumentParser(description='Self-alignment')
90 |
91 | parser.add_argument('checkpoint', type=str)
92 | parser.add_argument('--step', default=6, type=int, help='6 = 256px')
93 | parser.add_argument('--iterations', type=int, default=50_000, help='number of samples used')
94 | parser.add_argument('--lr', default=0.002, type=float, help='learning rate')
95 | parser.add_argument('--code_size', default=512, type=int)
96 | parser.add_argument('--n_mlp_style', default=8, type=int)
97 | parser.add_argument('--n_mlp_task', default=4, type=int)
98 | parser.add_argument('--task_size', default=64, type=int)
99 | parser.add_argument('--batch_size', default=32, type=int, help='max image size')
100 | parser.add_argument('--mixing', action='store_true', help='use mixing regularization')
101 | parser.add_argument('--device', default='cuda', type=str)
102 | parser.add_argument('--name', type=str, default=None)
103 | parser.add_argument('--run_dir', type=str, default='data/training-runs/self_align')
104 | parser.add_argument('--checkpoint_interval', type=int, default=5_000, help='number of samples used')
105 | parser.add_argument('--log_interval', type=int, default=500)
106 |
107 | return parser.parse_args()
108 |
109 |
110 | def main(args):
111 | from model.hyper_mod import StyledGenerator, Task
112 | from pretrained_converter import convert_generator, assert_loaded_keys, freeze_layers
113 |
114 | args.origin = 'self_align'
115 | g_running_orig = OriginalStyledGenerator(args.code_size).to(args.device)
116 | g_running_orig.train(False)
117 | generator = StyledGenerator(code_dim=args.code_size, task_dim=args.task_size, n_mlp=args.n_mlp_style).to(args.device)
118 | freeze_layers(generator)
119 | task = Task(args.task_size, n_mlp=args.n_mlp_task, num_labels=1).to(args.device)
120 |
121 | g_optimizer = optim.Adam(generator.generator.parameters(), lr=args.lr, betas=(0.0, 0.99))
122 | g_optimizer.add_param_group(
123 | {
124 | 'params': generator.style.parameters(),
125 | 'lr': args.lr * 0.01,
126 | 'mult': 0.01,
127 | }
128 | )
129 | g_optimizer.add_param_group(
130 | {
131 | 'params': task.parameters(),
132 | 'lr': args.lr * 0.01,
133 | 'mult': 0.001,
134 | }
135 | )
136 |
137 | ckpt = torch.load(args.checkpoint)
138 | g_running_orig.load_state_dict(ckpt['g_running'])
139 | missing, unexpected = generator.load_state_dict(convert_generator(ckpt['g_running']), strict=False)
140 | assert_loaded_keys(missing, unexpected)
141 | del ckpt
142 |
143 | criterion = nn.L1Loss()
144 |
145 | train(generator=generator, task=task, generator_original=g_running_orig, g_optimizer=g_optimizer,
146 | criterion=criterion, iterations=args.iterations, args=args, batch_size=args.batch_size,
147 | code_size=args.code_size, step=args.step, mixing=args.mixing, name=args.name, device=args.device,
148 | checkpoint_interval=args.checkpoint_interval)
149 |
150 |
151 | if __name__ == '__main__':
152 | main(get_args())
153 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from pathlib import Path
3 | import random
4 | import math
5 |
6 | from tqdm import tqdm
7 |
8 | import torch
9 | from torch import nn, optim
10 | from torch.nn import functional as F
11 | from torch.autograd import grad
12 | from torchvision import transforms
13 | from torch.utils.tensorboard import SummaryWriter
14 |
15 | from dataset import MultiResolutionDataset
16 | from model.hyper_mod import StyledGenerator, Task, Discriminator
17 | from pretrained_converter import convert_generator, assert_loaded_keys, freeze_layers
18 | from barlow_twins import BarlowTwins, DiffTransform as ContrTransform, compute_contrastive
19 | from utils import seed_everything, save_hparams, sample_data, adjust_lr, not_frozen_params, requires_grad, accumulate, log_images, get_run_id
20 |
21 |
22 | def train(args, dataset, generator, discriminator, task, g_running, t_running, g_optimizer, d_optimizer,
23 | contr_transform, contr_criterion):
24 | step = int(math.log2(args.init_size)) - 2
25 | resolution = 4 * 2 ** step
26 | loader = sample_data(
27 | dataset, args.batch.get(resolution, args.batch_default), resolution
28 | )
29 | data_loader = iter(loader)
30 |
31 | adjust_lr(g_optimizer, args.lr.get(resolution, 0.002))
32 | adjust_lr(d_optimizer, args.lr.get(resolution, 0.002))
33 |
34 | pbar = tqdm(range(args.iterations))
35 |
36 | grad_map_generator = not_frozen_params(generator)
37 | grad_map_discriminator = not_frozen_params(discriminator)
38 | requires_grad(generator, False)
39 | requires_grad(discriminator, True, grad_map=grad_map_discriminator)
40 |
41 | disc_loss_val = 0
42 | gen_loss_val = 0
43 | grad_loss_val = 0
44 |
45 | used_sample = 0
46 | num_imgs = 0
47 |
48 | max_step = int(math.log2(args.max_size)) - 2
49 | final_progress = False
50 |
51 | run_dir = f'{args.run_dir}/{get_run_id(args.run_dir):05d}-{args.name}'
52 | Path(run_dir).mkdir(parents=True)
53 | save_hparams(run_dir, args)
54 | if not args.no_tb:
55 | writer = SummaryWriter(run_dir)
56 |
57 | for i in pbar:
58 | d_optimizer.zero_grad()
59 |
60 | if args.const_alpha is not None:
61 | alpha = args.const_alpha
62 | else:
63 | alpha = min(1, 1 / args.phase * (used_sample + 1))
64 |
65 | if (resolution == args.init_size and args.ckpt is None) or final_progress:
66 | if args.const_alpha is not None:
67 | alpha = args.const_alpha
68 | else:
69 | alpha = 1
70 |
71 | if used_sample > args.phase * 2:
72 | used_sample = 0
73 | step += 1
74 |
75 | if step > max_step:
76 | step = max_step
77 | final_progress = True
78 | ckpt_step = step + 1
79 |
80 | else:
81 | if args.const_alpha is not None:
82 | alpha = args.const_alpha
83 | else:
84 | alpha = 0
85 | ckpt_step = step
86 |
87 | resolution = 4 * 2 ** step
88 |
89 | loader = sample_data(
90 | dataset, args.batch.get(resolution, args.batch_default), resolution
91 | )
92 | data_loader = iter(loader)
93 |
94 | if not args.no_checkpointing:
95 | torch.save(
96 | {
97 | 'generator': generator.module.state_dict(),
98 | 'discriminator': discriminator.module.state_dict(),
99 | 'g_optimizer': g_optimizer.state_dict(),
100 | 'd_optimizer': d_optimizer.state_dict(),
101 | 'g_running': g_running.state_dict(),
102 | 'task_g': task.module.state_dict(),
103 | 't_running_g': t_running.state_dict(),
104 | },
105 | f'{run_dir}/train_step-{ckpt_step}.model',
106 | )
107 |
108 | adjust_lr(g_optimizer, args.lr.get(resolution, 0.002))
109 | adjust_lr(d_optimizer, args.lr.get(resolution, 0.002))
110 |
111 | try:
112 | real_image, real_label = next(data_loader)
113 |
114 | except (OSError, StopIteration):
115 | data_loader = iter(loader)
116 | real_image, real_label = next(data_loader)
117 |
118 | used_sample += real_image.shape[0]
119 | num_imgs += real_image.shape[0]
120 |
121 | b_size = real_image.size(0)
122 | real_image = real_image.cuda()
123 | real_label = real_label.cuda()
124 |
125 | if args.loss == 'wgan-gp':
126 | real_predict = discriminator(real_image, c=real_label, step=step, alpha=alpha)
127 | real_predict = real_predict.mean() - 0.001 * (real_predict ** 2).mean()
128 | (-real_predict).backward()
129 |
130 | elif args.loss == 'r1':
131 | real_image.requires_grad = True
132 | real_scores = discriminator(real_image, c=real_label, step=step, alpha=alpha)
133 | real_predict = F.softplus(-real_scores).mean()
134 | real_predict.backward(retain_graph=True)
135 |
136 | grad_real = grad(
137 | outputs=real_scores.sum(), inputs=real_image, create_graph=True
138 | )[0]
139 | grad_penalty = (
140 | grad_real.view(grad_real.size(0), -1).norm(2, dim=1) ** 2
141 | ).mean()
142 | grad_penalty = 10 / 2 * grad_penalty
143 | grad_penalty.backward()
144 | if i % 10 == 0:
145 | grad_loss_val = grad_penalty.item()
146 |
147 | if args.contrastive:
148 | contr_loss_real = compute_contrastive(discriminator, real_image, contr_criterion, contr_transform, c=real_label,
149 | step=step, alpha=alpha, weighting=args.contr_weighting)
150 | contr_loss_real.backward()
151 |
152 | if args.mixing and random.random() < 0.9:
153 | gen_in11, gen_in12, gen_in21, gen_in22 = torch.randn(
154 | 4, b_size, args.code_size, device='cuda'
155 | ).chunk(4, 0)
156 | gen_in1 = [gen_in11.squeeze(0), gen_in12.squeeze(0)]
157 | gen_in2 = [gen_in21.squeeze(0), gen_in22.squeeze(0)]
158 |
159 | else:
160 | gen_in1, gen_in2 = torch.randn(2, b_size, args.code_size, device='cuda').chunk(
161 | 2, 0
162 | )
163 | gen_in1 = gen_in1.squeeze(0)
164 | gen_in2 = gen_in2.squeeze(0)
165 |
166 | fake_image = generator(gen_in1, step=step, alpha=alpha, task=task(real_label))
167 | fake_predict = discriminator(fake_image, c=real_label, step=step, alpha=alpha)
168 |
169 | if args.loss == 'wgan-gp':
170 | fake_predict = fake_predict.mean()
171 | fake_predict.backward()
172 |
173 | eps = torch.rand(b_size, 1, 1, 1).cuda()
174 | x_hat = eps * real_image.data + (1 - eps) * fake_image.data
175 | x_hat.requires_grad = True
176 | hat_predict = discriminator(x_hat, c=real_label, step=step, alpha=alpha)
177 | grad_x_hat = grad(
178 | outputs=hat_predict.sum(), inputs=x_hat, create_graph=True
179 | )[0]
180 | grad_penalty = (
181 | (grad_x_hat.view(grad_x_hat.size(0), -1).norm(2, dim=1) - 1) ** 2
182 | ).mean()
183 | grad_penalty = 10 * grad_penalty
184 | grad_penalty.backward()
185 | if i % 10 == 0:
186 | grad_loss_val = grad_penalty.item()
187 | disc_loss_val = (-real_predict + fake_predict).item()
188 |
189 | elif args.loss == 'r1':
190 | fake_predict = F.softplus(fake_predict).mean()
191 | fake_predict.backward()
192 | if i % 10 == 0:
193 | disc_loss_val = (real_predict + fake_predict).item()
194 |
195 | if args.contrastive:
196 | fake_image = generator(gen_in1, step=step, alpha=alpha, task=task(real_label))
197 | contr_loss_fake = compute_contrastive(discriminator, fake_image, contr_criterion, contr_transform,
198 | c=real_label, step=step, alpha=alpha, weighting=args.contr_weighting)
199 | contr_loss_fake.backward()
200 |
201 | d_optimizer.step()
202 |
203 | if (i + 1) % args.n_critic == 0:
204 | g_optimizer.zero_grad()
205 |
206 | requires_grad(generator, True, grad_map=grad_map_generator)
207 | requires_grad(discriminator, False)
208 |
209 | fake_image = generator(gen_in2, step=step, alpha=alpha, task=task(real_label))
210 |
211 | predict = discriminator(fake_image, c=real_label, step=step, alpha=alpha)
212 |
213 | if args.loss == 'wgan-gp':
214 | loss = -predict.mean()
215 |
216 | elif args.loss == 'r1':
217 | loss = F.softplus(-predict).mean()
218 |
219 | if i % 10 == 0:
220 | gen_loss_val = loss.item()
221 |
222 | loss.backward()
223 | g_optimizer.step()
224 | accumulate(g_running, generator.module, grad_map=grad_map_generator)
225 | accumulate(t_running, task.module)
226 |
227 | requires_grad(generator, False)
228 | requires_grad(discriminator, True, grad_map=grad_map_discriminator)
229 |
230 | if i % args.img_log_interval == 0:
231 | log_images(f'{run_dir}/fakes{i:06d}.png', alpha, args, dataset, resolution, real_image.size(0), step, g_running, t_running)
232 |
233 | if i % args.checkpoint_interval == 0 and i >= args.begin_checkpointing_at and not args.no_checkpointing:
234 | torch.save(
235 | {
236 | 'generator': generator.module.state_dict(),
237 | 'discriminator': discriminator.module.state_dict(),
238 | # 'g_optimizer': g_optimizer.state_dict(),
239 | # 'd_optimizer': d_optimizer.state_dict(),
240 | 'g_running': g_running.state_dict(),
241 | 'task_g': task.module.state_dict(),
242 | 't_running_g': t_running.state_dict(),
243 | },
244 | f'{run_dir}/{i:06d}.model',
245 | )
246 |
247 | state_msg = (
248 | f'Size: {4 * 2 ** step}; G: {gen_loss_val:.3f}; D: {disc_loss_val:.3f};'
249 | f' Grad: {grad_loss_val:.3f}; Alpha: {alpha:.5f}'
250 | )
251 |
252 | pbar.set_description(state_msg)
253 | if not args.no_tb:
254 | writer.add_scalar('1.loss/G', gen_loss_val, i)
255 | writer.add_scalar('1.loss/D', disc_loss_val, i)
256 | writer.add_scalar('1.loss/grad', grad_loss_val, i)
257 | if args.contrastive:
258 | writer.add_scalar('1.loss/BT_real', contr_loss_real, i)
259 | writer.add_scalar('1.loss/BT_fake', contr_loss_fake, i)
260 | writer.add_scalar('alpha', alpha, i)
261 | writer.add_scalar('resolution', 4 * 2 ** step, i)
262 | writer.add_scalar('batch_size', loader.batch_size, i)
263 | writer.add_scalar('lr/discriminator', d_optimizer.param_groups[0]['lr'], i)
264 | writer.add_scalar('lr/generator', g_optimizer.param_groups[0]['lr'], i)
265 | writer.add_scalar('lr/task', g_optimizer.param_groups[1]['lr'], i)
266 | writer.add_scalar('kimgs', num_imgs / 1000, i)
267 |
268 |
269 | def get_args():
270 | parser = argparse.ArgumentParser(description='Hyper-Modulation')
271 |
272 | # GAN config
273 | parser.add_argument('path', type=str, help='path of specified dataset')
274 | parser.add_argument('--iterations', default=3_000_000, type=int, help='training iterations')
275 | parser.add_argument('--phase', type=int, default=600_000, help='number of samples used for each training phases')
276 | parser.add_argument('--lr', default=0.001, type=float, help='learning rate')
277 | parser.add_argument('--sched', action='store_true', help='use lr scheduling')
278 | parser.add_argument('--init_size', default=8, type=int, help='initial image size')
279 | parser.add_argument('--max_size', default=1024, type=int, help='max image size')
280 | parser.add_argument('--batch_default', default=32, type=int, help='batch size if no lr scheduling is activated')
281 | parser.add_argument('--n_critic', default=1, type=int, help='Number of critic iterations in a global iteration')
282 | parser.add_argument('--ckpt', default=None, type=str, help='load from previous checkpoints')
283 | parser.add_argument('--no_from_rgb_activate', action='store_true', help='use activate in from_rgb (original implementation)')
284 | parser.add_argument('--mixing', action='store_true', help='use mixing regularization')
285 | parser.add_argument('--loss', type=str, default='wgan-gp', choices=['wgan-gp', 'r1'], help='class of gan loss')
286 | parser.add_argument('--const_alpha', default=None, type=float, help='Whether to use a constant alpha')
287 | # Logging config
288 | parser.add_argument('--name', default=None, type=str, help='Name of the experiment')
289 | parser.add_argument('--run_dir', default='data/training-runs', type=str, help='')
290 | parser.add_argument('--no_tb', action='store_true', help='dont log in tensorboard')
291 | parser.add_argument('--no_checkpointing', action='store_true', help='dont save checkpoints')
292 | parser.add_argument('--begin_checkpointing_at', default=0, type=int, help='dont save model before a training point')
293 | parser.add_argument('--checkpoint_interval', default=10_000, type=int, help='how often to save the model')
294 | parser.add_argument('--img_log_interval', default=1_000, type=int, help='how often to generate images')
295 | # Dataset config
296 | parser.add_argument('--filter_classes', default=None, nargs='+', type=int, help='Use only n classes')
297 | parser.add_argument('--dataset_length', default=None, type=int, help='Constrain the number of samples')
298 | parser.add_argument('--dataset_class_samples', default=None, type=int, help='Constrain the number of samples')
299 | parser.add_argument('--dataset_random_class_sampling', action='store_true', help='Random sample the subset of class samples')
300 | # Hyper-modulation config
301 | parser.add_argument('--n_mlp_style', default=8, type=int, help='')
302 | parser.add_argument('--n_mlp_task', default=4, type=int, help='Task net depth')
303 | parser.add_argument('--code_size', default=512, type=int, help='')
304 | parser.add_argument('--task_size', default=64, type=int, help='task width')
305 | parser.add_argument('--finetune', action='store_true', help='Finetune underlying pretrained model')
306 | # Contrastive config
307 | parser.add_argument('--contrastive', action='store_true', help='Enable contrastive training')
308 | parser.add_argument('--contr_weighting', default=0.01, type=float, help='Contrastive training weighting in the final loss')
309 | parser.add_argument('--contr_feats', default=512, type=int, help='Dimensionality of features passed to the contrastive loss')
310 |
311 | return parser.parse_args()
312 |
313 |
314 | def main(args):
315 | seed_everything(42)
316 | torch.backends.cudnn.benchmark = True
317 | contr_transform = None
318 | contr_criterion = None
319 |
320 | print(f'loss: {args.loss}')
321 |
322 | transform = transforms.Compose(
323 | [
324 | transforms.RandomHorizontalFlip(),
325 | transforms.ToTensor(),
326 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
327 | ]
328 | )
329 | if args.contrastive:
330 | contr_transform = ContrTransform(crop_resize=256)
331 |
332 | dataset = MultiResolutionDataset(args.path, transform, selected_classes=args.filter_classes, class_samples=args.dataset_class_samples, random_class_sampling=args.dataset_random_class_sampling, length=args.dataset_length)
333 | print(f'dataset len: {len(dataset)}')
334 | print(f'dataset classes: {dataset.num_classes}')
335 | args.num_classes = dataset.num_classes
336 | print(f'contrastive training: {args.contrastive}')
337 |
338 | generator = nn.DataParallel(StyledGenerator(code_dim=args.code_size, task_dim=args.task_size, n_mlp=args.n_mlp_style)).cuda()
339 | discriminator = nn.DataParallel(
340 | Discriminator(num_classes=args.num_classes, from_rgb_activate=not args.no_from_rgb_activate)
341 | ).cuda()
342 | task = nn.DataParallel(Task(args.task_size, n_mlp=args.n_mlp_task, num_labels=dataset.num_classes)).cuda() # class network
343 | g_running = StyledGenerator(code_dim=args.code_size, task_dim=args.task_size, n_mlp=args.n_mlp_style).cuda()
344 | g_running.train(False)
345 | t_running = Task(args.task_size, n_mlp=args.n_mlp_task, num_labels=dataset.num_classes).cuda()
346 | t_running.train(False)
347 | if args.contrastive:
348 | contr_criterion = nn.DataParallel(BarlowTwins(num_feats=args.contr_feats, use_projector=False)).cuda()
349 |
350 | g_optimizer = optim.Adam(
351 | generator.module.generator.parameters(), lr=args.lr, betas=(0.0, 0.99)
352 | )
353 | g_optimizer.add_param_group(
354 | {
355 | 'params': task.parameters(),
356 | 'lr': args.lr * 0.01,
357 | 'mult': 0.001,
358 | }
359 | )
360 | d_optimizer = optim.Adam(discriminator.parameters(), lr=args.lr, betas=(0.0, 0.99))
361 | if args.contrastive:
362 | d_optimizer.add_param_group({
363 | 'params': contr_criterion.parameters(),
364 | 'lr': args.lr,
365 | })
366 |
367 | accumulate(g_running, generator.module, 0)
368 | accumulate(t_running, task.module, 0)
369 |
370 | if args.ckpt is not None:
371 | ckpt = torch.load(args.ckpt)
372 | if 'task_g' in ckpt:
373 | ckpt['task'] = ckpt['task_g']
374 |
375 | if 'task' in ckpt: # load self-aligned checkpoint
376 | generator.module.load_state_dict(ckpt['generator'])
377 | # copy original class embedding for all new classes
378 | ckpt['task']['task.0.weight_orig'] = ckpt['task']['task.0.weight_orig'].repeat((dataset.num_classes, 1))
379 | task.module.load_state_dict(ckpt['task'])
380 |
381 | g_running.load_state_dict(ckpt['generator'])
382 | t_running.load_state_dict(ckpt['task'])
383 | del ckpt # avoid OOM
384 |
385 | # load vanilla discriminator
386 | ckpt = torch.load('data/stylegan-256px-new.model')
387 | missing, unexpected = discriminator.module.load_state_dict(ckpt['discriminator'], strict=False)
388 | unexpected = [k for k in unexpected if not k.startswith('linear.')]
389 | assert len(unexpected) == 0
390 | missing = [k for k in missing if not k.startswith('final.')]
391 | assert len(missing) == 0
392 | del ckpt # avoid OOM
393 | else: # begin training from pretrained vanilla stylegan
394 | missing, unexpected = generator.module.load_state_dict(convert_generator(ckpt['generator']), strict=False)
395 | assert_loaded_keys(missing, unexpected)
396 | missing, unexpected = g_running.load_state_dict(convert_generator(ckpt['g_running']), strict=False)
397 | assert_loaded_keys(missing, unexpected)
398 |
399 | # load vanilla discriminator
400 | missing, unexpected = discriminator.module.load_state_dict(ckpt['discriminator'], strict=False)
401 | unexpected = [k for k in unexpected if not k.startswith('linear.')]
402 | assert len(unexpected) == 0
403 | missing = [k for k in missing if not k.startswith('final.')]
404 | assert len(missing) == 0
405 | del ckpt # avoid OOM
406 |
407 | if not args.finetune:
408 | # can be stated here (after the optimizers declaration), as long as it's before the forward pass
409 | freeze_layers(generator, discriminator, g_running)
410 |
411 | if args.sched:
412 | args.lr = {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003}
413 | args.batch = {4: 512, 8: 256, 16: 128, 32: 64, 64: 30, 128: 20, 256: 10}
414 | else:
415 | args.lr = {}
416 | args.batch = {}
417 |
418 | args.gen_sample = {512: (8, 4), 1024: (4, 2)}
419 |
420 | train(args, dataset, generator, discriminator, task, g_running, t_running, g_optimizer, d_optimizer,
421 | contr_transform, contr_criterion)
422 |
423 |
424 | if __name__ == '__main__':
425 | main(get_args())
426 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import yaml
3 | from pathlib import Path
4 | import math
5 | import pickle
6 |
7 | import numpy as np
8 | import torch
9 | from torch.utils.data import DataLoader
10 | from torchvision.utils import save_image
11 | from pytorch_lightning import seed_everything as seed_everything_pl
12 |
13 | HPARAMS_FILENAME = 'hparams.yml'
14 |
15 |
16 | # ------------------------------------ I/O ------------------------------------
17 |
18 |
19 | def args_to_yaml(path, args, exist_ok=False):
20 | file = Path(path)
21 |
22 | if file.exists():
23 | if exist_ok:
24 | return
25 | else:
26 | raise FileExistsError(f'File already exists {file}')
27 |
28 | with file.open('w') as f:
29 | yaml.dump(args.__dict__, f,
30 | default_flow_style=False,
31 | sort_keys=True)
32 |
33 |
34 | def yaml_to_args(path):
35 | with open(path, 'r') as f:
36 | hparams = yaml.full_load(f)
37 |
38 | return argparse.Namespace(**hparams)
39 |
40 |
41 | def save_hparams(path, args, exist_ok=False):
42 | hparams_file = Path(path) / HPARAMS_FILENAME
43 | args_to_yaml(hparams_file, args, exist_ok=exist_ok)
44 |
45 |
46 | def load_hparams(path):
47 | path = Path(path)
48 | if not path.is_dir():
49 | path = path.parent
50 |
51 | hparams_file = path / HPARAMS_FILENAME
52 |
53 | if hparams_file.exists():
54 | return yaml_to_args(hparams_file)
55 |
56 | return None
57 |
58 |
59 | def num_to_one_hot(t, num_classes=3):
60 | return torch.nn.functional.one_hot(torch.tensor(t) if isinstance(t, int) or isinstance(t, list) else t, num_classes=num_classes).float()
61 |
62 |
63 | def load_checkpoint(args, hparams=None):
64 | if hparams is None:
65 | hparams = load_hparams(args.checkpoint_path)
66 | if hasattr(hparams, 'original_model') and hparams.original_model:
67 | from model.stylegan import StyledGenerator
68 | model = StyledGenerator(code_dim=hparams.code_size, n_mlp=hparams.n_mlp, c_dim=hparams.num_classes).to(args.device)
69 | checkpoint = torch.load(args.checkpoint_path, map_location=args.device)
70 | model.load_state_dict(checkpoint['g_running'] if 'g_running' in checkpoint else checkpoint)
71 | return model, num_to_one_hot
72 | from model.hyper_mod import StyledGenerator, Task
73 |
74 | checkpoint = torch.load(args.checkpoint_path, map_location=args.device)
75 | model = StyledGenerator(code_dim=hparams.code_size, task_dim=hparams.task_size, n_mlp=hparams.n_mlp_style).to(args.device)
76 | task = Task(hparams.task_size, n_mlp=hparams.n_mlp_task, num_labels=hparams.num_classes).to(args.device)
77 |
78 | from_self_align = False
79 | if hasattr(hparams, 'origin') and hparams.origin == 'self_align':
80 | from_self_align = True
81 |
82 | if from_self_align:
83 | model.load_state_dict(checkpoint['generator'])
84 | task.load_state_dict(checkpoint['task'])
85 | else:
86 | model.load_state_dict(checkpoint['g_running'])
87 | task.load_state_dict(checkpoint['t_running_g'])
88 | return model, task
89 |
90 |
91 | # ------------------------------------ Logging ------------------------------------
92 |
93 |
94 | def get_run_id(outdir):
95 | import os, re
96 | # From StyleGAN repo
97 | # Pick output directory.
98 | prev_run_dirs = []
99 | if os.path.isdir(outdir):
100 | prev_run_dirs = [x for x in os.listdir(outdir) if os.path.isdir(os.path.join(outdir, x))]
101 | prev_run_ids = [re.match(r'^\d+', x) for x in prev_run_dirs]
102 | prev_run_ids = [int(x.group()) for x in prev_run_ids if x is not None]
103 | return max(prev_run_ids, default=-1) + 1
104 |
105 |
106 | def log_images(fname, alpha, args, dataset, resolution, batch_size, step, generator, task, device='cuda', seed=2147483647):
107 | if args.no_tb:
108 | return
109 | rng = torch.Generator(device=device)
110 | rng.manual_seed(seed)
111 | images = []
112 | default_gen_n_classes = 8
113 | default_gen_n_samples = 4
114 | gen_n_classes, gen_n_samples = args.gen_sample.get(resolution, (default_gen_n_classes, default_gen_n_samples))
115 | if dataset.num_classes < gen_n_classes:
116 | # cycle through the classes
117 | gen_class_range = list(range(dataset.num_classes)) * math.ceil(gen_n_classes / dataset.num_classes)
118 | gen_class_range = gen_class_range[:gen_n_classes]
119 | else:
120 | gen_class_range = list(range(gen_n_classes))
121 | gen_class_range = gen_class_range * gen_n_samples
122 | with torch.no_grad():
123 | for j in range(0, len(gen_class_range), batch_size):
124 | gen_classes = gen_class_range[j:j + batch_size]
125 | images.append(
126 | generator(
127 | torch.randn(len(gen_classes), args.code_size, generator=rng, device=rng.device),
128 | step=step, alpha=alpha, task=task(torch.tensor(gen_classes).to(device))
129 | ).data.cpu()
130 | )
131 |
132 | save_image(torch.cat(images, 0), fname, nrow=gen_n_classes, normalize=True, range=(-1, 1))
133 |
134 |
135 | # ------------------------------------ Training ------------------------------------
136 |
137 |
138 | def seed_everything(seed): # cleaner imports
139 | seed_everything_pl(seed)
140 |
141 |
142 | def not_frozen_params(model):
143 | require = {}
144 | for name, param in model.named_parameters():
145 | require[name] = param.requires_grad
146 |
147 | return require
148 |
149 |
150 | def requires_grad(model, flag=True, grad_map=None):
151 | for n, p in model.named_parameters():
152 | if flag and grad_map is not None and not grad_map[n]: # filter those params which were originally frozen
153 | continue
154 |
155 | p.requires_grad = flag
156 |
157 |
158 | def accumulate(model1, model2, decay=0.999, grad_map=None):
159 | par1 = dict(model1.named_parameters())
160 | par2 = dict(model2.named_parameters())
161 |
162 | for k in par1.keys():
163 | if grad_map is not None and not grad_map['module.' + k]: # filter out frozen params
164 | continue
165 |
166 | par1[k].data.mul_(decay).add_(par2[k].data, alpha=1 - decay)
167 |
168 |
169 | def sample_data(dataset, batch_size, image_size=4):
170 | dataset.resolution = image_size
171 | loader = DataLoader(dataset, shuffle=True, batch_size=batch_size, num_workers=1, drop_last=True)
172 |
173 | return loader
174 |
175 |
176 | def adjust_lr(optimizer, lr):
177 | for group in optimizer.param_groups:
178 | mult = group.get('mult', 1)
179 | group['lr'] = lr * mult
180 |
181 |
182 | # ------------------------------------ Metrics ------------------------------------
183 |
184 |
185 | def load_scores(path):
186 | ms = {}
187 | for m_f in Path(path).glob('*.metric'):
188 | with open(m_f, 'rb') as f:
189 | data = pickle.load(f)
190 | ms[int(m_f.stem)] = data
191 | return dict(sorted(ms.items()))
192 |
193 |
194 | def get_submetric(s):
195 | parts = s.rsplit('.', 1)
196 | if len(parts) == 1:
197 | return s, None
198 | if parts[1] == '':
199 | raise ValueError('dot is used to mark metric index, dont use at the end')
200 | return parts[0], parts[1]
201 |
202 |
203 | def class_mean(ms, metric='fid', return_steps=False):
204 | ms = load_scores(ms)
205 | steps = list(ms.keys())
206 | class_m = 0
207 | classes = ms[steps[0]].keys()
208 | metric, sub_metric = get_submetric(metric)
209 | for class_i in classes:
210 | if sub_metric is not None:
211 | class_m += np.array([v[class_i][metric][sub_metric] for k, v in ms.items()])
212 | else:
213 | class_m += np.array([v[class_i][metric] for k, v in ms.items()])
214 |
215 | if return_steps:
216 | return steps, class_m / len(classes)
217 | return class_m / len(classes)
218 |
--------------------------------------------------------------------------------