├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── configs ├── discriminator_testbed │ ├── default.yaml │ ├── paper_runs │ │ ├── baboon64_pggan_blurpool.yaml │ │ ├── baboon_high64_pggan_blurpool.yaml │ │ ├── toyset64_pggan_avg.yaml │ │ ├── toyset64_pggan_blurpool.yaml │ │ ├── toyset64_pggan_stride.yaml │ │ └── toyset64_pggan_stride_spectrum_discriminator.yaml │ ├── pggan.yaml │ └── sg2.yaml └── generator_testbed │ ├── default.yaml │ ├── paper_runs │ ├── baboon64_pggan_bilinear.yaml │ ├── baboon_high64_pggan_bilinear.yaml │ ├── toyset64_pggan_bilinear.yaml │ ├── toyset64_pggan_nearest.yaml │ ├── toyset64_pggan_reshape.yaml │ └── toyset64_pggan_zeros.yaml │ ├── pggan.yaml │ ├── sg2.yaml │ └── sg3.yaml ├── data ├── baboon │ └── baboon.jpg └── toyset.py ├── dataset.py ├── discriminator_testbed.py ├── environment.yml ├── environment_sg3.yml ├── eval_discriminator.py ├── eval_generator.py ├── generator_testbed.py ├── gfx ├── teaser_disc.gif └── teaser_gen.gif ├── loss.py ├── models ├── __init__.py ├── direct_generator.py ├── mlp.py ├── pggan_discriminator.py ├── pggan_generator.py ├── stylegan2_discriminator.py ├── stylegan2_generator.py ├── stylegan3 │ └── __init__.py ├── stylegan3_generator.py └── utils.py ├── scripts ├── demo_discriminator_testbed.sh └── demo_generator_testbed.sh └── utils ├── __init__.py ├── checkpoints.py ├── gan_training.py ├── logger.py ├── metrics.py ├── misc.py ├── plot.py └── spectrum.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "models/stylegan3/stylegan3"] 2 | path = models/stylegan3/stylegan3 3 | url = https://github.com/NVlabs/stylegan3.git 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 autonomousvision 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Frequency Bias of Generative Models 2 | 3 | ![](gfx/teaser_gen.gif) | ![](gfx/teaser_disc.gif) 4 | :---:| :---: 5 | Generator Testbed | Discriminator Testbed 6 | 7 | This repository contains official code for the paper 8 | [On the Frequency Bias of Generative Models](http://cvlibs.net/publications/Schwarz2021NEURIPS.pdf). 9 | 10 | You can find detailed usage instructions for analyzing standard GAN-architectures and your own models below. 11 | 12 | 13 | If you find our code or paper useful, please consider citing 14 | 15 | @inproceedings{Schwarz2021NEURIPS, 16 | title = {On the Frequency Bias of Generative Models}, 17 | author = {Schwarz, Katja and Liao, Yiyi and Geiger, Andreas}, 18 | booktitle = {Advances in Neural Information Processing Systems (NeurIPS)}, 19 | year = {2021} 20 | } 21 | 22 | ## Installation 23 | Please note, that this repo requires one GPU for running. 24 | First you have to make sure that you have all dependencies in place. 25 | The simplest way to do so, is to use [anaconda](https://www.anaconda.com/). 
26 | 
27 | You can create an anaconda environment called `fbias` using
28 | ```
29 | conda env create -f environment.yml
30 | conda activate fbias
31 | ```
32 | 
33 | ## Generator Testbed
34 | 
35 | You can run a demo of our generator testbed via:
36 | ```
37 | chmod +x ./scripts/demo_generator_testbed.sh
38 | ./scripts/demo_generator_testbed.sh
39 | ```
40 | This will train the Generator of [Progressive Growing GAN](https://arxiv.org/abs/1710.10196) to regress a single image.
41 | Further, the training progression of the image regression, the spectrum, and the spectrum error is summarized in `output/generator_testbed/baboon64/pggan/eval`.
42 | 
43 | In general, to analyze the spectral properties of a generator architecture, you can train a model by running
44 | ```
45 | python generator_testbed.py *EXPERIMENT_NAME* *PATH/TO/CONFIG*
46 | ```
47 | This script creates a folder `output/generator_testbed/*EXPERIMENT_NAME*` where you can find the training progress.
48 | To evaluate the spectral properties of the trained model, run
49 | ```
50 | python eval_generator.py *EXPERIMENT_NAME* --psnr --image-evolution --spectrum-evolution --spectrum-error-evolution
51 | ```
52 | This will print the average PSNR of the regressed images and visualize the image evolution, spectrum evolution,
53 | and spectrum error evolution in `output/generator_testbed/*EXPERIMENT_NAME*/eval`.
54 | 
55 | ## Discriminator Testbed
56 | 
57 | You can run a demo of our discriminator testbed via:
58 | ```
59 | chmod +x ./scripts/demo_discriminator_testbed.sh
60 | ./scripts/demo_discriminator_testbed.sh
61 | ```
62 | This will train the Discriminator of [Progressive Growing GAN](https://arxiv.org/abs/1710.10196) to regress a single image.
63 | Further, the training progression of the image regression, the spectrum, and the spectrum error is summarized in `output/discriminator_testbed/baboon64/pggan/eval`.
64 | 
65 | In general, to analyze the spectral properties of a discriminator architecture, you can train a model by running
66 | ```
67 | python discriminator_testbed.py *EXPERIMENT_NAME* *PATH/TO/CONFIG*
68 | ```
69 | This script creates a folder `output/discriminator_testbed/*EXPERIMENT_NAME*` where you can find the training progress.
70 | To evaluate the spectral properties of the trained model, run
71 | ```
72 | python eval_discriminator.py *EXPERIMENT_NAME* --psnr --image-evolution --spectrum-evolution --spectrum-error-evolution
73 | ```
74 | This will print the average PSNR of the regressed images and visualize the image evolution, spectrum evolution,
75 | and spectrum error evolution in `output/discriminator_testbed/*EXPERIMENT_NAME*/eval`.
76 | 
77 | 
78 | ## Datasets
79 | 
80 | ### Toyset
81 | 
82 | You can generate a toy dataset whose spectrum consists of Gaussian peaks by running
83 | ```
84 | cd data
85 | python toyset.py 64 100
86 | cd ..
87 | ```
88 | This creates a folder `data/toyset64_100/` and generates 100 images of resolution 64x64 pixels.
89 | 
90 | ### CelebA-HQ
91 | 
92 | Download [celebA_hq](https://github.com/tkarras/progressive_growing_of_gans).
93 | Then update `data:root: *PATH/TO/CELEBA_HQ*` in the config file.
94 | 
95 | ### Other datasets
96 | 
97 | The config setting `data:root: *PATH/TO/DATA*` needs to point to a folder with the training images.
98 | You can use any dataset that follows the folder structure
99 | ```
100 | *PATH/TO/DATA*/xxx.png
101 | *PATH/TO/DATA*/xxy.png
102 | ...
103 | ```
104 | By default, the images are center-cropped and optionally resized to the resolution specified in the config file under `data:resolution`; an example `data` section is shown below.
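For illustration, a minimal `data` section could look like this (values borrowed from the toyset configs and the defaults; point `root` to your own data):
```
data:
  root: data/toyset64_100   # folder containing the training images
  resolution: 64            # images are center-cropped and resized to 64x64 pixels
  subset: 10                # optional: only use the first 10 images
```
Settings that are not specified in your config fall back to `configs/generator_testbed/default.yaml` or `configs/discriminator_testbed/default.yaml`, respectively.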
105 | Note that you can also use a subset of images via `data:subset`.
106 | 
107 | ## Architectures
108 | 
109 | ### StyleGAN Support
110 | 
111 | In addition to [Progressive Growing GAN](https://arxiv.org/abs/1710.10196), this repository supports analyzing the following architectures:
112 | - [StyleGAN2](https://arxiv.org/abs/1912.04958) Generator
113 | - [StyleGAN2](https://arxiv.org/abs/1912.04958) Discriminator
114 | - [StyleGAN3](https://nvlabs-fi-cdn.nvidia.com/stylegan3/stylegan3-paper.pdf) Generator
115 | 
116 | For this, you need to initialize the stylegan3 submodule by running
117 | ```
118 | git pull --recurse-submodules
119 | cd models/stylegan3/stylegan3
120 | git submodule init
121 | git submodule update
122 | cd ../../../
123 | ```
124 | 
125 | Next, you need to install the additional requirements for the StyleGAN architectures. You can do this by running
126 | ```
127 | conda activate fbias
128 | conda env update --file environment_sg3.yml --prune
129 | ```
130 | 
131 | You can now analyze the spectral properties of the StyleGAN architectures by running
132 | ```
133 | # StyleGAN2
134 | python generator_testbed.py baboon64/StyleGAN2 configs/generator_testbed/sg2.yaml
135 | python discriminator_testbed.py baboon64/StyleGAN2 configs/discriminator_testbed/sg2.yaml
136 | # StyleGAN3
137 | python generator_testbed.py baboon64/StyleGAN3 configs/generator_testbed/sg3.yaml
138 | ```
139 | 
140 | ### Other architectures
141 | 
142 | To analyze any other network architecture, you can add the respective model file (or submodule) under `models`.
143 | You then need to write a wrapper class to integrate the architecture seamlessly into this code base.
144 | Examples for wrapper classes are given in
145 | - `models/stylegan2_generator.py` for the Generator
146 | - `models/stylegan2_discriminator.py` for the Discriminator
147 | 
148 | 
149 | ## Further Information
150 | 
151 | This repository builds on Lars Mescheder's awesome framework for [GAN training](https://github.com/LMescheder/GAN_stability).
152 | Further, we utilize code from the [StyleGAN3 repository](https://github.com/NVlabs/stylegan3.git) and [GenForce](https://github.com/genforce/genforce).
--------------------------------------------------------------------------------
/configs/discriminator_testbed/default.yaml:
--------------------------------------------------------------------------------
1 | data:
2 |   root: data/baboon
3 |   resolution: 64
4 |   subset:
5 | model:
6 |   spectrum_disc: False
7 | training:
8 |   batch_size: 10
9 |   nworkers: 0
10 |   monitoring: tensorboard
11 |   nepochs: 10000
12 |   print_every: 100
13 |   eval_every: 100
14 |   save_every: 1000
15 |   lr_g: 0.01
16 |   lr_d: 0.001
17 |   reg_param: 10.
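  # reg_param above is the weight of the gradient penalty on real images (see discriminator_trainstep in discriminator_testbed.py)
  # model_average_beta below is the EMA decay used to update the generator copy generator_test (see gan_training.update_average)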
18 | model_average_beta: 0.999 19 | criterion: 20 | class_name: torch.nn.BCEWithLogitsLoss 21 | weight: 22 | model_file: model.pt 23 | seed: 0 24 | -------------------------------------------------------------------------------- /configs/discriminator_testbed/paper_runs/baboon64_pggan_blurpool.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | resolution: 128 3 | model: 4 | class_name: models.PGGANDiscriminator 5 | divide_channels_by: 8 6 | downsampling_mode: blurpool 7 | training: 8 | nepochs: 30000 9 | save_every: 10000 -------------------------------------------------------------------------------- /configs/discriminator_testbed/paper_runs/baboon_high64_pggan_blurpool.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | resolution: 128 3 | highpass: True 4 | model: 5 | class_name: models.PGGANDiscriminator 6 | divide_channels_by: 8 7 | downsampling_mode: blurpool 8 | training: 9 | nepochs: 30000 10 | save_every: 10000 -------------------------------------------------------------------------------- /configs/discriminator_testbed/paper_runs/toyset64_pggan_avg.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | root: data/toyset64_100 3 | subset: 10 4 | model: 5 | class_name: models.PGGANDiscriminator 6 | divide_channels_by: 8 7 | downsampling_mode: avg 8 | training: 9 | nepochs: 100000 10 | print_every: 1000 11 | eval_every: 1000 12 | save_every: 10000 -------------------------------------------------------------------------------- /configs/discriminator_testbed/paper_runs/toyset64_pggan_blurpool.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | root: data/toyset64_100 3 | subset: 10 4 | model: 5 | class_name: models.PGGANDiscriminator 6 | divide_channels_by: 8 7 | downsampling_mode: blurpool 8 | training: 9 | nepochs: 100000 10 | print_every: 1000 11 | eval_every: 1000 12 | save_every: 10000 -------------------------------------------------------------------------------- /configs/discriminator_testbed/paper_runs/toyset64_pggan_stride.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | root: data/toyset64_100 3 | subset: 10 4 | model: 5 | class_name: models.PGGANDiscriminator 6 | divide_channels_by: 8 7 | downsampling_mode: stride 8 | training: 9 | nepochs: 100000 10 | print_every: 1000 11 | eval_every: 1000 12 | save_every: 10000 -------------------------------------------------------------------------------- /configs/discriminator_testbed/paper_runs/toyset64_pggan_stride_spectrum_discriminator.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | root: data/toyset64_100 3 | subset: 10 4 | model: 5 | class_name: models.PGGANDiscriminator 6 | divide_channels_by: 8 7 | downsampling_mode: stride 8 | spectrum_disc: True 9 | training: 10 | nepochs: 100000 11 | print_every: 1000 12 | eval_every: 1000 13 | save_every: 10000 -------------------------------------------------------------------------------- /configs/discriminator_testbed/pggan.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | class_name: models.PGGANDiscriminator 3 | divide_channels_by: 8 4 | downsampling_mode: avg 5 | padding_mode: zeros 6 | -------------------------------------------------------------------------------- /configs/discriminator_testbed/sg2.yaml: 
-------------------------------------------------------------------------------- 1 | model: 2 | class_name: models.SG2Discriminator 3 | divide_channels_by: 8 4 | -------------------------------------------------------------------------------- /configs/generator_testbed/default.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | root: data/baboon 3 | resolution: 64 4 | subset: 5 | model: 6 | training: 7 | batch_size: 10 8 | nworkers: 0 9 | monitoring: tensorboard 10 | nepochs: 10000 11 | print_every: 100 12 | eval_every: 100 13 | save_every: 1000 14 | lr: 0.001 15 | criterion: 16 | class_name: torch.nn.MSELoss 17 | weight: 18 | model_file: model.pt 19 | seed: 0 20 | -------------------------------------------------------------------------------- /configs/generator_testbed/paper_runs/baboon64_pggan_bilinear.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | resolution: 128 3 | model: 4 | class_name: models.PGGANGenerator 5 | z_dim: 64 6 | divide_channels_by: 8 7 | upsampling_mode: bilinear 8 | padding_mode: reflect 9 | training: 10 | nepochs: 30000 11 | save_every: 10000 -------------------------------------------------------------------------------- /configs/generator_testbed/paper_runs/baboon_high64_pggan_bilinear.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | resolution: 128 3 | highpass: True 4 | model: 5 | class_name: models.PGGANGenerator 6 | z_dim: 64 7 | divide_channels_by: 8 8 | upsampling_mode: bilinear 9 | padding_mode: reflect 10 | training: 11 | nepochs: 30000 12 | save_every: 10000 -------------------------------------------------------------------------------- /configs/generator_testbed/paper_runs/toyset64_pggan_bilinear.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | root: data/toyset64_100 3 | subset: 10 4 | model: 5 | class_name: models.PGGANGenerator 6 | z_dim: 64 7 | divide_channels_by: 8 8 | upsampling_mode: bilinear 9 | training: 10 | nepochs: 100000 11 | print_every: 1000 12 | eval_every: 1000 13 | save_every: 10000 -------------------------------------------------------------------------------- /configs/generator_testbed/paper_runs/toyset64_pggan_nearest.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | root: data/toyset64_100 3 | subset: 10 4 | model: 5 | class_name: models.PGGANGenerator 6 | z_dim: 64 7 | divide_channels_by: 8 8 | upsampling_mode: nearest 9 | training: 10 | nepochs: 100000 11 | print_every: 1000 12 | eval_every: 1000 13 | save_every: 10000 -------------------------------------------------------------------------------- /configs/generator_testbed/paper_runs/toyset64_pggan_reshape.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | root: data/toyset64_100 3 | subset: 10 4 | model: 5 | class_name: models.PGGANGenerator 6 | z_dim: 64 7 | divide_channels_by: 12 8 | upsampling_mode: shuffle 9 | training: 10 | nepochs: 100000 11 | print_every: 1000 12 | eval_every: 1000 13 | save_every: 10000 -------------------------------------------------------------------------------- /configs/generator_testbed/paper_runs/toyset64_pggan_zeros.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | root: data/toyset64_100 3 | subset: 10 4 | model: 5 | class_name: models.PGGANGenerator 6 | z_dim: 64 7 | divide_channels_by: 8 8 
| upsampling_mode: zeros 9 | training: 10 | nepochs: 100000 11 | print_every: 1000 12 | eval_every: 1000 13 | save_every: 10000 -------------------------------------------------------------------------------- /configs/generator_testbed/pggan.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | class_name: models.PGGANGenerator 3 | z_dim: 64 4 | divide_channels_by: 8 5 | upsampling_mode: bilinear 6 | padding_mode: reflect 7 | -------------------------------------------------------------------------------- /configs/generator_testbed/sg2.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | class_name: models.SG2Generator 3 | z_dim: 64 4 | divide_channels_by: 8 -------------------------------------------------------------------------------- /configs/generator_testbed/sg3.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | class_name: models.SG3Generator 3 | z_dim: 64 4 | divide_channels_by: 8 -------------------------------------------------------------------------------- /data/baboon/baboon.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/frequency_bias/1bd267e897bd2d122524cbe3f01f241b0dea75c5/data/baboon/baboon.jpg -------------------------------------------------------------------------------- /data/toyset.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | from tqdm import tqdm 4 | from math import sqrt, pi 5 | from torch.fft import irfftn 6 | from torchvision.transforms import CenterCrop, ToPILImage 7 | 8 | 9 | def gaussian(x, mu, sigma): 10 | return 1 / (sqrt(2 * pi) * sigma) * torch.exp(- (x - mu) ** 2 / (2 * sigma ** 2)) 11 | 12 | 13 | def make_circ_magnitude(resolution, freq): 14 | """Create circular magnitude image using a 2D Gaussian with mean=freq and sigma~1pixel.""" 15 | if resolution % 2 != 0: 16 | raise NotImplementedError 17 | 18 | # magnitude image of (real) image has shape: (res x res//2+1) 19 | # we first only create the first quadrant of shape (res//2 x res//2) 20 | spectrum_size = resolution // 2 21 | 22 | # Use a 2D gaussian to create circular spectrum 23 | r = torch.stack( 24 | torch.meshgrid(torch.linspace(0, 1, spectrum_size), 25 | torch.linspace(0, 1, spectrum_size + 1), # add one to store the mean value 26 | ) 27 | ).norm(dim=0) 28 | 29 | mean = freq 30 | std = 1 / (spectrum_size + 1) # sigma = 1pxl 31 | magnitude = gaussian(r, mean, std) 32 | 33 | # We need to decide the normalization of the Gaussian 34 | # We choose it, s.t. 
35 | # image.norm() = fft(image, norm='ortho').norm() ~ resolution 36 | # fft(image, norm='ortho').norm() = sqrt(magnitude.sum()) 37 | # -> magnitude.sum() ~ resolution**2 38 | # Since we consider only a single channel but image will have three channels, we get 39 | # magnitude.sum() ~ resolution**2 / 3 40 | 41 | # Compute sum of a Gaussian with mean=1/sqrt(2) (middle of spectrum) and sigma~1pixel 42 | # to get the normalization constant 43 | norm = gaussian(r, 1 / sqrt(2), std).sum() 44 | norm *= 4 # we only considered one quadrant 45 | 46 | magnitude /= norm # normalize sum approx to 1 47 | magnitude *= resolution**2 / 3 # scale sum to approx resolution**2 / 3 48 | 49 | # Stack to get the full magnitude image of shape (res x res//2+1) 50 | magnitude = torch.cat([magnitude, magnitude.flipud()]) 51 | return magnitude 52 | 53 | 54 | def img_from_magnitude(magnitude): 55 | """Combine a given spectrum (magnitude) with a random phase and use inverse Fourier transform to obtain an image.""" 56 | magnitude = magnitude.repeat(3, 1, 1) # create 3 channels 57 | 58 | phase = torch.rand_like(magnitude) * 2 * pi # uniform distributed phase 59 | # set phase of mean to zero 60 | phase[0, 0] = 0 61 | phase[0, -1] = 0 62 | phase[magnitude.shape[0]//2, 0] = 0 63 | phase[magnitude.shape[0]//2, -1] = 0 64 | 65 | tanphi = torch.tan(phase) 66 | 67 | # compute real and imaginary part 68 | # tan(phase) = Im/Re -> Im = Re*tan(phase) 69 | # mag^2 = Im^2 + Re^2 -> mag^2 = (1+tan(phase)^2)*Re^2 -> Re = mag/sqrt(1+tan(phase)^2) 70 | real = magnitude / (1 + tanphi ** 2).sqrt() 71 | imag = real * tanphi 72 | 73 | freq_img = torch.complex(real, imag) 74 | img = irfftn(freq_img, dim=(1, 2), norm='ortho') 75 | 76 | return img 77 | 78 | 79 | def generate_toysamples(resolution, freqrange=[(0.05, 0.15), (0.75, 0.85)]): 80 | assert isinstance(freqrange, list) and isinstance(freqrange[0], tuple) 81 | resolution_gen = 2*resolution # generate images at higher resolution to avoid discretization artifacts 82 | crop = CenterCrop((resolution, resolution)) # center crop generated images <-> downsample spectrum 83 | to_pil = ToPILImage() 84 | 85 | fnyq = sqrt(2) # diagonal of image in range [0,1] 86 | while True: 87 | img = torch.zeros(3, resolution_gen, resolution_gen) 88 | for fmin, fmax in freqrange: 89 | f = torch.rand(1) * (fmax - fmin) + fmin 90 | f = f * fnyq 91 | 92 | img += img_from_magnitude(make_circ_magnitude(resolution_gen, f)) 93 | 94 | img /= len(freqrange) # compute mean 95 | # convert to uint8, assume range is [-1, 1] 96 | img = ((img + 1) * 127.5).clamp_(0, 255).to(torch.uint8) 97 | img = to_pil(crop(img)) 98 | yield img 99 | 100 | 101 | if __name__ == '__main__': 102 | import os 103 | import sys 104 | sys.path.append('..') 105 | # Arguments 106 | parser = argparse.ArgumentParser( 107 | description='Generate toyimages with multiple Gaussian peaks as spectrum.' 108 | ) 109 | parser.add_argument('res', type=int, help='Image resolution.') 110 | parser.add_argument('nsamples', type=int, help='Number of samples to generate.') 111 | parser.add_argument('--outdir', type=str, help='Directory for saving the images.') 112 | parser.add_argument('--seed', type=int, default=0, help='Random seed.') 113 | parser.add_argument('--plot_stats', action='store_true', help='Plot spectrum statistics of generated toyset.') 114 | parser.add_argument('--freqs', type=str, nargs='+', 115 | default=['0.05,0.15', '0.75, 0.85'], 116 | help='List of frequency ranges (fmin,fmax) wrt. the nyquist frequency.' 
117 | 'The number of freqranges equals the number of peaks in the spectrum of each image. ' 118 | 'The mean value of each peak i is drawn uniformly from (fmin,fmax)_i.') 119 | 120 | args = parser.parse_args() 121 | try: 122 | args.freqs = [tuple(map(float, s.split(','))) for s in args.freqs] 123 | for f in args.freqs: 124 | if len(f) != 2: 125 | raise TypeError 126 | except: 127 | raise TypeError("freqranges must be fmin,fmax") 128 | 129 | if args.outdir is None: 130 | args.outdir = f'toyset{args.res}_{args.nsamples}' 131 | os.makedirs(args.outdir, exist_ok=True) 132 | 133 | torch.manual_seed(args.seed) 134 | i = 0 135 | for img in tqdm(generate_toysamples(resolution=args.res), total=args.nsamples, desc='Creating toyset'): 136 | img.save(os.path.join(args.outdir, '%08d.png' % i)) 137 | i += 1 138 | if i == args.nsamples: 139 | break 140 | 141 | if args.plot_stats: 142 | import matplotlib.pyplot as plt; plt.switch_backend('TkAgg'); plt.ion() 143 | from glob import glob 144 | from PIL import Image 145 | from torchvision.transforms import ToTensor 146 | from utils.spectrum import get_spectrum 147 | from utils.plot import plot_std, HAS_LATEX 148 | 149 | imgs = torch.stack([ToTensor()(Image.open(f)) for f in glob(os.path.join(args.outdir, '*.png'))]) 150 | imgs = imgs * 2 - 1 151 | spectra = get_spectrum(imgs.flatten(0, 1)).unflatten(0, (args.nsamples, 3)).mean(dim=1) 152 | 153 | fig, ax = plt.subplots(1) 154 | 155 | # Settings for x-axis 156 | N = sqrt(2) * args.res 157 | fnyq = (N - 1) / 2 158 | x_ticks = [0, fnyq / 2, fnyq] 159 | x_ticklabels = ['%.1f' % (l / fnyq) for l in x_ticks] 160 | 161 | ax.set_xlim(0, fnyq) 162 | xlabel = r'$f/f_{nyq}$' if HAS_LATEX else 'f/fnyq' 163 | ax.set_xlabel(xlabel) 164 | ax.set_xticks(x_ticks) 165 | ax.set_xticklabels(x_ticklabels) 166 | 167 | # Settings for y-axis 168 | ax.set_ylabel(r'Spectral density') 169 | ax.set_yscale('log') 170 | 171 | plot_std(spectra.mean(dim=0), spectra.std(dim=0), ax=ax) 172 | for freqs in args.freqs: 173 | ax.axvline(sum(freqs) / 2 * fnyq, c='k', ls='--') 174 | 175 | plt.show() 176 | plt.waitforbuttonpress() -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from glob import glob 4 | from PIL import Image 5 | from torchvision.datasets import VisionDataset 6 | from torchvision import transforms 7 | 8 | 9 | IMG_EXTENSIONS = ['.png', '.jpg'] 10 | 11 | 12 | class ToSquare: 13 | """Crop a ``PIL Image`` at the center to make it square. This transform does not support torchscript. 14 | 15 | Crops a PIL Image (H x W x C) to (R x R x C) where R=min(H,W). 16 | """ 17 | 18 | def __call__(self, pic): 19 | """ 20 | Args: 21 | pic (PIL Image): Image to be cropped. 22 | 23 | Returns: 24 | PIL Image: Square PIL Image. 25 | """ 26 | R = min(pic.size) 27 | return transforms.functional.center_crop(pic, (R, R)) 28 | 29 | def __repr__(self): 30 | return self.__class__.__name__ + '()' 31 | 32 | 33 | class HighPass(torch.nn.Module): 34 | """Apply high pass filter to Image. 35 | This transform does not support PIL Image. 36 | """ 37 | 38 | def __init__(self): 39 | super().__init__() 40 | self.filter = torch.tensor([[1, -1], [-1, 1]]) 41 | self.padding = (1, 1, 1, 1) 42 | 43 | 44 | def forward(self, x): 45 | """ 46 | Args: 47 | tensor (Tensor): Tensor image to be high pass filtered. 48 | 49 | Returns: 50 | Tensor: Filtered Tensor image. 
51 | """ 52 | # Pad input with reflection padding 53 | C, H, W = x.shape 54 | x = x.unsqueeze(0) 55 | x = torch.nn.functional.pad(x, self.padding, mode='reflect') 56 | 57 | 58 | # Convolve with the filter to filter high frequencies. 59 | f = self.filter.view(1, 1, 2, 2).repeat(C, 1, 1, 1).to(x.device, x.dtype) 60 | x = torch.nn.functional.conv2d(input=x, weight=f, groups=C).squeeze(0)[:, :H, :W] 61 | return x 62 | 63 | 64 | def __repr__(self): 65 | return self.__class__.__name__ + '(filter={0}, padding={1})'.format(self.filter.tolist(), self.padding) 66 | 67 | 68 | def get_dataset(cfg): 69 | tf = transforms.Compose([ 70 | ToSquare(), 71 | transforms.ToTensor(), 72 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # [0, 1] -> [-1, 1] 73 | ]) 74 | if cfg.data.get('highpass', False): 75 | tf.transforms.append(HighPass()) 76 | z_dim = cfg.model.get('z_dim', 1) 77 | dset = ImageLatentDataset(cfg.data.root, z_dim, transform=tf) 78 | H, W = dset[0][0].shape[1:] 79 | if H != W: 80 | raise RuntimeError(f'Images need to be square but have H={H} and W={W}.') 81 | resolution = cfg.data.get('resolution', H) 82 | if H != cfg.data.resolution: 83 | print(f'Resize images from {H} to {cfg.data.resolution} using Lanczos filter.') 84 | tf.transforms.insert(1, transforms.Resize(cfg.data.resolution, interpolation=transforms.InterpolationMode.LANCZOS)) 85 | assert dset[0][0].shape[1] == resolution 86 | dset.resolution = resolution 87 | dset.highpass = cfg.data.get('highpass', False) 88 | 89 | if z_dim == 1: # replace z by index, used for discriminator testbed 90 | dset.z = torch.arange(len(dset)) 91 | 92 | if cfg.data.subset is not None: 93 | subset_idcs = range(cfg.data.subset) 94 | subset = torch.utils.data.Subset(dset, subset_idcs) 95 | 96 | # inherit attributes 97 | subset.root = dset.root 98 | subset.z = dset.z[subset.indices] 99 | subset.resolution = dset.resolution 100 | subset.highpass = dset.highpass 101 | 102 | dset = subset 103 | 104 | return dset 105 | 106 | 107 | class ImageFolder(VisionDataset): 108 | """Loads all images in given root 109 | """ 110 | def __init__(self, root, transform=None): 111 | super(ImageFolder, self).__init__(root, transform=transform) 112 | self.img_paths = [p for p in glob(os.path.join(root, '*')) if os.path.splitext(p)[-1] in IMG_EXTENSIONS] 113 | 114 | def __len__(self): 115 | return len(self.img_paths) 116 | 117 | def __getitem__(self, idx): 118 | img = Image.open(self.img_paths[idx]) 119 | if self.transform is not None: 120 | img = self.transform(img) 121 | return img 122 | 123 | 124 | class ImageLatentDataset(ImageFolder): 125 | """Wrapper class which pairs each image with a fixed latent code.""" 126 | def __init__(self, root, z_dim, transform=None): 127 | super(ImageLatentDataset, self).__init__(root, transform=transform) 128 | self.z = torch.randn(len(self), z_dim) 129 | 130 | def __getitem__(self, idx): 131 | img = super(ImageLatentDataset, self).__getitem__(idx) 132 | return img, self.z[idx] 133 | -------------------------------------------------------------------------------- /discriminator_testbed.py: -------------------------------------------------------------------------------- 1 | """Train a Discriminator architecture with a GAN loss.""" 2 | 3 | import argparse 4 | import os 5 | import torch 6 | import pickle 7 | import copy 8 | from utils import CheckpointIO, Logger 9 | import utils.misc as misc 10 | import utils.plot as plot 11 | import utils.spectrum as spectrum 12 | import utils.gan_training as gan_training 13 | from dataset import get_dataset 14 | 
from loss import get_criterion 15 | from torchvision.utils import save_image 16 | 17 | 18 | if __name__ == '__main__': 19 | import matplotlib.pyplot as plt; plt.switch_backend('Agg'); plt.ioff() 20 | # Arguments 21 | parser = argparse.ArgumentParser( 22 | description='Image regression with a Discriminator.' 23 | ) 24 | parser.add_argument('expname', type=str, help='Name of experiment.') 25 | parser.add_argument('config', type=str, help='Path to config file.') 26 | 27 | args = parser.parse_args() 28 | cfg = misc.load_config(args.config, 'configs/discriminator_testbed/default.yaml') 29 | 30 | # fix random seed 31 | torch.manual_seed(cfg['training']['seed']) 32 | torch.cuda.manual_seed_all(cfg['training']['seed']) 33 | 34 | device = torch.device("cuda:0") 35 | 36 | # Short hands 37 | batch_size = cfg['training']['batch_size'] 38 | nworkers = cfg['training']['nworkers'] 39 | nepochs = cfg['training']['nepochs'] 40 | eval_every = cfg['training']['eval_every'] 41 | save_every = cfg['training']['save_every'] 42 | out_dir = os.path.join('output/discriminator_testbed', args.expname) 43 | log_dir = os.path.join(out_dir, 'logs') 44 | img_dir = os.path.join(out_dir, 'imgs') 45 | plot_dir = os.path.join(out_dir, 'plots') 46 | 47 | # Create missing directories 48 | for d in [log_dir, img_dir, plot_dir]: 49 | os.makedirs(d, exist_ok=True) 50 | 51 | # Logger 52 | checkpoint_io = CheckpointIO(checkpoint_dir=out_dir) 53 | 54 | # Save config 55 | misc.save_config(os.path.join(out_dir, 'config.yaml'), cfg) 56 | 57 | # Dataset 58 | dataset = get_dataset(cfg) 59 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=nworkers, 60 | pin_memory=True, drop_last=False) 61 | 62 | # Visualize training data 63 | grid_size = (8, 4) 64 | images = [dataset[i][0] for i in range(min(len(dataset), grid_size[0]*grid_size[1]))] 65 | eval_z = torch.stack([dataset[i][1] for i in range(min(len(dataset), grid_size[0]*grid_size[1]))]).to(device) 66 | save_image(images, os.path.join(out_dir, 'training_data.png'), nrow=grid_size[1], normalize=True, value_range=(-1, 1)) 67 | 68 | # Create models 69 | use_spec_disc = cfg.model.pop('spectrum_disc', False) 70 | common_kwargs = misc.EasyDict(resolution=cfg.data.resolution) 71 | generator = misc.construct_class_by_name(class_name="models.DirectGenerator", z=dataset.z, **common_kwargs) 72 | generator_test = copy.deepcopy(generator) 73 | model = misc.construct_class_by_name(**cfg.model, label_size=len(dataset)-1, # class conditional 74 | **common_kwargs).train().requires_grad_(True).to(device) 75 | if use_spec_disc: # additional discriminator on log of reduced spectrum 76 | from models import MLP 77 | from functools import partial 78 | spec_len = spectrum.resolution_to_spectrum_length(cfg.data.resolution) 79 | model.spec_disc = MLP(input_size=spec_len, 80 | output_size=len(dataset), 81 | nhidden=0, dhidden=spec_len, activation=partial(torch.nn.LeakyReLU, negative_slope=0.2)) 82 | print(generator) 83 | try: 84 | print(model) 85 | except TypeError: # print model does not work with SG2 Discriminator 86 | pass 87 | print(f'Generator has {misc.count_trainable_parameters(generator)} trainable parameters.') 88 | print(f'Discriminator has {misc.count_trainable_parameters(model)} trainable parameters.') 89 | 90 | # Put model on gpu if needed 91 | generator = generator.to(device) 92 | generator_test = generator_test.to(device) 93 | model = model.to(device) 94 | 95 | g_optimizer = torch.optim.Adam(generator.parameters(), lr=cfg.training.lr_g, betas=(0., 
0.99), eps=1e-8) 96 | d_optimizer = torch.optim.Adam(model.parameters(), lr=cfg.training.lr_d, betas=(0., 0.99), eps=1e-8) 97 | criterion = get_criterion(**cfg.training.criterion) 98 | 99 | # Register modules to checkpoint 100 | checkpoint_io.register_modules( 101 | generator=generator, 102 | generator_test=generator_test, 103 | model=model, 104 | g_optimizer=g_optimizer, 105 | d_optimizer=d_optimizer, 106 | ) 107 | 108 | # Logger 109 | logger = Logger( 110 | log_dir=log_dir, 111 | img_dir=img_dir, 112 | monitoring=cfg['training']['monitoring'], 113 | monitoring_dir=os.path.join(out_dir, 'monitoring') 114 | ) 115 | 116 | # Load checkpoint if it exists 117 | try: 118 | load_dict = checkpoint_io.load('model.pt') 119 | except FileNotFoundError: 120 | epoch_idx = -1 121 | it = -1 122 | else: 123 | epoch_idx = load_dict.get('epoch_idx') 124 | it = load_dict.get('it') 125 | logger.load_stats('stats.p') 126 | 127 | def spec_disc_step(x, c=None): 128 | specs = spectrum.get_spectrum(x.flatten(0, 1), normalize=True).unflatten(0, x.shape[:2]).to(torch.float32) 129 | specs = specs.mean(dim=1) # average over channels 130 | specs = (1 + specs).log() # apply to logarithm of spectrum to avoid very large values 131 | return model.spec_disc(specs, c=c) 132 | 133 | def discriminator_trainstep(x_real, z, reg_param=10): 134 | gan_training.toggle_grad(generator, False) 135 | gan_training.toggle_grad(model, True) 136 | generator.train() 137 | model.train() 138 | d_optimizer.zero_grad() 139 | 140 | # On real data 141 | x_real.requires_grad_() 142 | 143 | d_real = model(x_real, c=z) 144 | targets = d_real.new_full(size=d_real.size(), fill_value=1) 145 | if use_spec_disc: 146 | d_real_spec = spec_disc_step(x_real, c=z) 147 | d_real = torch.cat([d_real, d_real_spec]) 148 | targets = targets.repeat(2, 1) 149 | 150 | dloss_real = criterion(d_real, targets) 151 | 152 | # Regularization on real 153 | dloss_real.backward(retain_graph=True) 154 | reg = reg_param * gan_training.compute_grad2(d_real, x_real).mean() 155 | reg.backward() 156 | 157 | # On fake data 158 | with torch.no_grad(): 159 | x_fake = generator(z) 160 | 161 | x_fake.requires_grad_() 162 | 163 | d_fake = model(x_fake, c=z) 164 | targets = d_fake.new_full(size=d_fake.size(), fill_value=0) 165 | if use_spec_disc: 166 | d_fake_spec = spec_disc_step(x_fake, c=z) 167 | d_fake = torch.cat([d_fake, d_fake_spec]) 168 | targets = targets.repeat(2, 1) 169 | 170 | dloss_fake = criterion(d_fake, targets) 171 | dloss_fake.backward() 172 | 173 | d_optimizer.step() 174 | 175 | gan_training.toggle_grad(model, False) 176 | 177 | # Output 178 | dloss = (dloss_real + dloss_fake) 179 | return dloss.item(), reg.item() 180 | 181 | def generator_trainstep(z): 182 | gan_training.toggle_grad(generator, True) 183 | gan_training.toggle_grad(model, False) 184 | generator.train() 185 | model.train() 186 | g_optimizer.zero_grad() 187 | 188 | x_fake = generator(z) 189 | d_fake = model(x_fake, c=z) 190 | targets = d_fake.new_full(size=d_fake.size(), fill_value=1) 191 | if use_spec_disc: 192 | d_fake_spec = spec_disc_step(x_fake, c=z) 193 | d_fake = torch.cat([d_fake, d_fake_spec]) 194 | targets = targets.repeat(2, 1) 195 | 196 | gloss = criterion(d_fake, targets) 197 | gloss.backward() 198 | 199 | g_optimizer.step() 200 | return gloss.item() 201 | 202 | # Training loop 203 | print('Start training...') 204 | while epoch_idx < nepochs: 205 | epoch_idx += 1 206 | 207 | for img, z in dataloader: 208 | it += 1 209 | img, z = img.to(device), z.to(device) 210 | 211 | if it > 0: # only evaluate 
at initialization 212 | model = model.train() 213 | 214 | # Model updates 215 | dloss, reg = discriminator_trainstep(img, z, reg_param=cfg.training.reg_param) 216 | logger.add('losses', 'discriminator', dloss, it=it) 217 | logger.add('losses', 'regularizer', reg, it=it) 218 | 219 | # Image updates 220 | gloss = generator_trainstep(z) 221 | logger.add('losses', 'generator', gloss, it=it) 222 | 223 | # Update ema 224 | gan_training.update_average(generator_test, generator, beta=cfg.training.model_average_beta) 225 | 226 | # Print stats 227 | if (it % cfg['training']['print_every']) == 0: 228 | dloss_last = logger.get_last('losses', 'discriminator') 229 | gloss_last = logger.get_last('losses', 'generator') 230 | print('[epoch %0d, it %4d] dloss = %.4f, gloss = %.4f' % (epoch_idx, it, dloss_last, gloss_last)) 231 | 232 | # Evaluate if necessary 233 | if (it % eval_every) == 0: 234 | generator_test.eval() 235 | # Evaluate spectrum 236 | spec_real, spec_gen = spectrum.evaluate_spectrum(dataset, generator_test, batch_size=batch_size) 237 | spec_gen.update({'it': it}) 238 | filename = os.path.join(log_dir, f'spectrum_%08d.pkl' % it) 239 | with open(filename, 'wb') as f: 240 | pickle.dump(spec_gen, f) 241 | 242 | # Save plot of spectrum 243 | filename = os.path.join(plot_dir, f'spectrum_%08d.png' % it) 244 | plot.plot_spectrum(spec_real, spec_gen, cfg.data.resolution, filename) 245 | 246 | # Save some generated images 247 | pred = generator_test(eval_z) 248 | filename = os.path.join(img_dir, 'samples_%08d.png' % it) 249 | save_image(pred, filename, normalize=True, value_range=(-1, 1)) 250 | 251 | # (iii) Checkpoint if necessary 252 | if (epoch_idx % save_every) == 0 and (it > 0): 253 | print('Saving checkpoint...') 254 | checkpoint_io.save('model.pt', epoch_idx=epoch_idx, it=it) 255 | logger.save_stats('stats.p') 256 | 257 | # Save model 258 | print('Saving last model...') 259 | checkpoint_io.save('model.pt', epoch_idx=epoch_idx, it=it) 260 | logger.save_stats('stats.p') 261 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: fbias 2 | channels: 3 | - pytorch 4 | - nvidia 5 | dependencies: 6 | - python >= 3.8 7 | - pip 8 | - numpy>=1.20 9 | - scipy=1.7.1 10 | - pytorch=1.9.1 11 | - torchvision 12 | - cudatoolkit=11.1 13 | - tqdm=4.62.2 14 | - matplotlib=3.4.2 15 | - pyyaml=5.3.1 16 | - imageio=2.9.0 17 | - tensorboardX 18 | - pip: 19 | - imageio-ffmpeg==0.4.3 20 | -------------------------------------------------------------------------------- /environment_sg3.yml: -------------------------------------------------------------------------------- 1 | name: fbias 2 | channels: 3 | - pytorch 4 | - nvidia 5 | dependencies: 6 | - python >= 3.8 7 | - pip 8 | - numpy>=1.20 9 | - click>=8.0 10 | - pillow=8.3.1 11 | - scipy=1.7.1 12 | - pytorch=1.9.1 13 | - cudatoolkit=11.1 14 | - requests=2.26.0 15 | - tqdm=4.62.2 16 | - ninja=1.10.2 17 | - matplotlib=3.4.2 18 | - imageio=2.9.0 19 | - pip: 20 | - imgui==1.3.0 21 | - glfw==2.2.0 22 | - pyopengl==3.1.5 23 | - imageio-ffmpeg==0.4.3 24 | - pyspng 25 | -------------------------------------------------------------------------------- /eval_discriminator.py: -------------------------------------------------------------------------------- 1 | """Train a Generator architecture with an MSE loss.""" 2 | 3 | import argparse 4 | import os 5 | import torch 6 | from glob import glob 7 | import pickle 8 | from tqdm import tqdm 9 | from 
PIL import Image 10 | from utils import CheckpointIO 11 | import utils.misc as misc 12 | import utils.plot as plot 13 | import utils.metrics as metrics 14 | from dataset import get_dataset 15 | 16 | 17 | if __name__ == '__main__': 18 | import matplotlib.pyplot as plt; plt.switch_backend('Agg'); plt.ioff() 19 | # Arguments 20 | parser = argparse.ArgumentParser( 21 | description='Evaluate image regression with a trained Generator.' 22 | ) 23 | parser.add_argument('expname', type=str, help='Name of experiment.') 24 | parser.add_argument('--psnr', action='store_true', help='Evaluate PSNR of regressed images.') 25 | parser.add_argument('--image-evolution', action='store_true', help='Create video of image evolution.') 26 | parser.add_argument('--spectrum-evolution', action='store_true', help='Create video of spectrum evolution.') 27 | parser.add_argument('--spectrum-error-evolution', action='store_true', help='Create image of spectrum error evolution.') 28 | 29 | args = parser.parse_args() 30 | run_dir = os.path.join('output/discriminator_testbed', args.expname) 31 | cfg = misc.load_config(os.path.join(run_dir, 'config.yaml')) 32 | 33 | # fix random seed (ensures to sample same latent codes as in training) 34 | torch.manual_seed(cfg['training']['seed']) 35 | torch.cuda.manual_seed_all(cfg['training']['seed']) 36 | 37 | device = torch.device("cuda:0") 38 | 39 | # Short hands 40 | batch_size = cfg['training']['batch_size'] 41 | nworkers = cfg['training']['nworkers'] 42 | out_dir = os.path.join(run_dir, 'eval') 43 | log_dir = os.path.join(run_dir, 'logs') 44 | img_dir = os.path.join(run_dir, 'imgs') 45 | plot_dir = os.path.join(run_dir, 'plots') 46 | 47 | # Create missing directories 48 | os.makedirs(out_dir, exist_ok=True) 49 | 50 | if args.psnr or args.spectrum_error_evolution: # Load dataset 51 | # Dataset 52 | dataset = get_dataset(cfg) 53 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=nworkers, 54 | pin_memory=True, drop_last=False) 55 | 56 | if args.psnr: # Load trained model to evaluate psnr of all images 57 | print('Evaluate PSNR...') 58 | # Logger 59 | checkpoint_io = CheckpointIO(checkpoint_dir=run_dir) 60 | 61 | # Create models 62 | common_kwargs = misc.EasyDict(resolution=cfg.data.resolution) 63 | generator = misc.construct_class_by_name(class_name="models.DirectGenerator", z=dataset.z, **common_kwargs) 64 | 65 | # Put generator on gpu if needed 66 | generator = generator.to(device) 67 | 68 | # Register modules to checkpoint 69 | checkpoint_io.register_modules( 70 | generator=generator, 71 | ) 72 | 73 | # Load checkpoint 74 | load_dict = checkpoint_io.load('model.pt') 75 | print(f'Using checkpoint from iteration {load_dict["it"]}.') 76 | 77 | psnr = [] 78 | for img, z in tqdm(dataloader): 79 | img, z = img.to(device), z.to(device) 80 | 81 | generator = generator.eval() 82 | pred = generator(z) 83 | psnr.append(metrics.psnr(pred, img)) 84 | psnr = torch.cat(psnr).mean().item() 85 | 86 | print(f'Average PSNR: {psnr:.1f}.') 87 | 88 | if args.image_evolution: 89 | print('Plot image evolution...') 90 | images = [Image.open(f) for f in sorted(glob(os.path.join(img_dir, 'samples_*.png')))] 91 | misc.make_video(images, os.path.join(out_dir, 'image_evolution.mp4'), fps=20, quality=8) 92 | print('Done.') 93 | 94 | if args.spectrum_evolution: 95 | print('Plot spectrum evolution...') 96 | images = [Image.open(f) for f in sorted(glob(os.path.join(plot_dir, 'spectrum_*.png')))] 97 | misc.make_video(images, os.path.join(out_dir, 
'spectrum_evolution.mp4'), fps=20, quality=8, macro_block_size=None) 98 | print('Done.') 99 | 100 | if args.spectrum_error_evolution: 101 | print('Plot spectrum error evolution...') 102 | spec_file_real = os.path.join(dataset.root, f'spectrum{dataset.resolution}_N{len(dataset)}.pkl') # Loads cache file from training 103 | with open(spec_file_real, 'rb') as f: 104 | spec_real = pickle.load(f) 105 | 106 | spec_gen_all = [] 107 | for spec_file in sorted(glob(os.path.join(log_dir, 'spectrum_*.pkl'))): 108 | with open(spec_file, 'rb') as f: 109 | spec = pickle.load(f) 110 | spec_gen_all.append(spec) 111 | 112 | filename = os.path.join(out_dir, 'spectrum_error_evolution.png') 113 | plot.plot_spectrum_error_evolution(spec_real, spec_gen_all, dataset.resolution, filename) 114 | print('Done.') 115 | -------------------------------------------------------------------------------- /eval_generator.py: -------------------------------------------------------------------------------- 1 | """Train a Generator architecture with an MSE loss.""" 2 | 3 | import argparse 4 | import os 5 | import torch 6 | from glob import glob 7 | import pickle 8 | from tqdm import tqdm 9 | from PIL import Image 10 | from utils import CheckpointIO 11 | import utils.misc as misc 12 | import utils.plot as plot 13 | import utils.metrics as metrics 14 | from dataset import get_dataset 15 | 16 | 17 | if __name__ == '__main__': 18 | import matplotlib.pyplot as plt; plt.switch_backend('Agg'); plt.ioff() 19 | # Arguments 20 | parser = argparse.ArgumentParser( 21 | description='Evaluate image regression with a trained Generator.' 22 | ) 23 | parser.add_argument('expname', type=str, help='Name of experiment.') 24 | parser.add_argument('--psnr', action='store_true', help='Evaluate PSNR of regressed images.') 25 | parser.add_argument('--image-evolution', action='store_true', help='Create video of image evolution.') 26 | parser.add_argument('--spectrum-evolution', action='store_true', help='Create video of spectrum evolution.') 27 | parser.add_argument('--spectrum-error-evolution', action='store_true', help='Create image of spectrum error evolution.') 28 | 29 | args = parser.parse_args() 30 | run_dir = os.path.join('output/generator_testbed', args.expname) 31 | cfg = misc.load_config(os.path.join(run_dir, 'config.yaml')) 32 | 33 | # fix random seed (ensures to sample same latent codes as in training) 34 | torch.manual_seed(cfg['training']['seed']) 35 | torch.cuda.manual_seed_all(cfg['training']['seed']) 36 | 37 | device = torch.device("cuda:0") 38 | 39 | # Short hands 40 | batch_size = cfg['training']['batch_size'] 41 | nworkers = cfg['training']['nworkers'] 42 | out_dir = os.path.join(run_dir, 'eval') 43 | log_dir = os.path.join(run_dir, 'logs') 44 | img_dir = os.path.join(run_dir, 'imgs') 45 | plot_dir = os.path.join(run_dir, 'plots') 46 | 47 | # Create missing directories 48 | os.makedirs(out_dir, exist_ok=True) 49 | 50 | if args.psnr or args.spectrum_error_evolution: # Load dataset 51 | # Dataset 52 | dataset = get_dataset(cfg) 53 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=nworkers, 54 | pin_memory=True, drop_last=False) 55 | 56 | if args.psnr: # Load trained model to evaluate psnr of all images 57 | print('Evaluate PSNR...') 58 | # Logger 59 | checkpoint_io = CheckpointIO(checkpoint_dir=run_dir) 60 | 61 | # Create models 62 | common_kwargs = misc.EasyDict(resolution=cfg.data.resolution) 63 | model = misc.construct_class_by_name(**cfg.model, 
**common_kwargs).train().requires_grad_(True).to(device) 64 | 65 | # Put model on gpu if needed 66 | model = model.to(device) 67 | 68 | # Register modules to checkpoint 69 | checkpoint_io.register_modules( 70 | model=model, 71 | ) 72 | 73 | # Load checkpoint 74 | load_dict = checkpoint_io.load('model.pt') 75 | print(f'Using checkpoint from iteration {load_dict["it"]}.') 76 | 77 | psnr = [] 78 | for img, z in tqdm(dataloader): 79 | img, z = img.to(device), z.to(device) 80 | 81 | model = model.eval() 82 | pred = model(z) 83 | psnr.append(metrics.psnr(pred, img)) 84 | psnr = torch.cat(psnr).mean().item() 85 | 86 | print(f'Average PSNR: {psnr:.1f}.') 87 | 88 | if args.image_evolution: 89 | print('Plot image evolution...') 90 | images = [Image.open(f) for f in sorted(glob(os.path.join(img_dir, 'samples_*.png')))] 91 | misc.make_video(images, os.path.join(out_dir, 'image_evolution.mp4'), fps=20, quality=8) 92 | print('Done.') 93 | 94 | if args.spectrum_evolution: 95 | print('Plot spectrum evolution...') 96 | images = [Image.open(f) for f in sorted(glob(os.path.join(plot_dir, 'spectrum_*.png')))] 97 | misc.make_video(images, os.path.join(out_dir, 'spectrum_evolution.mp4'), fps=20, quality=8, macro_block_size=None) 98 | print('Done.') 99 | 100 | if args.spectrum_error_evolution: 101 | print('Plot spectrum error evolution...') 102 | spec_file_real = os.path.join(dataset.root, f'spectrum{dataset.resolution}_N{len(dataset)}.pkl') # Loads cache file from training 103 | with open(spec_file_real, 'rb') as f: 104 | spec_real = pickle.load(f) 105 | 106 | spec_gen_all = [] 107 | for spec_file in sorted(glob(os.path.join(log_dir, 'spectrum_*.pkl'))): 108 | with open(spec_file, 'rb') as f: 109 | spec = pickle.load(f) 110 | spec_gen_all.append(spec) 111 | 112 | filename = os.path.join(out_dir, 'spectrum_error_evolution.png') 113 | plot.plot_spectrum_error_evolution(spec_real, spec_gen_all, dataset.resolution, filename) 114 | print('Done.') 115 | -------------------------------------------------------------------------------- /generator_testbed.py: -------------------------------------------------------------------------------- 1 | """Train a Generator architecture with an MSE loss.""" 2 | 3 | import argparse 4 | import os 5 | import torch 6 | import pickle 7 | from utils import CheckpointIO, Logger 8 | import utils.misc as misc 9 | import utils.plot as plot 10 | import utils.spectrum as spectrum 11 | from dataset import get_dataset 12 | from loss import get_criterion 13 | from torchvision.utils import save_image 14 | 15 | 16 | if __name__ == '__main__': 17 | import matplotlib.pyplot as plt; plt.switch_backend('Agg'); plt.ioff() 18 | # Arguments 19 | parser = argparse.ArgumentParser( 20 | description='Image regression with a Generator.' 
21 | ) 22 | parser.add_argument('expname', type=str, help='Name of experiment.') 23 | parser.add_argument('config', type=str, help='Path to config file.') 24 | 25 | args = parser.parse_args() 26 | cfg = misc.load_config(args.config, 'configs/generator_testbed/default.yaml') 27 | 28 | # fix random seed 29 | torch.manual_seed(cfg['training']['seed']) 30 | torch.cuda.manual_seed_all(cfg['training']['seed']) 31 | 32 | device = torch.device("cuda:0") 33 | 34 | # Short hands 35 | batch_size = cfg['training']['batch_size'] 36 | nworkers = cfg['training']['nworkers'] 37 | nepochs = cfg['training']['nepochs'] 38 | eval_every = cfg['training']['eval_every'] 39 | save_every = cfg['training']['save_every'] 40 | out_dir = os.path.join('output/generator_testbed', args.expname) 41 | log_dir = os.path.join(out_dir, 'logs') 42 | img_dir = os.path.join(out_dir, 'imgs') 43 | plot_dir = os.path.join(out_dir, 'plots') 44 | 45 | # Create missing directories 46 | for d in [log_dir, img_dir, plot_dir]: 47 | os.makedirs(d, exist_ok=True) 48 | 49 | # Logger 50 | checkpoint_io = CheckpointIO(checkpoint_dir=out_dir) 51 | 52 | # Save config 53 | misc.save_config(os.path.join(out_dir, 'config.yaml'), cfg) 54 | 55 | # Dataset 56 | dataset = get_dataset(cfg) 57 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=nworkers, 58 | pin_memory=True, drop_last=False) 59 | 60 | # Visualize training data 61 | grid_size = (8, 4) 62 | images = [dataset[i][0] for i in range(min(len(dataset), grid_size[0]*grid_size[1]))] 63 | eval_z = torch.stack([dataset[i][1] for i in range(min(len(dataset), grid_size[0]*grid_size[1]))]).to(device) 64 | save_image(images, os.path.join(out_dir, 'training_data.png'), nrow=grid_size[1], normalize=True, value_range=(-1, 1)) 65 | 66 | # Create models 67 | common_kwargs = misc.EasyDict(resolution=cfg.data.resolution) 68 | model = misc.construct_class_by_name(**cfg.model, **common_kwargs).train().requires_grad_(True).to(device) 69 | print(model) 70 | print(f'Model has {misc.count_trainable_parameters(model)} trainable parameters.') 71 | 72 | # Put model on gpu if needed 73 | model = model.to(device) 74 | 75 | optimizer = torch.optim.Adam(model.parameters(), lr=cfg.training.lr, betas=(0., 0.99), eps=1e-8) 76 | criterion = get_criterion(**cfg.training.criterion) 77 | 78 | # Register modules to checkpoint 79 | checkpoint_io.register_modules( 80 | model=model, 81 | optimizer=optimizer, 82 | ) 83 | 84 | # Logger 85 | logger = Logger( 86 | log_dir=log_dir, 87 | img_dir=img_dir, 88 | monitoring=cfg['training']['monitoring'], 89 | monitoring_dir=os.path.join(out_dir, 'monitoring') 90 | ) 91 | 92 | # Load checkpoint if it exists 93 | try: 94 | load_dict = checkpoint_io.load('model.pt') 95 | except FileNotFoundError: 96 | epoch_idx = -1 97 | it = -1 98 | else: 99 | epoch_idx = load_dict.get('epoch_idx') 100 | it = load_dict.get('it') 101 | logger.load_stats('stats.p') 102 | 103 | # Training loop 104 | print('Start training...') 105 | while epoch_idx < nepochs: 106 | epoch_idx += 1 107 | 108 | for img, z in dataloader: 109 | it += 1 110 | img, z = img.to(device), z.to(device) 111 | 112 | if it > 0: # only evaluate at initialization 113 | model = model.train() 114 | 115 | # Model updates 116 | optimizer.zero_grad() 117 | pred = model(z) 118 | loss = criterion(pred, img) 119 | loss.backward() 120 | 121 | optimizer.step() 122 | 123 | logger.add('losses', 'train', loss, it=it) 124 | 125 | # Print stats 126 | if (it % cfg['training']['print_every']) == 0: 127 | loss_last = 
logger.get_last('losses', 'train') 128 | print('[epoch %0d, it %4d] loss = %.4f' % (epoch_idx, it, loss_last)) 129 | 130 | # Evaluate if necessary 131 | if (it % eval_every) == 0: 132 | model = model.eval() 133 | # Evaluate spectrum 134 | spec_real, spec_gen = spectrum.evaluate_spectrum(dataset, model, batch_size=batch_size) 135 | spec_gen.update({'it': it}) 136 | filename = os.path.join(log_dir, f'spectrum_%08d.pkl' % it) 137 | with open(filename, 'wb') as f: 138 | pickle.dump(spec_gen, f) 139 | 140 | # Save plot of spectrum 141 | filename = os.path.join(plot_dir, f'spectrum_%08d.png' % it) 142 | plot.plot_spectrum(spec_real, spec_gen, cfg.data.resolution, filename) 143 | 144 | # Save some generated images 145 | pred = model(eval_z) 146 | filename = os.path.join(img_dir, 'samples_%08d.png' % it) 147 | save_image(pred, filename, normalize=True, value_range=(-1, 1)) 148 | 149 | # (iii) Checkpoint if necessary 150 | if (epoch_idx % save_every) == 0 and (it > 0): 151 | print('Saving checkpoint...') 152 | checkpoint_io.save('model.pt', epoch_idx=epoch_idx, it=it) 153 | logger.save_stats('stats.p') 154 | 155 | # Save model 156 | print('Saving last model...') 157 | checkpoint_io.save('model.pt', epoch_idx=epoch_idx, it=it) 158 | logger.save_stats('stats.p') 159 | -------------------------------------------------------------------------------- /gfx/teaser_disc.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/frequency_bias/1bd267e897bd2d122524cbe3f01f241b0dea75c5/gfx/teaser_disc.gif -------------------------------------------------------------------------------- /gfx/teaser_gen.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/frequency_bias/1bd267e897bd2d122524cbe3f01f241b0dea75c5/gfx/teaser_gen.gif -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from utils.misc import construct_class_by_name 3 | from utils.spectrum import get_spectrum 4 | from torch import Tensor 5 | 6 | 7 | class MSESpectrumLoss(torch.nn.MSELoss): 8 | def __init__(self, *args, **kwargs): 9 | super(MSESpectrumLoss, self).__init__(*args, **kwargs) 10 | 11 | @staticmethod 12 | def get_log_spectrum(input): 13 | spectra = get_spectrum(input.flatten(0, 1)).unflatten(0, input.shape[:2]) 14 | spectra = spectra.mean(dim=1) # average over channels 15 | return (1 + spectra).log() 16 | 17 | def forward(self, input: Tensor, target: Tensor) -> Tensor: 18 | input_spectrum = self.get_log_spectrum(input) 19 | target_spectrum = self.get_log_spectrum(target) 20 | return super(MSESpectrumLoss, self).forward(input_spectrum, target_spectrum) 21 | 22 | 23 | class MultiLoss(object): 24 | def __init__(self, losses, weights): 25 | self.losses = losses 26 | self.weights = weights 27 | 28 | def __call__(self, *args, **kwargs): 29 | loss = 0 30 | for loss_fn, w in zip(self.losses, self.weights): 31 | loss = loss + w * loss_fn(*args, **kwargs) 32 | return loss 33 | 34 | 35 | def get_criterion(class_name, weight=None): 36 | if not isinstance(class_name, list): 37 | return construct_class_by_name(class_name=class_name) 38 | 39 | if (weight is None) or (len(class_name) != len(weight)): 40 | raise AttributeError('Number of losses and number of weights have to match.') 41 | 42 | losses = [construct_class_by_name(class_name=n) for n in class_name] 
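# Illustrative: a config with class_name: ['torch.nn.MSELoss', 'loss.MSESpectrumLoss']
# and weight: [1.0, 1.0] (hypothetical values) would yield a weighted sum of a
# pixel-space and a spectrum-space loss.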
43 | return MultiLoss(losses, weight) 44 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from models.pggan_generator import PGGANGenerator 2 | from models.pggan_discriminator import PGGANDiscriminator 3 | 4 | from models.direct_generator import DirectGenerator 5 | from models.mlp import MLP 6 | 7 | import os 8 | if os.path.isfile('models/stylegan3/stylegan3/train.py'): 9 | from models.stylegan2_generator import SG2Generator 10 | from models.stylegan2_discriminator import SG2Discriminator 11 | from models.stylegan3_generator import SG3Generator 12 | else: 13 | print('StyleGAN3 submodule not initialized. If you want to add StyleGAN3 support run: \n' 14 | '\tcd models/stylegan3 \n' 15 | '\tgit submodule update --init --recursive --remote') -------------------------------------------------------------------------------- /models/direct_generator.py: -------------------------------------------------------------------------------- 1 | """Optimize image directly.""" 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | __all__ = ['DirectGenerator'] 7 | 8 | 9 | class DirectGenerator(nn.Module): 10 | def __init__(self, resolution, z, image_channels=3): 11 | super().__init__() 12 | N = len(z) 13 | imgs = torch.empty((N, image_channels, resolution, resolution)) 14 | nn.init.kaiming_normal_(imgs) 15 | self.imgs = nn.Parameter(imgs) 16 | 17 | def forward(self, idx): 18 | assert idx.ndim == 1 19 | return self.imgs[idx] 20 | 21 | def extra_repr(self): 22 | N, image_channels, resolution = self.imgs.shape[:3] 23 | s = f'{N}, {image_channels}, {resolution}, {resolution}' 24 | return s -------------------------------------------------------------------------------- /models/mlp.py: -------------------------------------------------------------------------------- 1 | """Fully-connected architecture.""" 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | __all__ = ['MLP'] 7 | 8 | 9 | class MLP(nn.Module): 10 | def __init__(self, input_size, output_size, nhidden=3, dhidden=16, activation=nn.ReLU, bias=True): 11 | super(MLP, self).__init__() 12 | self.nhidden = nhidden 13 | if isinstance(dhidden, int): 14 | dhidden = [dhidden] * (self.nhidden + 1) # one for input layer 15 | 16 | input_layer = nn.Linear(input_size, dhidden[0], bias=bias) 17 | hidden_layers = [nn.Linear(dhidden[i], dhidden[i+1], bias=bias) for i in range(nhidden)] 18 | output_layer = nn.Linear(dhidden[nhidden], output_size, bias=bias) 19 | 20 | layers = [input_layer] + hidden_layers + [output_layer] 21 | 22 | main = [] 23 | for l in layers: 24 | main.extend([l, activation()]) 25 | main = main[:-1] # no activation after last layer 26 | 27 | self.main = nn.Sequential(*main) 28 | 29 | def forward(self, x, c=None): 30 | assert x.ndim == 2 31 | out = self.main(x) 32 | if c is not None: 33 | out = out[range(len(c)), c].unsqueeze(1) 34 | return out -------------------------------------------------------------------------------- /models/pggan_discriminator.py: -------------------------------------------------------------------------------- 1 | """Adapted from https://github.com/genforce/genforce/blob/master/models/pggan_discriminator.py 2 | Added different downsampling modes and padding modes""" 3 | # python3.7 4 | """Contains the implementation of discriminator described in PGGAN. 
5 | 6 | Paper: https://arxiv.org/pdf/1710.10196.pdf 7 | 8 | Official TensorFlow implementation: 9 | https://github.com/tkarras/progressive_growing_of_gans 10 | """ 11 | 12 | import numpy as np 13 | 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | from .utils import DOWNSAMPLE_FNS 18 | 19 | __all__ = ['PGGANDiscriminator'] 20 | 21 | # Resolutions allowed. 22 | _RESOLUTIONS_ALLOWED = [8, 16, 32, 64, 128, 256, 512, 1024] 23 | 24 | # Initial resolution. 25 | _INIT_RES = 4 26 | 27 | # Default gain factor for weight scaling. 28 | _WSCALE_GAIN = np.sqrt(2.0) 29 | 30 | 31 | class PGGANDiscriminator(nn.Module): 32 | """Defines the discriminator network in PGGAN. 33 | 34 | NOTE: The discriminator takes images with `RGB` channel order and pixel 35 | range [-1, 1] as inputs. 36 | 37 | Settings for the network: 38 | 39 | (1) resolution: The resolution of the input image. 40 | (2) image_channels: Number of channels of the input image. (default: 3) 41 | (3) label_size: Size of the additional label for conditional generation. 42 | (default: 0) 43 | (4) fused_scale: Whether to fused `conv2d` and `downsample` together, 44 | resulting in `conv2d` with strides. (default: False) 45 | (5) use_wscale: Whether to use weight scaling. (default: True) 46 | (6) minibatch_std_group_size: Group size for the minibatch standard 47 | deviation layer. 0 means disable. (default: 16) 48 | (7) fmaps_base: Factor to control number of feature maps for each layer. 49 | (default: 16 << 10) 50 | (8) fmaps_max: Maximum number of feature maps in each layer. (default: 512) 51 | """ 52 | 53 | def __init__(self, 54 | resolution, 55 | image_channels=3, 56 | label_size=0, 57 | fused_scale=False, 58 | use_wscale=True, 59 | minibatch_std_group_size=16, 60 | fmaps_base=16 << 10, 61 | fmaps_max=512, 62 | downsampling_mode='nearest', 63 | padding_mode='zeros', 64 | divide_channels_by=1): 65 | """Initializes with basic settings. 66 | 67 | Raises: 68 | ValueError: If the `resolution` is not supported. 69 | """ 70 | super().__init__() 71 | 72 | if resolution not in _RESOLUTIONS_ALLOWED: 73 | raise ValueError(f'Invalid resolution: `{resolution}`!\n' 74 | f'Resolutions allowed: {_RESOLUTIONS_ALLOWED}.') 75 | 76 | self.init_res = _INIT_RES 77 | self.init_res_log2 = int(np.log2(self.init_res)) 78 | self.resolution = resolution 79 | self.final_res_log2 = int(np.log2(self.resolution)) 80 | self.image_channels = image_channels 81 | self.label_size = label_size 82 | self.fused_scale = fused_scale 83 | self.use_wscale = use_wscale 84 | self.minibatch_std_group_size = minibatch_std_group_size 85 | self.fmaps_base = fmaps_base 86 | self.fmaps_max = fmaps_max 87 | self.divide_channels_by = divide_channels_by 88 | 89 | # Level of detail (used for progressive training). 90 | self.register_buffer('lod', torch.zeros(())) 91 | self.pth_to_tf_var_mapping = {'lod': 'lod'} 92 | 93 | for res_log2 in range(self.final_res_log2, self.init_res_log2 - 1, -1): 94 | res = 2 ** res_log2 95 | block_idx = self.final_res_log2 - res_log2 96 | 97 | # Input convolution layer for each resolution. 
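# (Corresponds to `FromRGB` in the official TF implementation: a 1x1 convolution
# that maps the `image_channels` input image to `get_nf(res)` feature maps and
# serves as the entry point at each resolution during progressive fade-in.)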
98 | self.add_module( 99 | f'input{block_idx}', 100 | ConvBlock(in_channels=self.image_channels, 101 | out_channels=self.get_nf(res), 102 | kernel_size=1, 103 | padding=0, 104 | padding_mode=padding_mode, 105 | use_wscale=self.use_wscale)) 106 | self.pth_to_tf_var_mapping[f'input{block_idx}.weight'] = ( 107 | f'FromRGB_lod{block_idx}/weight') 108 | self.pth_to_tf_var_mapping[f'input{block_idx}.bias'] = ( 109 | f'FromRGB_lod{block_idx}/bias') 110 | 111 | # Convolution block for each resolution (except the last one). 112 | if res != self.init_res: 113 | self.add_module( 114 | f'layer{2 * block_idx}', 115 | ConvBlock(in_channels=self.get_nf(res), 116 | out_channels=self.get_nf(res), 117 | padding_mode=padding_mode, 118 | use_wscale=self.use_wscale)) 119 | tf_layer0_name = 'Conv0' 120 | self.add_module( 121 | f'layer{2 * block_idx + 1}', 122 | ConvBlock(in_channels=self.get_nf(res), 123 | out_channels=self.get_nf(res // 2), 124 | downsample=True, 125 | downsampling_mode=downsampling_mode, 126 | padding_mode=padding_mode, 127 | fused_scale=self.fused_scale, 128 | use_wscale=self.use_wscale)) 129 | tf_layer1_name = 'Conv1_down' if self.fused_scale else 'Conv1' 130 | 131 | # Convolution block for last resolution. 132 | else: 133 | self.add_module( 134 | f'layer{2 * block_idx}', 135 | ConvBlock( 136 | in_channels=self.get_nf(res), 137 | out_channels=self.get_nf(res), 138 | padding_mode=padding_mode, 139 | use_wscale=self.use_wscale, 140 | minibatch_std_group_size=self.minibatch_std_group_size)) 141 | tf_layer0_name = 'Conv' 142 | self.add_module( 143 | f'layer{2 * block_idx + 1}', 144 | DenseBlock(in_channels=self.get_nf(res) * res * res, 145 | out_channels=self.get_nf(res // 2), 146 | use_wscale=self.use_wscale)) 147 | tf_layer1_name = 'Dense0' 148 | 149 | self.pth_to_tf_var_mapping[f'layer{2 * block_idx}.weight'] = ( 150 | f'{res}x{res}/{tf_layer0_name}/weight') 151 | self.pth_to_tf_var_mapping[f'layer{2 * block_idx}.bias'] = ( 152 | f'{res}x{res}/{tf_layer0_name}/bias') 153 | self.pth_to_tf_var_mapping[f'layer{2 * block_idx + 1}.weight'] = ( 154 | f'{res}x{res}/{tf_layer1_name}/weight') 155 | self.pth_to_tf_var_mapping[f'layer{2 * block_idx + 1}.bias'] = ( 156 | f'{res}x{res}/{tf_layer1_name}/bias') 157 | 158 | # Final dense block. 
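# (Corresponds to `Dense1` in the official TF implementation: maps the output of
# the preceding dense layer to a single realness score plus `label_size` optional
# class scores.)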
159 | self.add_module( 160 | f'layer{2 * block_idx + 2}', 161 | DenseBlock(in_channels=self.get_nf(res // 2), 162 | out_channels=1 + self.label_size, 163 | use_wscale=self.use_wscale, 164 | wscale_gain=1.0, 165 | activation_type='linear')) 166 | self.pth_to_tf_var_mapping[f'layer{2 * block_idx + 2}.weight'] = ( 167 | f'{res}x{res}/Dense1/weight') 168 | self.pth_to_tf_var_mapping[f'layer{2 * block_idx + 2}.bias'] = ( 169 | f'{res}x{res}/Dense1/bias') 170 | 171 | self.downsample = DownsamplingLayer(mode='blurpool') # for fade-in always use blurpool downsampling 172 | 173 | def get_nf(self, res): 174 | """Gets number of feature maps according to current resolution.""" 175 | return min(self.fmaps_base // res, self.fmaps_max) // self.divide_channels_by 176 | 177 | def forward(self, image, c=None, lod=None, **_unused_kwargs): 178 | expected_shape = (self.image_channels, self.resolution, self.resolution) 179 | if image.ndim != 4 or image.shape[1:] != expected_shape: 180 | raise ValueError(f'The input tensor should be with shape ' 181 | f'[batch_size, channel, height, width], where ' 182 | f'`channel` equals to {self.image_channels}, ' 183 | f'`height`, `width` equal to {self.resolution}!\n' 184 | f'But `{image.shape}` is received!') 185 | 186 | lod = self.lod.cpu().tolist() if lod is None else lod 187 | if lod + self.init_res_log2 > self.final_res_log2: 188 | raise ValueError(f'Maximum level-of-detail (lod) is ' 189 | f'{self.final_res_log2 - self.init_res_log2}, ' 190 | f'but `{lod}` is received!') 191 | 192 | lod = self.lod.cpu().tolist() 193 | for res_log2 in range(self.final_res_log2, self.init_res_log2 - 1, -1): 194 | block_idx = current_lod = self.final_res_log2 - res_log2 195 | if current_lod <= lod < current_lod + 1: 196 | x = self.__getattr__(f'input{block_idx}')(image) 197 | elif current_lod - 1 < lod < current_lod: 198 | alpha = lod - np.floor(lod) 199 | x = (self.__getattr__(f'input{block_idx}')(image) * alpha + 200 | x * (1 - alpha)) 201 | if lod < current_lod + 1: 202 | x = self.__getattr__(f'layer{2 * block_idx}')(x) 203 | x = self.__getattr__(f'layer{2 * block_idx + 1}')(x) 204 | if lod > current_lod: 205 | image = self.downsample(image) 206 | x = self.__getattr__(f'layer{2 * block_idx + 2}')(x) 207 | 208 | if c is not None: 209 | x = x[range(len(c)), c].unsqueeze(1) 210 | return x 211 | 212 | 213 | class MiniBatchSTDLayer(nn.Module): 214 | """Implements the minibatch standard deviation layer.""" 215 | 216 | def __init__(self, group_size=16, epsilon=1e-8): 217 | super().__init__() 218 | self.group_size = group_size 219 | self.epsilon = epsilon 220 | 221 | def forward(self, x): 222 | if self.group_size <= 1: 223 | return x 224 | group_size = min(self.group_size, x.shape[0]) # [NCHW] 225 | y = x.view(group_size, -1, x.shape[1], x.shape[2], x.shape[3]) # [GMCHW] 226 | y = y - torch.mean(y, dim=0, keepdim=True) # [GMCHW] 227 | y = torch.mean(y ** 2, dim=0) # [MCHW] 228 | y = torch.sqrt(y + self.epsilon) # [MCHW] 229 | y = torch.mean(y, dim=[1, 2, 3], keepdim=True) # [M111] 230 | y = y.repeat(group_size, 1, x.shape[2], x.shape[3]) # [N1HW] 231 | return torch.cat([x, y], dim=1) 232 | 233 | 234 | class DownsamplingLayer(nn.Module): 235 | """Implements the downsampling layer. 236 | 237 | Basically, this layer can be used to downsample feature maps with average 238 | pooling. 
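    The available modes are the keys of DOWNSAMPLE_FNS in models/utils.py: 'avg'
    (average pooling), 'stride' (plain subsampling) and 'blurpool' (low-pass
    filtering followed by average pooling).

    Illustrative usage, assuming a 4D input tensor `x` of shape [N, C, H, W]:

        down = DownsamplingLayer(scale_factor=2, mode='blurpool')
        y = down(x)  # y has shape [N, C, H/2, W/2]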
239 | """ 240 | 241 | def __init__(self, scale_factor=2, mode='avg'): 242 | super().__init__() 243 | self.scale_factor = scale_factor 244 | # Added 245 | assert mode in DOWNSAMPLE_FNS 246 | self.mode = mode 247 | self.down = DOWNSAMPLE_FNS[mode] 248 | 249 | def forward(self, x): 250 | if self.scale_factor <= 1: 251 | return x 252 | return self.down(x, scale_factor=self.scale_factor) 253 | 254 | def extra_repr(self): 255 | s = ('downsampling_mode={mode}') 256 | return s.format(**self.__dict__) 257 | 258 | 259 | class ConvBlock(nn.Module): 260 | """Implements the convolutional block. 261 | 262 | Basically, this block executes minibatch standard deviation layer (if 263 | needed), convolutional layer, activation layer, and downsampling layer ( 264 | if needed) in sequence. 265 | """ 266 | 267 | def __init__(self, 268 | in_channels, 269 | out_channels, 270 | kernel_size=3, 271 | stride=1, 272 | padding=1, 273 | padding_mode='zeros', 274 | add_bias=True, 275 | downsample=False, 276 | downsampling_mode=False, 277 | fused_scale=False, 278 | use_wscale=True, 279 | wscale_gain=_WSCALE_GAIN, 280 | activation_type='lrelu', 281 | minibatch_std_group_size=0): 282 | """Initializes with block settings. 283 | 284 | Args: 285 | in_channels: Number of channels of the input tensor. 286 | out_channels: Number of channels of the output tensor. 287 | kernel_size: Size of the convolutional kernels. (default: 3) 288 | stride: Stride parameter for convolution operation. (default: 1) 289 | padding: Padding parameter for convolution operation. (default: 1) 290 | add_bias: Whether to add bias onto the convolutional result. 291 | (default: True) 292 | downsample: Whether to downsample the result after convolution. 293 | (default: False) 294 | fused_scale: Whether to fused `conv2d` and `downsample` together, 295 | resulting in `conv2d` with strides. (default: False) 296 | use_wscale: Whether to use weight scaling. (default: True) 297 | wscale_gain: Gain factor for weight scaling. (default: _WSCALE_GAIN) 298 | activation_type: Type of activation. Support `linear` and `lrelu`. 299 | (default: `lrelu`) 300 | minibatch_std_group_size: Group size for the minibatch standard 301 | deviation layer. 0 means disable. (default: 0) 302 | 303 | Raises: 304 | NotImplementedError: If the `activation_type` is not supported. 
305 | """ 306 | super().__init__() 307 | 308 | if minibatch_std_group_size > 1: 309 | in_channels = in_channels + 1 310 | self.mbstd = MiniBatchSTDLayer(group_size=minibatch_std_group_size) 311 | else: 312 | self.mbstd = nn.Identity() 313 | 314 | if downsample and not fused_scale: 315 | self.downsample = DownsamplingLayer(mode=downsampling_mode) 316 | else: 317 | self.downsample = nn.Identity() 318 | 319 | if downsample and fused_scale: 320 | raise RuntimeWarning('Custom downsampling and padding operations not implemented with fused_scale.') 321 | self.use_stride = True 322 | self.stride = 2 323 | self.padding = 1 324 | else: 325 | self.use_stride = False 326 | self.stride = stride 327 | self.padding = padding 328 | self.padding_mode = padding_mode 329 | 330 | weight_shape = (out_channels, in_channels, kernel_size, kernel_size) 331 | fan_in = kernel_size * kernel_size * in_channels 332 | wscale = wscale_gain / np.sqrt(fan_in) 333 | if use_wscale: 334 | self.weight = nn.Parameter(torch.randn(*weight_shape)) 335 | self.wscale = wscale 336 | else: 337 | self.weight = nn.Parameter(torch.randn(*weight_shape) * wscale) 338 | self.wscale = 1.0 339 | 340 | if add_bias: 341 | self.bias = nn.Parameter(torch.zeros(out_channels)) 342 | else: 343 | self.bias = None 344 | 345 | if activation_type == 'linear': 346 | self.activate = nn.Identity() 347 | elif activation_type == 'lrelu': 348 | self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True) 349 | else: 350 | raise NotImplementedError(f'Not implemented activation function: ' 351 | f'`{activation_type}`!') 352 | 353 | def forward(self, x): 354 | x = self.mbstd(x) 355 | weight = self.weight * self.wscale 356 | if self.use_stride: 357 | weight = F.pad(weight, (1, 1, 1, 1, 0, 0, 0, 0), 'constant', 0.0) 358 | weight = (weight[:, :, 1:, 1:] + weight[:, :, :-1, 1:] + 359 | weight[:, :, 1:, :-1] + weight[:, :, :-1, :-1]) * 0.25 360 | if self.padding_mode != 'zeros': 361 | _reversed_padding_repeated_twice = torch.nn.modules.utils._reverse_repeat_tuple( 362 | torch.nn.modules.utils._pair(self.padding), 2) 363 | x = F.pad(x, _reversed_padding_repeated_twice, mode=self.padding_mode) 364 | padding = torch.nn.modules.utils._pair(0) 365 | else: 366 | padding = self.padding 367 | x = F.conv2d(x, 368 | weight=weight, 369 | bias=self.bias, 370 | stride=self.stride, 371 | padding=padding) 372 | x = self.activate(x) 373 | x = self.downsample(x) 374 | return x 375 | 376 | def extra_repr(self): 377 | out_channels, in_channels, kernel_size = self.weight.shape[:3] 378 | s = (f'Conv2D({in_channels}, {out_channels}, kernel_size={kernel_size}' 379 | ', stride={stride}, use_stride={use_stride}') 380 | if self.padding != 0: 381 | s += ', padding={padding}' 382 | if self.bias is None: 383 | s += ', bias=False' 384 | if self.padding_mode != 'zeros': 385 | s += ', padding_mode={padding_mode}' 386 | s += ')' 387 | return s.format(**self.__dict__) 388 | 389 | 390 | class DenseBlock(nn.Module): 391 | """Implements the dense block. 392 | 393 | Basically, this block executes fully-connected layer, and activation layer. 394 | """ 395 | 396 | def __init__(self, 397 | in_channels, 398 | out_channels, 399 | add_bias=True, 400 | use_wscale=True, 401 | wscale_gain=_WSCALE_GAIN, 402 | activation_type='lrelu'): 403 | """Initializes with block settings. 404 | 405 | Args: 406 | in_channels: Number of channels of the input tensor. 407 | out_channels: Number of channels of the output tensor. 408 | add_bias: Whether to add bias onto the fully-connected result. 
409 | (default: True) 410 | use_wscale: Whether to use weight scaling. (default: True) 411 | wscale_gain: Gain factor for weight scaling. (default: _WSCALE_GAIN) 412 | activation_type: Type of activation. Support `linear` and `lrelu`. 413 | (default: `lrelu`) 414 | 415 | Raises: 416 | NotImplementedError: If the `activation_type` is not supported. 417 | """ 418 | super().__init__() 419 | weight_shape = (out_channels, in_channels) 420 | wscale = wscale_gain / np.sqrt(in_channels) 421 | if use_wscale: 422 | self.weight = nn.Parameter(torch.randn(*weight_shape)) 423 | self.wscale = wscale 424 | else: 425 | self.weight = nn.Parameter(torch.randn(*weight_shape) * wscale) 426 | self.wscale = 1.0 427 | 428 | if add_bias: 429 | self.bias = nn.Parameter(torch.zeros(out_channels)) 430 | else: 431 | self.bias = None 432 | 433 | if activation_type == 'linear': 434 | self.activate = nn.Identity() 435 | elif activation_type == 'lrelu': 436 | self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True) 437 | else: 438 | raise NotImplementedError(f'Not implemented activation function: ' 439 | f'`{activation_type}`!') 440 | 441 | def forward(self, x): 442 | if x.ndim != 2: 443 | x = x.view(x.shape[0], -1) 444 | x = F.linear(x, weight=self.weight * self.wscale, bias=self.bias) 445 | x = self.activate(x) 446 | return x -------------------------------------------------------------------------------- /models/pggan_generator.py: -------------------------------------------------------------------------------- 1 | """Adapted from https://github.com/genforce/genforce/blob/master/models/pggan_generator.py 2 | Added different upsampling modes and padding modes""" 3 | # python3.7 4 | """Contains the implementation of generator described in PGGAN. 5 | 6 | Paper: https://arxiv.org/pdf/1710.10196.pdf 7 | 8 | Official TensorFlow implementation: 9 | https://github.com/tkarras/progressive_growing_of_gans 10 | """ 11 | 12 | import numpy as np 13 | 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | from .utils import UPSAMPLE_FNS 18 | 19 | __all__ = ['PGGANGenerator'] 20 | 21 | # Resolutions allowed. 22 | _RESOLUTIONS_ALLOWED = [8, 16, 32, 64, 128, 256, 512, 1024] 23 | 24 | # Initial resolution. 25 | _INIT_RES = 4 26 | 27 | # Default gain factor for weight scaling. 28 | _WSCALE_GAIN = np.sqrt(2.0) 29 | 30 | 31 | class PGGANGenerator(nn.Module): 32 | """Defines the generator network in PGGAN. 33 | 34 | NOTE: The synthesized images are with `RGB` channel order and pixel range 35 | [-1, 1]. 36 | 37 | Settings for the network: 38 | 39 | (1) resolution: The resolution of the output image. 40 | (2) z_dim: The dimension of the latent space, Z. (default: 512) 41 | (3) image_channels: Number of channels of the output image. (default: 3) 42 | (4) final_tanh: Whether to use `tanh` to control the final pixel range. 43 | (default: False) 44 | (5) label_size: Size of the additional label for conditional generation. 45 | (default: 0) 46 | (6) fused_scale: Whether to fused `upsample` and `conv2d` together, 47 | resulting in `conv2d_transpose`. (default: False) 48 | (7) use_wscale: Whether to use weight scaling. (default: True) 49 | (8) fmaps_base: Factor to control number of feature maps for each layer. 50 | (default: 16 << 10) 51 | (9) fmaps_max: Maximum number of feature maps in each layer. 
(default: 512) 52 | """ 53 | 54 | def __init__(self, 55 | resolution, 56 | z_dim=512, 57 | image_channels=3, 58 | final_tanh=False, 59 | label_size=0, 60 | fused_scale=False, 61 | use_wscale=True, 62 | fmaps_base=16 << 10, 63 | fmaps_max=512, 64 | upsampling_mode='nearest', 65 | padding_mode='zeros', 66 | divide_channels_by=1): 67 | """Initializes with basic settings. 68 | 69 | Raises: 70 | ValueError: If the `resolution` is not supported. 71 | """ 72 | super().__init__() 73 | 74 | if resolution not in _RESOLUTIONS_ALLOWED: 75 | raise ValueError(f'Invalid resolution: `{resolution}`!\n' 76 | f'Resolutions allowed: {_RESOLUTIONS_ALLOWED}.') 77 | 78 | self.init_res = _INIT_RES 79 | self.init_res_log2 = int(np.log2(self.init_res)) 80 | self.resolution = resolution 81 | self.final_res_log2 = int(np.log2(self.resolution)) 82 | self.z_space_dim = z_dim 83 | self.image_channels = image_channels 84 | self.final_tanh = final_tanh 85 | self.label_size = label_size 86 | self.fused_scale = fused_scale 87 | self.use_wscale = use_wscale 88 | self.fmaps_base = fmaps_base 89 | self.fmaps_max = fmaps_max 90 | self.divide_channels_by = divide_channels_by 91 | 92 | # Number of convolutional layers. 93 | self.num_layers = (self.final_res_log2 - self.init_res_log2 + 1) * 2 94 | 95 | # Level of detail (used for progressive training). 96 | self.register_buffer('lod', torch.zeros(())) 97 | self.pth_to_tf_var_mapping = {'lod': 'lod'} 98 | 99 | for res_log2 in range(self.init_res_log2, self.final_res_log2 + 1): 100 | res = 2 ** res_log2 101 | block_idx = res_log2 - self.init_res_log2 102 | is_last = res_log2 == self.final_res_log2 103 | ch_multiplier = 4 if (upsampling_mode == 'shuffle' and not is_last) else 1 # pixel shuffle needs 4x channel dim 104 | 105 | # First convolution layer for each resolution. 106 | if res == self.init_res: 107 | self.add_module( 108 | f'layer{2 * block_idx}', 109 | ConvBlock(in_channels=self.z_space_dim + self.label_size, 110 | out_channels=self.get_nf(res), 111 | kernel_size=self.init_res, 112 | padding=self.init_res - 1, 113 | padding_mode='zeros', # always start with zero padding 114 | use_wscale=self.use_wscale)) 115 | tf_layer_name = 'Dense' 116 | else: 117 | self.add_module( 118 | f'layer{2 * block_idx}', 119 | ConvBlock(in_channels=self.get_nf(res // 2), 120 | out_channels=self.get_nf(res), 121 | upsample=True, 122 | upsampling_mode=upsampling_mode, 123 | padding_mode=padding_mode, 124 | fused_scale=self.fused_scale, 125 | use_wscale=self.use_wscale)) 126 | tf_layer_name = 'Conv0_up' if self.fused_scale else 'Conv0' 127 | self.pth_to_tf_var_mapping[f'layer{2 * block_idx}.weight'] = ( 128 | f'{res}x{res}/{tf_layer_name}/weight') 129 | self.pth_to_tf_var_mapping[f'layer{2 * block_idx}.bias'] = ( 130 | f'{res}x{res}/{tf_layer_name}/bias') 131 | 132 | # Second convolution layer for each resolution. 133 | self.add_module( 134 | f'layer{2 * block_idx + 1}', 135 | ConvBlock(in_channels=self.get_nf(res), 136 | out_channels=self.get_nf(res) * ch_multiplier, 137 | use_wscale=self.use_wscale, 138 | padding_mode=padding_mode)) 139 | tf_layer_name = 'Conv' if res == self.init_res else 'Conv1' 140 | self.pth_to_tf_var_mapping[f'layer{2 * block_idx + 1}.weight'] = ( 141 | f'{res}x{res}/{tf_layer_name}/weight') 142 | self.pth_to_tf_var_mapping[f'layer{2 * block_idx + 1}.bias'] = ( 143 | f'{res}x{res}/{tf_layer_name}/bias') 144 | 145 | # Output convolution layer for each resolution. 
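# (Corresponds to `ToRGB` in the official TF implementation: a 1x1 convolution with
# linear activation that renders the feature maps of the current resolution to an
# image, used for blending between resolutions during progressive fade-in.)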
146 | self.add_module( 147 | f'output{block_idx}', 148 | ConvBlock(in_channels=self.get_nf(res), 149 | out_channels=self.image_channels, 150 | kernel_size=1, 151 | padding=0, 152 | use_wscale=self.use_wscale, 153 | wscale_gain=1.0, 154 | activation_type='linear')) 155 | self.pth_to_tf_var_mapping[f'output{block_idx}.weight'] = ( 156 | f'ToRGB_lod{self.final_res_log2 - res_log2}/weight') 157 | self.pth_to_tf_var_mapping[f'output{block_idx}.bias'] = ( 158 | f'ToRGB_lod{self.final_res_log2 - res_log2}/bias') 159 | 160 | self.upsample = UpsamplingLayer(mode='nearest') # for fade-in always use nearest neighbor interpolation 161 | self.final_activate = nn.Tanh() if self.final_tanh else nn.Identity() 162 | 163 | def get_nf(self, res): 164 | """Gets number of feature maps according to current resolution.""" 165 | return min(self.fmaps_base // res, self.fmaps_max) // self.divide_channels_by 166 | 167 | def forward(self, z, label=None, lod=None, **_unused_kwargs): 168 | if z.ndim != 2 or z.shape[1] != self.z_space_dim: 169 | raise ValueError(f'Input latent code should be with shape ' 170 | f'[batch_size, latent_dim], where ' 171 | f'`latent_dim` equals to {self.z_space_dim}!\n' 172 | f'But `{z.shape}` is received!') 173 | z = self.layer0.pixel_norm(z) 174 | if self.label_size: 175 | if label is None: 176 | raise ValueError(f'Model requires an additional label ' 177 | f'(with size {self.label_size}) as input, ' 178 | f'but no label is received!') 179 | if label.ndim != 2 or label.shape != (z.shape[0], self.label_size): 180 | raise ValueError(f'Input label should be with shape ' 181 | f'[batch_size, label_size], where ' 182 | f'`batch_size` equals to that of ' 183 | f'latent codes ({z.shape[0]}) and ' 184 | f'`label_size` equals to {self.label_size}!\n' 185 | f'But `{label.shape}` is received!') 186 | z = torch.cat((z, label), dim=1) 187 | 188 | lod = self.lod.cpu().tolist() if lod is None else lod 189 | if lod + self.init_res_log2 > self.final_res_log2: 190 | raise ValueError(f'Maximum level-of-detail (lod) is ' 191 | f'{self.final_res_log2 - self.init_res_log2}, ' 192 | f'but `{lod}` is received!') 193 | 194 | x = z.view(z.shape[0], self.z_space_dim + self.label_size, 1, 1) 195 | for res_log2 in range(self.init_res_log2, self.final_res_log2 + 1): 196 | current_lod = self.final_res_log2 - res_log2 197 | if lod < current_lod + 1: 198 | block_idx = res_log2 - self.init_res_log2 199 | x = self.__getattr__(f'layer{2 * block_idx}')(x) 200 | x = self.__getattr__(f'layer{2 * block_idx + 1}')(x) 201 | if current_lod - 1 < lod <= current_lod: 202 | image = self.__getattr__(f'output{block_idx}')(x) 203 | elif current_lod < lod < current_lod + 1: 204 | alpha = np.ceil(lod) - lod 205 | image = (self.__getattr__(f'output{block_idx}')(x) * alpha + 206 | self.upsample(image) * (1 - alpha)) 207 | elif lod >= current_lod + 1: 208 | image = self.upsample(image) 209 | image = self.final_activate(image) 210 | 211 | return image 212 | 213 | 214 | class PixelNormLayer(nn.Module): 215 | """Implements pixel-wise feature vector normalization layer.""" 216 | 217 | def __init__(self, epsilon=1e-8): 218 | super().__init__() 219 | self.eps = epsilon 220 | 221 | def forward(self, x): 222 | norm = torch.sqrt(torch.mean(x ** 2, dim=1, keepdim=True) + self.eps) 223 | return x / norm 224 | 225 | 226 | class UpsamplingLayer(nn.Module): 227 | """Implements the upsampling layer. 228 | 229 | Basically, this layer can be used to upsample feature maps with bilinear or nearest 230 | neighbor interpolation or zero insertion or pixel shuffle. 
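    The available modes are the keys of UPSAMPLE_FNS in models/utils.py: 'bilinear',
    'nearest', 'zeros' (zero insertion) and 'shuffle' (pixel shuffle, which expects
    4x as many input channels).

    Illustrative usage, assuming a 4D input tensor `x` of shape [N, C, H, W]:

        up = UpsamplingLayer(scale_factor=2, mode='bilinear')
        y = up(x)  # y has shape [N, C, 2H, 2W]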
231 | """ 232 | 233 | def __init__(self, scale_factor=2, mode='nearest'): 234 | super().__init__() 235 | self.scale_factor = scale_factor 236 | # Added 237 | assert mode in UPSAMPLE_FNS 238 | self.mode = mode 239 | self.up = UPSAMPLE_FNS[mode] 240 | 241 | def forward(self, x): 242 | if self.scale_factor <= 1: 243 | return x 244 | return self.up(x, scale_factor=self.scale_factor) 245 | 246 | def extra_repr(self): 247 | s = ('upsampling_mode={mode}') 248 | return s.format(**self.__dict__) 249 | 250 | 251 | class ConvBlock(nn.Module): 252 | """Implements the convolutional block. 253 | 254 | Basically, this block executes pixel-wise normalization layer, upsampling 255 | layer (if needed), convolutional layer, and activation layer in sequence. 256 | """ 257 | 258 | def __init__(self, 259 | in_channels, 260 | out_channels, 261 | kernel_size=3, 262 | stride=1, 263 | padding=1, 264 | padding_mode='zeros', 265 | add_bias=True, 266 | upsample=False, 267 | upsampling_mode='nearest', 268 | fused_scale=False, 269 | use_wscale=True, 270 | wscale_gain=_WSCALE_GAIN, 271 | activation_type='lrelu'): 272 | """Initializes with block settings. 273 | 274 | Args: 275 | in_channels: Number of channels of the input tensor. 276 | out_channels: Number of channels of the output tensor. 277 | kernel_size: Size of the convolutional kernels. (default: 3) 278 | stride: Stride parameter for convolution operation. (default: 1) 279 | padding: Padding parameter for convolution operation. (default: 1) 280 | add_bias: Whether to add bias onto the convolutional result. 281 | (default: True) 282 | upsample: Whether to upsample the input tensor before convolution. 283 | (default: False) 284 | fused_scale: Whether to fused `upsample` and `conv2d` together, 285 | resulting in `conv2d_transpose`. (default: False) 286 | use_wscale: Whether to use weight scaling. (default: True) 287 | wscale_gain: Gain factor for weight scaling. (default: _WSCALE_GAIN) 288 | activation_type: Type of activation. Support `linear` and `lrelu`. 289 | (default: `lrelu`) 290 | 291 | Raises: 292 | NotImplementedError: If the `activation_type` is not supported. 
293 | """ 294 | super().__init__() 295 | 296 | self.pixel_norm = PixelNormLayer() 297 | 298 | if upsample and not fused_scale: 299 | self.upsample = UpsamplingLayer(mode=upsampling_mode) 300 | else: 301 | self.upsample = nn.Identity() 302 | 303 | if upsample and fused_scale: 304 | raise RuntimeWarning('Custom upsampling and padding operations not implemented with fused_scale.') 305 | self.use_conv2d_transpose = True 306 | weight_shape = (in_channels, out_channels, kernel_size, kernel_size) 307 | self.stride = 2 308 | self.padding = 1 309 | else: 310 | self.use_conv2d_transpose = False 311 | weight_shape = (out_channels, in_channels, kernel_size, kernel_size) 312 | self.stride = stride 313 | self.padding = padding 314 | self.padding_mode = padding_mode 315 | 316 | fan_in = kernel_size * kernel_size * in_channels 317 | wscale = wscale_gain / np.sqrt(fan_in) 318 | if use_wscale: 319 | self.weight = nn.Parameter(torch.randn(*weight_shape)) 320 | self.wscale = wscale 321 | else: 322 | self.weight = nn.Parameter(torch.randn(*weight_shape) * wscale) 323 | self.wscale = 1.0 324 | 325 | if add_bias: 326 | self.bias = nn.Parameter(torch.zeros(out_channels)) 327 | else: 328 | self.bias = None 329 | 330 | if activation_type == 'linear': 331 | self.activate = nn.Identity() 332 | elif activation_type == 'lrelu': 333 | self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True) 334 | else: 335 | raise NotImplementedError(f'Not implemented activation function: ' 336 | f'`{activation_type}`!') 337 | 338 | def forward(self, x): 339 | x = self.pixel_norm(x) 340 | x = self.upsample(x) 341 | weight = self.weight * self.wscale 342 | if self.use_conv2d_transpose: 343 | weight = F.pad(weight, (1, 1, 1, 1, 0, 0, 0, 0), 'constant', 0.0) 344 | weight = (weight[:, :, 1:, 1:] + weight[:, :, :-1, 1:] + 345 | weight[:, :, 1:, :-1] + weight[:, :, :-1, :-1]) 346 | x = F.conv_transpose2d(x, 347 | weight=weight, 348 | bias=self.bias, 349 | stride=self.stride, 350 | padding=self.padding) 351 | else: 352 | if self.padding_mode != 'zeros': 353 | _reversed_padding_repeated_twice = torch.nn.modules.utils._reverse_repeat_tuple(torch.nn.modules.utils._pair(self.padding), 2) 354 | x = F.pad(x, _reversed_padding_repeated_twice, mode=self.padding_mode) 355 | padding = torch.nn.modules.utils._pair(0) 356 | else: 357 | padding = self.padding 358 | x = F.conv2d(x, 359 | weight=weight, 360 | bias=self.bias, 361 | stride=self.stride, 362 | padding=padding) 363 | x = self.activate(x) 364 | return x 365 | 366 | def extra_repr(self): 367 | out_channels, in_channels, kernel_size = self.weight.shape[:3] 368 | s = (f'Conv2D({in_channels}, {out_channels}, kernel_size={kernel_size}' 369 | ', stride={stride}, use_conv2d_transpose={use_conv2d_transpose}') 370 | if self.padding != 0: 371 | s += ', padding={padding}' 372 | if self.bias is None: 373 | s += ', bias=False' 374 | if self.padding_mode != 'zeros': 375 | s += ', padding_mode={padding_mode}' 376 | s += ')' 377 | return s.format(**self.__dict__) -------------------------------------------------------------------------------- /models/stylegan2_discriminator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .stylegan3 import SG2Discriminator as DBase 3 | 4 | 5 | class SG2Discriminator(DBase): 6 | def __init__(self, resolution, label_size=0, divide_channels_by=1, **kwargs): 7 | default_kwargs = { 8 | 'channel_base': 32768 // divide_channels_by, 9 | 'channel_max': 512 // divide_channels_by, 10 | 'c_dim': label_size+1, 11 | 
'img_channels': 3, 12 | } 13 | super(SG2Discriminator, self).__init__(img_resolution=resolution, **default_kwargs, **kwargs) 14 | 15 | def forward(self, img, c=None): 16 | if c is None: 17 | c = torch.empty(img.shape[0], 0).to(img.device) 18 | else: 19 | c = c.unsqueeze(1) 20 | return super(SG2Discriminator, self).forward(img, c=c, update_emas=False) -------------------------------------------------------------------------------- /models/stylegan2_generator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .stylegan3 import SG2Generator as GBase 3 | 4 | 5 | class SG2Generator(GBase): 6 | def __init__(self, resolution, z_dim, divide_channels_by=1, **kwargs): 7 | default_kwargs = { 8 | 'channel_base': 32768 // divide_channels_by, 9 | 'channel_max': 512 // divide_channels_by, 10 | 'c_dim': 0, 11 | 'w_dim': z_dim, 12 | 'img_channels': 3, 13 | } 14 | super(SG2Generator, self).__init__(img_resolution=resolution, z_dim=z_dim, **default_kwargs, **kwargs) 15 | 16 | def forward(self, z): 17 | c = torch.empty(z.shape[0], 0).to(z.device) 18 | return super(SG2Generator, self).forward(z, c=c, truncation_psi=1, truncation_cutoff=None, update_emas=False) -------------------------------------------------------------------------------- /models/stylegan3/__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | from pathlib import Path 4 | 5 | current_path = os.getcwd() 6 | 7 | module_path = Path(__file__).parent / 'stylegan3' 8 | sys.path.insert(0, str(module_path.resolve())) 9 | os.chdir(module_path) 10 | 11 | from training.networks_stylegan2 import Generator as SG2Generator, Discriminator as SG2Discriminator 12 | from training.networks_stylegan3 import Generator as SG3Generator 13 | 14 | os.chdir(current_path) 15 | sys.path.pop(0) -------------------------------------------------------------------------------- /models/stylegan3_generator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .stylegan3 import SG3Generator as GBase 3 | 4 | 5 | class SG3Generator(GBase): 6 | def __init__(self, resolution, z_dim, divide_channels_by=1, **kwargs): 7 | default_kwargs = { 8 | 'channel_base': 32768 // divide_channels_by, 9 | 'channel_max': 512 // divide_channels_by, 10 | 'c_dim': 0, 11 | 'w_dim': z_dim, 12 | 'img_channels': 3, 13 | } 14 | super(SG3Generator, self).__init__(img_resolution=resolution, z_dim=z_dim, **default_kwargs, **kwargs) 15 | 16 | def forward(self, z): 17 | c = torch.empty(z.shape[0], 0).to(z.device) 18 | return super(SG3Generator, self).forward(z, c=c, truncation_psi=1, truncation_cutoff=None, update_emas=False) -------------------------------------------------------------------------------- /models/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def up_bilinear(x, scale_factor): 6 | return F.interpolate(x, scale_factor=scale_factor, mode='bilinear', align_corners=False) 7 | 8 | 9 | def up_nearest(x, scale_factor): 10 | return F.interpolate(x, scale_factor=scale_factor, mode='nearest') 11 | 12 | 13 | def up_zeros(x, scale_factor): 14 | if scale_factor != 2: 15 | raise NotImplementedError 16 | assert x.ndim == 4 17 | uph = torch.stack([x, torch.zeros_like(x)], dim=-1).flatten(-2, -1) 18 | uphw = torch.stack([uph, torch.zeros_like(uph)], dim=-2).flatten(-3, -2) 19 | return uphw 20 | 21 | 22 | def 
up_shuffle(x, scale_factor): 23 | return F.pixel_shuffle(x, upscale_factor=scale_factor) 24 | 25 | 26 | UPSAMPLE_FNS = { 27 | 'bilinear': up_bilinear, 28 | 'nearest': up_nearest, 29 | 'zeros': up_zeros, 30 | 'shuffle': up_shuffle, 31 | } 32 | 33 | 34 | def down_avg(x, scale_factor): 35 | return F.avg_pool2d(x, kernel_size=scale_factor, stride=scale_factor, padding=0) 36 | 37 | 38 | def down_stride(x, scale_factor): 39 | if scale_factor != 2: 40 | raise NotImplementedError 41 | assert x.ndim == 4 42 | return x[..., ::2, ::2] 43 | 44 | 45 | def down_blurpool(x, scale_factor): 46 | f = torch.tensor([[1, 3, 1], [3, 9, 3], [1, 3, 1]], dtype=x.dtype, device=x.device) / 25. 47 | f = f.flip(list(range(f.ndim))) 48 | 49 | # Pad input with reflection padding 50 | x = torch.nn.functional.pad(x, (1,1,1,1), mode='reflect') 51 | 52 | # Convolve with the filter to filter high frequencies. 53 | num_channels = x.shape[1] 54 | f = f.view(1, 1, *f.shape).repeat(num_channels, 1, 1, 1) 55 | x = F.conv2d(input=x, weight=f, groups=num_channels) 56 | 57 | return down_avg(x, scale_factor) 58 | 59 | 60 | DOWNSAMPLE_FNS = { 61 | 'avg': down_avg, 62 | 'stride': down_stride, 63 | 'blurpool': down_blurpool, 64 | } -------------------------------------------------------------------------------- /scripts/demo_discriminator_testbed.sh: -------------------------------------------------------------------------------- 1 | python discriminator_testbed.py baboon64/pggan configs/discriminator_testbed/pggan.yaml 2 | python eval_discriminator.py baboon64/pggan --psnr --image-evolution --spectrum-evolution --spectrum-error-evolution -------------------------------------------------------------------------------- /scripts/demo_generator_testbed.sh: -------------------------------------------------------------------------------- 1 | python generator_testbed.py baboon64/pggan configs/generator_testbed/pggan.yaml 2 | python eval_generator.py baboon64/pggan --psnr --image-evolution --spectrum-evolution --spectrum-error-evolution -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .checkpoints import CheckpointIO 2 | from .logger import Logger -------------------------------------------------------------------------------- /utils/checkpoints.py: -------------------------------------------------------------------------------- 1 | """From https://github.com/LMescheder/GAN_stability/blob/master/gan_training/checkpoints.py""" 2 | 3 | import os 4 | import urllib 5 | import torch 6 | from torch.utils import model_zoo 7 | 8 | 9 | class CheckpointIO(object): 10 | ''' CheckpointIO class. 11 | 12 | It handles saving and loading checkpoints. 13 | 14 | Args: 15 | checkpoint_dir (str): path where checkpoints are saved 16 | ''' 17 | 18 | def __init__(self, checkpoint_dir='./chkpts', **kwargs): 19 | self.module_dict = kwargs 20 | self.checkpoint_dir = checkpoint_dir 21 | if not os.path.exists(checkpoint_dir): 22 | os.makedirs(checkpoint_dir) 23 | 24 | def register_modules(self, **kwargs): 25 | ''' Registers modules in current module dictionary. 26 | ''' 27 | self.module_dict.update(kwargs) 28 | 29 | def save(self, filename, **kwargs): 30 | ''' Saves the current module dictionary. 
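        Any additional keyword arguments (e.g. epoch_idx, it) are stored alongside
        the state dicts of the registered modules and are returned as scalars by load().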
31 | 32 | Args: 33 | filename (str): name of output file 34 | ''' 35 | if not os.path.isabs(filename): 36 | filename = os.path.join(self.checkpoint_dir, filename) 37 | 38 | outdict = kwargs 39 | for k, v in self.module_dict.items(): 40 | outdict[k] = v.state_dict() 41 | torch.save(outdict, filename) 42 | 43 | def load(self, filename): 44 | '''Loads a module dictionary from local file or url. 45 | 46 | Args: 47 | filename (str): name of saved module dictionary 48 | ''' 49 | if is_url(filename): 50 | return self.load_url(filename) 51 | else: 52 | return self.load_file(filename) 53 | 54 | def load_file(self, filename): 55 | '''Loads a module dictionary from file. 56 | 57 | Args: 58 | filename (str): name of saved module dictionary 59 | ''' 60 | 61 | if not os.path.isabs(filename): 62 | filename = os.path.join(self.checkpoint_dir, filename) 63 | 64 | if os.path.exists(filename): 65 | print(filename) 66 | print('=> Loading checkpoint from local file...') 67 | state_dict = torch.load(filename) 68 | scalars = self.parse_state_dict(state_dict) 69 | return scalars 70 | else: 71 | raise FileNotFoundError 72 | 73 | def load_url(self, url): 74 | '''Load a module dictionary from url. 75 | 76 | Args: 77 | url (str): url to saved model 78 | ''' 79 | print(url) 80 | print('=> Loading checkpoint from url...') 81 | state_dict = model_zoo.load_url(url, progress=True) 82 | scalars = self.parse_state_dict(state_dict) 83 | return scalars 84 | 85 | def parse_state_dict(self, state_dict): 86 | '''Parse state_dict of model and return scalars. 87 | 88 | Args: 89 | state_dict (dict): State dict of model 90 | ''' 91 | 92 | for k, v in self.module_dict.items(): 93 | if k in state_dict: 94 | v.load_state_dict(state_dict[k]) 95 | else: 96 | print('Warning: Could not find %s in checkpoint!' % k) 97 | scalars = {k: v for k, v in state_dict.items() 98 | if k not in self.module_dict} 99 | return scalars 100 | 101 | 102 | def is_url(url): 103 | scheme = urllib.parse.urlparse(url).scheme 104 | return scheme in ('http', 'https') -------------------------------------------------------------------------------- /utils/gan_training.py: -------------------------------------------------------------------------------- 1 | """From https://github.com/LMescheder/GAN_stability/blob/master/gan_training/train.py""" 2 | 3 | import torch 4 | 5 | 6 | def toggle_grad(model, requires_grad): 7 | for p in model.parameters(): 8 | p.requires_grad_(requires_grad) 9 | 10 | 11 | def compute_grad2(d_out, x_in): 12 | batch_size = x_in.size(0) 13 | grad_dout = torch.autograd.grad(outputs=d_out.sum(), inputs=x_in, 14 | create_graph=True, retain_graph=True, only_inputs=True)[0] 15 | grad_dout2 = grad_dout.pow(2) 16 | assert(grad_dout2.size() == x_in.size()) 17 | reg = grad_dout2.reshape(batch_size, -1).sum(1) 18 | return reg 19 | 20 | 21 | def update_average(model_tgt, model_src, beta): 22 | toggle_grad(model_src, False) 23 | toggle_grad(model_tgt, False) 24 | 25 | param_dict_src = dict(model_src.named_parameters()) 26 | 27 | for p_name, p_tgt in model_tgt.named_parameters(): 28 | p_src = param_dict_src[p_name] 29 | assert(p_src is not p_tgt) 30 | p_tgt.copy_(beta*p_tgt + (1. 
- beta)*p_src) -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | """From https://github.com/LMescheder/GAN_stability/blob/master/gan_training/logger.py""" 2 | 3 | import pickle 4 | import os 5 | import torchvision 6 | 7 | 8 | class Logger(object): 9 | def __init__(self, log_dir='./logs', img_dir='./imgs', 10 | monitoring=None, monitoring_dir=None): 11 | self.stats = dict() 12 | self.log_dir = log_dir 13 | self.img_dir = img_dir 14 | 15 | if not os.path.exists(log_dir): 16 | os.makedirs(log_dir) 17 | 18 | if not os.path.exists(img_dir): 19 | os.makedirs(img_dir) 20 | 21 | if not (monitoring is None or monitoring == 'none'): 22 | self.setup_monitoring(monitoring, monitoring_dir) 23 | else: 24 | self.monitoring = None 25 | self.monitoring_dir = None 26 | 27 | def setup_monitoring(self, monitoring, monitoring_dir=None): 28 | self.monitoring = monitoring 29 | self.monitoring_dir = monitoring_dir 30 | 31 | if monitoring == 'telemetry': 32 | import telemetry 33 | self.tm = telemetry.ApplicationTelemetry() 34 | if self.tm.get_status() == 0: 35 | print('Telemetry successfully connected.') 36 | elif monitoring == 'tensorboard': 37 | import tensorboardX 38 | self.tb = tensorboardX.SummaryWriter(monitoring_dir) 39 | else: 40 | raise NotImplementedError('Monitoring tool "%s" not supported!' 41 | % monitoring) 42 | 43 | def add(self, category, k, v, it): 44 | if category not in self.stats: 45 | self.stats[category] = {} 46 | 47 | if k not in self.stats[category]: 48 | self.stats[category][k] = [] 49 | 50 | self.stats[category][k].append((it, v)) 51 | 52 | k_name = '%s/%s' % (category, k) 53 | if self.monitoring == 'telemetry': 54 | self.tm.metric_push_async({ 55 | 'metric': k_name, 'value': v, 'it': it 56 | }) 57 | elif self.monitoring == 'tensorboard': 58 | self.tb.add_scalar(k_name, v, it) 59 | 60 | def add_imgs(self, imgs, class_name, it): 61 | outdir = os.path.join(self.img_dir, class_name) 62 | if not os.path.exists(outdir): 63 | os.makedirs(outdir) 64 | outfile = os.path.join(outdir, '%08d.png' % it) 65 | 66 | imgs = imgs / 2 + 0.5 67 | imgs = torchvision.utils.make_grid(imgs) 68 | torchvision.utils.save_image(imgs, outfile, nrow=8) 69 | 70 | if self.monitoring == 'tensorboard': 71 | self.tb.add_image(class_name, imgs, it) 72 | 73 | def get_last(self, category, k, default=0.): 74 | if category not in self.stats: 75 | return default 76 | elif k not in self.stats[category]: 77 | return default 78 | else: 79 | return self.stats[category][k][-1][1] 80 | 81 | def save_stats(self, filename): 82 | filename = os.path.join(self.log_dir, filename) 83 | with open(filename, 'wb') as f: 84 | pickle.dump(self.stats, f) 85 | 86 | def load_stats(self, filename): 87 | filename = os.path.join(self.log_dir, filename) 88 | if not os.path.exists(filename): 89 | print('Warning: file "%s" does not exist!' 
% filename) 90 | return 91 | 92 | try: 93 | with open(filename, 'rb') as f: 94 | self.stats = pickle.load(f) 95 | except EOFError: 96 | print('Warning: log file corrupted!') -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .misc import to_image 3 | 4 | 5 | def psnr(pred, target): 6 | pred = to_image(pred) 7 | target = to_image(target) 8 | 9 | mse = torch.nn.functional.mse_loss(pred.to(torch.float), target.to(torch.float), reduction='none') 10 | mse = mse.mean([1,2,3]) 11 | max_i2 = 255**2 12 | 13 | return 10 * torch.log10(max_i2/mse) -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import re 3 | import types 4 | import importlib 5 | import yaml 6 | import torch 7 | import imageio 8 | from torchvision.transforms import Pad 9 | 10 | from distutils.util import strtobool 11 | from typing import Any, Tuple 12 | 13 | 14 | # ------------------------------------------------------------------------------------------ 15 | 16 | # Taken from https://github.com/NVlabs/stylegan3/blob/main/dnnlib/util.py 17 | class EasyDict(dict): 18 | """Convenience class that behaves like a dict but allows access with the attribute syntax.""" 19 | 20 | def __getattr__(self, name: str) -> Any: 21 | try: 22 | return self[name] 23 | except KeyError: 24 | raise AttributeError(name) 25 | 26 | def __setattr__(self, name: str, value: Any) -> None: 27 | self[name] = value 28 | 29 | def __delattr__(self, name: str) -> None: 30 | del self[name] 31 | 32 | 33 | def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]: 34 | """Searches for the underlying module behind the name to some python object. 35 | Returns the module and the object name (original name with module part removed).""" 36 | 37 | # allow convenience shorthands, substitute them by full names 38 | obj_name = re.sub("^np.", "numpy.", obj_name) 39 | obj_name = re.sub("^tf.", "tensorflow.", obj_name) 40 | 41 | # list alternatives for (module_name, local_obj_name) 42 | parts = obj_name.split(".") 43 | name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)] 44 | 45 | # try each alternative in turn 46 | for module_name, local_obj_name in name_pairs: 47 | try: 48 | module = importlib.import_module(module_name) # may raise ImportError 49 | get_obj_from_module(module, local_obj_name) # may raise AttributeError 50 | return module, local_obj_name 51 | except: 52 | pass 53 | 54 | # maybe some of the modules themselves contain errors? 55 | for module_name, _local_obj_name in name_pairs: 56 | try: 57 | importlib.import_module(module_name) # may raise ImportError 58 | except ImportError: 59 | if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"): 60 | raise 61 | 62 | # maybe the requested attribute is missing? 
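    # (if the import succeeds here, get_obj_from_module raises an informative
    # AttributeError for the missing attribute)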
63 | for module_name, local_obj_name in name_pairs: 64 | try: 65 | module = importlib.import_module(module_name) # may raise ImportError 66 | get_obj_from_module(module, local_obj_name) # may raise AttributeError 67 | except ImportError: 68 | pass 69 | 70 | # we are out of luck, but we have no idea why 71 | raise ImportError(obj_name) 72 | 73 | 74 | def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any: 75 | """Traverses the object name and returns the last (rightmost) python object.""" 76 | if obj_name == '': 77 | return module 78 | obj = module 79 | for part in obj_name.split("."): 80 | obj = getattr(obj, part) 81 | return obj 82 | 83 | 84 | def get_obj_by_name(name: str) -> Any: 85 | """Finds the python object with the given name.""" 86 | module, obj_name = get_module_from_obj_name(name) 87 | return get_obj_from_module(module, obj_name) 88 | 89 | 90 | def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any: 91 | """Finds the python object with the given name and calls it as a function.""" 92 | assert func_name is not None 93 | func_obj = get_obj_by_name(func_name) 94 | assert callable(func_obj) 95 | return func_obj(*args, **kwargs) 96 | 97 | 98 | def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any: 99 | """Finds the python class with the given name and constructs it with the given arguments.""" 100 | return call_func_by_name(*args, func_name=class_name, **kwargs) 101 | 102 | 103 | # ------------------------------------------------------------------------------------------ 104 | 105 | # Custom functions 106 | def to_easydict(d): 107 | """ Convert python dict to EasyDict recursively. 108 | Args: 109 | d (dict): dictionary to convert 110 | 111 | Returns: 112 | EasyDict 113 | 114 | """ 115 | d = d.copy() 116 | for k, v in d.items(): 117 | if isinstance(d[k], dict): 118 | d[k] = to_easydict(d[k]) 119 | return EasyDict(d) 120 | 121 | 122 | def to_dict(d): 123 | """ Convert EasyDict to python dict recursively. 124 | Args: 125 | d (EasyDict): dictionary to convert 126 | 127 | Returns: 128 | dict 129 | 130 | """ 131 | d = d.copy() 132 | for k, v in d.items(): 133 | if isinstance(d[k], dict): 134 | d[k] = to_dict(d[k]) 135 | return dict(d) 136 | 137 | 138 | def save_config(outpath, config): 139 | config = to_dict(config) 140 | with open(outpath, 'w') as f: 141 | yaml.safe_dump(config, f) 142 | 143 | 144 | def load_config(path, default_path=None): 145 | ''' Loads config file. 146 | 147 | Args: 148 | path (str): path to config file 149 | default_path (str): path to default config 150 | ''' 151 | # Load configuration from file itself 152 | if path is not None: 153 | with open(path, 'r') as f: 154 | cfg_special = yaml.safe_load(f) 155 | else: 156 | cfg_special = dict() 157 | 158 | if default_path is not None: 159 | with open(default_path, 'r') as f: 160 | cfg = yaml.safe_load(f) 161 | else: 162 | cfg = dict() 163 | 164 | # Include main configuration 165 | update_recursive(cfg, cfg_special) 166 | 167 | return to_easydict(cfg) 168 | 169 | 170 | def update_recursive(dict1, dict2, allow_new=True): 171 | ''' Update two config dictionaries recursively. 
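    Values from dict2 overwrite those in dict1. String booleans ('true'/'false')
    are converted via strtobool, and scalar values are cast to the type already
    present in dict1.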
172 | 173 | Args: 174 | dict1 (dict): first dictionary to be updated 175 | dict2 (dict): second dictionary which entries should be used 176 | allow_new(bool): allow adding new keys 177 | 178 | ''' 179 | for k, v in dict2.items(): 180 | # Add item if not yet in dict1 181 | if k not in dict1: 182 | if not allow_new: 183 | raise RuntimeError(f'New key {k} in dict2 but allow_new=False') 184 | dict1[k] = {} if isinstance(v, dict) else None 185 | # Update 186 | if isinstance(dict1[k], dict): 187 | update_recursive(dict1[k], v) 188 | else: 189 | if isinstance(v, str) and v.lower() in ['true', 'false']: 190 | v = strtobool(v.lower()) 191 | if not isinstance(v, list) and not isinstance(dict1[k], list): 192 | argtype = type(dict1[k]) 193 | if argtype is not type(None) and v is not None: 194 | v = argtype(v) 195 | 196 | if dict1[k] is not None: 197 | print(f'Changing {k} ---- {dict1[k]} to {v}') 198 | 199 | dict1[k] = v 200 | 201 | 202 | def args_to_dict(args): 203 | out = {} 204 | for k, v in zip(args[::2], args[1::2]): 205 | assert k.startswith('--'), f'Can only process kwargs starting with "--" but key is {k}' 206 | k = k.replace('--', '', 1) 207 | keys = k.split(':') 208 | Nk = len(keys) 209 | curr_dct = out 210 | for i, k_i in enumerate(keys): 211 | if i == (Nk-1): 212 | curr_dct[k_i] = v 213 | else: 214 | if k_i not in curr_dct: 215 | curr_dct[k_i] = {} 216 | curr_dct = curr_dct[k_i] 217 | 218 | return to_easydict(out) 219 | 220 | 221 | def count_trainable_parameters(model): 222 | return sum([p.numel() for p in model.parameters() if p.requires_grad]) 223 | 224 | 225 | def to_image(timg): 226 | """Convert tensor image in range [-1, 1] to range [0, 255].""" 227 | assert timg.dtype == torch.float 228 | return ((timg / 2 + 0.5).clamp(0, 1) * 255).to(torch.uint8) 229 | 230 | 231 | def make_video(images, filenpath, **kwargs): 232 | # ensure output sizes are even 233 | H, W = images[0].size 234 | pad = Pad((0, 0, H%2, W%2), fill=1) 235 | images = [pad(img) for img in images] 236 | imageio.mimwrite(filenpath, images, **kwargs) 237 | -------------------------------------------------------------------------------- /utils/plot.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from math import sqrt 4 | import matplotlib.pyplot as plt 5 | from distutils.spawn import find_executable 6 | 7 | 8 | HAS_LATEX = find_executable('latex') 9 | if HAS_LATEX: # use LaTeX fonts in plots 10 | plt.rc('text', usetex=True) 11 | plt.rc('font', family='serif', size=26) 12 | else: 13 | plt.rc('font', family='serif', size=20) 14 | 15 | 16 | def plot_std(mean, std, ax, x=None, **kwargs): 17 | if x is None: 18 | x = range(len(mean)) 19 | 20 | l = ax.plot(x, mean, **kwargs) 21 | ax.fill_between(x, mean - std, mean + std, color=l[0]._color, alpha=0.3) 22 | 23 | 24 | def plot_spectrum(spec_real, spec_gen, resolution, filename): 25 | fig, ax = plt.subplots(1) 26 | mean_real, std_real = spec_real['mean'][1:], spec_real['std'][1:] 27 | mean_gen, std_gen = spec_gen['mean'][1:], spec_gen['std'][1:] 28 | 29 | plot_std(mean_real, std_real, ax, c='C0', ls='--', label='ground truth') 30 | plot_std(mean_gen, std_gen, ax, c='orange', ls='-', label='prediction') 31 | 32 | # Settings for x-axis 33 | N = sqrt(2) * resolution 34 | fnyq = (N - 1) // 2 35 | x_ticks = [0, fnyq / 2, fnyq] 36 | x_ticklabels = ['%.1f' % (l / fnyq) for l in x_ticks] 37 | 38 | ax.set_xlim(0, fnyq) 39 | xlabel = r'$f/f_{nyq}$' if HAS_LATEX else 'f/fnyq' 40 | ax.set_xlabel(xlabel) 41 | 
ax.set_xticks(x_ticks) 42 | ax.set_xticklabels(x_ticklabels) 43 | 44 | # Settings for y-axis 45 | ax.set_ylabel(r'Spectral density') 46 | if std_gen.isfinite().all(): 47 | ymin = (mean_real-std_real).min() 48 | if ymin < 0: 49 | ymin = mean_real.min() 50 | y_lim = ymin * 0.1, (mean_real+std_real).max() * 1.1 51 | else: 52 | y_lim = mean_real.min() * 0.1, mean_real.max() * 1.1 53 | ax.set_ylim(y_lim) 54 | ax.set_yscale('log') 55 | 56 | # Legend 57 | fs = plt.rcParams['font.size'] 58 | plt.rcParams.update({'font.size': fs*0.75}) 59 | ax.legend(loc='upper right', ncol=2, columnspacing=1) 60 | plt.rc('font', size=fs) 61 | 62 | os.makedirs(os.path.dirname(filename), exist_ok=True) 63 | plt.savefig(filename, bbox_inches='tight') 64 | plt.close(fig) 65 | 66 | 67 | def plot_spectrum_error_evolution(spec_real, spec_gen_all, resolution, filename): 68 | fig, ax = plt.subplots(1) 69 | 70 | # ensure spec_gen_all are in correct order 71 | spec_gen_all = sorted(spec_gen_all, key=lambda x: x['it']) 72 | 73 | # compute error image 74 | niter = spec_gen_all[-1]['it'] 75 | nspec = len(spec_gen_all) 76 | iters = torch.linspace(0, niter, nspec).to(torch.long) 77 | mean_real = spec_real['mean'][1:] 78 | error_img = torch.empty(nspec, len(mean_real)) 79 | for i, spec_gen in enumerate(spec_gen_all): 80 | mean_gen = spec_gen['mean'][1:] 81 | assert spec_gen['it'] == iters[i] 82 | error_img[i] = mean_gen / mean_real - 1 83 | 84 | # clamp at 100% relative error 85 | error_img.clamp_(-1, 1) 86 | 87 | # plot 88 | cmap = plt.cm.get_cmap('bwr') 89 | aspect = len(mean_real) / nspec # aspect=H/W, make image square 90 | h = ax.imshow(error_img, cmap=cmap, vmin=-1, vmax=1, origin='lower', aspect=aspect) 91 | 92 | # Settings for x-axis 93 | N = sqrt(2) * resolution 94 | fnyq = (N - 1) // 2 95 | x_ticks = [0, fnyq / 2, fnyq] 96 | x_ticklabels = ['%.1f' % (l / fnyq) for l in x_ticks] 97 | 98 | ax.set_xlim(0, fnyq) 99 | xlabel = r'$f/f_{nyq}$' if HAS_LATEX else 'f/fnyq' 100 | ax.set_xlabel(xlabel) 101 | ax.set_xticks(x_ticks) 102 | ax.set_xticklabels(x_ticklabels) 103 | 104 | # Settings for y-axis 105 | y_ticks = [0, nspec // 2, nspec] 106 | y_ticklabels = [t // 1000 for t in [0, niter // 2, niter]] 107 | 108 | ax.set_ylabel(r'Training Iteration [it/1000]') 109 | ax.set_yticks(y_ticks) 110 | ax.set_yticklabels(y_ticklabels) 111 | 112 | # Colorbar 113 | fig.colorbar(h, ticks=[-1, 0, 1], fraction=0.046, pad=0.04) 114 | 115 | os.makedirs(os.path.dirname(filename), exist_ok=True) 116 | plt.savefig(filename, bbox_inches='tight') 117 | plt.close(fig) 118 | 119 | 120 | if __name__ == '__main__': 121 | from glob import glob 122 | import pickle 123 | spec_file_real = '../data/baboon/spectrum64_N1.pkl' 124 | traindir = '../output/generator_testbed/pggan' 125 | evaldir = os.path.join(traindir, 'eval') 126 | resolution = 64 127 | 128 | with open(spec_file_real, 'rb') as f: 129 | spec_real = pickle.load(f) 130 | 131 | spec_files_gen_all = glob(os.path.join(traindir, 'logs', 'spectrum_*.pkl')) 132 | spec_gen_all = [] 133 | for path in spec_files_gen_all: 134 | with open(path, 'rb') as f: 135 | stats = pickle.load(f) 136 | spec_gen_all.append(stats) 137 | 138 | filename = os.path.join(evaldir, 'spectrum_error_evolution.png') 139 | plot_spectrum_error_evolution(spec_real, spec_gen_all, resolution, filename) -------------------------------------------------------------------------------- /utils/spectrum.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import torch 4 | from 
torch.fft import fftn 5 | from math import sqrt, ceil 6 | 7 | 8 | def resolution_to_spectrum_length(res): 9 | res_spec = (res+1) // 2 10 | res_azim_avg = ceil(sqrt(2) * res_spec) 11 | return res_azim_avg 12 | 13 | 14 | def roll_quadrants(data): 15 | """ 16 | Shift low frequencies to the center of the fourier transform (fftshift), i.e. reorder frequencies [0, ..., N-1] -> [-N/2, ..., +N/2] 17 | Args: 18 | data: fourier transform, (NxHxW) 19 | 20 | Returns: 21 | Shifted fourier transform. 22 | """ 23 | dim = data.ndim - 1 24 | 25 | if dim != 2: 26 | raise AttributeError(f'Data must be 2d but it is {dim}d.') 27 | if any(s % 2 == 0 for s in data.shape[1:]): 28 | raise RuntimeWarning('Roll quadrants for 2d input should only be used with odd (uneven) spatial sizes.') 29 | 30 | # for each dimension swap left and right half 31 | dims = tuple(range(1, dim + 1)) # add one for batch dimension 32 | shifts = torch.tensor(data.shape[1:]).div(2, rounding_mode='floor') # N/2 if N even, (N-1)/2 if N odd 33 | return data.roll(shifts.tolist(), dims=dims) 34 | 35 | 36 | def batch_fft(data, normalize=False): 37 | """ 38 | Compute fourier transform of batch. 39 | Args: 40 | data: input tensor, (NxHxW) 41 | normalize: if True, use the orthonormal ('ortho') FFT normalization 42 | Returns: 43 | Batch fourier transform of input data. 44 | """ 45 | 46 | dim = data.ndim - 1 # subtract one for batch dimension 47 | if dim != 2: 48 | raise AttributeError(f'Data must be 2d but it is {dim}d.') 49 | 50 | dims = tuple(range(1, dim + 1)) # add one for batch dimension 51 | 52 | if not torch.is_complex(data): 53 | data = torch.complex(data, torch.zeros_like(data)) 54 | freq = fftn(data, dim=dims, norm='ortho' if normalize else 'backward') 55 | 56 | return freq 57 | 58 | 59 | def azimuthal_average(image, center=None): 60 | # modified to tensor inputs from https://www.astrobetter.com/blog/2010/03/03/fourier-transforms-of-images-in-python/ 61 | """ 62 | Calculate the azimuthally averaged radial profile. 63 | Requires low frequencies to be at the center of the image. 64 | Args: 65 | image: Batch of 2D images, NxHxW 66 | center: The [x,y] pixel coordinates used as the center. The default is 67 | None, which then uses the center of the image (including 68 | fractional pixels). 69 | 70 | Returns: 71 | Azimuthal average over the image around the center 72 | """ 73 | # Check input shapes 74 | assert center is None or (len(center) == 2), f'Center has to be None or len(center)=2 ' \ 75 | f'(but it is len(center)={len(center)}).' 76 | # Calculate the indices from the image 77 | H, W = image.shape[-2:] 78 | h, w = torch.meshgrid(torch.arange(0, H), torch.arange(0, W)) 79 | 80 | if center is None: 81 | center = torch.tensor([(w.max() - w.min()) / 2.0, (h.max() - h.min()) / 2.0]) 82 | 83 | # Compute radius for each pixel wrt center 84 | r = torch.stack([w - center[0], h - center[1]]).norm(2, 0) 85 | 86 | # Get sorted radii 87 | r_sorted, ind = r.flatten().sort() 88 | i_sorted = image.flatten(-2, -1)[..., ind] 89 | 90 | # Get the integer part of the radii (bin size = 1) 91 | r_int = r_sorted.long() # attribute to the smaller integer 92 | 93 | # Find all pixels that fall within each radial bin. 94 | deltar = r_int[1:] - r_int[:-1] # Assumes all radii represented, computes bin change between subsequent radii 95 | rind = torch.where(deltar)[0] # location of changed radius 96 | 97 | # compute number of elements in each bin 98 | nind = rind + 1 # number of elements = idx + 1 99 | nind = torch.cat([torch.tensor([0]), nind, torch.tensor([H * W])]) # add borders 100 | nr = nind[1:] - nind[:-1] # number of elements per radius bin, i.e.
counter for bins belonging to each radius 101 | 102 | # Cumulative sum to figure out sums for each radius bin 103 | if H % 2 == 0: 104 | raise NotImplementedError('Not sure if implementation correct, please check') 105 | rind = torch.cat([torch.tensor([0]), rind, torch.tensor([H * W - 1])]) # add borders 106 | else: 107 | rind = torch.cat([rind, torch.tensor([H * W - 1])]) # add borders 108 | csim = i_sorted.cumsum(-1, dtype=torch.float64) # integrate over all values with smaller radius 109 | tbin = csim[..., rind[1:]] - csim[..., rind[:-1]] 110 | # add mean 111 | tbin = torch.cat([csim[:, 0:1], tbin], 1) 112 | 113 | radial_prof = tbin / nr.to(tbin.device) # normalize by counted bins 114 | 115 | return radial_prof 116 | 117 | 118 | def get_spectrum(data, normalize=False): 119 | if (data.ndim - 1) != 2: 120 | raise AttributeError(f'Data must be 2d.') 121 | 122 | freq = batch_fft(data, normalize=normalize) 123 | power_spec = freq.real ** 2 + freq.imag ** 2 124 | N = data.shape[1] 125 | if N % 2 == 0: # duplicate value for N/2 so it is put at the end of the spectrum and is not averaged with the mean value 126 | N_2 = N//2 127 | power_spec = torch.cat([power_spec[:, :N_2+1], power_spec[:, N_2:N_2+1], power_spec[:, N_2+1:]], dim=1) 128 | power_spec = torch.cat([power_spec[:, :, :N_2+1], power_spec[:, :, N_2:N_2+1], power_spec[:, :, N_2+1:]], dim=2) 129 | 130 | power_spec = roll_quadrants(power_spec) 131 | power_spec = azimuthal_average(power_spec) 132 | return power_spec 133 | 134 | 135 | def compute_spectrum_stats_for_dataset(dataset, batch_size=32): 136 | # Try to lookup from cache. 137 | resolution = dataset[0][0].shape[1] 138 | cache_file = os.path.join(dataset.root, f'spectrum{resolution}_N{len(dataset)}.pkl') 139 | if dataset.highpass: 140 | cache_file = cache_file.replace('.pkl', '_highpass.pkl') 141 | if os.path.isfile(cache_file): 142 | with open(cache_file, 'rb') as f: 143 | return pickle.load(f) 144 | 145 | # Main loop. 146 | spectra = [] 147 | for data in torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False, drop_last=False): 148 | imgs = data[0] 149 | if imgs.shape[1] == 1: 150 | imgs = imgs.repeat([1, 3, 1, 1]) 151 | imgs = imgs.to('cuda:0', torch.float32) 152 | spec = get_spectrum(imgs.flatten(0, 1)).unflatten(0, (imgs.shape[0], imgs.shape[1])) 153 | spec = spec.mean(dim=1) # average over channels 154 | spectra.append(spec.cpu()) 155 | 156 | spectra = torch.cat(spectra) 157 | stats = {'mean': spectra.mean(dim=0), 'std': spectra.std(dim=0)} 158 | 159 | # Save to cache. 160 | with open(cache_file, 'wb') as f: 161 | pickle.dump(stats, f) 162 | return stats 163 | 164 | 165 | def compute_spectrum_stats_for_generator(dataset, model, batch_size=32): 166 | device = 'cuda:0' 167 | model = model.eval().to(device) 168 | 169 | # Main loop. 
170 | spectra = [] 171 | with torch.no_grad(): 172 | for data in torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False, drop_last=False): 173 | z = data[1].to(device) 174 | imgs = model(z) 175 | if imgs.shape[1] == 1: 176 | imgs = imgs.repeat([1, 3, 1, 1]) 177 | imgs = imgs.to(torch.float32) 178 | spec = get_spectrum(imgs.flatten(0, 1)).unflatten(0, (imgs.shape[0], imgs.shape[1])) 179 | spec = spec.mean(dim=1) # average over channels 180 | spectra.append(spec.cpu()) 181 | 182 | spectra = torch.cat(spectra) 183 | stats = {'mean': spectra.mean(dim=0), 'std': spectra.std(dim=0)} 184 | return stats 185 | 186 | 187 | def evaluate_spectrum(dataset, model, batch_size=32): 188 | spec_real = compute_spectrum_stats_for_dataset(dataset, batch_size=batch_size) 189 | spec_gen = compute_spectrum_stats_for_generator(dataset, model, batch_size=batch_size) 190 | 191 | return spec_real, spec_gen --------------------------------------------------------------------------------
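The helpers in `utils/misc.py` above are easiest to understand together. The following is a minimal usage sketch, not part of the repository: the config paths follow the layout of the `configs/` directory, but the `training:batch_size` override key and the output path are purely illustrative, and it assumes the repository root is on `PYTHONPATH`.

    import os
    from utils.misc import load_config, args_to_dict, update_recursive, save_config

    # Merge an experiment config on top of the defaults (paths are illustrative).
    cfg = load_config('configs/generator_testbed/pggan.yaml',
                      default_path='configs/generator_testbed/default.yaml')

    # Command-line style overrides of the form "--outer:inner value";
    # string values are cast to the type already stored in the config.
    overrides = args_to_dict(['--training:batch_size', '16'])  # hypothetical key
    update_recursive(cfg, overrides)

    # Persist the merged configuration, e.g. next to the run outputs.
    os.makedirs('output/run0', exist_ok=True)
    save_config(os.path.join('output/run0', 'config.yaml'), cfg)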
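Similarly, the spectrum and plotting utilities can be exercised without a trained model. Below is a small self-contained sketch, again not part of the repository; the random tensors merely stand in for real and generated image batches in [-1, 1], and the output path is made up. It mirrors how `compute_spectrum_stats_for_dataset` folds the channel dimension into the batch before calling `get_spectrum` and then averages the per-channel spectra.

    import torch
    from utils.spectrum import get_spectrum
    from utils.plot import plot_spectrum

    resolution = 64
    real = torch.rand(8, 3, resolution, resolution) * 2 - 1  # stand-in for dataset images
    fake = torch.rand(8, 3, resolution, resolution) * 2 - 1  # stand-in for generator samples

    def spectrum_stats(imgs):
        # get_spectrum expects a batch of 2D inputs (NxHxW), so fold the channel
        # dimension into the batch and average the per-channel spectra afterwards.
        spec = get_spectrum(imgs.flatten(0, 1)).unflatten(0, (imgs.shape[0], imgs.shape[1]))
        spec = spec.mean(dim=1)
        return {'mean': spec.mean(dim=0), 'std': spec.std(dim=0)}

    plot_spectrum(spectrum_stats(real), spectrum_stats(fake), resolution, 'output/debug/spectrum.png')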