├── .circleci └── config.yml ├── .coveragerc ├── .readthedocs.yml ├── .style.yapf ├── LICENSE ├── README.md ├── docs ├── Makefile ├── gallery │ └── README.md ├── images │ ├── celeba_128.png │ ├── cifar10.png │ ├── cifar100.png │ ├── imagenet_32.png │ ├── lsun_bedroom_128.png │ ├── mimicry_logo.png │ └── stl10_48.png ├── requirements.txt └── source │ ├── _static │ ├── css │ │ └── custom.css │ └── img │ │ └── mimicry_logo.png │ ├── conf.py │ ├── guides │ ├── baselines.rst │ ├── images │ │ ├── fake_vis.png │ │ ├── fixed_fake_vis.png │ │ ├── output.gif │ │ ├── resnet_backbone.png │ │ ├── ssgan.png │ │ └── tensorboard.png │ ├── introduction.rst │ └── tutorial.rst │ ├── index.rst │ └── modules │ ├── datasets.rst │ ├── metrics.rst │ ├── modules.rst │ ├── nets.rst │ ├── training.rst │ └── utils.rst ├── examples ├── eval_pretrained.py ├── sngan_example.py └── ssgan_tutorial.py ├── pytest.ini ├── requirements.txt ├── setup.cfg ├── setup.py ├── tests ├── datasets │ ├── imagenet │ │ └── test.bin │ ├── test_data_utils.py │ └── test_image_loader.py ├── metrics │ ├── fid │ │ └── test_fid.py │ ├── inception_model │ │ └── test_inception_utils.py │ ├── inception_score │ │ └── test_inception_score.py │ ├── kid │ │ └── test_kid.py │ ├── test_compute_fid.py │ ├── test_compute_is.py │ ├── test_compute_kid.py │ └── test_compute_metrics.py ├── modules │ ├── test_layers.py │ ├── test_losses.py │ ├── test_resblocks.py │ └── test_spectral_norm.py ├── nets │ ├── basemodel │ │ └── test_basemodel.py │ ├── cgan_pd │ │ ├── test_cgan_pd_128.py │ │ └── test_cgan_pd_32.py │ ├── dcgan │ │ ├── test_dcgan_128.py │ │ ├── test_dcgan_32.py │ │ ├── test_dcgan_48.py │ │ ├── test_dcgan_64.py │ │ └── test_dcgan_cifar.py │ ├── gan │ │ ├── test_cgan.py │ │ └── test_gan.py │ ├── infomax_gan │ │ ├── test_infomax_gan_128.py │ │ ├── test_infomax_gan_32.py │ │ ├── test_infomax_gan_48.py │ │ ├── test_infomax_gan_64.py │ │ └── test_infomax_gan_base.py │ ├── sagan │ │ ├── test_sagan_128.py │ │ └── test_sagan_32.py │ ├── sngan │ │ ├── test_sngan_128.py │ │ ├── test_sngan_32.py │ │ ├── test_sngan_48.py │ │ └── test_sngan_64.py │ ├── ssgan │ │ ├── test_ssgan_128.py │ │ ├── test_ssgan_32.py │ │ ├── test_ssgan_48.py │ │ ├── test_ssgan_64.py │ │ └── test_ssgan_base.py │ └── wgan_gp │ │ ├── test_wgan_gp_128.py │ │ ├── test_wgan_gp_32.py │ │ ├── test_wgan_gp_48.py │ │ ├── test_wgan_gp_64.py │ │ └── test_wgan_gp_resblocks.py ├── training │ ├── test_logger.py │ ├── test_metric_log.py │ ├── test_scheduler.py │ └── test_trainer.py └── utils │ └── test_common.py └── torch_mimicry ├── __init__.py ├── datasets ├── __init__.py ├── data_utils.py ├── image_loader.py └── imagenet │ ├── __init__.py │ └── imagenet.py ├── metrics ├── __init__.py ├── compute_fid.py ├── compute_is.py ├── compute_kid.py ├── compute_metrics.py ├── fid │ ├── __init__.py │ └── fid_utils.py ├── inception_model │ ├── __init__.py │ └── inception_utils.py ├── inception_score │ ├── __init__.py │ └── inception_score_utils.py └── kid │ ├── __init__.py │ └── kid_utils.py ├── modules ├── __init__.py ├── layers.py ├── losses.py ├── resblocks.py └── spectral_norm.py ├── nets ├── __init__.py ├── basemodel │ ├── __init__.py │ └── basemodel.py ├── cgan_pd │ ├── __init__.py │ ├── cgan_pd_128.py │ ├── cgan_pd_32.py │ └── cgan_pd_base.py ├── dcgan │ ├── __init__.py │ ├── dcgan_128.py │ ├── dcgan_32.py │ ├── dcgan_48.py │ ├── dcgan_64.py │ ├── dcgan_base.py │ └── dcgan_cifar.py ├── gan │ ├── __init__.py │ ├── cgan.py │ └── gan.py ├── infomax_gan │ ├── __init__.py │ ├── infomax_gan_128.py │ ├── 
infomax_gan_32.py │ ├── infomax_gan_48.py │ ├── infomax_gan_64.py │ └── infomax_gan_base.py ├── sagan │ ├── __init__.py │ ├── sagan_128.py │ ├── sagan_32.py │ └── sagan_base.py ├── sngan │ ├── __init__.py │ ├── sngan_128.py │ ├── sngan_32.py │ ├── sngan_48.py │ ├── sngan_64.py │ └── sngan_base.py ├── ssgan │ ├── __init__.py │ ├── ssgan_128.py │ ├── ssgan_32.py │ ├── ssgan_48.py │ ├── ssgan_64.py │ └── ssgan_base.py └── wgan_gp │ ├── __init__.py │ ├── wgan_gp_128.py │ ├── wgan_gp_32.py │ ├── wgan_gp_48.py │ ├── wgan_gp_64.py │ ├── wgan_gp_base.py │ └── wgan_gp_resblocks.py ├── training ├── __init__.py ├── logger.py ├── metric_log.py ├── scheduler.py └── trainer.py └── utils ├── __init__.py └── common.py /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | # Python CircleCI 2.1 configuration file 2 | version: 2.1 3 | orbs: 4 | codecov: codecov/codecov@1.0.5 5 | jobs: 6 | build: 7 | # machine: 8 | # image: ubuntu-1604:201903-01 # recommended linux image - includes Ubuntu 16.04, docker 18.09.3, docker-compose 1.23.1 9 | 10 | docker: 11 | # specify the version you desire here 12 | # use `-browsers` prefix for selenium tests, e.g. `3.6.1-browsers` 13 | - image: cimg/python:3.8.13 14 | 15 | # Specify service dependencies here if necessary 16 | # CircleCI maintains a library of pre-built images 17 | # documented at https://circleci.com/docs/2.0/circleci-images/ 18 | # - image: circleci/postgres:9.4 19 | 20 | working_directory: ~/repo 21 | parallelism: 4 22 | resource_class: large 23 | 24 | steps: 25 | - checkout 26 | 27 | # Download and cache dependencies 28 | - restore_cache: 29 | keys: 30 | - v1-dependencies-{{ checksum "requirements.txt" }} 31 | # fallback to using the latest cache if no exact match is found 32 | - v1-dependencies- 33 | 34 | - run: 35 | name: install dependencies 36 | command: | 37 | python3 -m venv venv 38 | . venv/bin/activate 39 | pip install pip==22.1.2 && pip install -r requirements.txt 40 | 41 | - save_cache: 42 | paths: 43 | - ./venv 44 | key: v1-dependencies-{{ checksum "requirements.txt" }} 45 | 46 | # run tests! 47 | # this example uses Django's built-in test-runner 48 | # other common Python testing frameworks include pytest and nose 49 | # https://pytest.org 50 | # https://nose.readthedocs.io 51 | - run: 52 | name: run tests 53 | command: | 54 | . venv/bin/activate 55 | python -m pytest --cov=./ --cov-report=xml:coverage.xml tests 56 | 57 | - codecov/upload: 58 | file: ./coverage.xml 59 | token: e963353d-df2c-4b11-bcc8-c8295b40bd7f 60 | 61 | - store_artifacts: 62 | path: test-reports 63 | destination: test-reports 64 | -------------------------------------------------------------------------------- /.coveragerc: -------------------------------------------------------------------------------- 1 | [report] 2 | exclude_lines = 3 | if __name__ == .__main__.: 4 | 5 | [run] 6 | omit = 7 | *test* 8 | torch_mimicry/datasets/imagenet/* 9 | setup.py 10 | -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | build: 9 | image: stable 10 | 11 | python: 12 | version: 3.6 13 | system_packages: true 14 | install: 15 | - requirements: docs/requirements.txt 16 | - method: setuptools 17 | path: . 
18 | 19 | formats: [] 20 | -------------------------------------------------------------------------------- /.style.yapf: -------------------------------------------------------------------------------- 1 | [style] 2 | based_on_style = pep8 3 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Kwot Sin Lee 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | SPHINXBUILD = sphinx-build 2 | SPHINXPROJ = torch_mimicry 3 | SOURCEDIR = source 4 | BUILDDIR = build 5 | 6 | .PHONY: help Makefile 7 | 8 | %: Makefile 9 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" 10 | -------------------------------------------------------------------------------- /docs/gallery/README.md: -------------------------------------------------------------------------------- 1 | # Gallery 2 | We provide randomly sampled, non-cherry picked images for various datasets as produced by InfoMax-GAN. 
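The grids below can be reproduced with a short sampling script. The following sketch is illustrative rather than the exact script used for these images: it restores the SNGAN generator used in the examples elsewhere in this repository (the corresponding InfoMax-GAN generator under `torch_mimicry.nets.infomax_gan` can be swapped in), and the checkpoint path and latent size of 128 are assumptions.

```python
import torch
from torchvision.utils import save_image
from torch_mimicry.nets import sngan

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Restore a trained 32x32 generator (checkpoint path is hypothetical).
netG = sngan.SNGANGenerator32().to(device)
netG.restore_checkpoint('./log/example/checkpoints/netG/netG_100000_steps.pth')
netG.eval()

# Sample latent vectors and decode them into a batch of images.
with torch.no_grad():
    noise = torch.randn(64, 128, device=device)  # 128 = assumed latent dimension
    images = netG(noise)

# Tile the batch into an 8x8 grid; normalize=True rescales pixel values for viewing.
save_image(images, 'samples.png', nrow=8, normalize=True)
```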
3 | 4 | ### LSUN-Bedroom 5 | Resolution: 128 x 128 6 | 7 | ![alt text](https://github.com/kwotsin/mimicry/blob/master/docs/images/lsun_bedroom_128.png "LSUN-Bedroom (128 x 128)") 8 | 9 | ### CelebA 10 | Resolution: 128 x 128 11 | 12 | ![alt text](https://github.com/kwotsin/mimicry/blob/master/docs/images/celeba_128.png "CelebA (128 x 128)") 13 | 14 | ### STL-10 15 | Resolution: 48 x 48 16 | 17 | ![alt text](https://github.com/kwotsin/mimicry/blob/master/docs/images/stl10_48.png "STL-10 (48 x 48)") 18 | 19 | ### ImageNet 20 | Resolution: 32 x 32 21 | 22 | ![alt text](https://github.com/kwotsin/mimicry/blob/master/docs/images/imagenet_32.png "ImageNet (32 x 32)") 23 | 24 | ### CIFAR-10 25 | Resolution: 32 x 32 26 | 27 | ![alt text](https://github.com/kwotsin/mimicry/blob/master/docs/images/cifar10.png "CIFAR-10 (32 x 32)") 28 | 29 | ### CIFAR-100 30 | Resolution: 32 x 32 31 | 32 | ![alt text](https://github.com/kwotsin/mimicry/blob/master/docs/images/cifar100.png "CIFAR-100 (32 x 32)") 33 | -------------------------------------------------------------------------------- /docs/images/celeba_128.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kwotsin/mimicry/a7fda06c4aff1e6af8dc4c4a35ed6636e434c766/docs/images/celeba_128.png -------------------------------------------------------------------------------- /docs/images/cifar10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kwotsin/mimicry/a7fda06c4aff1e6af8dc4c4a35ed6636e434c766/docs/images/cifar10.png -------------------------------------------------------------------------------- /docs/images/cifar100.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kwotsin/mimicry/a7fda06c4aff1e6af8dc4c4a35ed6636e434c766/docs/images/cifar100.png -------------------------------------------------------------------------------- /docs/images/imagenet_32.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kwotsin/mimicry/a7fda06c4aff1e6af8dc4c4a35ed6636e434c766/docs/images/imagenet_32.png -------------------------------------------------------------------------------- /docs/images/lsun_bedroom_128.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kwotsin/mimicry/a7fda06c4aff1e6af8dc4c4a35ed6636e434c766/docs/images/lsun_bedroom_128.png -------------------------------------------------------------------------------- /docs/images/mimicry_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kwotsin/mimicry/a7fda06c4aff1e6af8dc4c4a35ed6636e434c766/docs/images/mimicry_logo.png -------------------------------------------------------------------------------- /docs/images/stl10_48.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kwotsin/mimicry/a7fda06c4aff1e6af8dc4c4a35ed6636e434c766/docs/images/stl10_48.png -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | python_version >= '3.6' 2 | scipy==1.4.1 3 | sphinx 4 | sphinx_rtd_theme 5 | torch 6 | torchvision 7 | numpy 8 | -e git+https://github.com/kwotsin/mimicry.git#egg=torch_mimicry 9 | 
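The docs/Makefile shown earlier simply wraps `sphinx-build`, so the documentation can be built locally against this requirements file — roughly (an assumed workflow, not a documented command sequence): `pip install -r docs/requirements.txt`, then `make html` from inside `docs/`, which expands to `sphinx-build -M html source build` and writes the rendered pages under `docs/build/html/`.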
-------------------------------------------------------------------------------- /docs/source/_static/css/custom.css: -------------------------------------------------------------------------------- 1 | /* White logo background. */ 2 | .wy-side-nav-search { 3 | background-color: #fff; 4 | } 5 | 6 | .wy-side-nav-search > div.version { 7 | color: #000; 8 | } 9 | -------------------------------------------------------------------------------- /docs/source/_static/img/mimicry_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kwotsin/mimicry/a7fda06c4aff1e6af8dc4c4a35ed6636e434c766/docs/source/_static/img/mimicry_logo.png -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import sphinx_rtd_theme 3 | import doctest 4 | import torch_mimicry 5 | 6 | extensions = [ 7 | 'sphinx.ext.autodoc', 8 | 'sphinx.ext.doctest', 9 | 'sphinx.ext.intersphinx', 10 | 'sphinx.ext.mathjax', 11 | 'sphinx.ext.napoleon', 12 | 'sphinx.ext.viewcode', 13 | 'sphinx.ext.githubpages', 14 | ] 15 | 16 | source_suffix = '.rst' 17 | master_doc = 'index' 18 | 19 | author = 'Kwot Sin Lee' 20 | project = 'torch_mimicry' 21 | copyright = '{}, {}'.format(datetime.datetime.now().year, author) 22 | 23 | version = 'master' 24 | release = torch_mimicry.__version__ 25 | 26 | html_theme = 'sphinx_rtd_theme' 27 | html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] 28 | 29 | doctest_default_flags = doctest.NORMALIZE_WHITESPACE 30 | intersphinx_mapping = {'python': ('https://docs.python.org/', None)} 31 | 32 | html_theme_options = { 33 | 'collapse_navigation': False, 34 | 'display_version': True, 35 | 'logo_only': True, 36 | } 37 | 38 | html_logo = '_static/img/mimicry_logo.png' 39 | html_static_path = ['_static'] 40 | html_context = {'css_files': ['_static/css/custom.css']} 41 | 42 | add_module_names = False 43 | 44 | 45 | def setup(app): 46 | def skip(app, what, name, obj, skip, options): 47 | members = [ 48 | '__init__', 49 | '__repr__', 50 | '__weakref__', 51 | '__dict__', 52 | '__module__', 53 | ] 54 | return True if name in members else skip 55 | 56 | app.connect('autodoc-skip-member', skip) 57 | -------------------------------------------------------------------------------- /docs/source/guides/images/fake_vis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kwotsin/mimicry/a7fda06c4aff1e6af8dc4c4a35ed6636e434c766/docs/source/guides/images/fake_vis.png -------------------------------------------------------------------------------- /docs/source/guides/images/fixed_fake_vis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kwotsin/mimicry/a7fda06c4aff1e6af8dc4c4a35ed6636e434c766/docs/source/guides/images/fixed_fake_vis.png -------------------------------------------------------------------------------- /docs/source/guides/images/output.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kwotsin/mimicry/a7fda06c4aff1e6af8dc4c4a35ed6636e434c766/docs/source/guides/images/output.gif -------------------------------------------------------------------------------- /docs/source/guides/images/resnet_backbone.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/kwotsin/mimicry/a7fda06c4aff1e6af8dc4c4a35ed6636e434c766/docs/source/guides/images/resnet_backbone.png -------------------------------------------------------------------------------- /docs/source/guides/images/ssgan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kwotsin/mimicry/a7fda06c4aff1e6af8dc4c4a35ed6636e434c766/docs/source/guides/images/ssgan.png -------------------------------------------------------------------------------- /docs/source/guides/images/tensorboard.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kwotsin/mimicry/a7fda06c4aff1e6af8dc4c4a35ed6636e434c766/docs/source/guides/images/tensorboard.png -------------------------------------------------------------------------------- /docs/source/guides/introduction.rst: -------------------------------------------------------------------------------- 1 | Introduction 2 | ============ 3 | 4 | Installation 5 | ------------ 6 | For best performance, we recommend you to install the GPU versions of both TensorFlow and PyTorch, which are used in this library. 7 | 8 | First create a new environment with :code:`conda` using Python 3.6 or 3.7 (Python 3.8 is currently not supported by TensorFlow): 9 | 10 | .. code-block:: none 11 | 12 | $ conda create -n mimicry python=3.6 13 | 14 | Install TensorFlow (GPU): 15 | 16 | .. code-block:: none 17 | 18 | $ conda install -c anaconda tensorflow-gpu 19 | 20 | Install PyTorch (GPU): 21 | 22 | .. code-block:: none 23 | 24 | $ conda install pytorch torchvision cudatoolkit=9.2 -c pytorch 25 | 26 | For installing the CUDA version matching your drivers, see the `official instructions `_. 27 | 28 | Mimicry can be installed using :code:`pip` directly: 29 | 30 | .. code-block:: none 31 | 32 | $ pip install torch-mimicry 33 | 34 | .. To install the GPU versions of TensorFlow and PyTorch see the official `PyTorch `_ and `TensorFlow `_ pages based on your system. The library is currently only compatible with :code:`TensorFlow 1.x` to accommodate the original implementations of the GAN metrics. 35 | 36 | Quick Start 37 | ----------- 38 | We provide a sample training script for training the `Spectral Normalization GAN `_ on the CIFAR-10 dataset, with the same training hyperparameters that reproduce results in the paper. 39 | 40 | 41 | .. code-block:: python 42 | 43 | import torch 44 | import torch.optim as optim 45 | import torch_mimicry as mmc 46 | from torch_mimicry.nets import sngan 47 | 48 | 49 | # Data handling objects 50 | device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu") 51 | dataset = mmc.datasets.load_dataset(root='./datasets', name='cifar10') 52 | dataloader = torch.utils.data.DataLoader( 53 | dataset, batch_size=64, shuffle=True, num_workers=4) 54 | 55 | # Define models and optimizers 56 | netG = sngan.SNGANGenerator32().to(device) 57 | netD = sngan.SNGANDiscriminator32().to(device) 58 | optD = optim.Adam(netD.parameters(), 2e-4, betas=(0.0, 0.9)) 59 | optG = optim.Adam(netG.parameters(), 2e-4, betas=(0.0, 0.9)) 60 | 61 | # Start training 62 | trainer = mmc.training.Trainer( 63 | netD=netD, 64 | netG=netG, 65 | optD=optD, 66 | optG=optG, 67 | n_dis=5, 68 | num_steps=100000, 69 | lr_decay='linear', 70 | dataloader=dataloader, 71 | log_dir='./log/example', 72 | device=device) 73 | trainer.train() 74 | 75 | To evaluate its FID, we can simply run the following: 76 | 77 | .. 
code-block:: python 78 | 79 | # Evaluate fid 80 | mmc.metrics.evaluate( 81 | metric='fid', 82 | log_dir='./log/example', 83 | netG=netG, 84 | dataset='cifar10', 85 | num_real_samples=50000, 86 | num_fake_samples=50000, 87 | evaluate_step=100000, 88 | device=device) 89 | 90 | Alternatively, one could evaluate FID progressively over an interval by swapping the `evaluate_step` argument for `evaluate_range`: 91 | 92 | .. code-block:: python 93 | 94 | # Evaluate fid 95 | mmc.metrics.evaluate( 96 | metric='fid', 97 | log_dir='./log/example', 98 | netG=netG, 99 | dataset='cifar10', 100 | num_real_samples=50000, 101 | num_fake_samples=50000, 102 | evaluate_range=(5000, 100000, 5000), 103 | device=device) 104 | 105 | We support other datasets and models See `datasets `_ and `nets `_ for more information. 106 | 107 | Visualizations 108 | -------------- 109 | Mimicry provides TensorBoard support for visualizing the following: 110 | 111 | - Loss and probability curves for monitoring GAN training 112 | - Randomly generated images for checking diversity. 113 | - Generated images from a fixed set of noise vectors. 114 | 115 | .. code-block:: none 116 | 117 | $ tensorboard --logdir=./log/example 118 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | :github_url: https://github.com/kwotsin/mimicry 2 | 3 | Mimicry Documentation 4 | ===================== 5 | 6 | `Mimicry `_ is a lightweight PyTorch library aimed towards the reproducibility of GAN research. 7 | 8 | Comparing GANs is often difficult - mild differences in implementations and evaluation methodologies can result in huge performance differences. Mimicry aims to resolve this by providing: 9 | 10 | - Standardized implementations of popular GANs that closely reproduce reported scores 11 | 12 | - Baseline scores of GANs trained and evaluated under the same conditions 13 | 14 | - A framework for researchers to focus on implementation of GANs without rewriting most of GAN training boilerplate code, with support for multiple GAN evaluation metrics. 15 | 16 | .. toctree:: 17 | :glob: 18 | :maxdepth: 1 19 | :caption: Guides 20 | 21 | guides/introduction 22 | guides/tutorial 23 | guides/baselines 24 | 25 | .. toctree:: 26 | :caption: API Reference 27 | :maxdepth: 1 28 | 29 | modules/nets 30 | modules/modules 31 | modules/training 32 | modules/metrics 33 | modules/datasets 34 | modules/utils 35 | -------------------------------------------------------------------------------- /docs/source/modules/datasets.rst: -------------------------------------------------------------------------------- 1 | torch_mimicry.datasets 2 | ====================== 3 | 4 | .. contents:: Contents 5 | :local: 6 | 7 | 8 | Dataset Loaders 9 | --------------- 10 | .. automodule:: torch_mimicry.datasets.data_utils 11 | :members: 12 | 13 | 14 | Image Loaders 15 | ------------- 16 | .. automodule:: torch_mimicry.datasets.image_loader 17 | :members: -------------------------------------------------------------------------------- /docs/source/modules/metrics.rst: -------------------------------------------------------------------------------- 1 | torch_mimicry.metrics 2 | ===================== 3 | 4 | .. contents:: Contents 5 | :local: 6 | 7 | 8 | FID 9 | --- 10 | .. automodule:: torch_mimicry.metrics.compute_fid 11 | :members: 12 | 13 | KID 14 | --- 15 | .. 
automodule:: torch_mimicry.metrics.compute_kid 16 | :members: 17 | 18 | Inception Score 19 | --------------- 20 | .. automodule:: torch_mimicry.metrics.compute_is 21 | :members: 22 | 23 | Metrics 24 | --------------- 25 | .. automodule:: torch_mimicry.metrics.compute_metrics 26 | :members: 27 | -------------------------------------------------------------------------------- /docs/source/modules/modules.rst: -------------------------------------------------------------------------------- 1 | torch_mimicry.modules 2 | ===================== 3 | 4 | .. contents:: Contents 5 | :local: 6 | 7 | Layers 8 | ------ 9 | .. automodule:: torch_mimicry.modules.layers 10 | :members: 11 | 12 | Residual Blocks 13 | --------------- 14 | .. automodule:: torch_mimicry.modules.resblocks 15 | :members: 16 | 17 | Losses 18 | ------ 19 | .. automodule:: torch_mimicry.modules.losses 20 | :members: 21 | -------------------------------------------------------------------------------- /docs/source/modules/nets.rst: -------------------------------------------------------------------------------- 1 | torch_mimicry.nets 2 | ================== 3 | 4 | .. contents:: Contents 5 | :local: 6 | 7 | torch_mimicry.nets.basemodel 8 | ---------------------------- 9 | 10 | .. automodule:: torch_mimicry.nets.basemodel.basemodel 11 | :members: 12 | 13 | torch_mimicry.nets.gan 14 | ---------------------- 15 | 16 | Base Unconditional GAN 17 | ^^^^^^^^^^^^^^^^^^^^^^ 18 | 19 | .. automodule:: torch_mimicry.nets.gan.gan 20 | :members: 21 | 22 | Base Conditional GAN 23 | ^^^^^^^^^^^^^^^^^^^^ 24 | .. automodule:: torch_mimicry.nets.gan.cgan 25 | :members: 26 | 27 | 28 | torch_mimicry.nets.dcgan 29 | ------------------------ 30 | DCGAN CIFAR 31 | ^^^^^^^^^^^ 32 | .. automodule:: torch_mimicry.nets.dcgan.dcgan_cifar 33 | :members: 34 | 35 | DCGAN 32 36 | ^^^^^^^^ 37 | .. automodule:: torch_mimicry.nets.dcgan.dcgan_32 38 | :members: 39 | 40 | DCGAN 48 41 | ^^^^^^^^ 42 | .. automodule:: torch_mimicry.nets.dcgan.dcgan_48 43 | :members: 44 | 45 | DCGAN 64 46 | ^^^^^^^^ 47 | .. automodule:: torch_mimicry.nets.dcgan.dcgan_64 48 | :members: 49 | 50 | 51 | DCGAN 128 52 | ^^^^^^^^^ 53 | .. automodule:: torch_mimicry.nets.dcgan.dcgan_128 54 | :members: 55 | 56 | DCGAN Base 57 | ^^^^^^^^^^ 58 | .. automodule:: torch_mimicry.nets.dcgan.dcgan_base 59 | :members: 60 | 61 | torch_mimicry.nets.wgan_gp 62 | -------------------------- 63 | WGAN-GP 32 64 | ^^^^^^^^^^ 65 | .. automodule:: torch_mimicry.nets.wgan_gp.wgan_gp_32 66 | :members: 67 | 68 | WGAN-GP 48 69 | ^^^^^^^^^^ 70 | .. automodule:: torch_mimicry.nets.wgan_gp.wgan_gp_48 71 | :members: 72 | 73 | WGAN-GP 64 74 | ^^^^^^^^^^ 75 | .. automodule:: torch_mimicry.nets.wgan_gp.wgan_gp_64 76 | :members: 77 | 78 | WGAN-GP 128 79 | ^^^^^^^^^^^ 80 | .. automodule:: torch_mimicry.nets.wgan_gp.wgan_gp_128 81 | :members: 82 | 83 | WGAN-GP Base 84 | ^^^^^^^^^^^^ 85 | .. automodule:: torch_mimicry.nets.wgan_gp.wgan_gp_base 86 | :members: 87 | 88 | torch_mimicry.nets.sngan 89 | ------------------------ 90 | SNGAN 32 91 | ^^^^^^^^ 92 | .. automodule:: torch_mimicry.nets.sngan.sngan_32 93 | :members: 94 | 95 | SNGAN 48 96 | ^^^^^^^^ 97 | .. automodule:: torch_mimicry.nets.sngan.sngan_48 98 | :members: 99 | 100 | SNGAN 64 101 | ^^^^^^^^ 102 | .. automodule:: torch_mimicry.nets.sngan.sngan_64 103 | :members: 104 | 105 | SNGAN 128 106 | ^^^^^^^^^ 107 | .. automodule:: torch_mimicry.nets.sngan.sngan_128 108 | :members: 109 | 110 | SNGAN Base 111 | ^^^^^^^^^^ 112 | .. 
automodule:: torch_mimicry.nets.sngan.sngan_base 113 | :members: 114 | 115 | torch_mimicry.nets.cgan_pd 116 | -------------------------- 117 | CGAN-PD 32 118 | ^^^^^^^^^^ 119 | .. automodule:: torch_mimicry.nets.cgan_pd.cgan_pd_32 120 | :members: 121 | 122 | CGAN-PD 128 123 | ^^^^^^^^^^^ 124 | .. automodule:: torch_mimicry.nets.cgan_pd.cgan_pd_128 125 | :members: 126 | 127 | CGAN-PD Base 128 | ^^^^^^^^^^^^ 129 | .. automodule:: torch_mimicry.nets.cgan_pd.cgan_pd_base 130 | :members: 131 | 132 | torch_mimicry.nets.ssgan 133 | ------------------------ 134 | SSGAN 32 135 | ^^^^^^^^ 136 | .. automodule:: torch_mimicry.nets.ssgan.ssgan_32 137 | :members: 138 | 139 | SSGAN 48 140 | ^^^^^^^^ 141 | .. automodule:: torch_mimicry.nets.ssgan.ssgan_48 142 | :members: 143 | 144 | SSGAN 64 145 | ^^^^^^^^ 146 | .. automodule:: torch_mimicry.nets.ssgan.ssgan_64 147 | :members: 148 | 149 | SSGAN 128 150 | ^^^^^^^^^ 151 | .. automodule:: torch_mimicry.nets.ssgan.ssgan_128 152 | :members: 153 | 154 | SSGAN Base 155 | ^^^^^^^^^^ 156 | .. automodule:: torch_mimicry.nets.ssgan.ssgan_base 157 | :members: 158 | 159 | torch_mimicry.nets.infomax_gan 160 | ------------------------------ 161 | InfoMax-GAN 32 162 | ^^^^^^^^^^^^^^ 163 | .. automodule:: torch_mimicry.nets.infomax_gan.infomax_gan_32 164 | :members: 165 | 166 | InfoMax-GAN 48 167 | ^^^^^^^^^^^^^^ 168 | .. automodule:: torch_mimicry.nets.infomax_gan.infomax_gan_48 169 | :members: 170 | 171 | InfoMax-GAN 64 172 | ^^^^^^^^^^^^^^ 173 | .. automodule:: torch_mimicry.nets.infomax_gan.infomax_gan_64 174 | :members: 175 | 176 | InfoMax-GAN 128 177 | ^^^^^^^^^^^^^^^ 178 | .. automodule:: torch_mimicry.nets.infomax_gan.infomax_gan_128 179 | :members: 180 | 181 | InfoMax-GAN Base 182 | ^^^^^^^^^^^^^^^^ 183 | .. automodule:: torch_mimicry.nets.infomax_gan.infomax_gan_base 184 | :members: 185 | -------------------------------------------------------------------------------- /docs/source/modules/training.rst: -------------------------------------------------------------------------------- 1 | torch_mimicry.training 2 | ====================== 3 | 4 | .. contents:: Contents 5 | :local: 6 | 7 | 8 | Trainer 9 | ------- 10 | .. automodule:: torch_mimicry.training.trainer 11 | :members: 12 | 13 | 14 | Logger 15 | ------ 16 | .. automodule:: torch_mimicry.training.logger 17 | :members: 18 | 19 | MetricLog 20 | ---------- 21 | .. automodule:: torch_mimicry.training.metric_log 22 | :members: 23 | 24 | Scheduler 25 | --------- 26 | .. automodule:: torch_mimicry.training.scheduler 27 | :members: 28 | -------------------------------------------------------------------------------- /docs/source/modules/utils.rst: -------------------------------------------------------------------------------- 1 | torch_mimicry.utils 2 | =================== 3 | 4 | .. contents:: Contents 5 | :local: 6 | 7 | 8 | Common Utilities 9 | ---------------- 10 | .. automodule:: torch_mimicry.utils.common 11 | :members: -------------------------------------------------------------------------------- /examples/eval_pretrained.py: -------------------------------------------------------------------------------- 1 | """ 2 | Example script of evaluating a pretrained generator. 
3 | """ 4 | import torch 5 | import torch_mimicry as mmc 6 | from torch_mimicry.nets import sngan 7 | 8 | ###################################################### 9 | # Computing Metrics with Default Datasets 10 | ###################################################### 11 | 12 | # Download cifar10 checkpoint: https://drive.google.com/uc?id=1Gn4ouslRAHq3D7AP_V-T2x8Wi1S1hTXJ&export=download 13 | ckpt_file = "./log/sngan_example/checkpoints/netG/netG_100000_steps.pth" 14 | 15 | # Default variables 16 | log_dir = './examples/example_log' 17 | dataset = 'cifar10' 18 | seed = 0 19 | device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu") 20 | 21 | # Restore model 22 | netG = sngan.SNGANGenerator32().to(device) 23 | netG.restore_checkpoint(ckpt_file) 24 | 25 | # Metrics with a known/popular dataset. 26 | mmc.metrics.fid_score(num_real_samples=50000, 27 | num_fake_samples=50000, 28 | netG=netG, 29 | seed=seed, 30 | dataset=dataset, 31 | log_dir=log_dir, 32 | device=device) 33 | 34 | mmc.metrics.kid_score(num_samples=50000, 35 | netG=netG, 36 | seed=seed, 37 | dataset=dataset, 38 | log_dir=log_dir, 39 | device=device) 40 | 41 | mmc.metrics.inception_score(num_samples=50000, 42 | netG=netG, 43 | seed=seed, 44 | log_dir=log_dir, 45 | device=device) 46 | 47 | ###################################################### 48 | # Computing Metrics with Custom Datasets 49 | ###################################################### 50 | """ 51 | Simply define a custom dataset as below to compute FID/KID, and define 52 | a stats_file/feat_file to save the cached statistics since we don't know what 53 | name to give your file. 54 | """ 55 | 56 | 57 | class CustomDataset(torch.utils.data.Dataset): 58 | def __init__(self): 59 | super().__init__() 60 | self.data = torch.ones(1000, 3, 32, 32) 61 | 62 | def __len__(self): 63 | return self.data.shape[0] 64 | 65 | def __getitem__(self, idx): 66 | return self.data[idx] 67 | 68 | 69 | custom_dataset = CustomDataset() 70 | 71 | # Metrics with a custom dataset. 72 | mmc.metrics.fid_score(num_real_samples=1000, 73 | num_fake_samples=1000, 74 | netG=netG, 75 | seed=seed, 76 | dataset=custom_dataset, 77 | log_dir=log_dir, 78 | device=device, 79 | stats_file='./examples/example_log/fid_stats.npz') 80 | 81 | mmc.metrics.kid_score(num_samples=1000, 82 | netG=netG, 83 | seed=seed, 84 | dataset=custom_dataset, 85 | log_dir=log_dir, 86 | device=device, 87 | feat_file='./examples/example_log/kid_stats.npz') 88 | 89 | # Using the evaluate API, which assumes a more fixed directory. 90 | netG = sngan.SNGANGenerator32().to(device) 91 | mmc.metrics.evaluate(metric='fid', 92 | log_dir='./log/sngan_example/', 93 | netG=netG, 94 | dataset=custom_dataset, 95 | num_real_samples=1000, 96 | num_fake_samples=1000, 97 | evaluate_step=100000, 98 | stats_file='./examples/example_log/fid_stats.npz', 99 | device=device) -------------------------------------------------------------------------------- /examples/sngan_example.py: -------------------------------------------------------------------------------- 1 | """ 2 | Typical usage example. 
3 | """ 4 | 5 | import torch 6 | import torch.optim as optim 7 | import torch_mimicry as mmc 8 | from torch_mimicry.nets import sngan 9 | 10 | if __name__ == "__main__": 11 | # Data handling objects 12 | device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu") 13 | dataset = mmc.datasets.load_dataset(root='./datasets', name='cifar10') 14 | dataloader = torch.utils.data.DataLoader(dataset, 15 | batch_size=64, 16 | shuffle=True, 17 | num_workers=4) 18 | 19 | # Define models and optimizers 20 | netG = sngan.SNGANGenerator32().to(device) 21 | netD = sngan.SNGANDiscriminator32().to(device) 22 | optD = optim.Adam(netD.parameters(), 2e-4, betas=(0.0, 0.9)) 23 | optG = optim.Adam(netG.parameters(), 2e-4, betas=(0.0, 0.9)) 24 | 25 | # Start training 26 | trainer = mmc.training.Trainer(netD=netD, 27 | netG=netG, 28 | optD=optD, 29 | optG=optG, 30 | n_dis=5, 31 | num_steps=30, 32 | lr_decay='linear', 33 | dataloader=dataloader, 34 | log_dir='./log/example', 35 | device=device) 36 | trainer.train() 37 | 38 | # Evaluate fid 39 | mmc.metrics.evaluate(metric='fid', 40 | log_dir='./log/example', 41 | netG=netG, 42 | dataset='cifar10', 43 | num_real_samples=50000, 44 | num_fake_samples=50000, 45 | evaluate_step=30, 46 | device=device) 47 | 48 | # Evaluate kid 49 | mmc.metrics.evaluate(metric='kid', 50 | log_dir='./log/example', 51 | netG=netG, 52 | dataset='cifar10', 53 | num_samples=50000, 54 | evaluate_step=30, 55 | device=device) 56 | 57 | # Evaluate inception score 58 | mmc.metrics.evaluate(metric='inception_score', 59 | log_dir='./log/example', 60 | netG=netG, 61 | num_samples=50000, 62 | evaluate_step=30, 63 | device=device) 64 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | filterwarnings = 3 | ignore::DeprecationWarning 4 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.23.1 2 | scipy==1.9.0 3 | requests==2.28.1 4 | torch==1.12.1 5 | tensorflow==2.9.1 6 | torchvision==0.13.1 7 | six==1.16.0 8 | matplotlib==3.5.2 9 | Pillow==9.2.0 10 | scikit-image==0.19.3 11 | pytest==7.1.2 12 | scikit-learn==1.1.2 13 | future==0.18.2 14 | pytest-cov==3.0.0 15 | pandas==1.4.3 16 | psutil==5.9.1 17 | yapf==0.32.0 18 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | description-file = README.md 3 | 4 | [aliases] 5 | test=pytest 6 | 7 | [tool:pytest] 8 | addopts = --capture=no --cov 9 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | __version__ = '0.1.16' 4 | url = 'https://github.com/kwotsin/mimicry' 5 | 6 | install_requires = [ 7 | 'numpy', 8 | 'scipy', 9 | 'requests', 10 | 'torch', 11 | 'tensorflow', 12 | 'torchvision', 13 | 'six', 14 | 'matplotlib', 15 | 'Pillow', 16 | 'scikit-image', 17 | 'pytest', 18 | 'scikit-learn', 19 | 'future', 20 | 'pytest-cov', 21 | 'pandas', 22 | 'psutil', 23 | 'yapf', 24 | 'lmdb', 25 | ] 26 | 27 | setup_requires = ['pytest-runner'] 28 | tests_require = ['pytest', 'pytest-cov', 'mock'] 29 | 30 | long_description = """ 31 | Mimicry is a lightweight PyTorch library 
aimed towards the reproducibility of GAN research. 32 | 33 | Comparing GANs is often difficult - mild differences in implementations and evaluation methodologies can result in huge performance differences. 34 | Mimicry aims to resolve this by providing: 35 | (a) Standardized implementations of popular GANs that closely reproduce reported scores; 36 | (b) Baseline scores of GANs trained and evaluated under the same conditions; 37 | (c) A framework for researchers to focus on implementation of GANs without rewriting most of GAN training boilerplate code, with support for multiple GAN evaluation metrics. 38 | 39 | We provide a model zoo and set of baselines to benchmark different GANs of the same model size trained under the same conditions, using multiple metrics. To ensure reproducibility, we verify scores of our implemented models against reported scores in literature. 40 | """ 41 | 42 | setup( 43 | name='torch_mimicry', 44 | version=__version__, 45 | long_description=long_description, 46 | long_description_content_type='text/markdown', 47 | description='Mimicry: Towards the Reproducibility of GAN Research', 48 | author='Kwot Sin Lee', 49 | author_email='ksl36@cam.ac.uk', 50 | url=url, 51 | download_url='{}/archive/{}.tar.gz'.format(url, __version__), 52 | keywords=[ 53 | 'pytorch', 54 | 'generative-adversarial-networks', 55 | 'gans', 56 | 'GAN', 57 | ], 58 | python_requires='>=3.6', 59 | install_requires=install_requires, 60 | setup_requires=setup_requires, 61 | tests_require=tests_require, 62 | packages=find_packages(), 63 | ) 64 | -------------------------------------------------------------------------------- /tests/datasets/imagenet/test.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kwotsin/mimicry/a7fda06c4aff1e6af8dc4c4a35ed6636e434c766/tests/datasets/imagenet/test.bin -------------------------------------------------------------------------------- /tests/metrics/fid/test_fid.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import numpy as np 3 | import tensorflow as tf 4 | 5 | from torch_mimicry.metrics.fid import fid_utils 6 | from torch_mimicry.metrics.inception_model import inception_utils 7 | 8 | 9 | class TestFID: 10 | def setup(self): 11 | self.images = np.ones((4, 32, 32, 3)) 12 | self.sess = tf.compat.v1.Session() 13 | 14 | def test_calculate_activation_statistics(self): 15 | inception_path = './metrics/inception_model' 16 | inception_utils.create_inception_graph(inception_path) 17 | 18 | mu, sigma = fid_utils.calculate_activation_statistics( 19 | images=self.images, sess=self.sess) 20 | 21 | assert mu.shape == (2048, ) 22 | assert sigma.shape == (2048, 2048) 23 | 24 | def test_calculate_frechet_distance(self): 25 | mu1, sigma1 = np.ones((16, )), np.ones((16, 16)) 26 | mu2, sigma2 = mu1 * 2, sigma1 * 2 27 | 28 | score = fid_utils.calculate_frechet_distance(mu1=mu1, 29 | mu2=mu2, 30 | sigma1=sigma1, 31 | sigma2=sigma2) 32 | 33 | assert type(score) == np.float64 34 | 35 | # Inputs check 36 | bad_mu2, bad_sigma2 = np.ones((15, 15)), np.ones((15, 15)) 37 | with pytest.raises(ValueError): 38 | fid_utils.calculate_frechet_distance(mu1=mu1, 39 | mu2=bad_mu2, 40 | sigma1=sigma1, 41 | sigma2=bad_sigma2) 42 | 43 | def teardown(self): 44 | del self.images 45 | self.sess.close() 46 | 47 | 48 | if __name__ == "__main__": 49 | test = TestFID() 50 | test.setup() 51 | test.test_calculate_activation_statistics() 52 | test.test_calculate_frechet_distance() 53 | 
test.teardown() 54 | -------------------------------------------------------------------------------- /tests/metrics/inception_model/test_inception_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | from torch_mimicry.metrics.inception_model import inception_utils 5 | 6 | 7 | class TestInceptionUtils: 8 | def test_get_activations(self): 9 | for inception_path in ['./metrics/inception_model', None]: 10 | inception_utils.create_inception_graph(inception_path) 11 | 12 | images = np.ones((4, 32, 32, 3)) 13 | with tf.compat.v1.Session() as sess: 14 | feat = inception_utils.get_activations(images=images, 15 | sess=sess) 16 | 17 | assert feat.shape == (4, 2048) 18 | 19 | 20 | if __name__ == "__main__": 21 | test = TestInceptionUtils() 22 | test.test_get_activations() 23 | -------------------------------------------------------------------------------- /tests/metrics/inception_score/test_inception_score.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import numpy as np 3 | import torch 4 | 5 | from torch_mimicry.metrics.inception_model import inception_utils 6 | from torch_mimicry.metrics.inception_score import inception_score_utils 7 | 8 | 9 | class TestInceptionScore: 10 | def test_get_predictions(self): 11 | inception_utils.create_inception_graph('./metrics/inception_model') 12 | 13 | images = np.ones((4, 32, 32, 3)) 14 | preds = inception_score_utils.get_predictions(images) 15 | assert preds.shape == (4, 1008) 16 | 17 | preds = inception_score_utils.get_predictions( 18 | images, device=torch.device('cpu')) 19 | assert preds.shape == (4, 1008) 20 | 21 | def test_get_inception_score(self): 22 | images = np.ones((4, 32, 32, 3)) 23 | mean, std = inception_score_utils.get_inception_score(images) 24 | 25 | assert type(mean) == float 26 | assert type(std) == float 27 | 28 | with pytest.raises(ValueError): 29 | images *= -1 30 | inception_score_utils.get_inception_score(images) 31 | 32 | 33 | if __name__ == "__main__": 34 | test = TestInceptionScore() 35 | test.test_get_predictions() 36 | test.test_get_inception_score() 37 | -------------------------------------------------------------------------------- /tests/metrics/kid/test_kid.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import numpy as np 3 | from sklearn.metrics.pairwise import polynomial_kernel 4 | 5 | from torch_mimicry.metrics.kid import kid_utils 6 | 7 | 8 | class TestKID: 9 | def setup(self): 10 | self.codes_g = np.ones((4, 16)) 11 | self.codes_r = np.ones((4, 16)) 12 | 13 | def test_polynomial_mmd(self): 14 | score = kid_utils.polynomial_mmd(codes_g=self.codes_g, 15 | codes_r=self.codes_r) 16 | 17 | assert type(score) == np.float64 18 | assert score < 1e-5 19 | 20 | def test_polynomial_mmd_averages(self): 21 | 22 | scores = kid_utils.polynomial_mmd_averages(codes_g=self.codes_g, 23 | codes_r=self.codes_r, 24 | n_subsets=4, 25 | subset_size=1) 26 | 27 | assert len(scores) == 4 28 | assert type(scores[0]) == np.float64 29 | 30 | def test_compute_mmd2(self): 31 | X = self.codes_g 32 | Y = self.codes_r 33 | K_XX = polynomial_kernel(X) 34 | K_YY = polynomial_kernel(Y) 35 | K_XY = polynomial_kernel(X, Y) 36 | 37 | mmd_est_args = ['u-statistic', 'unbiased'] 38 | 39 | for mmd_est in mmd_est_args: 40 | for unit_diagonal in [True, False]: 41 | mmd2_score = kid_utils._compute_mmd2( 42 | K_XX=K_XX, 43 | K_YY=K_YY, 44 | K_XY=K_XY, 45 | 
mmd_est=mmd_est, 46 | unit_diagonal=unit_diagonal) 47 | 48 | assert type(mmd2_score) == np.float64 49 | 50 | # Input checks 51 | with pytest.raises(ValueError): 52 | kid_utils._compute_mmd2(K_XX=K_XX, 53 | K_YY=K_YY, 54 | K_XY=K_XY, 55 | mmd_est='invalid_option', 56 | unit_diagonal=unit_diagonal) 57 | 58 | m = K_XX.shape[0] 59 | with pytest.raises(ValueError): 60 | bad_K_XX = np.ones((m + 1, m + 1)) 61 | kid_utils._compute_mmd2(K_XX=bad_K_XX, 62 | K_YY=K_YY, 63 | K_XY=K_XY, 64 | mmd_est='unbiased', 65 | unit_diagonal=unit_diagonal) 66 | 67 | with pytest.raises(ValueError): 68 | bad_K_YY = np.ones((m + 1, m + 1)) 69 | kid_utils._compute_mmd2(K_XX=K_XX, 70 | K_YY=bad_K_YY, 71 | K_XY=K_XY, 72 | mmd_est='unbiased', 73 | unit_diagonal=unit_diagonal) 74 | 75 | with pytest.raises(ValueError): 76 | bad_K_XY = np.ones((m + 1, m + 1)) 77 | kid_utils._compute_mmd2(K_XX=K_XX, 78 | K_YY=K_YY, 79 | K_XY=bad_K_XY, 80 | mmd_est='unbiased', 81 | unit_diagonal=unit_diagonal) 82 | 83 | def teardown(self): 84 | del self.codes_g 85 | del self.codes_r 86 | 87 | 88 | if __name__ == "__main__": 89 | test = TestKID() 90 | test.setup() 91 | test.test_polynomial_mmd() 92 | test.test_polynomial_mmd_averages() 93 | test.test_compute_mmd2() 94 | test.teardown() 95 | -------------------------------------------------------------------------------- /tests/metrics/test_compute_is.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch_mimicry.metrics import compute_is 4 | from torch_mimicry.nets.gan import gan 5 | 6 | 7 | class ExampleGen(gan.BaseGenerator): 8 | def __init__(self, 9 | bottom_width=4, 10 | nz=4, 11 | ngf=256, 12 | loss_type='gan', 13 | *args, 14 | **kwargs): 15 | super().__init__(nz=nz, 16 | ngf=ngf, 17 | bottom_width=bottom_width, 18 | loss_type=loss_type, 19 | *args, 20 | **kwargs) 21 | 22 | def forward(self, x): 23 | output = torch.ones(x.shape[0], 3, 32, 32) 24 | 25 | return output 26 | 27 | 28 | class TestComputeIS: 29 | def setup(self): 30 | self.device = torch.device('cpu') 31 | self.netG = ExampleGen() 32 | 33 | def test_compute_inception_score(self): 34 | mean, std = compute_is.inception_score(netG=self.netG, 35 | device=self.device, 36 | num_samples=10, 37 | batch_size=10) 38 | 39 | assert type(mean) == float 40 | assert type(std) == float 41 | 42 | def teardown(self): 43 | del self.netG 44 | 45 | 46 | if __name__ == "__main__": 47 | test = TestComputeIS() 48 | test.setup() 49 | test.test_compute_inception_score() 50 | -------------------------------------------------------------------------------- /tests/modules/test_layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from torch_mimicry.modules import layers 5 | 6 | 7 | class TestLayers: 8 | def setup(self): 9 | self.N, self.C, self.H, self.W = (4, 3, 32, 32) 10 | self.n_out = 8 11 | 12 | def test_ConditionalBatchNorm2d(self): 13 | num_classes = 10 14 | X = torch.ones(self.N, self.C, self.H, self.W) 15 | y = torch.randint(low=0, high=num_classes, size=(self.N, )) 16 | 17 | # Setup cond. 
BN --> Note: because we usually do 18 | # BN-ReLU-Conv, we take num feat as input channels 19 | conv = nn.Conv2d(self.C, self.n_out, 1, 1, 0) 20 | bn = layers.ConditionalBatchNorm2d(num_features=self.C, 21 | num_classes=num_classes) 22 | 23 | output = bn(X, y) 24 | output = conv(output) 25 | 26 | assert output.shape == (self.N, self.n_out, self.H, self.W) 27 | 28 | def test_SelfAttention(self): 29 | ngf = 16 30 | X = torch.randn(self.N, ngf, 16, 16) 31 | for spectral_norm in [True, False]: 32 | sa_layer = layers.SelfAttention(ngf, spectral_norm=spectral_norm) 33 | 34 | assert sa_layer(X).shape == (self.N, ngf, 16, 16) 35 | 36 | def test_SNConv2d(self): 37 | X = torch.ones(self.N, self.C, self.H, self.W) 38 | for default in [True, False]: 39 | layer = layers.SNConv2d(self.C, 40 | self.n_out, 41 | 1, 42 | 1, 43 | 0, 44 | default=default) 45 | 46 | assert layer(X).shape == (self.N, self.n_out, self.H, self.W) 47 | 48 | def test_SNLinear(self): 49 | X = torch.ones(self.N, self.C) 50 | for default in [True, False]: 51 | layer = layers.SNLinear(self.C, self.n_out, default=default) 52 | 53 | assert layer(X).shape == (self.N, self.n_out) 54 | 55 | def test_SNEmbedding(self): 56 | num_classes = 10 57 | X = torch.ones(self.N, dtype=torch.int64) 58 | for default in [True, False]: 59 | layer = layers.SNEmbedding(num_classes, 60 | self.n_out, 61 | default=default) 62 | 63 | assert layer(X).shape == (self.N, self.n_out) 64 | 65 | 66 | if __name__ == "__main__": 67 | test = TestLayers() 68 | test.setup() 69 | test.test_ConditionalBatchNorm2d() 70 | test.test_SelfAttention() 71 | test.test_SNConv2d() 72 | test.test_SNLinear() 73 | test.test_SNEmbedding() 74 | -------------------------------------------------------------------------------- /tests/modules/test_losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch_mimicry.modules import losses 4 | 5 | 6 | class TestLosses: 7 | def setup(self): 8 | self.output_real = torch.ones(4, 1) 9 | self.output_fake = torch.ones(4, 1) 10 | self.device = torch.device('cpu') 11 | 12 | def test_minimax_loss(self): 13 | loss_gen = losses.minimax_loss_gen(output_fake=self.output_fake) 14 | 15 | loss_dis = losses.minimax_loss_dis(output_fake=self.output_fake, 16 | output_real=self.output_real) 17 | 18 | assert loss_gen.dtype == torch.float32 19 | assert loss_dis.dtype == torch.float32 20 | assert loss_gen.item() - 0.3133 < 1e-2 21 | assert loss_dis.item() - 1.6265 < 1e-2 22 | 23 | def test_ns_loss(self): 24 | loss_gen = losses.ns_loss_gen(self.output_fake) 25 | 26 | assert loss_gen.dtype == torch.float32 27 | assert loss_gen.item() - 0.3133 < 1e-2 28 | 29 | def test_wasserstein_loss(self): 30 | loss_gen = losses.wasserstein_loss_gen(self.output_fake) 31 | loss_dis = losses.wasserstein_loss_dis(output_real=self.output_real, 32 | output_fake=self.output_fake) 33 | 34 | assert loss_gen.dtype == torch.float32 35 | assert loss_dis.dtype == torch.float32 36 | assert loss_gen.item() + 1.0 < 1e-2 37 | assert loss_dis.item() < 1e-2 38 | 39 | def test_hinge_loss(self): 40 | loss_gen = losses.hinge_loss_gen(output_fake=self.output_fake) 41 | loss_dis = losses.hinge_loss_dis(output_fake=self.output_fake, 42 | output_real=self.output_real) 43 | 44 | assert loss_gen.dtype == torch.float32 45 | assert loss_dis.dtype == torch.float32 46 | assert loss_gen.item() + 1.0 < 1e-2 47 | assert loss_dis.item() - 2.0 < 1e-2 48 | 49 | def teardown(self): 50 | del self.output_real 51 | del self.output_fake 52 | del self.device 53 | 
54 | 55 | if __name__ == "__main__": 56 | test = TestLosses() 57 | test.setup() 58 | test.test_minimax_loss() 59 | test.test_ns_loss() 60 | test.test_wasserstein_loss() 61 | test.test_hinge_loss() 62 | test.teardown() 63 | -------------------------------------------------------------------------------- /tests/modules/test_resblocks.py: -------------------------------------------------------------------------------- 1 | from itertools import product 2 | 3 | import torch 4 | 5 | from torch_mimicry.modules import resblocks 6 | 7 | 8 | class TestResBlocks: 9 | def setup(self): 10 | self.images = torch.ones(4, 3, 16, 16) 11 | 12 | def test_GBlock(self): 13 | # Arguments 14 | num_classes_list = [0, 10] 15 | spectral_norm_list = [True, False] 16 | in_channels = 3 17 | out_channels = 8 18 | args_comb = product(num_classes_list, spectral_norm_list) 19 | 20 | for args in args_comb: 21 | num_classes = args[0] 22 | spectral_norm = args[1] 23 | 24 | if num_classes > 0: 25 | y = torch.ones((4, ), dtype=torch.int64) 26 | else: 27 | y = None 28 | 29 | gen_block_up = resblocks.GBlock(in_channels=in_channels, 30 | out_channels=out_channels, 31 | upsample=True, 32 | num_classes=num_classes, 33 | spectral_norm=spectral_norm) 34 | 35 | gen_block = resblocks.GBlock(in_channels=in_channels, 36 | out_channels=out_channels, 37 | upsample=False, 38 | num_classes=num_classes, 39 | spectral_norm=spectral_norm) 40 | 41 | gen_block_no_sc = resblocks.GBlock(in_channels=in_channels, 42 | out_channels=in_channels, 43 | upsample=False, 44 | num_classes=num_classes, 45 | spectral_norm=spectral_norm) 46 | 47 | assert gen_block_up(self.images, y).shape == (4, 8, 32, 32) 48 | assert gen_block(self.images, y).shape == (4, 8, 16, 16) 49 | assert gen_block_no_sc(self.images, y).shape == (4, 3, 16, 16) 50 | 51 | def test_DBlocks(self): 52 | in_channels = 3 53 | out_channels = 8 54 | 55 | for spectral_norm in [True, False]: 56 | dis_block_down = resblocks.DBlock(in_channels=in_channels, 57 | out_channels=out_channels, 58 | downsample=True, 59 | spectral_norm=spectral_norm) 60 | 61 | dis_block = resblocks.DBlock(in_channels=in_channels, 62 | out_channels=out_channels, 63 | downsample=False, 64 | spectral_norm=spectral_norm) 65 | 66 | dis_block_opt = resblocks.DBlockOptimized( 67 | in_channels=in_channels, 68 | out_channels=out_channels, 69 | spectral_norm=spectral_norm) 70 | 71 | assert dis_block(self.images).shape == (4, out_channels, 16, 16) 72 | assert dis_block_down(self.images).shape == (4, out_channels, 8, 8) 73 | assert dis_block_opt(self.images).shape == (4, out_channels, 8, 8) 74 | 75 | def teardown(self): 76 | del self.images 77 | 78 | 79 | if __name__ == "__main__": 80 | test = TestResBlocks() 81 | test.setup() 82 | test.test_GBlock() 83 | test.test_DBlocks() 84 | test.teardown() 85 | -------------------------------------------------------------------------------- /tests/modules/test_spectral_norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from torch_mimicry.modules import spectral_norm 5 | 6 | 7 | class TestSpectralNorm: 8 | def setup(self): 9 | torch.manual_seed(0) 10 | self.N, self.C, self.H, self.W = (32, 16, 32, 32) 11 | self.n_in = self.C 12 | self.n_out = 32 13 | 14 | def test_SNConv2d(self): 15 | conv = spectral_norm.SNConv2d(self.n_in, self.n_out, 1, 1, 0) 16 | conv_def = nn.utils.spectral_norm( 17 | nn.Conv2d(self.n_in, self.n_out, 1, 1, 0)) 18 | 19 | # Init with ones to test implementation without randomness. 
20 | nn.init.ones_(conv.weight.data) 21 | nn.init.ones_(conv_def.weight.data) 22 | 23 | # Get outputs 24 | X = torch.ones(self.N, self.C, self.H, self.W) 25 | output = conv(X) 26 | output_def = conv_def(X) 27 | 28 | # Test valid shape 29 | assert output.shape == output_def.shape == (32, 32, 32, 32) 30 | 31 | # Test per element it is very close to default implementation 32 | # to preserve correctness even when user toggles b/w implementations 33 | assert abs(torch.mean(output_def) - torch.mean(output)) < 1 34 | 35 | def test_SNLinear(self): 36 | linear = spectral_norm.SNLinear(self.n_in, self.n_out) 37 | linear_def = nn.utils.spectral_norm(nn.Linear(self.n_in, self.n_out)) 38 | 39 | nn.init.ones_(linear.weight.data) 40 | nn.init.ones_(linear_def.weight.data) 41 | 42 | X = torch.ones(self.N, self.n_in) 43 | output = linear(X) 44 | output_def = linear_def(X) 45 | 46 | assert output.shape == output_def.shape == (32, 32) 47 | assert abs(torch.mean(output_def) - torch.mean(output)) < 1 48 | 49 | def test_SNEmbedding(self): 50 | embedding = spectral_norm.SNEmbedding(self.N, self.n_out) 51 | embedding_def = nn.utils.spectral_norm(nn.Embedding( 52 | self.N, self.n_out)) 53 | 54 | nn.init.ones_(embedding.weight.data) 55 | nn.init.ones_(embedding_def.weight.data) 56 | 57 | X = torch.ones(self.N, dtype=torch.int64) 58 | output = embedding(X) 59 | output_def = embedding_def(X) 60 | 61 | assert output.shape == output_def.shape == (32, 32) 62 | assert abs(torch.mean(output_def) - torch.mean(output)) < 1 63 | 64 | 65 | if __name__ == "__main__": 66 | test = TestSpectralNorm() 67 | test.setup() 68 | test.test_SNConv2d() 69 | test.test_SNLinear() 70 | test.test_SNEmbedding() 71 | -------------------------------------------------------------------------------- /tests/nets/basemodel/test_basemodel.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import pytest 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.optim as optim 8 | 9 | from torch_mimicry.nets.basemodel.basemodel import BaseModel 10 | 11 | 12 | class ExampleModel(BaseModel): 13 | def __init__(self, *args, **kwargs): 14 | super().__init__(*args, **kwargs) 15 | self.linear = nn.Linear(1, 4) 16 | nn.init.xavier_uniform_(self.linear.weight.data) 17 | 18 | def forward(self, x): 19 | return 20 | 21 | 22 | class TestBaseModel: 23 | def setup(self): 24 | self.model = ExampleModel() 25 | self.opt = optim.Adam(self.model.parameters(), 2e-4, betas=(0.0, 0.9)) 26 | self.global_step = 0 27 | 28 | self.log_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 29 | "test_log") 30 | 31 | def test_save_and_restore_checkpoint(self): 32 | ckpt_dir = os.path.join(self.log_dir, 'checkpoints/model') 33 | ckpt_file = os.path.join(ckpt_dir, 34 | "model_{}_steps.pth".format(self.global_step)) 35 | 36 | self.model.save_checkpoint(directory=ckpt_dir, 37 | optimizer=self.opt, 38 | global_step=self.global_step) 39 | 40 | restored_model = ExampleModel() 41 | restored_opt = optim.Adam(self.model.parameters(), 42 | 2e-4, 43 | betas=(0.0, 0.9)) 44 | 45 | restored_model.restore_checkpoint(ckpt_file=ckpt_file, 46 | optimizer=self.opt) 47 | 48 | # Check weights are preserved 49 | assert all( 50 | (restored_model.linear.weight == self.model.linear.weight) == 1) 51 | 52 | with pytest.raises(ValueError): 53 | restored_model.restore_checkpoint(ckpt_file=None, 54 | optimizer=self.opt) 55 | 56 | # Check optimizers have same state dict 57 | assert self.opt.state_dict() == restored_opt.state_dict() 
58 | 59 | def test_count_params(self): 60 | num_total_params, num_trainable_params = self.model.count_params() 61 | 62 | assert num_trainable_params == num_total_params == 8 63 | 64 | def test_get_device(self): 65 | assert type(self.model.device) == torch.device 66 | 67 | def teardown(self): 68 | if os.path.exists(self.log_dir): 69 | shutil.rmtree(self.log_dir) 70 | 71 | del self.model 72 | del self.opt 73 | 74 | 75 | if __name__ == "__main__": 76 | test = TestBaseModel() 77 | test.setup() 78 | test.test_save_and_restore_checkpoint() 79 | test.test_count_params() 80 | test.test_get_device() 81 | test.teardown() 82 | -------------------------------------------------------------------------------- /tests/nets/cgan_pd/test_cgan_pd_128.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test functions for cGAN-PD for image size 128. 3 | """ 4 | import torch 5 | import torch.optim as optim 6 | 7 | from torch_mimicry.nets.cgan_pd.cgan_pd_128 import CGANPDGenerator128, CGANPDDiscriminator128 8 | from torch_mimicry.training import metric_log 9 | from torch_mimicry.utils import common 10 | 11 | 12 | class TestCGANPD128: 13 | def setup(self): 14 | self.num_classes = 10 15 | self.nz = 128 16 | self.N, self.C, self.H, self.W = (4, 3, 128, 128) 17 | 18 | self.noise = torch.ones(self.N, self.nz) 19 | self.images = torch.ones(self.N, self.C, self.H, self.W) 20 | self.Y = torch.randint(low=0, high=self.num_classes, size=(self.N, )) 21 | 22 | self.ngf = 16 23 | self.ndf = 16 24 | 25 | self.netG = CGANPDGenerator128(num_classes=self.num_classes, 26 | ngf=self.ngf) 27 | self.netD = CGANPDDiscriminator128(num_classes=self.num_classes, 28 | ndf=self.ndf) 29 | 30 | def test_CGANPDGenerator128(self): 31 | images = self.netG(self.noise, self.Y) 32 | assert images.shape == (self.N, self.C, self.H, self.W) 33 | 34 | images = self.netG(self.noise, None) 35 | assert images.shape == (self.N, self.C, self.H, self.W) 36 | 37 | def test_CGANPDDiscriminator128(self): 38 | output = self.netD(self.images, self.Y) 39 | 40 | assert output.shape == (self.N, 1) 41 | 42 | def test_train_steps(self): 43 | real_batch = common.load_images(self.N) 44 | 45 | # Setup optimizers 46 | optD = optim.Adam(self.netD.parameters(), 2e-4, betas=(0.0, 0.9)) 47 | optG = optim.Adam(self.netG.parameters(), 2e-4, betas=(0.0, 0.9)) 48 | 49 | # Log statistics to check 50 | log_data = metric_log.MetricLog() 51 | 52 | # Test D train step 53 | log_data = self.netD.train_step(real_batch=real_batch, 54 | netG=self.netG, 55 | optD=optD, 56 | device='cpu', 57 | log_data=log_data) 58 | 59 | log_data = self.netG.train_step(real_batch=real_batch, 60 | netD=self.netD, 61 | optG=optG, 62 | log_data=log_data, 63 | device='cpu') 64 | 65 | for name, metric_dict in log_data.items(): 66 | assert type(name) == str 67 | assert type(metric_dict['value']) == float 68 | 69 | def teardown(self): 70 | del self.noise 71 | del self.images 72 | del self.Y 73 | del self.netG 74 | del self.netD 75 | 76 | 77 | if __name__ == "__main__": 78 | test = TestCGANPD128() 79 | test.setup() 80 | test.test_CGANPDGenerator128() 81 | test.test_CGANPDDiscriminator128() 82 | test.test_train_steps() 83 | test.teardown() 84 | -------------------------------------------------------------------------------- /tests/nets/cgan_pd/test_cgan_pd_32.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test functions for cGAN-PD for image size 32. 
3 | """ 4 | import torch 5 | import torch.optim as optim 6 | 7 | from torch_mimicry.nets.cgan_pd.cgan_pd_32 import CGANPDGenerator32, CGANPDDiscriminator32 8 | from torch_mimicry.training import metric_log 9 | from torch_mimicry.utils import common 10 | 11 | 12 | class TestCGANPD32: 13 | def setup(self): 14 | self.num_classes = 10 15 | self.nz = 128 16 | self.N, self.C, self.H, self.W = (4, 3, 32, 32) 17 | 18 | self.noise = torch.ones(self.N, self.nz) 19 | self.images = torch.ones(self.N, self.C, self.H, self.W) 20 | self.Y = torch.randint(low=0, high=self.num_classes, size=(self.N, )) 21 | 22 | self.ngf = 16 23 | self.ndf = 16 24 | 25 | self.netG = CGANPDGenerator32(num_classes=self.num_classes, 26 | ngf=self.ngf) 27 | self.netD = CGANPDDiscriminator32(num_classes=self.num_classes, 28 | ndf=self.ndf) 29 | 30 | def test_CGANPDGenerator32(self): 31 | images = self.netG(self.noise, self.Y) 32 | assert images.shape == (self.N, self.C, self.H, self.W) 33 | 34 | images = self.netG(self.noise, None) 35 | assert images.shape == (self.N, self.C, self.H, self.W) 36 | 37 | def test_CGANPDDiscriminator32(self): 38 | output = self.netD(self.images, self.Y) 39 | assert output.shape == (self.N, 1) 40 | 41 | def test_train_steps(self): 42 | real_batch = common.load_images(self.N) 43 | 44 | # Setup optimizers 45 | optD = optim.Adam(self.netD.parameters(), 2e-4, betas=(0.0, 0.9)) 46 | optG = optim.Adam(self.netG.parameters(), 2e-4, betas=(0.0, 0.9)) 47 | 48 | # Log statistics to check 49 | log_data = metric_log.MetricLog() 50 | 51 | # Test D train step 52 | log_data = self.netD.train_step(real_batch=real_batch, 53 | netG=self.netG, 54 | optD=optD, 55 | device='cpu', 56 | log_data=log_data) 57 | 58 | log_data = self.netG.train_step(real_batch=real_batch, 59 | netD=self.netD, 60 | optG=optG, 61 | log_data=log_data, 62 | device='cpu') 63 | 64 | for name, metric_dict in log_data.items(): 65 | assert type(name) == str 66 | assert type(metric_dict['value']) == float 67 | 68 | def teardown(self): 69 | del self.noise 70 | del self.images 71 | del self.Y 72 | del self.netG 73 | del self.netD 74 | 75 | 76 | if __name__ == "__main__": 77 | test = TestCGANPD32() 78 | test.setup() 79 | test.test_CGANPDGenerator32() 80 | test.test_CGANPDDiscriminator32() 81 | test.test_train_steps() 82 | test.teardown() 83 | -------------------------------------------------------------------------------- /tests/nets/dcgan/test_dcgan_128.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test functions for DCGAN for image size 128. 
3 | 4 | """ 5 | import torch 6 | import torch.optim as optim 7 | 8 | from torch_mimicry.nets.dcgan.dcgan_128 import DCGANGenerator128, DCGANDiscriminator128 9 | from torch_mimicry.training import metric_log 10 | from torch_mimicry.utils import common 11 | 12 | 13 | class TestDCGAN128: 14 | def setup(self): 15 | self.nz = 128 16 | self.N, self.C, self.H, self.W = (8, 3, 128, 128) 17 | self.ngf = 16 18 | self.ndf = 16 19 | 20 | self.netG = DCGANGenerator128(ngf=self.ngf) 21 | self.netD = DCGANDiscriminator128(ndf=self.ndf) 22 | 23 | def test_DCGANGenerator128(self): 24 | noise = torch.ones(self.N, self.nz) 25 | output = self.netG(noise) 26 | 27 | assert output.shape == (self.N, self.C, self.H, self.W) 28 | 29 | def test_DCGANDiscriminator128(self): 30 | images = torch.ones(self.N, self.C, self.H, self.W) 31 | output = self.netD(images) 32 | 33 | assert output.shape == (self.N, 1) 34 | 35 | def test_train_steps(self): 36 | real_batch = common.load_images(self.N, size=self.H) 37 | 38 | # Setup optimizers 39 | optD = optim.Adam(self.netD.parameters(), 2e-4, betas=(0.0, 0.9)) 40 | optG = optim.Adam(self.netG.parameters(), 2e-4, betas=(0.0, 0.9)) 41 | 42 | # Log statistics to check 43 | log_data = metric_log.MetricLog() 44 | 45 | # Test D train step 46 | log_data = self.netD.train_step(real_batch=real_batch, 47 | netG=self.netG, 48 | optD=optD, 49 | device='cpu', 50 | log_data=log_data) 51 | 52 | log_data = self.netG.train_step(real_batch=real_batch, 53 | netD=self.netD, 54 | optG=optG, 55 | log_data=log_data, 56 | device='cpu') 57 | 58 | for name, metric_dict in log_data.items(): 59 | assert type(name) == str 60 | assert type(metric_dict['value']) == float 61 | 62 | def teardown(self): 63 | del self.netG 64 | del self.netD 65 | 66 | 67 | if __name__ == "__main__": 68 | test = TestDCGAN128() 69 | test.setup() 70 | test.test_DCGANGenerator128() 71 | test.test_DCGANDiscriminator128() 72 | test.test_train_steps() 73 | test.teardown() 74 | -------------------------------------------------------------------------------- /tests/nets/dcgan/test_dcgan_32.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test functions for DCGAN for image size 32. 
3 | """ 4 | import torch 5 | import torch.optim as optim 6 | 7 | from torch_mimicry.nets.dcgan.dcgan_32 import DCGANGenerator32, DCGANDiscriminator32 8 | from torch_mimicry.training import metric_log 9 | from torch_mimicry.utils import common 10 | 11 | 12 | class TestDCGAN32: 13 | def setup(self): 14 | self.nz = 128 15 | self.N, self.C, self.H, self.W = (8, 3, 32, 32) 16 | self.ngf = 16 17 | self.ndf = 16 18 | 19 | self.netG = DCGANGenerator32(ngf=self.ngf) 20 | self.netD = DCGANDiscriminator32(ndf=self.ndf) 21 | 22 | def test_DCGANGenerator32(self): 23 | noise = torch.ones(self.N, self.nz) 24 | output = self.netG(noise) 25 | 26 | assert output.shape == (self.N, self.C, self.H, self.W) 27 | 28 | def test_DCGANDiscriminator32(self): 29 | images = torch.ones(self.N, self.C, self.H, self.W) 30 | output = self.netD(images) 31 | 32 | assert output.shape == (self.N, 1) 33 | 34 | def test_train_steps(self): 35 | real_batch = common.load_images(self.N, size=self.H) 36 | 37 | # Setup optimizers 38 | optD = optim.Adam(self.netD.parameters(), 2e-4, betas=(0.0, 0.9)) 39 | optG = optim.Adam(self.netG.parameters(), 2e-4, betas=(0.0, 0.9)) 40 | 41 | # Log statistics to check 42 | log_data = metric_log.MetricLog() 43 | 44 | # Test D train step 45 | log_data = self.netD.train_step(real_batch=real_batch, 46 | netG=self.netG, 47 | optD=optD, 48 | device='cpu', 49 | log_data=log_data) 50 | 51 | log_data = self.netG.train_step(real_batch=real_batch, 52 | netD=self.netD, 53 | optG=optG, 54 | log_data=log_data, 55 | device='cpu') 56 | 57 | for name, metric_dict in log_data.items(): 58 | assert type(name) == str 59 | assert type(metric_dict['value']) == float 60 | 61 | def teardown(self): 62 | del self.netG 63 | del self.netD 64 | 65 | 66 | if __name__ == "__main__": 67 | test = TestDCGAN32() 68 | test.setup() 69 | test.test_DCGANGenerator32() 70 | test.test_DCGANDiscriminator32() 71 | test.test_train_steps() 72 | test.teardown() 73 | -------------------------------------------------------------------------------- /tests/nets/dcgan/test_dcgan_48.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test functions for DCGAN for image size 48. 
3 | """ 4 | import torch 5 | import torch.optim as optim 6 | 7 | from torch_mimicry.nets.dcgan.dcgan_48 import DCGANGenerator48, DCGANDiscriminator48 8 | from torch_mimicry.training import metric_log 9 | from torch_mimicry.utils import common 10 | 11 | 12 | class TestDCGAN48: 13 | def setup(self): 14 | self.nz = 128 15 | self.N, self.C, self.H, self.W = (8, 3, 48, 48) 16 | self.ngf = 16 17 | self.ndf = 16 18 | 19 | self.netG = DCGANGenerator48(ngf=self.ngf) 20 | self.netD = DCGANDiscriminator48(ndf=self.ndf) 21 | 22 | def test_DCGANGenerator48(self): 23 | noise = torch.ones(self.N, self.nz) 24 | output = self.netG(noise) 25 | 26 | assert output.shape == (self.N, self.C, self.H, self.W) 27 | 28 | def test_DCGANDiscriminator48(self): 29 | images = torch.ones(self.N, self.C, self.H, self.W) 30 | output = self.netD(images) 31 | 32 | assert output.shape == (self.N, 1) 33 | 34 | def test_train_steps(self): 35 | real_batch = common.load_images(self.N, size=self.H) 36 | 37 | # Setup optimizers 38 | optD = optim.Adam(self.netD.parameters(), 2e-4, betas=(0.0, 0.9)) 39 | optG = optim.Adam(self.netG.parameters(), 2e-4, betas=(0.0, 0.9)) 40 | 41 | # Log statistics to check 42 | log_data = metric_log.MetricLog() 43 | 44 | # Test D train step 45 | log_data = self.netD.train_step(real_batch=real_batch, 46 | netG=self.netG, 47 | optD=optD, 48 | device='cpu', 49 | log_data=log_data) 50 | 51 | log_data = self.netG.train_step(real_batch=real_batch, 52 | netD=self.netD, 53 | optG=optG, 54 | log_data=log_data, 55 | device='cpu') 56 | 57 | for name, metric_dict in log_data.items(): 58 | assert type(name) == str 59 | assert type(metric_dict['value']) == float 60 | 61 | def teardown(self): 62 | del self.netG 63 | del self.netD 64 | 65 | 66 | if __name__ == "__main__": 67 | test = TestDCGAN48() 68 | test.setup() 69 | test.test_DCGANGenerator48() 70 | test.test_DCGANDiscriminator48() 71 | test.test_train_steps() 72 | test.teardown() 73 | -------------------------------------------------------------------------------- /tests/nets/dcgan/test_dcgan_64.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test functions for DCGAN for image size 64. 
3 | 4 | """ 5 | import torch 6 | import torch.optim as optim 7 | 8 | from torch_mimicry.nets.dcgan.dcgan_64 import DCGANGenerator64, DCGANDiscriminator64 9 | from torch_mimicry.training import metric_log 10 | from torch_mimicry.utils import common 11 | 12 | 13 | class TestDCGAN64: 14 | def setup(self): 15 | self.nz = 128 16 | self.N, self.C, self.H, self.W = (8, 3, 64, 64) 17 | self.ngf = 16 18 | self.ndf = 16 19 | 20 | self.netG = DCGANGenerator64(ngf=self.ngf) 21 | self.netD = DCGANDiscriminator64(ndf=self.ndf) 22 | 23 | def test_DCGANGenerator64(self): 24 | noise = torch.ones(self.N, self.nz) 25 | output = self.netG(noise) 26 | 27 | assert output.shape == (self.N, self.C, self.H, self.W) 28 | 29 | def test_DCGANDiscriminator64(self): 30 | images = torch.ones(self.N, self.C, self.H, self.W) 31 | output = self.netD(images) 32 | 33 | assert output.shape == (self.N, 1) 34 | 35 | def test_train_steps(self): 36 | real_batch = common.load_images(self.N, size=self.H) 37 | 38 | # Setup optimizers 39 | optD = optim.Adam(self.netD.parameters(), 2e-4, betas=(0.0, 0.9)) 40 | optG = optim.Adam(self.netG.parameters(), 2e-4, betas=(0.0, 0.9)) 41 | 42 | # Log statistics to check 43 | log_data = metric_log.MetricLog() 44 | 45 | # Test D train step 46 | log_data = self.netD.train_step(real_batch=real_batch, 47 | netG=self.netG, 48 | optD=optD, 49 | device='cpu', 50 | log_data=log_data) 51 | 52 | log_data = self.netG.train_step(real_batch=real_batch, 53 | netD=self.netD, 54 | optG=optG, 55 | log_data=log_data, 56 | device='cpu') 57 | 58 | for name, metric_dict in log_data.items(): 59 | assert type(name) == str 60 | assert type(metric_dict['value']) == float 61 | 62 | def teardown(self): 63 | del self.netG 64 | del self.netD 65 | 66 | 67 | if __name__ == "__main__": 68 | test = TestDCGAN64() 69 | test.setup() 70 | test.test_DCGANGenerator64() 71 | test.test_DCGANDiscriminator64() 72 | test.test_train_steps() 73 | test.teardown() 74 | -------------------------------------------------------------------------------- /tests/nets/dcgan/test_dcgan_cifar.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test functions for DCGAN for image size 32. 
3 | """ 4 | import torch 5 | import torch.optim as optim 6 | 7 | from torch_mimicry.nets.dcgan.dcgan_cifar import DCGANGeneratorCIFAR, DCGANDiscriminatorCIFAR 8 | from torch_mimicry.training import metric_log 9 | from torch_mimicry.utils import common 10 | 11 | 12 | class TestDCGAN32: 13 | def setup(self): 14 | self.nz = 128 15 | self.N, self.C, self.H, self.W = (8, 3, 32, 32) 16 | self.ngf = 16 17 | self.ndf = 16 18 | 19 | self.netG = DCGANGeneratorCIFAR(ngf=self.ngf) 20 | self.netD = DCGANDiscriminatorCIFAR(ndf=self.ndf) 21 | 22 | def test_DCGANGeneratorCIFAR(self): 23 | noise = torch.ones(self.N, self.nz) 24 | output = self.netG(noise) 25 | 26 | assert output.shape == (self.N, self.C, self.H, self.W) 27 | 28 | def test_DCGANDiscriminatorCIFAR(self): 29 | images = torch.ones(self.N, self.C, self.H, self.W) 30 | output = self.netD(images) 31 | 32 | assert output.shape == (self.N, 1) 33 | 34 | def test_train_steps(self): 35 | real_batch = common.load_images(self.N, size=self.H) 36 | 37 | # Setup optimizers 38 | optD = optim.Adam(self.netD.parameters(), 2e-4, betas=(0.0, 0.9)) 39 | optG = optim.Adam(self.netG.parameters(), 2e-4, betas=(0.0, 0.9)) 40 | 41 | # Log statistics to check 42 | log_data = metric_log.MetricLog() 43 | 44 | # Test D train step 45 | log_data = self.netD.train_step(real_batch=real_batch, 46 | netG=self.netG, 47 | optD=optD, 48 | device='cpu', 49 | log_data=log_data) 50 | 51 | log_data = self.netG.train_step(real_batch=real_batch, 52 | netD=self.netD, 53 | optG=optG, 54 | log_data=log_data, 55 | device='cpu') 56 | 57 | for name, metric_dict in log_data.items(): 58 | assert type(name) == str 59 | assert type(metric_dict['value']) == float 60 | 61 | def teardown(self): 62 | del self.netG 63 | del self.netD 64 | 65 | 66 | if __name__ == "__main__": 67 | test = TestDCGAN32() 68 | test.setup() 69 | test.test_DCGANGeneratorCIFAR() 70 | test.test_DCGANDiscriminatorCIFAR() 71 | test.test_train_steps() 72 | test.teardown() 73 | -------------------------------------------------------------------------------- /tests/nets/gan/test_cgan.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | import torch.nn as nn 4 | 5 | from torch_mimicry.nets.gan.cgan import BaseConditionalGenerator, BaseConditionalDiscriminator 6 | 7 | 8 | class ExampleGenerator(BaseConditionalGenerator): 9 | def __init__(self, *args, **kwargs): 10 | super().__init__(*args, **kwargs) 11 | self.l1 = nn.Linear(1, 1) 12 | 13 | def forward(self, x, y): 14 | return torch.ones(x.shape[0], 3, 32, 32) 15 | 16 | 17 | class ExampleDiscriminator(BaseConditionalDiscriminator): 18 | def __init__(self, *args, **kwargs): 19 | super().__init__(*args, **kwargs) 20 | self.l1 = nn.Linear(1, 1) 21 | 22 | def forward(self, x, y): 23 | return torch.ones(x.shape[0]) 24 | 25 | 26 | class TestBaseGAN: 27 | def setup(self): 28 | self.N = 1 29 | self.device = "cpu" 30 | 31 | self.nz = 16 32 | self.ngf = 16 33 | self.ndf = 16 34 | self.bottom_width = 4 35 | self.num_classes = 10 36 | self.loss_type = 'gan' 37 | 38 | self.netG = ExampleGenerator(num_classes=self.num_classes, 39 | ngf=self.ngf, 40 | bottom_width=self.bottom_width, 41 | nz=self.nz, 42 | loss_type=self.loss_type) 43 | self.netD = ExampleDiscriminator(num_classes=self.num_classes, 44 | ndf=self.ndf, 45 | loss_type=self.loss_type) 46 | self.output_fake = torch.ones(self.N, 1) 47 | self.output_real = torch.ones(self.N, 1) 48 | 49 | def test_generate_images(self): 50 | with pytest.raises(ValueError): 51 | images = 
self.netG.generate_images(10, c=self.num_classes + 1) 52 | 53 | images = self.netG.generate_images(10) 54 | assert images.shape == (10, 3, 32, 32) 55 | assert images.device == self.netG.device 56 | 57 | images = self.netG.generate_images(10, c=0) 58 | assert images.shape == (10, 3, 32, 32) 59 | assert images.device == self.netG.device 60 | 61 | def test_generate_images_with_labels(self): 62 | with pytest.raises(ValueError): 63 | images, labels = self.netG.generate_images_with_labels( 64 | 10, c=self.num_classes + 1) 65 | 66 | images, labels = self.netG.generate_images_with_labels(10) 67 | assert images.shape == (10, 3, 32, 32) 68 | assert images.device == self.netG.device 69 | assert labels.shape == (10, ) 70 | assert labels.device == self.netG.device 71 | 72 | images, labels = self.netG.generate_images_with_labels(10, c=0) 73 | assert images.shape == (10, 3, 32, 32) 74 | assert images.device == self.netG.device 75 | assert labels.shape == (10, ) 76 | assert labels.device == self.netG.device 77 | 78 | def test_compute_GAN_loss(self): 79 | losses = ['gan', 'ns', 'hinge', 'wasserstein'] 80 | 81 | for loss_type in losses: 82 | self.netG.loss_type = loss_type 83 | self.netD.loss_type = loss_type 84 | 85 | errG = self.netG.compute_gan_loss(output=self.output_fake) 86 | errD = self.netD.compute_gan_loss(output_real=self.output_real, 87 | output_fake=self.output_fake) 88 | 89 | assert type(errG.item()) == float 90 | assert type(errD.item()) == float 91 | 92 | def teardown(self): 93 | del self.netG 94 | del self.netD 95 | del self.output_real 96 | del self.output_fake 97 | 98 | 99 | if __name__ == "__main__": 100 | test = TestBaseGAN() 101 | test.setup() 102 | test.test_generate_images() 103 | test.test_generate_images_with_labels() 104 | test.test_compute_GAN_loss() 105 | test.teardown() 106 | -------------------------------------------------------------------------------- /tests/nets/gan/test_gan.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | import torch.nn as nn 4 | 5 | from torch_mimicry.nets.gan.gan import BaseGenerator, BaseDiscriminator 6 | 7 | 8 | class ExampleGenerator(BaseGenerator): 9 | def __init__(self, *args, **kwargs): 10 | super().__init__(*args, **kwargs) 11 | self.l1 = nn.Linear(1, 1) 12 | 13 | def forward(self, x): 14 | return torch.ones(x.shape[0], 3, 32, 32) 15 | 16 | 17 | class ExampleDiscriminator(BaseDiscriminator): 18 | def __init__(self, *args, **kwargs): 19 | super().__init__(*args, **kwargs) 20 | self.l1 = nn.Linear(1, 1) 21 | 22 | def forward(self, x): 23 | return torch.ones(x.shape[0]) 24 | 25 | 26 | class TestBaseGAN: 27 | def setup(self): 28 | self.N = 1 29 | self.device = "cpu" 30 | self.real_label_val = 1.0 31 | self.fake_label_val = 0.0 32 | 33 | self.nz = 16 34 | self.ngf = 16 35 | self.ndf = 16 36 | self.bottom_width = 4 37 | self.loss_type = 'gan' 38 | 39 | self.netG = ExampleGenerator(ngf=self.ngf, 40 | bottom_width=self.bottom_width, 41 | nz=self.nz, 42 | loss_type=self.loss_type) 43 | self.netD = ExampleDiscriminator(ndf=self.ndf, 44 | loss_type=self.loss_type) 45 | self.output_fake = torch.ones(self.N, 1) 46 | self.output_real = torch.ones(self.N, 1) 47 | 48 | def test_generate_images(self): 49 | images = self.netG.generate_images(10) 50 | 51 | assert images.shape == (10, 3, 32, 32) 52 | assert images.device == self.netG.device 53 | 54 | def test_compute_GAN_loss(self): 55 | losses = ['gan', 'ns', 'hinge', 'wasserstein'] 56 | 57 | for loss_type in losses: 58 | self.netG.loss_type 
= loss_type 59 | self.netD.loss_type = loss_type 60 | 61 | errG = self.netG.compute_gan_loss(output=self.output_fake) 62 | errD = self.netD.compute_gan_loss(output_real=self.output_real, 63 | output_fake=self.output_fake) 64 | 65 | assert type(errG.item()) == float 66 | assert type(errD.item()) == float 67 | 68 | with pytest.raises(ValueError): 69 | self.netG.loss_type = 'invalid' 70 | self.netG.compute_gan_loss(output=self.output_fake) 71 | 72 | with pytest.raises(ValueError): 73 | self.netD.loss_type = 'invalid' 74 | self.netD.compute_gan_loss(output_real=self.output_real, 75 | output_fake=self.output_fake) 76 | 77 | def teardown(self): 78 | del self.netG 79 | del self.netD 80 | del self.output_real 81 | del self.output_fake 82 | 83 | 84 | if __name__ == "__main__": 85 | test = TestBaseGAN() 86 | test.setup() 87 | test.test_generate_images() 88 | test.test_compute_GAN_loss() 89 | test.teardown() 90 | -------------------------------------------------------------------------------- /tests/nets/infomax_gan/test_infomax_gan_128.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test functions for InfoMaxGAN for image size 128. 3 | """ 4 | import torch 5 | import torch.optim as optim 6 | 7 | from torch_mimicry.nets.infomax_gan.infomax_gan_128 import InfoMaxGANGenerator128, InfoMaxGANDiscriminator128 8 | from torch_mimicry.training import metric_log 9 | from torch_mimicry.utils import common 10 | 11 | 12 | class TestInfoMaxGAN128: 13 | def setup(self): 14 | self.nz = 128 15 | self.N, self.C, self.H, self.W = (8, 3, 128, 128) 16 | self.ngf = 16 17 | self.ndf = 16 18 | 19 | self.netG = InfoMaxGANGenerator128(ngf=self.ngf) 20 | self.netD = InfoMaxGANDiscriminator128(ndf=self.ndf) 21 | 22 | def test_InfoMaxGANGenerator128(self): 23 | noise = torch.ones(self.N, self.nz) 24 | output = self.netG(noise) 25 | 26 | assert output.shape == (self.N, self.C, self.H, self.W) 27 | 28 | def test_InfoMaxGANDiscriminator128(self): 29 | images = torch.ones(self.N, self.C, self.H, self.W) 30 | output, local_feat, global_feat = self.netD(images) 31 | 32 | assert output.shape == (self.N, 1) 33 | assert local_feat.shape == (self.N, self.netD.ndf, self.H >> 5, 34 | self.W >> 5) 35 | assert global_feat.shape == (self.N, self.netD.ndf) 36 | 37 | def test_train_steps(self): 38 | real_batch = common.load_images(self.N, size=self.H) 39 | 40 | # Setup optimizers 41 | optD = optim.Adam(self.netD.parameters(), 2e-4, betas=(0.0, 0.9)) 42 | optG = optim.Adam(self.netG.parameters(), 2e-4, betas=(0.0, 0.9)) 43 | 44 | # Log statistics to check 45 | log_data = metric_log.MetricLog() 46 | 47 | # Test D train step 48 | log_data = self.netD.train_step(real_batch=real_batch, 49 | netG=self.netG, 50 | optD=optD, 51 | device='cpu', 52 | log_data=log_data) 53 | 54 | log_data = self.netG.train_step(real_batch=real_batch, 55 | netD=self.netD, 56 | optG=optG, 57 | log_data=log_data, 58 | device='cpu') 59 | 60 | for name, metric_dict in log_data.items(): 61 | assert type(name) == str 62 | assert type(metric_dict['value']) == float 63 | 64 | def teardown(self): 65 | del self.netG 66 | del self.netD 67 | 68 | 69 | if __name__ == "__main__": 70 | test = TestInfoMaxGAN128() 71 | test.setup() 72 | test.test_InfoMaxGANGenerator128() 73 | test.test_InfoMaxGANDiscriminator128() 74 | test.test_train_steps() 75 | test.teardown() 76 | -------------------------------------------------------------------------------- /tests/nets/infomax_gan/test_infomax_gan_32.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Test functions for InfoMaxGAN for image size 32. 3 | """ 4 | import torch 5 | import torch.optim as optim 6 | 7 | from torch_mimicry.nets.infomax_gan.infomax_gan_32 import InfoMaxGANGenerator32, InfoMaxGANDiscriminator32 8 | from torch_mimicry.training import metric_log 9 | from torch_mimicry.utils import common 10 | 11 | 12 | class TestInfoMaxGAN32: 13 | def setup(self): 14 | self.nz = 128 15 | self.N, self.C, self.H, self.W = (8, 3, 32, 32) 16 | self.ngf = 16 17 | self.ndf = 16 18 | 19 | self.netG = InfoMaxGANGenerator32(ngf=self.ngf) 20 | self.netD = InfoMaxGANDiscriminator32(ndf=self.ndf) 21 | 22 | def test_InfoMaxGANGenerator32(self): 23 | noise = torch.ones(self.N, self.nz) 24 | output = self.netG(noise) 25 | 26 | assert output.shape == (self.N, self.C, self.H, self.W) 27 | 28 | def test_InfoMaxGANDiscriminator32(self): 29 | images = torch.ones(self.N, self.C, self.H, self.W) 30 | output, local_feat, global_feat = self.netD(images) 31 | 32 | assert output.shape == (self.N, 1) 33 | assert local_feat.shape == (self.N, self.netD.ndf, self.H >> 2, 34 | self.W >> 2) 35 | assert global_feat.shape == (self.N, self.netD.ndf) 36 | 37 | def test_train_steps(self): 38 | real_batch = common.load_images(self.N, size=self.H) 39 | 40 | # Setup optimizers 41 | optD = optim.Adam(self.netD.parameters(), 2e-4, betas=(0.0, 0.9)) 42 | optG = optim.Adam(self.netG.parameters(), 2e-4, betas=(0.0, 0.9)) 43 | 44 | # Log statistics to check 45 | log_data = metric_log.MetricLog() 46 | 47 | # Test D train step 48 | log_data = self.netD.train_step(real_batch=real_batch, 49 | netG=self.netG, 50 | optD=optD, 51 | device='cpu', 52 | log_data=log_data) 53 | 54 | log_data = self.netG.train_step(real_batch=real_batch, 55 | netD=self.netD, 56 | optG=optG, 57 | log_data=log_data, 58 | device='cpu') 59 | 60 | for name, metric_dict in log_data.items(): 61 | assert type(name) == str 62 | assert type(metric_dict['value']) == float 63 | 64 | def teardown(self): 65 | del self.netG 66 | del self.netD 67 | 68 | 69 | if __name__ == "__main__": 70 | test = TestInfoMaxGAN32() 71 | test.setup() 72 | test.test_InfoMaxGANGenerator32() 73 | test.test_InfoMaxGANDiscriminator32() 74 | test.test_train_steps() 75 | test.teardown() 76 | -------------------------------------------------------------------------------- /tests/nets/infomax_gan/test_infomax_gan_48.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test functions for InfoMaxGAN for image size 48. 
3 | """ 4 | import torch 5 | import torch.optim as optim 6 | 7 | from torch_mimicry.nets.infomax_gan.infomax_gan_48 import InfoMaxGANGenerator48, InfoMaxGANDiscriminator48 8 | from torch_mimicry.training import metric_log 9 | from torch_mimicry.utils import common 10 | 11 | 12 | class TestInfoMaxGAN48: 13 | def setup(self): 14 | self.nz = 128 15 | self.N, self.C, self.H, self.W = (8, 3, 48, 48) 16 | self.ngf = 16 17 | self.ndf = 16 18 | 19 | self.netG = InfoMaxGANGenerator48(ngf=self.ngf) 20 | self.netD = InfoMaxGANDiscriminator48(ndf=self.ndf) 21 | 22 | def test_InfoMaxGANGenerator48(self): 23 | noise = torch.ones(self.N, self.nz) 24 | output = self.netG(noise) 25 | 26 | assert output.shape == (self.N, self.C, self.H, self.W) 27 | 28 | def test_InfoMaxGANDiscriminator48(self): 29 | images = torch.ones(self.N, self.C, self.H, self.W) 30 | output, local_feat, global_feat = self.netD(images) 31 | 32 | assert output.shape == (self.N, 1) 33 | assert local_feat.shape == (self.N, self.netD.ndf >> 1, self.H >> 4, 34 | self.W >> 4) 35 | assert global_feat.shape == (self.N, self.netD.ndf) 36 | 37 | def test_train_steps(self): 38 | real_batch = common.load_images(self.N, size=self.H) 39 | 40 | # Setup optimizers 41 | optD = optim.Adam(self.netD.parameters(), 2e-4, betas=(0.0, 0.9)) 42 | optG = optim.Adam(self.netG.parameters(), 2e-4, betas=(0.0, 0.9)) 43 | 44 | # Log statistics to check 45 | log_data = metric_log.MetricLog() 46 | 47 | # Test D train step 48 | log_data = self.netD.train_step(real_batch=real_batch, 49 | netG=self.netG, 50 | optD=optD, 51 | device='cpu', 52 | log_data=log_data) 53 | 54 | log_data = self.netG.train_step(real_batch=real_batch, 55 | netD=self.netD, 56 | optG=optG, 57 | log_data=log_data, 58 | device='cpu') 59 | 60 | for name, metric_dict in log_data.items(): 61 | assert type(name) == str 62 | assert type(metric_dict['value']) == float 63 | 64 | def teardown(self): 65 | del self.netG 66 | del self.netD 67 | 68 | 69 | if __name__ == "__main__": 70 | test = TestInfoMaxGAN48() 71 | test.setup() 72 | test.test_InfoMaxGANGenerator48() 73 | test.test_InfoMaxGANDiscriminator48() 74 | test.test_train_steps() 75 | test.teardown() 76 | -------------------------------------------------------------------------------- /tests/nets/infomax_gan/test_infomax_gan_64.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test functions for InfoMaxGAN for image size 64. 
3 | """ 4 | import torch 5 | import torch.optim as optim 6 | 7 | from torch_mimicry.nets.infomax_gan.infomax_gan_64 import InfoMaxGANGenerator64, InfoMaxGANDiscriminator64 8 | from torch_mimicry.training import metric_log 9 | from torch_mimicry.utils import common 10 | 11 | 12 | class TestInfoMaxGAN64: 13 | def setup(self): 14 | self.nz = 128 15 | self.N, self.C, self.H, self.W = (8, 3, 64, 64) 16 | self.ngf = 16 17 | self.ndf = 16 18 | 19 | self.netG = InfoMaxGANGenerator64(ngf=self.ngf) 20 | self.netD = InfoMaxGANDiscriminator64(ndf=self.ndf) 21 | 22 | def test_InfoMaxGANGenerator64(self): 23 | noise = torch.ones(self.N, self.nz) 24 | output = self.netG(noise) 25 | 26 | assert output.shape == (self.N, self.C, self.H, self.W) 27 | 28 | def test_InfoMaxGANDiscriminator64(self): 29 | images = torch.ones(self.N, self.C, self.H, self.W) 30 | output, local_feat, global_feat = self.netD(images) 31 | 32 | assert output.shape == (self.N, 1) 33 | assert local_feat.shape == (self.N, self.netD.ndf >> 1, self.H >> 4, 34 | self.W >> 4) 35 | assert global_feat.shape == (self.N, self.netD.ndf) 36 | 37 | def test_train_steps(self): 38 | real_batch = common.load_images(self.N, size=self.H) 39 | 40 | # Setup optimizers 41 | optD = optim.Adam(self.netD.parameters(), 2e-4, betas=(0.0, 0.9)) 42 | optG = optim.Adam(self.netG.parameters(), 2e-4, betas=(0.0, 0.9)) 43 | 44 | # Log statistics to check 45 | log_data = metric_log.MetricLog() 46 | 47 | # Test D train step 48 | log_data = self.netD.train_step(real_batch=real_batch, 49 | netG=self.netG, 50 | optD=optD, 51 | device='cpu', 52 | log_data=log_data) 53 | 54 | log_data = self.netG.train_step(real_batch=real_batch, 55 | netD=self.netD, 56 | optG=optG, 57 | log_data=log_data, 58 | device='cpu') 59 | 60 | for name, metric_dict in log_data.items(): 61 | assert type(name) == str 62 | assert type(metric_dict['value']) == float 63 | 64 | def teardown(self): 65 | del self.netG 66 | del self.netD 67 | 68 | 69 | if __name__ == "__main__": 70 | test = TestInfoMaxGAN64() 71 | test.setup() 72 | test.test_InfoMaxGANGenerator64() 73 | test.test_InfoMaxGANDiscriminator64() 74 | test.test_train_steps() 75 | test.teardown() 76 | -------------------------------------------------------------------------------- /tests/nets/infomax_gan/test_infomax_gan_base.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test for SSGAN specific functions at the discriminator. 
3 | """ 4 | import pytest 5 | import math 6 | import torch 7 | 8 | from torch_mimicry.nets.infomax_gan.infomax_gan_base import BaseDiscriminator 9 | 10 | 11 | class ExampleDiscriminator(BaseDiscriminator): 12 | def __init__(self, *args, **kwargs): 13 | super().__init__(*args, **kwargs) 14 | 15 | def forward(self, x): 16 | return 17 | 18 | 19 | class TestInfoMaxGANBase: 20 | def setup(self): 21 | self.ndf = 16 22 | self.nrkhs = 32 23 | self.N = 4 24 | self.netD = ExampleDiscriminator(ndf=self.ndf, nrkhs=self.nrkhs) 25 | 26 | def test_infonce_loss(self): 27 | l = torch.ones(self.N, self.nrkhs, 1) 28 | m = torch.ones(self.N, self.nrkhs, 1) 29 | 30 | loss = self.netD.infonce_loss(l=l, m=m) 31 | prob = math.exp(-1 * loss.item()) 32 | 33 | assert type(loss.item()) == float 34 | 35 | # 1/4 probability 36 | assert abs(prob - 0.25) < 1e-2 37 | 38 | def test_compute_infomax_loss(self): 39 | with pytest.raises(ValueError): 40 | local_feat = torch.ones(self.N, self.ndf, 4, 4) 41 | global_feat = torch.ones(self.N, self.nrkhs) 42 | scale = 0.2 43 | self.netD.compute_infomax_loss(local_feat, global_feat, scale) 44 | 45 | def teardown(self): 46 | del self.netD 47 | 48 | 49 | if __name__ == "__main__": 50 | test = TestInfoMaxGANBase() 51 | test.setup() 52 | test.test_infonce_loss() 53 | test.test_compute_infomax_loss() 54 | test.teardown() 55 | -------------------------------------------------------------------------------- /tests/nets/sagan/test_sagan_128.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test functions for SAGAN for image size 128. 3 | """ 4 | import torch 5 | import torch.optim as optim 6 | 7 | from torch_mimicry.nets.sagan.sagan_128 import SAGANGenerator128, SAGANDiscriminator128 8 | from torch_mimicry.training import metric_log 9 | from torch_mimicry.utils import common 10 | 11 | 12 | class TestSAGAN128: 13 | def setup(self): 14 | self.num_classes = 10 15 | self.nz = 128 16 | self.N, self.C, self.H, self.W = (4, 3, 128, 128) 17 | 18 | self.noise = torch.ones(self.N, self.nz) 19 | self.images = torch.ones(self.N, self.C, self.H, self.W) 20 | self.Y = torch.randint(low=0, high=self.num_classes, size=(self.N, )) 21 | 22 | self.ngf = 32 23 | self.ndf = 64 24 | 25 | self.netG = SAGANGenerator128(num_classes=self.num_classes, 26 | ngf=self.ngf) 27 | self.netD = SAGANDiscriminator128(num_classes=self.num_classes, 28 | ndf=self.ndf) 29 | 30 | def test_SAGANGenerator128(self): 31 | images = self.netG(self.noise, self.Y) 32 | assert images.shape == (self.N, self.C, self.H, self.W) 33 | 34 | images = self.netG(self.noise, None) 35 | assert images.shape == (self.N, self.C, self.H, self.W) 36 | 37 | def test_SAGANDiscriminator128(self): 38 | output = self.netD(self.images, self.Y) 39 | 40 | assert output.shape == (self.N, 1) 41 | 42 | def test_train_steps(self): 43 | real_batch = common.load_images(self.N) 44 | 45 | # Setup optimizers 46 | optD = optim.Adam(self.netD.parameters(), 2e-4, betas=(0.0, 0.9)) 47 | optG = optim.Adam(self.netG.parameters(), 2e-4, betas=(0.0, 0.9)) 48 | 49 | # Log statistics to check 50 | log_data = metric_log.MetricLog() 51 | 52 | # Test D train step 53 | log_data = self.netD.train_step(real_batch=real_batch, 54 | netG=self.netG, 55 | optD=optD, 56 | device='cpu', 57 | log_data=log_data) 58 | 59 | log_data = self.netG.train_step(real_batch=real_batch, 60 | netD=self.netD, 61 | optG=optG, 62 | log_data=log_data, 63 | device='cpu') 64 | 65 | for name, metric_dict in log_data.items(): 66 | assert type(name) == str 67 | assert 
type(metric_dict['value']) == float 68 | 69 | def teardown(self): 70 | del self.noise 71 | del self.images 72 | del self.Y 73 | del self.netG 74 | del self.netD 75 | 76 | 77 | if __name__ == "__main__": 78 | test = TestSAGAN128() 79 | test.setup() 80 | test.test_SAGANGenerator128() 81 | test.test_SAGANDiscriminator128() 82 | test.test_train_steps() 83 | test.teardown() 84 | -------------------------------------------------------------------------------- /tests/nets/sagan/test_sagan_32.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test functions for SAGAN for image size 32. 3 | """ 4 | import torch 5 | import torch.optim as optim 6 | 7 | from torch_mimicry.nets.sagan.sagan_32 import SAGANGenerator32, SAGANDiscriminator32 8 | from torch_mimicry.training import metric_log 9 | from torch_mimicry.utils import common 10 | 11 | 12 | class TestSAGAN32: 13 | def setup(self): 14 | self.num_classes = 10 15 | self.nz = 128 16 | self.N, self.C, self.H, self.W = (4, 3, 32, 32) 17 | 18 | self.noise = torch.ones(self.N, self.nz) 19 | self.images = torch.ones(self.N, self.C, self.H, self.W) 20 | self.Y = torch.randint(low=0, high=self.num_classes, size=(self.N, )) 21 | 22 | self.ngf = 32 23 | self.ndf = 64 24 | 25 | self.netG = SAGANGenerator32(num_classes=self.num_classes, 26 | ngf=self.ngf) 27 | self.netD = SAGANDiscriminator32(num_classes=self.num_classes, 28 | ndf=self.ndf) 29 | 30 | def test_SAGANGenerator32(self): 31 | images = self.netG(self.noise, self.Y) 32 | assert images.shape == (self.N, self.C, self.H, self.W) 33 | 34 | images = self.netG(self.noise, None) 35 | assert images.shape == (self.N, self.C, self.H, self.W) 36 | 37 | def test_SAGANDiscriminator32(self): 38 | output = self.netD(self.images, self.Y) 39 | assert output.shape == (self.N, 1) 40 | 41 | def test_train_steps(self): 42 | real_batch = common.load_images(self.N) 43 | 44 | # Setup optimizers 45 | optD = optim.Adam(self.netD.parameters(), 2e-4, betas=(0.0, 0.9)) 46 | optG = optim.Adam(self.netG.parameters(), 2e-4, betas=(0.0, 0.9)) 47 | 48 | # Log statistics to check 49 | log_data = metric_log.MetricLog() 50 | 51 | # Test D train step 52 | log_data = self.netD.train_step(real_batch=real_batch, 53 | netG=self.netG, 54 | optD=optD, 55 | device='cpu', 56 | log_data=log_data) 57 | 58 | log_data = self.netG.train_step(real_batch=real_batch, 59 | netD=self.netD, 60 | optG=optG, 61 | log_data=log_data, 62 | device='cpu') 63 | 64 | for name, metric_dict in log_data.items(): 65 | assert type(name) == str 66 | assert type(metric_dict['value']) == float 67 | 68 | def teardown(self): 69 | del self.noise 70 | del self.images 71 | del self.Y 72 | del self.netG 73 | del self.netD 74 | 75 | 76 | if __name__ == "__main__": 77 | test = TestSAGAN32() 78 | test.setup() 79 | test.test_SAGANGenerator32() 80 | test.test_SAGANDiscriminator32() 81 | test.test_train_steps() 82 | test.teardown() 83 | -------------------------------------------------------------------------------- /tests/nets/sngan/test_sngan_128.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test functions for SNGAN for image size 128. 
3 | """ 4 | import torch 5 | import torch.optim as optim 6 | 7 | from torch_mimicry.nets.sngan.sngan_128 import SNGANGenerator128, SNGANDiscriminator128 8 | from torch_mimicry.training import metric_log 9 | from torch_mimicry.utils import common 10 | 11 | 12 | class TestSNGAN128: 13 | def setup(self): 14 | self.nz = 128 15 | self.N, self.C, self.H, self.W = (8, 3, 128, 128) 16 | self.ngf = 16 17 | self.ndf = 16 18 | 19 | self.netG = SNGANGenerator128(ngf=self.ngf) 20 | self.netD = SNGANDiscriminator128(ndf=self.ndf) 21 | 22 | def test_SNGANGenerator128(self): 23 | noise = torch.ones(self.N, self.nz) 24 | output = self.netG(noise) 25 | 26 | assert output.shape == (self.N, self.C, self.H, self.W) 27 | 28 | def test_SNGANDiscriminator128(self): 29 | images = torch.ones(self.N, self.C, self.H, self.W) 30 | output = self.netD(images) 31 | 32 | assert output.shape == (self.N, 1) 33 | 34 | def test_train_steps(self): 35 | real_batch = common.load_images(self.N, size=self.H) 36 | 37 | # Setup optimizers 38 | optD = optim.Adam(self.netD.parameters(), 2e-4, betas=(0.0, 0.9)) 39 | optG = optim.Adam(self.netG.parameters(), 2e-4, betas=(0.0, 0.9)) 40 | 41 | # Log statistics to check 42 | log_data = metric_log.MetricLog() 43 | 44 | # Test D train step 45 | log_data = self.netD.train_step(real_batch=real_batch, 46 | netG=self.netG, 47 | optD=optD, 48 | device='cpu', 49 | log_data=log_data) 50 | 51 | log_data = self.netG.train_step(real_batch=real_batch, 52 | netD=self.netD, 53 | optG=optG, 54 | log_data=log_data, 55 | device='cpu') 56 | 57 | for name, metric_dict in log_data.items(): 58 | assert type(name) == str 59 | assert type(metric_dict['value']) == float 60 | 61 | def teardown(self): 62 | del self.netG 63 | del self.netD 64 | 65 | 66 | if __name__ == "__main__": 67 | test = TestSNGAN128() 68 | test.setup() 69 | test.test_SNGANGenerator128() 70 | test.test_SNGANDiscriminator128() 71 | test.test_train_steps() 72 | test.teardown() 73 | -------------------------------------------------------------------------------- /tests/nets/sngan/test_sngan_32.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test functions for SNGAN for image size 32. 
3 | """ 4 | import torch 5 | import torch.optim as optim 6 | 7 | from torch_mimicry.nets.sngan.sngan_32 import SNGANGenerator32, SNGANDiscriminator32 8 | from torch_mimicry.training import metric_log 9 | from torch_mimicry.utils import common 10 | 11 | 12 | class TestSNGAN32: 13 | def setup(self): 14 | self.nz = 128 15 | self.N, self.C, self.H, self.W = (8, 3, 32, 32) 16 | self.ngf = 16 17 | self.ndf = 16 18 | 19 | self.netG = SNGANGenerator32(ngf=self.ngf) 20 | self.netD = SNGANDiscriminator32(ndf=self.ndf) 21 | 22 | def test_SNGANGenerator32(self): 23 | noise = torch.ones(self.N, self.nz) 24 | output = self.netG(noise) 25 | 26 | assert output.shape == (self.N, self.C, self.H, self.W) 27 | 28 | def test_SNGANDiscriminator32(self): 29 | images = torch.ones(self.N, self.C, self.H, self.W) 30 | output = self.netD(images) 31 | 32 | assert output.shape == (self.N, 1) 33 | 34 | def test_train_steps(self): 35 | # Get real and fake images 36 | real_batch = common.load_images(self.N, size=self.H) 37 | 38 | # Setup optimizers 39 | optD = optim.Adam(self.netD.parameters(), 2e-4, betas=(0.0, 0.9)) 40 | optG = optim.Adam(self.netG.parameters(), 2e-4, betas=(0.0, 0.9)) 41 | 42 | # Log statistics to check 43 | log_data = metric_log.MetricLog() 44 | 45 | # Test D train step 46 | log_data = self.netD.train_step(real_batch=real_batch, 47 | netG=self.netG, 48 | optD=optD, 49 | device='cpu', 50 | log_data=log_data) 51 | 52 | log_data = self.netG.train_step(real_batch=real_batch, 53 | netD=self.netD, 54 | optG=optG, 55 | log_data=log_data, 56 | device='cpu') 57 | 58 | for name, metric_dict in log_data.items(): 59 | assert type(name) == str 60 | assert type(metric_dict['value']) == float 61 | 62 | def teardown(self): 63 | del self.netG 64 | del self.netD 65 | 66 | 67 | if __name__ == "__main__": 68 | test = TestSNGAN32() 69 | test.setup() 70 | test.test_SNGANGenerator32() 71 | test.test_SNGANDiscriminator32() 72 | test.test_train_steps() 73 | test.teardown() 74 | -------------------------------------------------------------------------------- /tests/nets/sngan/test_sngan_48.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test functions for SNGAN for image size 48. 
3 | """ 4 | import torch 5 | import torch.optim as optim 6 | 7 | from torch_mimicry.nets.sngan.sngan_48 import SNGANGenerator48, SNGANDiscriminator48 8 | from torch_mimicry.training import metric_log 9 | from torch_mimicry.utils import common 10 | 11 | 12 | class TestSNGAN48: 13 | def setup(self): 14 | self.nz = 128 15 | self.N, self.C, self.H, self.W = (8, 3, 48, 48) 16 | self.ngf = 16 17 | self.ndf = 16 18 | 19 | self.netG = SNGANGenerator48(ngf=self.ngf) 20 | self.netD = SNGANDiscriminator48(ndf=self.ndf) 21 | 22 | def test_SNGANGenerator48(self): 23 | noise = torch.ones(self.N, self.nz) 24 | output = self.netG(noise) 25 | 26 | assert output.shape == (self.N, self.C, self.H, self.W) 27 | 28 | def test_SNGANDiscriminator48(self): 29 | images = torch.ones(self.N, self.C, self.H, self.W) 30 | output = self.netD(images) 31 | 32 | assert output.shape == (self.N, 1) 33 | 34 | def test_train_steps(self): 35 | real_batch = common.load_images(self.N, size=self.H) 36 | 37 | # Setup optimizers 38 | optD = optim.Adam(self.netD.parameters(), 2e-4, betas=(0.0, 0.9)) 39 | optG = optim.Adam(self.netG.parameters(), 2e-4, betas=(0.0, 0.9)) 40 | 41 | # Log statistics to check 42 | log_data = metric_log.MetricLog() 43 | 44 | # Test D train step 45 | log_data = self.netD.train_step(real_batch=real_batch, 46 | netG=self.netG, 47 | optD=optD, 48 | device='cpu', 49 | log_data=log_data) 50 | 51 | log_data = self.netG.train_step(real_batch=real_batch, 52 | netD=self.netD, 53 | optG=optG, 54 | log_data=log_data, 55 | device='cpu') 56 | 57 | for name, metric_dict in log_data.items(): 58 | assert type(name) == str 59 | assert type(metric_dict['value']) == float 60 | 61 | def teardown(self): 62 | del self.netG 63 | del self.netD 64 | 65 | 66 | if __name__ == "__main__": 67 | test = TestSNGAN48() 68 | test.setup() 69 | test.test_SNGANGenerator48() 70 | test.test_SNGANDiscriminator48() 71 | test.test_train_steps() 72 | test.teardown() 73 | -------------------------------------------------------------------------------- /tests/nets/sngan/test_sngan_64.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test functions for SNGAN for image size 64. 
3 | """ 4 | import torch 5 | import torch.optim as optim 6 | 7 | from torch_mimicry.nets.sngan.sngan_64 import SNGANGenerator64, SNGANDiscriminator64 8 | from torch_mimicry.training import metric_log 9 | from torch_mimicry.utils import common 10 | 11 | 12 | class TestSNGAN64: 13 | def setup(self): 14 | self.nz = 128 15 | self.N, self.C, self.H, self.W = (8, 3, 64, 64) 16 | self.ngf = 16 17 | self.ndf = 16 18 | 19 | self.netG = SNGANGenerator64(ngf=self.ngf) 20 | self.netD = SNGANDiscriminator64(ndf=self.ndf) 21 | 22 | def test_SNGANGenerator64(self): 23 | noise = torch.ones(self.N, self.nz) 24 | output = self.netG(noise) 25 | 26 | assert output.shape == (self.N, self.C, self.H, self.W) 27 | 28 | def test_SNGANDiscriminator64(self): 29 | images = torch.ones(self.N, self.C, self.H, self.W) 30 | output = self.netD(images) 31 | 32 | assert output.shape == (self.N, 1) 33 | 34 | def test_train_steps(self): 35 | real_batch = common.load_images(self.N, size=self.H) 36 | 37 | # Setup optimizers 38 | optD = optim.Adam(self.netD.parameters(), 2e-4, betas=(0.0, 0.9)) 39 | optG = optim.Adam(self.netG.parameters(), 2e-4, betas=(0.0, 0.9)) 40 | 41 | # Log statistics to check 42 | log_data = metric_log.MetricLog() 43 | 44 | # Test D train step 45 | log_data = self.netD.train_step(real_batch=real_batch, 46 | netG=self.netG, 47 | optD=optD, 48 | device='cpu', 49 | log_data=log_data) 50 | 51 | log_data = self.netG.train_step(real_batch=real_batch, 52 | netD=self.netD, 53 | optG=optG, 54 | log_data=log_data, 55 | device='cpu') 56 | 57 | for name, metric_dict in log_data.items(): 58 | assert type(name) == str 59 | assert type(metric_dict['value']) == float 60 | 61 | def teardown(self): 62 | del self.netG 63 | del self.netD 64 | 65 | 66 | if __name__ == "__main__": 67 | test = TestSNGAN64() 68 | test.setup() 69 | test.test_SNGANGenerator64() 70 | test.test_SNGANDiscriminator64() 71 | test.test_train_steps() 72 | test.teardown() 73 | -------------------------------------------------------------------------------- /tests/nets/ssgan/test_ssgan_128.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test functions for SSGAN for image size 128. 
3 | """ 4 | import torch 5 | import torch.optim as optim 6 | 7 | from torch_mimicry.nets.ssgan.ssgan_128 import SSGANGenerator128, SSGANDiscriminator128 8 | from torch_mimicry.training import metric_log 9 | from torch_mimicry.utils import common 10 | 11 | 12 | class TestSSGAN128: 13 | def setup(self): 14 | self.nz = 128 15 | self.N, self.C, self.H, self.W = (8, 3, 128, 128) 16 | self.ngf = 16 17 | self.ndf = 16 18 | self.device = 'cpu' 19 | 20 | self.netG = SSGANGenerator128(ngf=self.ngf) 21 | self.netD = SSGANDiscriminator128(ndf=self.ndf) 22 | 23 | def test_SSGANGenerator128(self): 24 | noise = torch.ones(self.N, self.nz) 25 | output = self.netG(noise) 26 | 27 | assert output.shape == (self.N, self.C, self.H, self.W) 28 | 29 | def test_SSGANDiscriminator128(self): 30 | images = torch.ones(self.N, self.C, self.H, self.W) 31 | output, labels = self.netD(images) 32 | 33 | assert output.shape == (self.N, 1) 34 | assert labels.shape == (self.N, 4) 35 | 36 | def test_train_steps(self): 37 | real_batch = common.load_images(self.N, size=self.H) 38 | 39 | # Setup optimizers 40 | optD = optim.Adam(self.netD.parameters(), 2e-4, betas=(0.0, 0.9)) 41 | optG = optim.Adam(self.netG.parameters(), 2e-4, betas=(0.0, 0.9)) 42 | 43 | # Log statistics to check 44 | log_data = metric_log.MetricLog() 45 | 46 | # Test D train step 47 | log_data = self.netD.train_step(real_batch=real_batch, 48 | netG=self.netG, 49 | optD=optD, 50 | device=self.device, 51 | log_data=log_data) 52 | 53 | log_data = self.netG.train_step(real_batch=real_batch, 54 | netD=self.netD, 55 | optG=optG, 56 | log_data=log_data, 57 | device=self.device) 58 | 59 | for name, metric_dict in log_data.items(): 60 | assert type(name) == str 61 | assert type(metric_dict['value']) == float 62 | 63 | def teardown(self): 64 | del self.netG 65 | del self.netD 66 | 67 | 68 | if __name__ == "__main__": 69 | test = TestSSGAN128() 70 | test.setup() 71 | test.test_SSGANGenerator128() 72 | test.test_SSGANDiscriminator128() 73 | test.test_train_steps() 74 | test.teardown() 75 | -------------------------------------------------------------------------------- /tests/nets/ssgan/test_ssgan_32.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test functions for SSGAN for image size 32. 
3 | """ 4 | import torch 5 | import torch.optim as optim 6 | 7 | from torch_mimicry.nets.ssgan.ssgan_32 import SSGANGenerator32, SSGANDiscriminator32 8 | from torch_mimicry.training import metric_log 9 | from torch_mimicry.utils import common 10 | 11 | 12 | class TestSSGAN32: 13 | def setup(self): 14 | self.nz = 128 15 | self.N, self.C, self.H, self.W = (8, 3, 32, 32) 16 | self.ngf = 16 17 | self.ndf = 16 18 | 19 | self.netG = SSGANGenerator32(ngf=self.ngf) 20 | self.netD = SSGANDiscriminator32(ndf=self.ndf) 21 | 22 | def test_SSGANGenerator32(self): 23 | noise = torch.ones(self.N, self.nz) 24 | output = self.netG(noise) 25 | 26 | assert output.shape == (self.N, self.C, self.H, self.W) 27 | 28 | def test_SSGANDiscriminator32(self): 29 | images = torch.ones(self.N, self.C, self.H, self.W) 30 | output, labels = self.netD(images) 31 | 32 | assert output.shape == (self.N, 1) 33 | assert labels.shape == (self.N, 4) 34 | 35 | def test_train_steps(self): 36 | real_batch = common.load_images(self.N, size=self.H) 37 | 38 | # Setup optimizers 39 | optD = optim.Adam(self.netD.parameters(), 2e-4, betas=(0.0, 0.9)) 40 | optG = optim.Adam(self.netG.parameters(), 2e-4, betas=(0.0, 0.9)) 41 | 42 | # Log statistics to check 43 | log_data = metric_log.MetricLog() 44 | 45 | # Test D train step 46 | log_data = self.netD.train_step(real_batch=real_batch, 47 | netG=self.netG, 48 | optD=optD, 49 | device='cpu', 50 | log_data=log_data) 51 | 52 | log_data = self.netG.train_step(real_batch=real_batch, 53 | netD=self.netD, 54 | optG=optG, 55 | log_data=log_data, 56 | device='cpu') 57 | 58 | for name, metric_dict in log_data.items(): 59 | assert type(name) == str 60 | assert type(metric_dict['value']) == float 61 | 62 | def teardown(self): 63 | del self.netG 64 | del self.netD 65 | 66 | 67 | if __name__ == "__main__": 68 | test = TestSSGAN32() 69 | test.setup() 70 | test.test_SSGANGenerator32() 71 | test.test_SSGANDiscriminator32() 72 | test.test_train_steps() 73 | test.teardown() 74 | -------------------------------------------------------------------------------- /tests/nets/ssgan/test_ssgan_48.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test functions for SSGAN for image size 48. 
3 | """ 4 | import torch 5 | import torch.optim as optim 6 | 7 | from torch_mimicry.nets.ssgan.ssgan_48 import SSGANGenerator48, SSGANDiscriminator48 8 | from torch_mimicry.training import metric_log 9 | from torch_mimicry.utils import common 10 | 11 | 12 | class TestSSGAN48: 13 | def setup(self): 14 | self.nz = 128 15 | self.N, self.C, self.H, self.W = (8, 3, 48, 48) 16 | self.ngf = 16 17 | self.ndf = 16 18 | 19 | self.netG = SSGANGenerator48(ngf=self.ngf) 20 | self.netD = SSGANDiscriminator48(ndf=self.ndf) 21 | 22 | def test_SSGANGenerator48(self): 23 | noise = torch.ones(self.N, self.nz) 24 | output = self.netG(noise) 25 | 26 | assert output.shape == (self.N, self.C, self.H, self.W) 27 | 28 | def test_SSGANDiscriminator48(self): 29 | images = torch.ones(self.N, self.C, self.H, self.W) 30 | output, labels = self.netD(images) 31 | 32 | assert output.shape == (self.N, 1) 33 | assert labels.shape == (self.N, 4) 34 | 35 | def test_train_steps(self): 36 | real_batch = common.load_images(self.N, size=self.H) 37 | 38 | # Setup optimizers 39 | optD = optim.Adam(self.netD.parameters(), 2e-4, betas=(0.0, 0.9)) 40 | optG = optim.Adam(self.netG.parameters(), 2e-4, betas=(0.0, 0.9)) 41 | 42 | # Log statistics to check 43 | log_data = metric_log.MetricLog() 44 | 45 | # Test D train step 46 | log_data = self.netD.train_step(real_batch=real_batch, 47 | netG=self.netG, 48 | optD=optD, 49 | device='cpu', 50 | log_data=log_data) 51 | 52 | log_data = self.netG.train_step(real_batch=real_batch, 53 | netD=self.netD, 54 | optG=optG, 55 | log_data=log_data, 56 | device='cpu') 57 | 58 | for name, metric_dict in log_data.items(): 59 | assert type(name) == str 60 | assert type(metric_dict['value']) == float 61 | 62 | def teardown(self): 63 | del self.netG 64 | del self.netD 65 | 66 | 67 | if __name__ == "__main__": 68 | test = TestSSGAN48() 69 | test.setup() 70 | test.test_SSGANGenerator48() 71 | test.test_SSGANDiscriminator48() 72 | test.test_train_steps() 73 | test.teardown() 74 | -------------------------------------------------------------------------------- /tests/nets/ssgan/test_ssgan_64.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test functions for SSGAN for image size 64. 
3 | """ 4 | import torch 5 | import torch.optim as optim 6 | 7 | from torch_mimicry.nets.ssgan.ssgan_64 import SSGANGenerator64, SSGANDiscriminator64 8 | from torch_mimicry.training import metric_log 9 | from torch_mimicry.utils import common 10 | 11 | 12 | class TestSSGAN64: 13 | def setup(self): 14 | self.nz = 128 15 | self.N, self.C, self.H, self.W = (8, 3, 64, 64) 16 | self.ngf = 16 17 | self.ndf = 16 18 | 19 | self.netG = SSGANGenerator64(ngf=self.ngf) 20 | self.netD = SSGANDiscriminator64(ndf=self.ndf) 21 | 22 | def test_SSGANGenerator64(self): 23 | noise = torch.ones(self.N, self.nz) 24 | output = self.netG(noise) 25 | 26 | assert output.shape == (self.N, self.C, self.H, self.W) 27 | 28 | def test_SSGANDiscriminator64(self): 29 | images = torch.ones(self.N, self.C, self.H, self.W) 30 | output, labels = self.netD(images) 31 | 32 | assert output.shape == (self.N, 1) 33 | assert labels.shape == (self.N, 4) 34 | 35 | def test_train_steps(self): 36 | real_batch = common.load_images(self.N, size=self.H) 37 | 38 | # Setup optimizers 39 | optD = optim.Adam(self.netD.parameters(), 2e-4, betas=(0.0, 0.9)) 40 | optG = optim.Adam(self.netG.parameters(), 2e-4, betas=(0.0, 0.9)) 41 | 42 | # Log statistics to check 43 | log_data = metric_log.MetricLog() 44 | 45 | # Test D train step 46 | log_data = self.netD.train_step(real_batch=real_batch, 47 | netG=self.netG, 48 | optD=optD, 49 | device='cpu', 50 | log_data=log_data) 51 | 52 | log_data = self.netG.train_step(real_batch=real_batch, 53 | netD=self.netD, 54 | optG=optG, 55 | log_data=log_data, 56 | device='cpu') 57 | 58 | for name, metric_dict in log_data.items(): 59 | assert type(name) == str 60 | assert type(metric_dict['value']) == float 61 | 62 | def teardown(self): 63 | del self.netG 64 | del self.netD 65 | 66 | 67 | if __name__ == "__main__": 68 | test = TestSSGAN64() 69 | test.setup() 70 | test.test_SSGANGenerator64() 71 | test.test_SSGANDiscriminator64() 72 | test.test_train_steps() 73 | test.teardown() 74 | -------------------------------------------------------------------------------- /tests/nets/ssgan/test_ssgan_base.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test for SSGAN specific functions at the discriminator. 3 | """ 4 | import pytest 5 | import torch 6 | 7 | from torch_mimicry.nets.ssgan.ssgan_base import SSGANBaseDiscriminator 8 | from torch_mimicry.utils import common 9 | 10 | 11 | class ExampleDiscriminator(SSGANBaseDiscriminator): 12 | def __init__(self, *args, **kwargs): 13 | super().__init__(*args, **kwargs) 14 | 15 | def forward(self, x): 16 | return torch.ones(x.shape[0]) 17 | 18 | 19 | class TestSSGANBase: 20 | def setup(self): 21 | self.netD = ExampleDiscriminator(ndf=16) 22 | 23 | def test_rot_tensor(self): 24 | # Load image and model 25 | image, _ = common.load_images(1, size=32) 26 | 27 | # For any rotation, after performing the same action 4 times, 28 | # you should return to the same pixel value 29 | for deg in [0, 90, 180, 270]: 30 | x = image.clone() 31 | for _ in range(4): 32 | x = self.netD._rot_tensor(x, deg) 33 | 34 | assert torch.sum((x - image)**2) < 1e-5 35 | 36 | def test_rotate_batch(self): 37 | # Load image and model 38 | images, _ = common.load_images(8, size=32) 39 | 40 | check = images.clone() 41 | check, labels = self.netD._rotate_batch(check) 42 | degrees = [0, 90, 180, 270] 43 | 44 | # Rotate 3 more times to get back to original. 
45 | for i in range(check.shape[0]): 46 | for _ in range(3): 47 | check[i] = self.netD._rot_tensor(check[i], degrees[labels[i]]) 48 | 49 | assert torch.sum((images - check)**2) < 1e-5 50 | 51 | with pytest.raises(ValueError): 52 | self.netD._rot_tensor(check[i], 9999) 53 | 54 | def teardown(self): 55 | del self.netD 56 | 57 | 58 | if __name__ == "__main__": 59 | test = TestSSGANBase() 60 | test.setup() 61 | test.test_rot_tensor() 62 | test.test_rotate_batch() 63 | test.teardown() 64 | -------------------------------------------------------------------------------- /tests/nets/wgan_gp/test_wgan_gp_128.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test functions for WGAN-GP for image size 128. 3 | """ 4 | import torch 5 | import torch.optim as optim 6 | 7 | from torch_mimicry.nets.wgan_gp.wgan_gp_128 import WGANGPGenerator128, WGANGPDiscriminator128 8 | from torch_mimicry.training import metric_log 9 | from torch_mimicry.utils import common 10 | 11 | 12 | class TestWGANGP128: 13 | def setup(self): 14 | self.nz = 128 15 | self.N, self.C, self.H, self.W = (8, 3, 128, 128) 16 | self.ngf = 16 17 | self.ndf = 16 18 | 19 | self.netG = WGANGPGenerator128(ngf=self.ngf) 20 | self.netD = WGANGPDiscriminator128(ndf=self.ndf) 21 | 22 | def test_WGANGPGenerator128(self): 23 | noise = torch.ones(self.N, self.nz) 24 | output = self.netG(noise) 25 | 26 | assert output.shape == (self.N, self.C, self.H, self.W) 27 | 28 | def test_WGANGPDiscriminator128(self): 29 | images = torch.ones(self.N, self.C, self.H, self.W) 30 | output = self.netD(images) 31 | 32 | assert output.shape == (self.N, 1) 33 | 34 | def test_train_steps(self): 35 | # Get real and fake images 36 | real_batch = common.load_images(self.N, size=self.H) 37 | 38 | # Setup optimizers 39 | optD = optim.Adam(self.netD.parameters(), 2e-4, betas=(0.0, 0.9)) 40 | optG = optim.Adam(self.netG.parameters(), 2e-4, betas=(0.0, 0.9)) 41 | 42 | # Log statistics to check 43 | log_data = metric_log.MetricLog() 44 | 45 | # Test D train step 46 | log_data = self.netD.train_step(real_batch=real_batch, 47 | netG=self.netG, 48 | optD=optD, 49 | device='cpu', 50 | log_data=log_data) 51 | 52 | log_data = self.netG.train_step(real_batch=real_batch, 53 | netD=self.netD, 54 | optG=optG, 55 | log_data=log_data, 56 | device='cpu') 57 | 58 | for name, metric_dict in log_data.items(): 59 | assert type(name) == str 60 | assert type(metric_dict['value']) == float 61 | 62 | def teardown(self): 63 | del self.netG 64 | del self.netD 65 | 66 | 67 | if __name__ == "__main__": 68 | test = TestWGANGP128() 69 | test.setup() 70 | test.test_WGANGPGenerator128() 71 | test.test_WGANGPDiscriminator128() 72 | test.test_train_steps() 73 | test.teardown() 74 | -------------------------------------------------------------------------------- /tests/nets/wgan_gp/test_wgan_gp_32.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test functions for WGAN-GP for image size 32. 
3 | """ 4 | import torch 5 | import torch.optim as optim 6 | 7 | from torch_mimicry.nets.wgan_gp.wgan_gp_32 import WGANGPGenerator32, WGANGPDiscriminator32 8 | from torch_mimicry.training import metric_log 9 | from torch_mimicry.utils import common 10 | 11 | 12 | class TestWGANGP32: 13 | def setup(self): 14 | self.nz = 128 15 | self.N, self.C, self.H, self.W = (8, 3, 32, 32) 16 | self.ngf = 16 17 | self.ndf = 16 18 | 19 | self.netG = WGANGPGenerator32(ngf=self.ngf) 20 | self.netD = WGANGPDiscriminator32(ndf=self.ndf) 21 | 22 | def test_WGANGPGenerator32(self): 23 | noise = torch.ones(self.N, self.nz) 24 | output = self.netG(noise) 25 | 26 | assert output.shape == (self.N, self.C, self.H, self.W) 27 | 28 | def test_WGANGPDiscriminator32(self): 29 | images = torch.ones(self.N, self.C, self.H, self.W) 30 | output = self.netD(images) 31 | 32 | assert output.shape == (self.N, 1) 33 | 34 | def test_train_steps(self): 35 | # Get real and fake images 36 | real_batch = common.load_images(self.N, size=self.H) 37 | 38 | # Setup optimizers 39 | optD = optim.Adam(self.netD.parameters(), 2e-4, betas=(0.0, 0.9)) 40 | optG = optim.Adam(self.netG.parameters(), 2e-4, betas=(0.0, 0.9)) 41 | 42 | # Log statistics to check 43 | log_data = metric_log.MetricLog() 44 | 45 | # Test D train step 46 | log_data = self.netD.train_step(real_batch=real_batch, 47 | netG=self.netG, 48 | optD=optD, 49 | device='cpu', 50 | log_data=log_data) 51 | 52 | log_data = self.netG.train_step(real_batch=real_batch, 53 | netD=self.netD, 54 | optG=optG, 55 | log_data=log_data, 56 | device='cpu') 57 | 58 | for name, metric_dict in log_data.items(): 59 | assert type(name) == str 60 | assert type(metric_dict['value']) == float 61 | 62 | def teardown(self): 63 | del self.netG 64 | del self.netD 65 | 66 | 67 | if __name__ == "__main__": 68 | test = TestWGANGP32() 69 | test.setup() 70 | test.test_WGANGPGenerator32() 71 | test.test_WGANGPDiscriminator32() 72 | test.test_train_steps() 73 | test.teardown() 74 | -------------------------------------------------------------------------------- /tests/nets/wgan_gp/test_wgan_gp_48.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test functions for WGAN-GP for image size 48. 
3 | """ 4 | import torch 5 | import torch.optim as optim 6 | 7 | from torch_mimicry.nets.wgan_gp.wgan_gp_48 import WGANGPGenerator48, WGANGPDiscriminator48 8 | from torch_mimicry.training import metric_log 9 | from torch_mimicry.utils import common 10 | 11 | 12 | class TestWGANGP48: 13 | def setup(self): 14 | self.nz = 128 15 | self.N, self.C, self.H, self.W = (8, 3, 48, 48) 16 | self.ngf = 16 17 | self.ndf = 16 18 | 19 | self.netG = WGANGPGenerator48(ngf=self.ngf) 20 | self.netD = WGANGPDiscriminator48(ndf=self.ndf) 21 | 22 | def test_WGANGPGenerator48(self): 23 | noise = torch.ones(self.N, self.nz) 24 | output = self.netG(noise) 25 | 26 | assert output.shape == (self.N, self.C, self.H, self.W) 27 | 28 | def test_WGANGPDiscriminator48(self): 29 | images = torch.ones(self.N, self.C, self.H, self.W) 30 | output = self.netD(images) 31 | 32 | assert output.shape == (self.N, 1) 33 | 34 | def test_train_steps(self): 35 | # Get real and fake images 36 | real_batch = common.load_images(self.N, size=self.H) 37 | 38 | # Setup optimizers 39 | optD = optim.Adam(self.netD.parameters(), 2e-4, betas=(0.0, 0.9)) 40 | optG = optim.Adam(self.netG.parameters(), 2e-4, betas=(0.0, 0.9)) 41 | 42 | # Log statistics to check 43 | log_data = metric_log.MetricLog() 44 | 45 | # Test D train step 46 | log_data = self.netD.train_step(real_batch=real_batch, 47 | netG=self.netG, 48 | optD=optD, 49 | device='cpu', 50 | log_data=log_data) 51 | 52 | log_data = self.netG.train_step(real_batch=real_batch, 53 | netD=self.netD, 54 | optG=optG, 55 | log_data=log_data, 56 | device='cpu') 57 | 58 | for name, metric_dict in log_data.items(): 59 | assert type(name) == str 60 | assert type(metric_dict['value']) == float 61 | 62 | def teardown(self): 63 | del self.netG 64 | del self.netD 65 | 66 | 67 | if __name__ == "__main__": 68 | test = TestWGANGP48() 69 | test.setup() 70 | test.test_WGANGPGenerator48() 71 | test.test_WGANGPDiscriminator48() 72 | test.test_train_steps() 73 | test.teardown() 74 | -------------------------------------------------------------------------------- /tests/nets/wgan_gp/test_wgan_gp_64.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test functions for WGAN-GP for image size 64. 
3 | """ 4 | import torch 5 | import torch.optim as optim 6 | 7 | from torch_mimicry.nets.wgan_gp.wgan_gp_64 import WGANGPGenerator64, WGANGPDiscriminator64 8 | from torch_mimicry.training import metric_log 9 | from torch_mimicry.utils import common 10 | 11 | 12 | class TestWGANGP64: 13 | def setup(self): 14 | self.nz = 128 15 | self.N, self.C, self.H, self.W = (8, 3, 64, 64) 16 | self.ngf = 16 17 | self.ndf = 16 18 | 19 | self.netG = WGANGPGenerator64(ngf=self.ngf) 20 | self.netD = WGANGPDiscriminator64(ndf=self.ndf) 21 | 22 | def test_WGANGPGenerator64(self): 23 | noise = torch.ones(self.N, self.nz) 24 | output = self.netG(noise) 25 | 26 | assert output.shape == (self.N, self.C, self.H, self.W) 27 | 28 | def test_WGANGPDiscriminator64(self): 29 | images = torch.ones(self.N, self.C, self.H, self.W) 30 | output = self.netD(images) 31 | 32 | assert output.shape == (self.N, 1) 33 | 34 | def test_train_steps(self): 35 | # Get real and fake images 36 | real_batch = common.load_images(self.N, size=self.H) 37 | 38 | # Setup optimizers 39 | optD = optim.Adam(self.netD.parameters(), 2e-4, betas=(0.0, 0.9)) 40 | optG = optim.Adam(self.netG.parameters(), 2e-4, betas=(0.0, 0.9)) 41 | 42 | # Log statistics to check 43 | log_data = metric_log.MetricLog() 44 | 45 | # Test D train step 46 | log_data = self.netD.train_step(real_batch=real_batch, 47 | netG=self.netG, 48 | optD=optD, 49 | device='cpu', 50 | log_data=log_data) 51 | 52 | log_data = self.netG.train_step(real_batch=real_batch, 53 | netD=self.netD, 54 | optG=optG, 55 | log_data=log_data, 56 | device='cpu') 57 | 58 | for name, metric_dict in log_data.items(): 59 | assert type(name) == str 60 | assert type(metric_dict['value']) == float 61 | 62 | def teardown(self): 63 | del self.netG 64 | del self.netD 65 | 66 | 67 | if __name__ == "__main__": 68 | test = TestWGANGP64() 69 | test.setup() 70 | test.test_WGANGPGenerator64() 71 | test.test_WGANGPDiscriminator64() 72 | test.test_train_steps() 73 | test.teardown() 74 | -------------------------------------------------------------------------------- /tests/nets/wgan_gp/test_wgan_gp_resblocks.py: -------------------------------------------------------------------------------- 1 | from itertools import product 2 | 3 | import torch 4 | 5 | from torch_mimicry.nets.wgan_gp import wgan_gp_resblocks 6 | 7 | 8 | class TestResBlocks: 9 | def setup(self): 10 | self.images = torch.ones(4, 3, 16, 16) 11 | 12 | def test_GBlock(self): 13 | # Arguments 14 | num_classes_list = [0, 10] 15 | spectral_norm_list = [True, False] 16 | in_channels = 3 17 | out_channels = 8 18 | args_comb = product(num_classes_list, spectral_norm_list) 19 | 20 | for args in args_comb: 21 | num_classes = args[0] 22 | spectral_norm = args[1] 23 | 24 | if num_classes > 0: 25 | y = torch.ones((4, ), dtype=torch.int64) 26 | else: 27 | y = None 28 | 29 | gen_block_up = wgan_gp_resblocks.GBlock( 30 | in_channels=in_channels, 31 | out_channels=out_channels, 32 | upsample=True, 33 | num_classes=num_classes, 34 | spectral_norm=spectral_norm) 35 | 36 | gen_block = wgan_gp_resblocks.GBlock(in_channels=in_channels, 37 | out_channels=out_channels, 38 | upsample=False, 39 | num_classes=num_classes, 40 | spectral_norm=spectral_norm) 41 | 42 | assert gen_block_up(self.images, y).shape == (4, 8, 32, 32) 43 | assert gen_block(self.images, y).shape == (4, 8, 16, 16) 44 | 45 | def test_DBlocks(self): 46 | in_channels = 3 47 | out_channels = 8 48 | 49 | for spectral_norm in [True, False]: 50 | dis_block_down = wgan_gp_resblocks.DBlock( 51 | in_channels=in_channels, 
52 | out_channels=out_channels, 53 | downsample=True, 54 | spectral_norm=spectral_norm) 55 | 56 | dis_block = wgan_gp_resblocks.DBlock(in_channels=in_channels, 57 | out_channels=out_channels, 58 | downsample=False, 59 | spectral_norm=spectral_norm) 60 | 61 | dis_block_opt = wgan_gp_resblocks.DBlockOptimized( 62 | in_channels=in_channels, 63 | out_channels=out_channels, 64 | spectral_norm=spectral_norm) 65 | 66 | assert dis_block(self.images).shape == (4, out_channels, 16, 16) 67 | assert dis_block_down(self.images).shape == (4, out_channels, 8, 8) 68 | assert dis_block_opt(self.images).shape == (4, out_channels, 8, 8) 69 | 70 | def teardown(self): 71 | del self.images 72 | 73 | 74 | if __name__ == "__main__": 75 | test = TestResBlocks() 76 | test.setup() 77 | test.test_GBlock() 78 | test.test_DBlocks() 79 | test.teardown() 80 | -------------------------------------------------------------------------------- /tests/training/test_logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | from torch_mimicry.nets.gan import gan, cgan 8 | from torch_mimicry.training import logger, metric_log 9 | 10 | 11 | class ExampleGen(gan.BaseGenerator): 12 | def __init__(self, 13 | bottom_width=4, 14 | nz=4, 15 | ngf=16, 16 | loss_type='gan', 17 | *args, 18 | **kwargs): 19 | super().__init__(nz=nz, 20 | ngf=ngf, 21 | bottom_width=bottom_width, 22 | loss_type=loss_type, 23 | *args, 24 | **kwargs) 25 | self.linear = nn.Linear(self.nz, 3072) 26 | 27 | def forward(self, x): 28 | output = self.linear(x) 29 | output = output.view(x.shape[0], 3, 32, 32) 30 | 31 | return output 32 | 33 | 34 | class ExampleConditionalGen(cgan.BaseConditionalGenerator): 35 | def __init__(self, 36 | bottom_width=4, 37 | nz=4, 38 | ngf=16, 39 | loss_type='gan', 40 | num_classes=10, 41 | **kwargs): 42 | super().__init__(nz=nz, 43 | ngf=ngf, 44 | bottom_width=bottom_width, 45 | loss_type=loss_type, 46 | num_classes=num_classes, 47 | **kwargs) 48 | self.linear = nn.Linear(self.nz, 3072) 49 | 50 | def forward(self, x, y=None): 51 | output = self.linear(x) 52 | output = output.view(x.shape[0], 3, 32, 32) 53 | 54 | return output 55 | 56 | 57 | class TestLogger: 58 | def setup(self): 59 | self.log_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 60 | "test_log") 61 | 62 | self.logger = logger.Logger(log_dir=self.log_dir, 63 | num_steps=100, 64 | dataset_size=50000, 65 | flush_secs=30, 66 | device=torch.device('cpu')) 67 | 68 | self.scalars = [ 69 | 'errG', 70 | 'errD', 71 | 'D(x)', 72 | 'D(G(z))', 73 | 'img', 74 | 'lr_D', 75 | 'lr_G', 76 | ] 77 | 78 | def test_print_log(self): 79 | log_data = metric_log.MetricLog() 80 | global_step = 10 81 | 82 | # Populate log data with some value 83 | for scalar in self.scalars: 84 | if scalar == 'img': 85 | continue 86 | 87 | log_data.add_metric(scalar, 1.0) 88 | 89 | printed = self.logger.print_log(global_step=global_step, 90 | log_data=log_data, 91 | time_taken=10) 92 | 93 | assert printed == ( 94 | 'INFO: [Epoch 1/1][Global Step: 10/100] ' + 95 | '\n| D(G(z)): 1.0\n| D(x): 1.0\n| errD: 1.0\n| errG: 1.0' + 96 | '\n| lr_D: 1.0\n| lr_G: 1.0\n| (10.0000 sec/idx)') 97 | 98 | def test_vis_images(self): 99 | netG = ExampleGen() 100 | netG_conditional = ExampleConditionalGen() 101 | 102 | global_step = 10 103 | num_images = 64 104 | 105 | # Test unconditional 106 | self.logger.vis_images(netG, global_step, num_images) 107 | img_dir = os.path.join(self.log_dir, 'images') 108 | 
filenames = os.listdir(img_dir) 109 | assert 'fake_samples_step_10.png' in filenames 110 | assert 'fixed_fake_samples_step_10.png' in filenames 111 | 112 | # Remove images 113 | for file in filenames: 114 | os.remove(os.path.join(img_dir, file)) 115 | 116 | # Test conditional 117 | self.logger.vis_images(netG_conditional, global_step, num_images) 118 | assert 'fake_samples_step_10.png' in filenames 119 | assert 'fixed_fake_samples_step_10.png' in filenames 120 | 121 | def teardown(self): 122 | shutil.rmtree(self.log_dir) 123 | 124 | 125 | if __name__ == "__main__": 126 | test = TestLogger() 127 | test.setup() 128 | test.test_print_log() 129 | test.test_vis_images() 130 | test.teardown() 131 | -------------------------------------------------------------------------------- /tests/training/test_metric_log.py: -------------------------------------------------------------------------------- 1 | from torch_mimicry.training import metric_log 2 | 3 | 4 | class TestMetricLog: 5 | def setup(self): 6 | self.log_data = metric_log.MetricLog() 7 | 8 | def test_add_metric(self): 9 | # Singular metric 10 | self.log_data.add_metric('singular', 1.0124214) 11 | assert self.log_data['singular'] == 1.0124 12 | 13 | # Multiple metrics under same group 14 | self.log_data.add_metric('errD', 1.00001, group='loss') 15 | self.log_data.add_metric('errG', 1.0011, group='loss') 16 | 17 | assert self.log_data.get_group_name( 18 | 'errD') == self.log_data.get_group_name('errG') 19 | 20 | def teardown(self): 21 | del self.log_data 22 | 23 | 24 | if __name__ == "__main__": 25 | test = TestMetricLog() 26 | test.setup() 27 | test.test_add_metric() 28 | test.teardown() 29 | -------------------------------------------------------------------------------- /tests/training/test_scheduler.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.optim as optim 3 | import pytest 4 | 5 | from torch_mimicry.training import scheduler, metric_log 6 | 7 | 8 | class TestLRScheduler: 9 | def setup(self): 10 | self.netD = nn.Linear(10, 10) 11 | self.netG = nn.Linear(10, 10) 12 | 13 | self.num_steps = 10 14 | self.lr_D = 2e-4 15 | self.lr_G = 2e-4 16 | 17 | def get_lr(self, optimizer): 18 | return optimizer.param_groups[0]['lr'] 19 | 20 | def test_linear_decay(self): 21 | optD = optim.Adam(self.netD.parameters(), self.lr_D, betas=(0.0, 0.9)) 22 | optG = optim.Adam(self.netG.parameters(), self.lr_G, betas=(0.0, 0.9)) 23 | 24 | lr_scheduler = scheduler.LRScheduler(lr_decay='linear', 25 | optD=optD, 26 | optG=optG, 27 | num_steps=self.num_steps, 28 | start_step=5) 29 | 30 | log_data = metric_log.MetricLog() 31 | for step in range(self.num_steps): 32 | lr_scheduler.step(log_data, step) 33 | 34 | if step < lr_scheduler.start_step: 35 | assert abs(2e-4 - self.get_lr(optD)) < 1e-5 36 | assert abs(2e-4 - self.get_lr(optG)) < 1e-5 37 | 38 | else: 39 | curr_lr = ((1 - (max(0, step - lr_scheduler.start_step) / 40 | (self.num_steps - lr_scheduler.start_step))) * 41 | self.lr_D) 42 | 43 | assert abs(curr_lr - self.get_lr(optD)) < 1e-5 44 | assert abs(curr_lr - self.get_lr(optG)) < 1e-5 45 | 46 | def test_no_decay(self): 47 | optD = optim.Adam(self.netD.parameters(), self.lr_D, betas=(0.0, 0.9)) 48 | optG = optim.Adam(self.netG.parameters(), self.lr_G, betas=(0.0, 0.9)) 49 | 50 | lr_scheduler = scheduler.LRScheduler(lr_decay='None', 51 | optD=optD, 52 | optG=optG, 53 | num_steps=self.num_steps) 54 | 55 | log_data = metric_log.MetricLog() 56 | for step in range(1, self.num_steps + 1): 
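            # Note (added for clarity): with lr_decay='None' the scheduler step
            # is expected to be a no-op, so both optimizers should keep their
            # initial learning rates at every step of the loop below.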
57 | lr_scheduler.step(log_data, step) 58 | 59 | assert (self.lr_D == self.get_lr(optD)) 60 | assert (self.lr_G == self.get_lr(optG)) 61 | 62 | def test_arguments(self): 63 | with pytest.raises(NotImplementedError): 64 | optD = optim.Adam(self.netD.parameters(), 65 | self.lr_D, 66 | betas=(0.0, 0.9)) 67 | optG = optim.Adam(self.netG.parameters(), 68 | self.lr_G, 69 | betas=(0.0, 0.9)) 70 | scheduler.LRScheduler(lr_decay='does_not_exist', 71 | optD=optD, 72 | optG=optG, 73 | num_steps=self.num_steps) 74 | 75 | # with pytest. 76 | 77 | 78 | if __name__ == "__main__": 79 | test = TestLRScheduler() 80 | test.setup() 81 | test.test_arguments() 82 | test.test_linear_decay() 83 | test.test_no_decay() 84 | -------------------------------------------------------------------------------- /tests/utils/test_common.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | import numpy as np 5 | import torch 6 | from PIL import Image 7 | 8 | from torch_mimicry.utils import common 9 | 10 | 11 | class TestCommon: 12 | def setup(self): 13 | # Build directories 14 | self.log_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 15 | "test_log") 16 | if not os.path.exists(self.log_dir): 17 | os.makedirs(self.log_dir) 18 | 19 | def test_json_write_and_load(self): 20 | dict_to_write = dict(a=1, b=2, c=3) 21 | output_file = os.path.join(self.log_dir, 'output.json') 22 | common.write_to_json(dict_to_write, output_file) 23 | check = common.load_from_json(output_file) 24 | 25 | assert dict_to_write == check 26 | 27 | def test_load_and_save_image(self): 28 | image, label = common.load_images(n=1) 29 | 30 | image = torch.squeeze(image, dim=0) 31 | output_file = os.path.join(self.log_dir, 'images', 'test_img.png') 32 | common.save_tensor_image(image, output_file=output_file) 33 | 34 | check = np.array(Image.open(output_file)) 35 | 36 | assert check.shape == (32, 32, 3) 37 | assert label.shape == (1, ) 38 | 39 | def teardown(self): 40 | shutil.rmtree(self.log_dir) 41 | 42 | 43 | if __name__ == "__main__": 44 | test = TestCommon() 45 | test.setup() 46 | test.test_json_write_and_load() 47 | test.test_load_and_save_image() 48 | test.teardown() 49 | -------------------------------------------------------------------------------- /torch_mimicry/__init__.py: -------------------------------------------------------------------------------- 1 | from torch_mimicry import nets, training, metrics, datasets, modules 2 | 3 | __version__ = "0.1.16" 4 | -------------------------------------------------------------------------------- /torch_mimicry/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from . import imagenet 2 | from .data_utils import * 3 | from .image_loader import * 4 | -------------------------------------------------------------------------------- /torch_mimicry/datasets/imagenet/__init__.py: -------------------------------------------------------------------------------- 1 | from .imagenet import * 2 | -------------------------------------------------------------------------------- /torch_mimicry/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import fid, kid, inception_score, inception_model 2 | from .compute_fid import * 3 | from .compute_is import * 4 | from .compute_kid import * 5 | from .compute_metrics import * 6 | -------------------------------------------------------------------------------- /torch_mimicry/metrics/fid/__init__.py: -------------------------------------------------------------------------------- 1 | from .fid_utils import * 2 | -------------------------------------------------------------------------------- /torch_mimicry/metrics/fid/fid_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helper functions for calculating FID as adapted from the official FID code: 3 | https://github.com/bioinf-jku/TTUR/blob/master/fid.py 4 | """ 5 | import numpy as np 6 | from scipy import linalg 7 | 8 | from torch_mimicry.metrics.inception_model import inception_utils 9 | 10 | 11 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): 12 | """ 13 | Numpy implementation of the Frechet Distance. 14 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 15 | and X_2 ~ N(mu_2, C_2) is 16 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 17 | 18 | Stable version by Dougal J. Sutherland. 19 | 20 | Args: 21 | mu1 (ndarray): The sample mean over activations of the pool_3 layer of the 22 | inception net (as returned by the function 'get_predictions') 23 | for generated samples. 24 | mu2 (ndarray): The sample mean over activations of the pool_3 layer, precalculated 25 | on a representative data set. 26 | sigma1 (ndarray): The covariance matrix over activations of the pool_3 layer for 27 | generated samples. 28 | sigma2 (ndarray): The covariance matrix over activations of the pool_3 layer, 29 | precalculated on a representative data set. 30 | 31 | Returns: 32 | np.float64: The Frechet Distance. 33 | """ 34 | if mu1.shape != mu2.shape or sigma1.shape != sigma2.shape: 35 | raise ValueError( 36 | "(mu1, sigma1) should have exactly the same shape as (mu2, sigma2)." 37 | ) 38 | 39 | mu1 = np.atleast_1d(mu1) 40 | mu2 = np.atleast_1d(mu2) 41 | 42 | sigma1 = np.atleast_2d(sigma1) 43 | sigma2 = np.atleast_2d(sigma2) 44 | 45 | diff = mu1 - mu2 46 | 47 | # Product might be almost singular 48 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 49 | if not np.isfinite(covmean).all(): 50 | print( 51 | "WARNING: fid calculation produces singular product; adding {} to diagonal of cov estimates" 52 | .format(eps)) 53 | 54 | offset = np.eye(sigma1.shape[0]) * eps 55 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 56 | 57 | # Numerical error might give slight imaginary component 58 | if np.iscomplexobj(covmean): 59 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 60 | m = np.max(np.abs(covmean.imag)) 61 | raise ValueError("Imaginary component {}".format(m)) 62 | covmean = covmean.real 63 | 64 | tr_covmean = np.trace(covmean) 65 | 66 | return diff.dot(diff) + np.trace(sigma1) + np.trace( 67 | sigma2) - 2 * tr_covmean 68 | 69 | 70 | def calculate_activation_statistics(images, sess, batch_size=50, verbose=True): 71 | """ 72 | Calculation of the statistics used by the FID. 73 | 74 | Args: 75 | images (ndarray): Numpy array of shape (N, H, W, 3) and values in 76 | the range [0, 255]. 77 | sess (Session): TensorFlow session object. 78 | batch_size (int): Batch size for inference. 79 | verbose (bool): If True, prints out logging information. 80 | 81 | Returns: 82 | ndarray: Mean of inception features from samples.
83 | ndarray: Covariance of inception features from samples. 84 | """ 85 | act = inception_utils.get_activations(images, sess, batch_size, verbose) 86 | mu = np.mean(act, axis=0) 87 | sigma = np.cov(act, rowvar=False) 88 | 89 | return mu, sigma 90 | -------------------------------------------------------------------------------- /torch_mimicry/metrics/inception_model/__init__.py: -------------------------------------------------------------------------------- 1 | from .inception_utils import * 2 | -------------------------------------------------------------------------------- /torch_mimicry/metrics/inception_score/__init__.py: -------------------------------------------------------------------------------- 1 | from .inception_score_utils import * 2 | -------------------------------------------------------------------------------- /torch_mimicry/metrics/kid/__init__.py: -------------------------------------------------------------------------------- 1 | from .kid_utils import * 2 | -------------------------------------------------------------------------------- /torch_mimicry/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .layers import * 2 | from .losses import * 3 | from .resblocks import * 4 | from .spectral_norm import * -------------------------------------------------------------------------------- /torch_mimicry/nets/__init__.py: -------------------------------------------------------------------------------- 1 | from . import (basemodel, gan, dcgan, wgan_gp, sngan, cgan_pd, ssgan, 2 | infomax_gan, sagan) 3 | -------------------------------------------------------------------------------- /torch_mimicry/nets/basemodel/__init__.py: -------------------------------------------------------------------------------- 1 | from .basemodel import * 2 | -------------------------------------------------------------------------------- /torch_mimicry/nets/basemodel/basemodel.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of BaseModel. 3 | """ 4 | import os 5 | from abc import ABC, abstractmethod 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | 11 | class BaseModel(nn.Module, ABC): 12 | r""" 13 | BaseModel with basic functionalities for checkpointing and restoration. 14 | """ 15 | def __init__(self): 16 | super().__init__() 17 | 18 | @abstractmethod 19 | def forward(self, x): 20 | pass 21 | 22 | @property 23 | def device(self): 24 | return next(self.parameters()).device 25 | 26 | def restore_checkpoint(self, ckpt_file, optimizer=None): 27 | r""" 28 | Restores checkpoint from a pth file and restores optimizer state. 29 | 30 | Args: 31 | ckpt_file (str): A PyTorch pth file containing model weights. 32 | optimizer (Optimizer): A vanilla optimizer to have its state restored from. 33 | 34 | Returns: 35 | int: Global step variable where the model was last checkpointed. 36 | """ 37 | if not ckpt_file: 38 | raise ValueError("No checkpoint file to be restored.") 39 | 40 | try: 41 | ckpt_dict = torch.load(ckpt_file) 42 | except RuntimeError: 43 | ckpt_dict = torch.load(ckpt_file, 44 | map_location=lambda storage, loc: storage) 45 | 46 | # Restore model weights 47 | self.load_state_dict(ckpt_dict['model_state_dict']) 48 | 49 | # Restore optimizer status if existing. 
Evaluation doesn't need this 50 | if optimizer: 51 | optimizer.load_state_dict(ckpt_dict['optimizer_state_dict']) 52 | 53 | # Return global step 54 | return ckpt_dict['global_step'] 55 | 56 | def save_checkpoint(self, 57 | directory, 58 | global_step, 59 | optimizer=None, 60 | name=None): 61 | r""" 62 | Saves checkpoint at a certain global step during training. Optimizer state 63 | is also saved together. 64 | 65 | Args: 66 | directory (str): Path to save checkpoint to. 67 | global_step (int): The global step variable during training. 68 | optimizer (Optimizer): Optimizer state to be saved concurrently. 69 | name (str): The name to save the checkpoint file as. 70 | 71 | Returns: 72 | None 73 | """ 74 | # Create directory to save to 75 | if not os.path.exists(directory): 76 | os.makedirs(directory) 77 | 78 | # Build checkpoint dict to save. 79 | ckpt_dict = { 80 | 'model_state_dict': 81 | self.state_dict(), 82 | 'optimizer_state_dict': 83 | optimizer.state_dict() if optimizer is not None else None, 84 | 'global_step': 85 | global_step 86 | } 87 | 88 | # Save the file with specific name 89 | if name is None: 90 | name = "{}_{}_steps.pth".format( 91 | os.path.basename(directory), # netD or netG 92 | global_step) 93 | 94 | torch.save(ckpt_dict, os.path.join(directory, name)) 95 | 96 | def count_params(self): 97 | r""" 98 | Computes the number of parameters in this model. 99 | 100 | Args: None 101 | 102 | Returns: 103 | int: Total number of weight parameters for this model. 104 | int: Total number of trainable parameters for this model. 105 | 106 | """ 107 | num_total_params = sum(p.numel() for p in self.parameters()) 108 | num_trainable_params = sum(p.numel() for p in self.parameters() 109 | if p.requires_grad) 110 | 111 | return num_total_params, num_trainable_params 112 | -------------------------------------------------------------------------------- /torch_mimicry/nets/cgan_pd/__init__.py: -------------------------------------------------------------------------------- 1 | from .cgan_pd_32 import * 2 | from .cgan_pd_base import * -------------------------------------------------------------------------------- /torch_mimicry/nets/cgan_pd/cgan_pd_base.py: -------------------------------------------------------------------------------- 1 | """ 2 | Base class definition of cGAN-PD. 3 | """ 4 | 5 | from torch_mimicry.nets.gan import cgan 6 | 7 | 8 | class CGANPDBaseGenerator(cgan.BaseConditionalGenerator): 9 | r""" 10 | ResNet backbone generator for cGAN-PD, 11 | 12 | Attributes: 13 | num_classes (int): Number of classes, more than 0 for conditional GANs. 14 | nz (int): Noise dimension for upsampling. 15 | ngf (int): Variable controlling generator feature map sizes. 16 | bottom_width (int): Starting width for upsampling generator output to an image. 17 | loss_type (str): Name of loss to use for GAN loss. 18 | """ 19 | def __init__(self, 20 | num_classes, 21 | bottom_width, 22 | nz, 23 | ngf, 24 | loss_type='hinge', 25 | **kwargs): 26 | super().__init__(nz=nz, 27 | ngf=ngf, 28 | bottom_width=bottom_width, 29 | loss_type=loss_type, 30 | num_classes=num_classes, 31 | **kwargs) 32 | 33 | 34 | class CGANPDBaseDiscriminator(cgan.BaseConditionalDiscriminator): 35 | r""" 36 | ResNet backbone discriminator for cGAN-PD. 37 | 38 | Attributes: 39 | num_classes (int): Number of classes, more than 0 for conditional GANs. 40 | ndf (int): Variable controlling discriminator feature map sizes. 41 | loss_type (str): Name of loss to use for GAN loss. 
42 | """ 43 | def __init__(self, num_classes, ndf, loss_type='hinge', **kwargs): 44 | super().__init__(ndf=ndf, 45 | loss_type=loss_type, 46 | num_classes=num_classes, 47 | **kwargs) 48 | -------------------------------------------------------------------------------- /torch_mimicry/nets/dcgan/__init__.py: -------------------------------------------------------------------------------- 1 | from .dcgan_128 import * 2 | from .dcgan_32 import * 3 | from .dcgan_48 import * 4 | from .dcgan_64 import * 5 | from .dcgan_base import * 6 | from .dcgan_cifar import * 7 | -------------------------------------------------------------------------------- /torch_mimicry/nets/dcgan/dcgan_32.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of DCGAN for image size 32. 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | 7 | from torch_mimicry.nets.dcgan import dcgan_base 8 | from torch_mimicry.modules.resblocks import DBlockOptimized, DBlock, GBlock 9 | 10 | 11 | class DCGANGenerator32(dcgan_base.DCGANBaseGenerator): 12 | r""" 13 | ResNet backbone generator for ResNet DCGAN. 14 | 15 | Attributes: 16 | nz (int): Noise dimension for upsampling. 17 | ngf (int): Variable controlling generator feature map sizes. 18 | bottom_width (int): Starting width for upsampling generator output to an image. 19 | loss_type (str): Name of loss to use for GAN loss. 20 | """ 21 | def __init__(self, nz=128, ngf=256, bottom_width=4, **kwargs): 22 | super().__init__(nz=nz, ngf=ngf, bottom_width=bottom_width, **kwargs) 23 | 24 | # Build the layers 25 | self.l1 = nn.Linear(self.nz, (self.bottom_width**2) * self.ngf) 26 | self.block2 = GBlock(self.ngf, self.ngf, upsample=True) 27 | self.block3 = GBlock(self.ngf, self.ngf, upsample=True) 28 | self.block4 = GBlock(self.ngf, self.ngf, upsample=True) 29 | self.b5 = nn.BatchNorm2d(self.ngf) 30 | self.c5 = nn.Conv2d(self.ngf, 3, 3, 1, padding=1) 31 | self.activation = nn.ReLU(True) 32 | 33 | # Initialise the weights 34 | nn.init.xavier_uniform_(self.l1.weight.data, 1.0) 35 | nn.init.xavier_uniform_(self.c5.weight.data, 1.0) 36 | 37 | def forward(self, x): 38 | r""" 39 | Feedforwards a batch of noise vectors into a batch of fake images. 40 | 41 | Args: 42 | x (Tensor): A batch of noise vectors of shape (N, nz). 43 | 44 | Returns: 45 | Tensor: A batch of fake images of shape (N, C, H, W). 46 | """ 47 | h = self.l1(x) 48 | h = h.view(x.shape[0], -1, self.bottom_width, self.bottom_width) 49 | h = self.block2(h) 50 | h = self.block3(h) 51 | h = self.block4(h) 52 | h = self.b5(h) 53 | h = self.activation(h) 54 | h = torch.tanh(self.c5(h)) # tanh 55 | 56 | return h 57 | 58 | 59 | class DCGANDiscriminator32(dcgan_base.DCGANBaseDiscriminator): 60 | r""" 61 | ResNet backbone discriminator for ResNet DCGAN. 62 | 63 | Attributes: 64 | ndf (int): Variable controlling discriminator feature map sizes. 65 | loss_type (str): Name of loss to use for GAN loss. 
66 | """ 67 | def __init__(self, ndf=128, **kwargs): 68 | super().__init__(ndf=ndf, **kwargs) 69 | 70 | # Build layers 71 | self.block1 = DBlockOptimized(3, self.ndf, spectral_norm=False) 72 | self.block2 = DBlock(self.ndf, 73 | self.ndf, 74 | downsample=True, 75 | spectral_norm=False) 76 | self.block3 = DBlock(self.ndf, 77 | self.ndf, 78 | downsample=False, 79 | spectral_norm=False) 80 | self.block4 = DBlock(self.ndf, 81 | self.ndf, 82 | downsample=False, 83 | spectral_norm=False) 84 | self.l5 = nn.Linear(self.ndf, 1) 85 | self.activation = nn.ReLU(True) 86 | 87 | # Initialise the weights 88 | nn.init.xavier_uniform_(self.l5.weight.data, 1.0) 89 | 90 | def forward(self, x): 91 | r""" 92 | Feedforwards a batch of real/fake images and produces a batch of GAN logits. 93 | 94 | Args: 95 | x (Tensor): A batch of images of shape (N, C, H, W). 96 | 97 | Returns: 98 | Tensor: A batch of GAN logits of shape (N, 1). 99 | """ 100 | h = x 101 | h = self.block1(h) 102 | h = self.block2(h) 103 | h = self.block3(h) 104 | h = self.block4(h) 105 | h = self.activation(h) 106 | 107 | # Global average pooling 108 | h = torch.sum(h, dim=(2, 3)) 109 | output = self.l5(h) 110 | 111 | return output 112 | -------------------------------------------------------------------------------- /torch_mimicry/nets/dcgan/dcgan_48.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of DCGAN for image size 48. 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | 7 | from torch_mimicry.nets.dcgan import dcgan_base 8 | from torch_mimicry.modules.resblocks import DBlockOptimized, DBlock, GBlock 9 | 10 | 11 | class DCGANGenerator48(dcgan_base.DCGANBaseGenerator): 12 | r""" 13 | ResNet backbone generator for ResNet DCGAN. 14 | 15 | Attributes: 16 | nz (int): Noise dimension for upsampling. 17 | ngf (int): Variable controlling generator feature map sizes. 18 | bottom_width (int): Starting width for upsampling generator output to an image. 19 | loss_type (str): Name of loss to use for GAN loss. 20 | """ 21 | def __init__(self, nz=128, ngf=512, bottom_width=6, **kwargs): 22 | super().__init__(nz=nz, ngf=ngf, bottom_width=bottom_width, **kwargs) 23 | 24 | # Build the layers 25 | self.l1 = nn.Linear(self.nz, (self.bottom_width**2) * self.ngf) 26 | self.block2 = GBlock(self.ngf, self.ngf >> 1, upsample=True) 27 | self.block3 = GBlock(self.ngf >> 1, self.ngf >> 2, upsample=True) 28 | self.block4 = GBlock(self.ngf >> 2, self.ngf >> 3, upsample=True) 29 | self.b5 = nn.BatchNorm2d(self.ngf >> 3) 30 | self.c5 = nn.Conv2d(self.ngf >> 3, 3, 3, 1, padding=1) 31 | self.activation = nn.ReLU(True) 32 | 33 | # Initialise the weights 34 | nn.init.xavier_uniform_(self.l1.weight.data, 1.0) 35 | nn.init.xavier_uniform_(self.c5.weight.data, 1.0) 36 | 37 | def forward(self, x): 38 | r""" 39 | Feedforwards a batch of noise vectors into a batch of fake images. 40 | 41 | Args: 42 | x (Tensor): A batch of noise vectors of shape (N, nz). 43 | 44 | Returns: 45 | Tensor: A batch of fake images of shape (N, C, H, W). 46 | """ 47 | h = self.l1(x) 48 | h = h.view(x.shape[0], -1, self.bottom_width, self.bottom_width) 49 | h = self.block2(h) 50 | h = self.block3(h) 51 | h = self.block4(h) 52 | h = self.b5(h) 53 | h = self.activation(h) 54 | h = torch.tanh(self.c5(h)) 55 | 56 | return h 57 | 58 | 59 | class DCGANDiscriminator48(dcgan_base.DCGANBaseDiscriminator): 60 | r""" 61 | ResNet backbone discriminator for ResNet DCGAN. 
62 | 63 | Attributes: 64 | ndf (int): Variable controlling discriminator feature map sizes. 65 | loss_type (str): Name of loss to use for GAN loss. 66 | """ 67 | def __init__(self, ndf=1024, **kwargs): 68 | super().__init__(ndf=ndf, **kwargs) 69 | 70 | # Build layers 71 | self.block1 = DBlockOptimized(3, self.ndf >> 4, spectral_norm=False) 72 | self.block2 = DBlock(self.ndf >> 4, 73 | self.ndf >> 3, 74 | downsample=True, 75 | spectral_norm=False) 76 | self.block3 = DBlock(self.ndf >> 3, 77 | self.ndf >> 2, 78 | downsample=True, 79 | spectral_norm=False) 80 | self.block4 = DBlock(self.ndf >> 2, 81 | self.ndf >> 1, 82 | downsample=True, 83 | spectral_norm=False) 84 | self.block5 = DBlock(self.ndf >> 1, 85 | self.ndf, 86 | downsample=False, 87 | spectral_norm=False) 88 | self.l5 = nn.Linear(self.ndf, 1) 89 | 90 | self.activation = nn.ReLU(True) 91 | 92 | # Initialise the weights 93 | nn.init.xavier_uniform_(self.l5.weight.data, 1.0) 94 | 95 | def forward(self, x): 96 | r""" 97 | Feedforwards a batch of real/fake images and produces a batch of GAN logits. 98 | 99 | Args: 100 | x (Tensor): A batch of images of shape (N, C, H, W). 101 | 102 | Returns: 103 | Tensor: A batch of GAN logits of shape (N, 1). 104 | """ 105 | h = x 106 | h = self.block1(h) 107 | h = self.block2(h) 108 | h = self.block3(h) 109 | h = self.block4(h) 110 | h = self.block5(h) 111 | h = self.activation(h) 112 | 113 | # Global average pooling 114 | h = torch.sum(h, dim=(2, 3)) 115 | output = self.l5(h) 116 | 117 | return output 118 | -------------------------------------------------------------------------------- /torch_mimicry/nets/dcgan/dcgan_base.py: -------------------------------------------------------------------------------- 1 | """ 2 | Base class definition of DCGAN. 3 | """ 4 | from torch_mimicry.nets.gan import gan 5 | 6 | 7 | class DCGANBaseGenerator(gan.BaseGenerator): 8 | r""" 9 | ResNet backbone generator for ResNet DCGAN. 10 | 11 | Attributes: 12 | nz (int): Noise dimension for upsampling. 13 | ngf (int): Variable controlling generator feature map sizes. 14 | bottom_width (int): Starting width for upsampling generator output to an image. 15 | loss_type (str): Name of loss to use for GAN loss. 16 | """ 17 | def __init__(self, nz, ngf, bottom_width, loss_type='ns', **kwargs): 18 | super().__init__(nz=nz, 19 | ngf=ngf, 20 | bottom_width=bottom_width, 21 | loss_type=loss_type, 22 | **kwargs) 23 | 24 | 25 | class DCGANBaseDiscriminator(gan.BaseDiscriminator): 26 | r""" 27 | ResNet backbone discriminator for ResNet DCGAN. 28 | 29 | Attributes: 30 | ndf (int): Variable controlling discriminator feature map sizes. 31 | loss_type (str): Name of loss to use for GAN loss. 32 | """ 33 | def __init__(self, ndf, loss_type='ns', **kwargs): 34 | super().__init__(ndf=ndf, loss_type=loss_type, **kwargs) 35 | -------------------------------------------------------------------------------- /torch_mimicry/nets/dcgan/dcgan_cifar.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of DCGAN based on Kurach et al. specifically for CIFAR-10. 3 | The main difference with dcgan_32 is in using sigmoid 4 | as the final activation for the generator instead of tanh. 5 | 6 | To reproduce scores, CIFAR-10 images should not be normalized from -1 to 1, and should 7 | instead have values from 0 to 1, which is the default when loading images as np arrays. 
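A minimal loading sketch (illustrative only; the dataset root path is an
assumption and this snippet is not part of the original module):

    import torchvision.transforms as transforms
    from torchvision.datasets import CIFAR10

    # ToTensor() alone keeps pixel values in [0, 1]; omit the usual
    # Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) that rescales to [-1, 1].
    dataset = CIFAR10(root='./datasets/cifar10', download=True,
                      transform=transforms.ToTensor())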
8 | """ 9 | import torch 10 | import torch.nn as nn 11 | 12 | from torch_mimicry.nets.dcgan import dcgan_base 13 | from torch_mimicry.modules.resblocks import DBlockOptimized, DBlock, GBlock 14 | 15 | 16 | class DCGANGeneratorCIFAR(dcgan_base.DCGANBaseGenerator): 17 | r""" 18 | ResNet backbone generator for ResNet DCGAN. 19 | 20 | Attributes: 21 | nz (int): Noise dimension for upsampling. 22 | ngf (int): Variable controlling generator feature map sizes. 23 | bottom_width (int): Starting width for upsampling generator output to an image. 24 | loss_type (str): Name of loss to use for GAN loss. 25 | """ 26 | def __init__(self, nz=128, ngf=256, bottom_width=4, **kwargs): 27 | super().__init__(nz=nz, ngf=ngf, bottom_width=bottom_width, **kwargs) 28 | 29 | # Build the layers 30 | self.l1 = nn.Linear(self.nz, (self.bottom_width**2) * self.ngf) 31 | self.block2 = GBlock(self.ngf, self.ngf, upsample=True) 32 | self.block3 = GBlock(self.ngf, self.ngf, upsample=True) 33 | self.block4 = GBlock(self.ngf, self.ngf, upsample=True) 34 | self.b5 = nn.BatchNorm2d(self.ngf) 35 | self.c5 = nn.Conv2d(self.ngf, 3, 3, 1, padding=1) 36 | self.activation = nn.ReLU(True) 37 | 38 | # Initialise the weights 39 | nn.init.xavier_uniform_(self.l1.weight.data, 1.0) 40 | nn.init.xavier_uniform_(self.c5.weight.data, 1.0) 41 | 42 | def forward(self, x): 43 | r""" 44 | Feedforwards a batch of noise vectors into a batch of fake images. 45 | 46 | Args: 47 | x (Tensor): A batch of noise vectors of shape (N, nz). 48 | 49 | Returns: 50 | Tensor: A batch of fake images of shape (N, C, H, W). 51 | """ 52 | h = self.l1(x) 53 | h = h.view(x.shape[0], -1, self.bottom_width, self.bottom_width) 54 | h = self.block2(h) 55 | h = self.block3(h) 56 | h = self.block4(h) 57 | h = self.b5(h) 58 | h = self.activation(h) 59 | h = torch.sigmoid(self.c5(h)) 60 | 61 | return h 62 | 63 | 64 | class DCGANDiscriminatorCIFAR(dcgan_base.DCGANBaseDiscriminator): 65 | r""" 66 | ResNet backbone discriminator for ResNet DCGAN. 67 | 68 | Attributes: 69 | ndf (int): Variable controlling discriminator feature map sizes. 70 | loss_type (str): Name of loss to use for GAN loss. 71 | """ 72 | def __init__(self, ndf=128, **kwargs): 73 | super().__init__(ndf=ndf, **kwargs) 74 | 75 | # Build layers 76 | self.block1 = DBlockOptimized(3, self.ndf, spectral_norm=False) 77 | self.block2 = DBlock(self.ndf, 78 | self.ndf, 79 | downsample=True, 80 | spectral_norm=False) 81 | self.block3 = DBlock(self.ndf, 82 | self.ndf, 83 | downsample=False, 84 | spectral_norm=False) 85 | self.block4 = DBlock(self.ndf, 86 | self.ndf, 87 | downsample=False, 88 | spectral_norm=False) 89 | self.l5 = nn.Linear(self.ndf, 1) 90 | self.activation = nn.ReLU(True) 91 | 92 | # Initialise the weights 93 | nn.init.xavier_uniform_(self.l5.weight.data, 1.0) 94 | 95 | def forward(self, x): 96 | r""" 97 | Feedforwards a batch of real/fake images and produces a batch of GAN logits. 98 | 99 | Args: 100 | x (Tensor): A batch of images of shape (N, C, H, W). 101 | 102 | Returns: 103 | Tensor: A batch of GAN logits of shape (N, 1). 
104 | """ 105 | h = x 106 | h = self.block1(h) 107 | h = self.block2(h) 108 | h = self.block3(h) 109 | h = self.block4(h) 110 | h = self.activation(h) 111 | 112 | # Global mean pooling 113 | h = torch.sum(h, dim=(2, 3)) 114 | output = self.l5(h) 115 | 116 | return output 117 | -------------------------------------------------------------------------------- /torch_mimicry/nets/gan/__init__.py: -------------------------------------------------------------------------------- 1 | from .cgan import * 2 | from .gan import * -------------------------------------------------------------------------------- /torch_mimicry/nets/infomax_gan/__init__.py: -------------------------------------------------------------------------------- 1 | from .infomax_gan_128 import * 2 | from .infomax_gan_32 import * 3 | from .infomax_gan_48 import * 4 | from .infomax_gan_64 import * 5 | from .infomax_gan_base import * 6 | -------------------------------------------------------------------------------- /torch_mimicry/nets/sagan/__init__.py: -------------------------------------------------------------------------------- 1 | from .sagan_128 import * 2 | from .sagan_32 import * 3 | from .sagan_base import * 4 | -------------------------------------------------------------------------------- /torch_mimicry/nets/sagan/sagan_base.py: -------------------------------------------------------------------------------- 1 | """ 2 | Base implementation of SAGAN. 3 | """ 4 | from torch_mimicry.nets.gan import cgan 5 | 6 | 7 | class SAGANBaseGenerator(cgan.BaseConditionalGenerator): 8 | r""" 9 | ResNet backbone generator for SAGAN. 10 | 11 | Attributes: 12 | num_classes (int): Number of classes, more than 0 for conditional GANs. 13 | nz (int): Noise dimension for upsampling. 14 | ngf (int): Variable controlling generator feature map sizes. 15 | bottom_width (int): Starting width for upsampling generator output to an image. 16 | loss_type (str): Name of loss to use for GAN loss. 17 | """ 18 | def __init__(self, 19 | num_classes, 20 | bottom_width, 21 | nz, 22 | ngf, 23 | loss_type='hinge', 24 | **kwargs): 25 | super().__init__(nz=nz, 26 | ngf=ngf, 27 | bottom_width=bottom_width, 28 | loss_type=loss_type, 29 | num_classes=num_classes, 30 | **kwargs) 31 | 32 | 33 | class SAGANBaseDiscriminator(cgan.BaseConditionalDiscriminator): 34 | r""" 35 | ResNet backbone discriminator for SAGAN. 36 | 37 | Attributes: 38 | num_classes (int): Number of classes, more than 0 for conditional GANs. 39 | ndf (int): Variable controlling discriminator feature map sizes. 40 | loss_type (str): Name of loss to use for GAN loss. 41 | """ 42 | def __init__(self, num_classes, ndf, loss_type='hinge', **kwargs): 43 | super().__init__(ndf=ndf, 44 | loss_type=loss_type, 45 | num_classes=num_classes, 46 | **kwargs) 47 | -------------------------------------------------------------------------------- /torch_mimicry/nets/sngan/__init__.py: -------------------------------------------------------------------------------- 1 | from .sngan_128 import * 2 | from .sngan_32 import * 3 | from .sngan_48 import * 4 | from .sngan_64 import * 5 | from .sngan_base import * 6 | -------------------------------------------------------------------------------- /torch_mimicry/nets/sngan/sngan_128.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of SNGAN for image size 128.
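Example usage (a sketch assuming the package's public exports; the batch size
here is arbitrary and this snippet is not part of the original module):

    import torch
    from torch_mimicry.nets.sngan import SNGANGenerator128, SNGANDiscriminator128

    netG = SNGANGenerator128()
    netD = SNGANDiscriminator128()
    fake = netG(torch.randn(4, netG.nz))   # fake images of shape (4, 3, 128, 128)
    logits = netD(fake)                    # GAN logits of shape (4, 1)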
3 | """ 4 | import torch 5 | import torch.nn as nn 6 | 7 | from torch_mimicry.modules.layers import SNLinear 8 | from torch_mimicry.modules.resblocks import DBlockOptimized, DBlock, GBlock 9 | from torch_mimicry.nets.sngan import sngan_base 10 | 11 | 12 | class SNGANGenerator128(sngan_base.SNGANBaseGenerator): 13 | r""" 14 | ResNet backbone generator for SNGAN. 15 | 16 | Attributes: 17 | nz (int): Noise dimension for upsampling. 18 | ngf (int): Variable controlling generator feature map sizes. 19 | bottom_width (int): Starting width for upsampling generator output to an image. 20 | loss_type (str): Name of loss to use for GAN loss. 21 | """ 22 | def __init__(self, nz=128, ngf=1024, bottom_width=4, **kwargs): 23 | super().__init__(nz=nz, ngf=ngf, bottom_width=bottom_width, **kwargs) 24 | 25 | # Build the layers 26 | self.l1 = nn.Linear(self.nz, (self.bottom_width**2) * self.ngf) 27 | self.block2 = GBlock(self.ngf, self.ngf, upsample=True) 28 | self.block3 = GBlock(self.ngf, self.ngf >> 1, upsample=True) 29 | self.block4 = GBlock(self.ngf >> 1, self.ngf >> 2, upsample=True) 30 | self.block5 = GBlock(self.ngf >> 2, self.ngf >> 3, upsample=True) 31 | self.block6 = GBlock(self.ngf >> 3, self.ngf >> 4, upsample=True) 32 | self.b7 = nn.BatchNorm2d(self.ngf >> 4) 33 | self.c7 = nn.Conv2d(self.ngf >> 4, 3, 3, 1, padding=1) 34 | self.activation = nn.ReLU(True) 35 | 36 | # Initialise the weights 37 | nn.init.xavier_uniform_(self.l1.weight.data, 1.0) 38 | nn.init.xavier_uniform_(self.c7.weight.data, 1.0) 39 | 40 | def forward(self, x): 41 | r""" 42 | Feedforwards a batch of noise vectors into a batch of fake images. 43 | 44 | Args: 45 | x (Tensor): A batch of noise vectors of shape (N, nz). 46 | 47 | Returns: 48 | Tensor: A batch of fake images of shape (N, C, H, W). 49 | """ 50 | h = self.l1(x) 51 | h = h.view(x.shape[0], -1, self.bottom_width, self.bottom_width) 52 | h = self.block2(h) 53 | h = self.block3(h) 54 | h = self.block4(h) 55 | h = self.block5(h) 56 | h = self.block6(h) 57 | h = self.b7(h) 58 | h = self.activation(h) 59 | h = torch.tanh(self.c7(h)) 60 | 61 | return h 62 | 63 | 64 | class SNGANDiscriminator128(sngan_base.SNGANBaseDiscriminator): 65 | r""" 66 | ResNet backbone discriminator for SNGAN. 67 | 68 | Attributes: 69 | ndf (int): Variable controlling discriminator feature map sizes. 70 | loss_type (str): Name of loss to use for GAN loss. 71 | """ 72 | def __init__(self, ndf=1024, **kwargs): 73 | super().__init__(ndf=ndf, **kwargs) 74 | 75 | # Build layers 76 | self.block1 = DBlockOptimized(3, self.ndf >> 4) 77 | self.block2 = DBlock(self.ndf >> 4, self.ndf >> 3, downsample=True) 78 | self.block3 = DBlock(self.ndf >> 3, self.ndf >> 2, downsample=True) 79 | self.block4 = DBlock(self.ndf >> 2, self.ndf >> 1, downsample=True) 80 | self.block5 = DBlock(self.ndf >> 1, self.ndf, downsample=True) 81 | self.block6 = DBlock(self.ndf, self.ndf, downsample=False) 82 | self.l7 = SNLinear(self.ndf, 1) 83 | self.activation = nn.ReLU(True) 84 | 85 | # Initialise the weights 86 | nn.init.xavier_uniform_(self.l7.weight.data, 1.0) 87 | 88 | def forward(self, x): 89 | r""" 90 | Feedforwards a batch of real/fake images and produces a batch of GAN logits. 91 | 92 | Args: 93 | x (Tensor): A batch of images of shape (N, C, H, W). 94 | 95 | Returns: 96 | Tensor: A batch of GAN logits of shape (N, 1). 
97 | """ 98 | h = x 99 | h = self.block1(h) 100 | h = self.block2(h) 101 | h = self.block3(h) 102 | h = self.block4(h) 103 | h = self.block5(h) 104 | h = self.block6(h) 105 | h = self.activation(h) 106 | 107 | # Global sum pooling 108 | h = torch.sum(h, dim=(2, 3)) 109 | output = self.l7(h) 110 | 111 | return output 112 | -------------------------------------------------------------------------------- /torch_mimicry/nets/sngan/sngan_32.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of SNGAN for image size 32. 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | 7 | from torch_mimicry.modules.layers import SNLinear 8 | from torch_mimicry.modules.resblocks import DBlockOptimized, DBlock, GBlock 9 | from torch_mimicry.nets.sngan import sngan_base 10 | 11 | 12 | class SNGANGenerator32(sngan_base.SNGANBaseGenerator): 13 | r""" 14 | ResNet backbone generator for SNGAN. 15 | 16 | Attributes: 17 | nz (int): Noise dimension for upsampling. 18 | ngf (int): Variable controlling generator feature map sizes. 19 | bottom_width (int): Starting width for upsampling generator output to an image. 20 | loss_type (str): Name of loss to use for GAN loss. 21 | """ 22 | def __init__(self, nz=128, ngf=256, bottom_width=4, **kwargs): 23 | super().__init__(nz=nz, ngf=ngf, bottom_width=bottom_width, **kwargs) 24 | 25 | # Build the layers 26 | self.l1 = nn.Linear(self.nz, (self.bottom_width**2) * self.ngf) 27 | self.block2 = GBlock(self.ngf, self.ngf, upsample=True) 28 | self.block3 = GBlock(self.ngf, self.ngf, upsample=True) 29 | self.block4 = GBlock(self.ngf, self.ngf, upsample=True) 30 | self.b5 = nn.BatchNorm2d(self.ngf) 31 | self.c5 = nn.Conv2d(self.ngf, 3, 3, 1, padding=1) 32 | self.activation = nn.ReLU(True) 33 | 34 | # Initialise the weights 35 | nn.init.xavier_uniform_(self.l1.weight.data, 1.0) 36 | nn.init.xavier_uniform_(self.c5.weight.data, 1.0) 37 | 38 | def forward(self, x): 39 | r""" 40 | Feedforwards a batch of noise vectors into a batch of fake images. 41 | 42 | Args: 43 | x (Tensor): A batch of noise vectors of shape (N, nz). 44 | 45 | Returns: 46 | Tensor: A batch of fake images of shape (N, C, H, W). 47 | """ 48 | h = self.l1(x) 49 | h = h.view(x.shape[0], -1, self.bottom_width, self.bottom_width) 50 | h = self.block2(h) 51 | h = self.block3(h) 52 | h = self.block4(h) 53 | h = self.b5(h) 54 | h = self.activation(h) 55 | h = torch.tanh(self.c5(h)) 56 | 57 | return h 58 | 59 | 60 | class SNGANDiscriminator32(sngan_base.SNGANBaseDiscriminator): 61 | r""" 62 | ResNet backbone discriminator for SNGAN. 63 | 64 | Attributes: 65 | ndf (int): Variable controlling discriminator feature map sizes. 66 | loss_type (str): Name of loss to use for GAN loss. 67 | """ 68 | def __init__(self, ndf=128, **kwargs): 69 | super().__init__(ndf=ndf, **kwargs) 70 | 71 | # Build layers 72 | self.block1 = DBlockOptimized(3, self.ndf) 73 | self.block2 = DBlock(self.ndf, self.ndf, downsample=True) 74 | self.block3 = DBlock(self.ndf, self.ndf, downsample=False) 75 | self.block4 = DBlock(self.ndf, self.ndf, downsample=False) 76 | self.l5 = SNLinear(self.ndf, 1) 77 | self.activation = nn.ReLU(True) 78 | 79 | # Initialise the weights 80 | nn.init.xavier_uniform_(self.l5.weight.data, 1.0) 81 | 82 | def forward(self, x): 83 | r""" 84 | Feedforwards a batch of real/fake images and produces a batch of GAN logits. 85 | 86 | Args: 87 | x (Tensor): A batch of images of shape (N, C, H, W). 88 | 89 | Returns: 90 | Tensor: A batch of GAN logits of shape (N, 1). 
91 | """ 92 | h = x 93 | h = self.block1(h) 94 | h = self.block2(h) 95 | h = self.block3(h) 96 | h = self.block4(h) 97 | h = self.activation(h) 98 | 99 | # Global sum pooling 100 | h = torch.sum(h, dim=(2, 3)) 101 | output = self.l5(h) 102 | 103 | return output 104 | -------------------------------------------------------------------------------- /torch_mimicry/nets/sngan/sngan_48.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of SNGAN for image size 48. 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | 7 | from torch_mimicry.modules.layers import SNLinear 8 | from torch_mimicry.modules.resblocks import DBlockOptimized, DBlock, GBlock 9 | from torch_mimicry.nets.sngan import sngan_base 10 | 11 | 12 | class SNGANGenerator48(sngan_base.SNGANBaseGenerator): 13 | r""" 14 | ResNet backbone generator for SNGAN. 15 | 16 | Attributes: 17 | nz (int): Noise dimension for upsampling. 18 | ngf (int): Variable controlling generator feature map sizes. 19 | bottom_width (int): Starting width for upsampling generator output to an image. 20 | loss_type (str): Name of loss to use for GAN loss. 21 | """ 22 | def __init__(self, nz=128, ngf=512, bottom_width=6, **kwargs): 23 | super().__init__(nz=nz, ngf=ngf, bottom_width=bottom_width, **kwargs) 24 | 25 | # Build the layers 26 | self.l1 = nn.Linear(self.nz, (self.bottom_width**2) * self.ngf) 27 | self.block2 = GBlock(self.ngf, self.ngf >> 1, upsample=True) 28 | self.block3 = GBlock(self.ngf >> 1, self.ngf >> 2, upsample=True) 29 | self.block4 = GBlock(self.ngf >> 2, self.ngf >> 3, upsample=True) 30 | self.b5 = nn.BatchNorm2d(self.ngf >> 3) 31 | self.c5 = nn.Conv2d(self.ngf >> 3, 3, 3, 1, padding=1) 32 | self.activation = nn.ReLU(True) 33 | 34 | # Initialise the weights 35 | nn.init.xavier_uniform_(self.l1.weight.data, 1.0) 36 | nn.init.xavier_uniform_(self.c5.weight.data, 1.0) 37 | 38 | def forward(self, x): 39 | r""" 40 | Feedforwards a batch of noise vectors into a batch of fake images. 41 | 42 | Args: 43 | x (Tensor): A batch of noise vectors of shape (N, nz). 44 | 45 | Returns: 46 | Tensor: A batch of fake images of shape (N, C, H, W). 47 | """ 48 | h = self.l1(x) 49 | h = h.view(x.shape[0], -1, self.bottom_width, self.bottom_width) 50 | h = self.block2(h) 51 | h = self.block3(h) 52 | h = self.block4(h) 53 | h = self.b5(h) 54 | h = self.activation(h) 55 | h = torch.tanh(self.c5(h)) 56 | 57 | return h 58 | 59 | 60 | class SNGANDiscriminator48(sngan_base.SNGANBaseDiscriminator): 61 | r""" 62 | ResNet backbone discriminator for SNGAN. 63 | 64 | Attributes: 65 | ndf (int): Variable controlling discriminator feature map sizes. 66 | loss_type (str): Name of loss to use for GAN loss. 67 | """ 68 | def __init__(self, ndf=1024, **kwargs): 69 | super().__init__(ndf=ndf, **kwargs) 70 | 71 | # Build layers 72 | self.block1 = DBlockOptimized(3, self.ndf >> 4) 73 | self.block2 = DBlock(self.ndf >> 4, self.ndf >> 3, downsample=True) 74 | self.block3 = DBlock(self.ndf >> 3, self.ndf >> 2, downsample=True) 75 | self.block4 = DBlock(self.ndf >> 2, self.ndf >> 1, downsample=True) 76 | self.block5 = DBlock(self.ndf >> 1, self.ndf, downsample=False) 77 | self.l5 = SNLinear(self.ndf, 1) 78 | 79 | self.activation = nn.ReLU(True) 80 | 81 | # Initialise the weights 82 | nn.init.xavier_uniform_(self.l5.weight.data, 1.0) 83 | 84 | def forward(self, x): 85 | r""" 86 | Feedforwards a batch of real/fake images and produces a batch of GAN logits.
87 | 88 | Args: 89 | x (Tensor): A batch of images of shape (N, C, H, W). 90 | 91 | Returns: 92 | Tensor: A batch of GAN logits of shape (N, 1). 93 | """ 94 | h = x 95 | h = self.block1(h) 96 | h = self.block2(h) 97 | h = self.block3(h) 98 | h = self.block4(h) 99 | h = self.block5(h) 100 | h = self.activation(h) 101 | 102 | # Global sum pooling 103 | h = torch.sum(h, dim=(2, 3)) 104 | output = self.l5(h) 105 | 106 | return output 107 | -------------------------------------------------------------------------------- /torch_mimicry/nets/sngan/sngan_64.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of SNGAN for image size 64. 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | 7 | from torch_mimicry.modules.layers import SNLinear 8 | from torch_mimicry.modules.resblocks import DBlockOptimized, DBlock, GBlock 9 | from torch_mimicry.nets.sngan import sngan_base 10 | 11 | 12 | class SNGANGenerator64(sngan_base.SNGANBaseGenerator): 13 | r""" 14 | ResNet backbone generator for SNGAN. 15 | 16 | Attributes: 17 | nz (int): Noise dimension for upsampling. 18 | ngf (int): Variable controlling generator feature map sizes. 19 | bottom_width (int): Starting width for upsampling generator output to an image. 20 | loss_type (str): Name of loss to use for GAN loss. 21 | """ 22 | def __init__(self, nz=128, ngf=1024, bottom_width=4, **kwargs): 23 | super().__init__(nz=nz, ngf=ngf, bottom_width=bottom_width, **kwargs) 24 | 25 | # Build the layers 26 | self.l1 = nn.Linear(self.nz, (self.bottom_width**2) * self.ngf) 27 | self.block2 = GBlock(self.ngf, self.ngf >> 1, upsample=True) 28 | self.block3 = GBlock(self.ngf >> 1, self.ngf >> 2, upsample=True) 29 | self.block4 = GBlock(self.ngf >> 2, self.ngf >> 3, upsample=True) 30 | self.block5 = GBlock(self.ngf >> 3, self.ngf >> 4, upsample=True) 31 | self.b6 = nn.BatchNorm2d(self.ngf >> 4) 32 | self.c6 = nn.Conv2d(self.ngf >> 4, 3, 3, 1, padding=1) 33 | self.activation = nn.ReLU(True) 34 | 35 | # Initialise the weights 36 | nn.init.xavier_uniform_(self.l1.weight.data, 1.0) 37 | nn.init.xavier_uniform_(self.c6.weight.data, 1.0) 38 | 39 | def forward(self, x): 40 | r""" 41 | Feedforwards a batch of noise vectors into a batch of fake images. 42 | 43 | Args: 44 | x (Tensor): A batch of noise vectors of shape (N, nz). 45 | 46 | Returns: 47 | Tensor: A batch of fake images of shape (N, C, H, W). 48 | """ 49 | h = self.l1(x) 50 | h = h.view(x.shape[0], -1, self.bottom_width, self.bottom_width) 51 | h = self.block2(h) 52 | h = self.block3(h) 53 | h = self.block4(h) 54 | h = self.block5(h) 55 | h = self.b6(h) 56 | h = self.activation(h) 57 | h = torch.tanh(self.c6(h)) 58 | 59 | return h 60 | 61 | 62 | class SNGANDiscriminator64(sngan_base.SNGANBaseDiscriminator): 63 | r""" 64 | ResNet backbone discriminator for SNGAN. 65 | 66 | Attributes: 67 | ndf (int): Variable controlling discriminator feature map sizes. 68 | loss_type (str): Name of loss to use for GAN loss. 
69 | """ 70 | def __init__(self, ndf=1024, **kwargs): 71 | super().__init__(ndf=ndf, **kwargs) 72 | 73 | # Build layers 74 | self.block1 = DBlockOptimized(3, self.ndf >> 4) 75 | self.block2 = DBlock(self.ndf >> 4, self.ndf >> 3, downsample=True) 76 | self.block3 = DBlock(self.ndf >> 3, self.ndf >> 2, downsample=True) 77 | self.block4 = DBlock(self.ndf >> 2, self.ndf >> 1, downsample=True) 78 | self.block5 = DBlock(self.ndf >> 1, self.ndf, downsample=True) 79 | self.l6 = SNLinear(self.ndf, 1) 80 | self.activation = nn.ReLU(True) 81 | 82 | # Initialise the weights 83 | nn.init.xavier_uniform_(self.l6.weight.data, 1.0) 84 | 85 | def forward(self, x): 86 | r""" 87 | Feedforwards a batch of real/fake images and produces a batch of GAN logits. 88 | 89 | Args: 90 | x (Tensor): A batch of images of shape (N, C, H, W). 91 | 92 | Returns: 93 | Tensor: A batch of GAN logits of shape (N, 1). 94 | """ 95 | h = x 96 | h = self.block1(h) 97 | h = self.block2(h) 98 | h = self.block3(h) 99 | h = self.block4(h) 100 | h = self.block5(h) 101 | h = self.activation(h) 102 | 103 | # Global sum pooling 104 | h = torch.sum(h, dim=(2, 3)) 105 | output = self.l6(h) 106 | 107 | return output 108 | -------------------------------------------------------------------------------- /torch_mimicry/nets/sngan/sngan_base.py: -------------------------------------------------------------------------------- 1 | """ 2 | Base implementation of SNGAN with default variables. 3 | """ 4 | from torch_mimicry.nets.gan import gan 5 | 6 | 7 | class SNGANBaseGenerator(gan.BaseGenerator): 8 | r""" 9 | ResNet backbone generator for SNGAN. 10 | 11 | Attributes: 12 | nz (int): Noise dimension for upsampling. 13 | ngf (int): Variable controlling generator feature map sizes. 14 | bottom_width (int): Starting width for upsampling generator output to an image. 15 | loss_type (str): Name of loss to use for GAN loss. 16 | """ 17 | def __init__(self, nz, ngf, bottom_width, loss_type='hinge', **kwargs): 18 | super().__init__(nz=nz, 19 | ngf=ngf, 20 | bottom_width=bottom_width, 21 | loss_type=loss_type, 22 | **kwargs) 23 | 24 | 25 | class SNGANBaseDiscriminator(gan.BaseDiscriminator): 26 | r""" 27 | ResNet backbone discriminator for SNGAN. 28 | 29 | Attributes: 30 | ndf (int): Variable controlling discriminator feature map sizes. 31 | """ 32 | def __init__(self, ndf, loss_type='hinge', **kwargs): 33 | super().__init__(ndf=ndf, loss_type=loss_type, **kwargs) 34 | -------------------------------------------------------------------------------- /torch_mimicry/nets/ssgan/__init__.py: -------------------------------------------------------------------------------- 1 | from .ssgan_128 import * 2 | from .ssgan_32 import * 3 | from .ssgan_48 import * 4 | from .ssgan_64 import * 5 | from .ssgan_base import * 6 | -------------------------------------------------------------------------------- /torch_mimicry/nets/ssgan/ssgan_32.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of SSGAN for image size 32. 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | 7 | from torch_mimicry.modules import SNLinear 8 | from torch_mimicry.modules.resblocks import DBlockOptimized, DBlock, GBlock 9 | from torch_mimicry.nets.ssgan import ssgan_base 10 | 11 | 12 | class SSGANGenerator32(ssgan_base.SSGANBaseGenerator): 13 | r""" 14 | ResNet backbone generator for SSGAN. 15 | 16 | Attributes: 17 | nz (int): Noise dimension for upsampling. 18 | ngf (int): Variable controlling generator feature map sizes. 
19 | bottom_width (int): Starting width for upsampling generator output to an image. 20 | loss_type (str): Name of loss to use for GAN loss. 21 | ss_loss_scale (float): Self-supervised loss scale for generator. 22 | """ 23 | def __init__(self, nz=128, ngf=256, bottom_width=4, **kwargs): 24 | super().__init__(nz=nz, ngf=ngf, bottom_width=bottom_width, **kwargs) 25 | 26 | # Build the layers 27 | self.l1 = nn.Linear(self.nz, (self.bottom_width**2) * self.ngf) 28 | self.block2 = GBlock(self.ngf, self.ngf, upsample=True) 29 | self.block3 = GBlock(self.ngf, self.ngf, upsample=True) 30 | self.block4 = GBlock(self.ngf, self.ngf, upsample=True) 31 | self.b5 = nn.BatchNorm2d(self.ngf) 32 | self.c5 = nn.Conv2d(self.ngf, 3, 3, 1, padding=1) 33 | self.activation = nn.ReLU(True) 34 | 35 | # Initialise the weights 36 | nn.init.xavier_uniform_(self.l1.weight.data, 1.0) 37 | nn.init.xavier_uniform_(self.c5.weight.data, 1.0) 38 | 39 | def forward(self, x): 40 | r""" 41 | Feedforwards a batch of noise vectors into a batch of fake images. 42 | 43 | Args: 44 | x (Tensor): A batch of noise vectors of shape (N, nz). 45 | 46 | Returns: 47 | Tensor: A batch of fake images of shape (N, C, H, W). 48 | """ 49 | h = self.l1(x) 50 | h = h.view(x.shape[0], -1, self.bottom_width, self.bottom_width) 51 | h = self.block2(h) 52 | h = self.block3(h) 53 | h = self.block4(h) 54 | h = self.b5(h) 55 | h = self.activation(h) 56 | h = torch.tanh(self.c5(h)) 57 | 58 | return h 59 | 60 | 61 | class SSGANDiscriminator32(ssgan_base.SSGANBaseDiscriminator): 62 | r""" 63 | ResNet backbone discriminator for SSGAN. 64 | 65 | Attributes: 66 | ndf (int): Variable controlling discriminator feature map sizes. 67 | loss_type (str): Name of loss to use for GAN loss. 68 | ss_loss_scale (float): Self-supervised loss scale for discriminator. 69 | """ 70 | def __init__(self, ndf=128, **kwargs): 71 | super().__init__(ndf=ndf, **kwargs) 72 | 73 | # Build layers 74 | self.block1 = DBlockOptimized(3, self.ndf) 75 | self.block2 = DBlock(self.ndf, self.ndf, downsample=True) 76 | self.block3 = DBlock(self.ndf, self.ndf, downsample=False) 77 | self.block4 = DBlock(self.ndf, self.ndf, downsample=False) 78 | self.l5 = SNLinear(self.ndf, 1) 79 | 80 | # Rotation class prediction layer 81 | self.l_y = SNLinear(self.ndf, self.num_classes) 82 | 83 | # Initialise the weights 84 | nn.init.xavier_uniform_(self.l5.weight.data, 1.0) 85 | nn.init.xavier_uniform_(self.l_y.weight.data, 1.0) 86 | 87 | self.activation = nn.ReLU(True) 88 | 89 | def forward(self, x): 90 | r""" 91 | Feedforwards a batch of real/fake images and produces a batch of GAN logits, 92 | and rotation classes. 93 | 94 | Args: 95 | x (Tensor): A batch of images of shape (N, C, H, W). 96 | 97 | Returns: 98 | Tensor: A batch of GAN logits of shape (N, 1). 99 | Tensor: A batch of predicted classes of shape (N, num_classes).
100 | """ 101 | h = x 102 | h = self.block1(h) 103 | h = self.block2(h) 104 | h = self.block3(h) 105 | h = self.block4(h) 106 | h = self.activation(h) 107 | 108 | # Global sum pooling 109 | h = torch.sum(h, dim=(2, 3)) 110 | output = self.l5(h) 111 | 112 | # Produce the class output logits 113 | output_classes = self.l_y(h) 114 | 115 | return output, output_classes 116 | -------------------------------------------------------------------------------- /torch_mimicry/nets/wgan_gp/__init__.py: -------------------------------------------------------------------------------- 1 | from .wgan_gp_128 import * 2 | from .wgan_gp_32 import * 3 | from .wgan_gp_48 import * 4 | from .wgan_gp_64 import * 5 | from .wgan_gp_base import * 6 | from .wgan_gp_resblocks import * 7 | -------------------------------------------------------------------------------- /torch_mimicry/nets/wgan_gp/wgan_gp_128.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of WGAN-GP for image size 128. 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | 7 | from torch_mimicry.nets.wgan_gp import wgan_gp_base 8 | from torch_mimicry.nets.wgan_gp.wgan_gp_resblocks import DBlockOptimized, DBlock, GBlock 9 | 10 | 11 | class WGANGPGenerator128(wgan_gp_base.WGANGPBaseGenerator): 12 | r""" 13 | ResNet backbone generator for WGAN-GP. 14 | 15 | Attributes: 16 | nz (int): Noise dimension for upsampling. 17 | ngf (int): Variable controlling generator feature map sizes. 18 | bottom_width (int): Starting width for upsampling generator output to an image. 19 | loss_type (str): Name of loss to use for GAN loss. 20 | """ 21 | def __init__(self, nz=128, ngf=1024, bottom_width=4, **kwargs): 22 | super().__init__(nz=nz, ngf=ngf, bottom_width=bottom_width, **kwargs) 23 | 24 | # Build the layers 25 | self.l1 = nn.Linear(self.nz, (self.bottom_width**2) * self.ngf) 26 | self.block2 = GBlock(self.ngf, self.ngf, upsample=True) 27 | self.block3 = GBlock(self.ngf, self.ngf >> 1, upsample=True) 28 | self.block4 = GBlock(self.ngf >> 1, self.ngf >> 2, upsample=True) 29 | self.block5 = GBlock(self.ngf >> 2, self.ngf >> 3, upsample=True) 30 | self.block6 = GBlock(self.ngf >> 3, self.ngf >> 4, upsample=True) 31 | self.b7 = nn.BatchNorm2d(self.ngf >> 4) 32 | self.c7 = nn.Conv2d(self.ngf >> 4, 3, 3, 1, padding=1) 33 | self.activation = nn.ReLU(True) 34 | 35 | # Initialise the weights 36 | nn.init.xavier_uniform_(self.l1.weight.data, 1.0) 37 | 38 | def forward(self, x): 39 | r""" 40 | Feedforwards a batch of noise vectors into a batch of fake images. 41 | 42 | Args: 43 | x (Tensor): A batch of noise vectors of shape (N, nz). 44 | 45 | Returns: 46 | Tensor: A batch of fake images of shape (N, C, H, W). 47 | """ 48 | h = self.l1(x) 49 | h = h.view(x.shape[0], -1, self.bottom_width, self.bottom_width) 50 | h = self.block2(h) 51 | h = self.block3(h) 52 | h = self.block4(h) 53 | h = self.block5(h) 54 | h = self.block6(h) 55 | h = self.b7(h) 56 | h = self.activation(h) 57 | h = torch.tanh(self.c7(h)) 58 | 59 | return h 60 | 61 | 62 | class WGANGPDiscriminator128(wgan_gp_base.WGANGPBaseDiscriminator): 63 | r""" 64 | ResNet backbone discriminator for WGAN-GP. 65 | 66 | Attributes: 67 | ndf (int): Variable controlling discriminator feature map sizes. 68 | loss_type (str): Name of loss to use for GAN loss. 69 | gp_scale (float): Lamda parameter for gradient penalty. 
70 | """ 71 | def __init__(self, ndf=1024, **kwargs): 72 | super().__init__(ndf=ndf, **kwargs) 73 | 74 | # Build layers 75 | self.block1 = DBlockOptimized(3, self.ndf >> 4) 76 | self.block2 = DBlock(self.ndf >> 4, self.ndf >> 3, downsample=True) 77 | self.block3 = DBlock(self.ndf >> 3, self.ndf >> 2, downsample=True) 78 | self.block4 = DBlock(self.ndf >> 2, self.ndf >> 1, downsample=True) 79 | self.block5 = DBlock(self.ndf >> 1, self.ndf, downsample=True) 80 | self.block6 = DBlock(self.ndf, self.ndf, downsample=False) 81 | self.l7 = nn.Linear(self.ndf, 1) 82 | self.activation = nn.ReLU(True) 83 | 84 | # Initialise the weights 85 | nn.init.xavier_uniform_(self.l7.weight.data, 1.0) 86 | 87 | def forward(self, x): 88 | r""" 89 | Feedforwards a batch of real/fake images and produces a batch of GAN logits. 90 | 91 | Args: 92 | x (Tensor): A batch of images of shape (N, C, H, W). 93 | 94 | Returns: 95 | Tensor: A batch of GAN logits of shape (N, 1). 96 | """ 97 | h = x 98 | h = self.block1(h) 99 | h = self.block2(h) 100 | h = self.block3(h) 101 | h = self.block4(h) 102 | h = self.block5(h) 103 | h = self.block6(h) 104 | h = self.activation(h) 105 | 106 | # Global average pooling 107 | h = torch.mean(h, dim=(2, 3)) # WGAN uses mean pooling 108 | output = self.l7(h) 109 | 110 | return output 111 | -------------------------------------------------------------------------------- /torch_mimicry/nets/wgan_gp/wgan_gp_32.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of WGAN-GP for image size 32. 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | 7 | from torch_mimicry.nets.wgan_gp import wgan_gp_base 8 | from torch_mimicry.nets.wgan_gp.wgan_gp_resblocks import DBlockOptimized, DBlock, GBlock 9 | 10 | 11 | class WGANGPGenerator32(wgan_gp_base.WGANGPBaseGenerator): 12 | r""" 13 | ResNet backbone generator for WGAN-GP. 14 | 15 | Attributes: 16 | nz (int): Noise dimension for upsampling. 17 | ngf (int): Variable controlling generator feature map sizes. 18 | bottom_width (int): Starting width for upsampling generator output to an image. 19 | loss_type (str): Name of loss to use for GAN loss. 20 | """ 21 | def __init__(self, nz=128, ngf=256, bottom_width=4, **kwargs): 22 | super().__init__(nz=nz, ngf=ngf, bottom_width=bottom_width, **kwargs) 23 | 24 | # Build the layers 25 | self.l1 = nn.Linear(self.nz, (self.bottom_width**2) * self.ngf) 26 | self.block2 = GBlock(self.ngf, self.ngf, upsample=True) 27 | self.block3 = GBlock(self.ngf, self.ngf, upsample=True) 28 | self.block4 = GBlock(self.ngf, self.ngf, upsample=True) 29 | self.b5 = nn.BatchNorm2d(self.ngf) 30 | self.c5 = nn.Conv2d(self.ngf, 3, 3, 1, padding=1) 31 | self.activation = nn.ReLU(True) 32 | 33 | # Initialise the weights 34 | nn.init.xavier_uniform_(self.l1.weight.data, 1.0) 35 | 36 | def forward(self, x): 37 | r""" 38 | Feedforwards a batch of noise vectors into a batch of fake images. 39 | 40 | Args: 41 | x (Tensor): A batch of noise vectors of shape (N, nz). 42 | 43 | Returns: 44 | Tensor: A batch of fake images of shape (N, C, H, W). 45 | """ 46 | h = self.l1(x) 47 | h = h.view(x.shape[0], -1, self.bottom_width, self.bottom_width) 48 | h = self.block2(h) 49 | h = self.block3(h) 50 | h = self.block4(h) 51 | h = self.b5(h) 52 | h = self.activation(h) 53 | h = torch.tanh(self.c5(h)) 54 | 55 | return h 56 | 57 | 58 | class WGANGPDiscriminator32(wgan_gp_base.WGANGPBaseDiscriminator): 59 | r""" 60 | ResNet backbone discriminator for WGAN-GP. 
61 | 62 | Attributes: 63 | ndf (int): Variable controlling discriminator feature map sizes. 64 | loss_type (str): Name of loss to use for GAN loss. 65 | gp_scale (float): Lambda parameter for gradient penalty. 66 | """ 67 | def __init__(self, ndf=128, **kwargs): 68 | super().__init__(ndf=ndf, **kwargs) 69 | 70 | # Build layers 71 | self.block1 = DBlockOptimized(3, self.ndf) 72 | self.block2 = DBlock(self.ndf, self.ndf, downsample=True) 73 | self.block3 = DBlock(self.ndf, self.ndf, downsample=False) 74 | self.block4 = DBlock(self.ndf, self.ndf, downsample=False) 75 | self.l5 = nn.Linear(self.ndf, 1) 76 | 77 | self.activation = nn.ReLU(True) 78 | 79 | # Initialise the weights 80 | nn.init.xavier_uniform_(self.l5.weight.data, 1.0) 81 | 82 | def forward(self, x): 83 | r""" 84 | Feedforwards a batch of real/fake images and produces a batch of GAN logits. 85 | 86 | Args: 87 | x (Tensor): A batch of images of shape (N, C, H, W). 88 | 89 | Returns: 90 | Tensor: A batch of GAN logits of shape (N, 1). 91 | """ 92 | h = x 93 | h = self.block1(h) 94 | h = self.block2(h) 95 | h = self.block3(h) 96 | h = self.block4(h) 97 | h = self.activation(h) 98 | 99 | # Global average pooling 100 | h = torch.mean(h, dim=(2, 3)) # WGAN uses mean pooling 101 | output = self.l5(h) 102 | 103 | return output 104 | -------------------------------------------------------------------------------- /torch_mimicry/nets/wgan_gp/wgan_gp_48.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of WGAN-GP for image size 48. 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | 7 | from torch_mimicry.nets.wgan_gp import wgan_gp_base 8 | from torch_mimicry.nets.wgan_gp.wgan_gp_resblocks import DBlockOptimized, DBlock, GBlock 9 | 10 | 11 | class WGANGPGenerator48(wgan_gp_base.WGANGPBaseGenerator): 12 | r""" 13 | ResNet backbone generator for WGAN-GP. 14 | 15 | Attributes: 16 | nz (int): Noise dimension for upsampling. 17 | ngf (int): Variable controlling generator feature map sizes. 18 | bottom_width (int): Starting width for upsampling generator output to an image. 19 | loss_type (str): Name of loss to use for GAN loss. 20 | """ 21 | def __init__(self, nz=128, ngf=512, bottom_width=6, **kwargs): 22 | super().__init__(nz=nz, ngf=ngf, bottom_width=bottom_width, **kwargs) 23 | 24 | # Build the layers 25 | self.l1 = nn.Linear(self.nz, (self.bottom_width**2) * self.ngf) 26 | self.block2 = GBlock(self.ngf, self.ngf >> 1, upsample=True) 27 | self.block3 = GBlock(self.ngf >> 1, self.ngf >> 2, upsample=True) 28 | self.block4 = GBlock(self.ngf >> 2, self.ngf >> 3, upsample=True) 29 | self.b5 = nn.BatchNorm2d(self.ngf >> 3) 30 | self.c5 = nn.Conv2d(self.ngf >> 3, 3, 3, 1, padding=1) 31 | self.activation = nn.ReLU(True) 32 | 33 | # Initialise the weights 34 | nn.init.xavier_uniform_(self.l1.weight.data, 1.0) 35 | 36 | def forward(self, x): 37 | r""" 38 | Feedforwards a batch of noise vectors into a batch of fake images. 39 | 40 | Args: 41 | x (Tensor): A batch of noise vectors of shape (N, nz). 42 | 43 | Returns: 44 | Tensor: A batch of fake images of shape (N, C, H, W). 45 | """ 46 | h = self.l1(x) 47 | h = h.view(x.shape[0], -1, self.bottom_width, self.bottom_width) 48 | h = self.block2(h) 49 | h = self.block3(h) 50 | h = self.block4(h) 51 | h = self.b5(h) 52 | h = self.activation(h) 53 | h = torch.tanh(self.c5(h)) 54 | 55 | return h 56 | 57 | 58 | class WGANGPDiscriminator48(wgan_gp_base.WGANGPBaseDiscriminator): 59 | r""" 60 | ResNet backbone discriminator for WGAN-GP.
61 | 62 | Attributes: 63 | ndf (int): Variable controlling discriminator feature map sizes. 64 | loss_type (str): Name of loss to use for GAN loss. 65 | gp_scale (float): Lambda parameter for gradient penalty. 66 | """ 67 | def __init__(self, ndf=1024, **kwargs): 68 | super().__init__(ndf=ndf, **kwargs) 69 | 70 | # Build layers 71 | self.block1 = DBlockOptimized(3, self.ndf >> 4) 72 | self.block2 = DBlock(self.ndf >> 4, self.ndf >> 3, downsample=True) 73 | self.block3 = DBlock(self.ndf >> 3, self.ndf >> 2, downsample=True) 74 | self.block4 = DBlock(self.ndf >> 2, self.ndf >> 1, downsample=True) 75 | self.block5 = DBlock(self.ndf >> 1, self.ndf, downsample=False) 76 | self.l6 = nn.Linear(self.ndf, 1) 77 | 78 | self.activation = nn.ReLU(True) 79 | 80 | # Initialise the weights 81 | nn.init.xavier_uniform_(self.l6.weight.data, 1.0) 82 | 83 | def forward(self, x): 84 | r""" 85 | Feedforwards a batch of real/fake images and produces a batch of GAN logits. 86 | 87 | Args: 88 | x (Tensor): A batch of images of shape (N, C, H, W). 89 | 90 | Returns: 91 | Tensor: A batch of GAN logits of shape (N, 1). 92 | """ 93 | h = x 94 | h = self.block1(h) 95 | h = self.block2(h) 96 | h = self.block3(h) 97 | h = self.block4(h) 98 | h = self.block5(h) 99 | h = self.activation(h) 100 | 101 | # Global average pooling 102 | h = torch.mean(h, dim=(2, 3)) # WGAN uses mean pooling 103 | output = self.l6(h) 104 | 105 | return output 106 | -------------------------------------------------------------------------------- /torch_mimicry/nets/wgan_gp/wgan_gp_64.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of WGAN-GP for image size 64. 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | 7 | from torch_mimicry.nets.wgan_gp import wgan_gp_base 8 | from torch_mimicry.nets.wgan_gp.wgan_gp_resblocks import DBlockOptimized, DBlock, GBlock 9 | 10 | 11 | class WGANGPGenerator64(wgan_gp_base.WGANGPBaseGenerator): 12 | r""" 13 | ResNet backbone generator for WGAN-GP. 14 | 15 | Attributes: 16 | nz (int): Noise dimension for upsampling. 17 | ngf (int): Variable controlling generator feature map sizes. 18 | bottom_width (int): Starting width for upsampling generator output to an image. 19 | loss_type (str): Name of loss to use for GAN loss. 20 | """ 21 | def __init__(self, nz=128, ngf=1024, bottom_width=4, **kwargs): 22 | super().__init__(nz=nz, ngf=ngf, bottom_width=bottom_width, **kwargs) 23 | 24 | # Build the layers 25 | self.l1 = nn.Linear(self.nz, (self.bottom_width**2) * self.ngf) 26 | self.block2 = GBlock(self.ngf, self.ngf >> 1, upsample=True) 27 | self.block3 = GBlock(self.ngf >> 1, self.ngf >> 2, upsample=True) 28 | self.block4 = GBlock(self.ngf >> 2, self.ngf >> 3, upsample=True) 29 | self.block5 = GBlock(self.ngf >> 3, self.ngf >> 4, upsample=True) 30 | self.b6 = nn.BatchNorm2d(self.ngf >> 4) 31 | self.c6 = nn.Conv2d(self.ngf >> 4, 3, 3, 1, padding=1) 32 | self.activation = nn.ReLU(True) 33 | 34 | # Initialise the weights 35 | nn.init.xavier_uniform_(self.l1.weight.data, 1.0) 36 | 37 | def forward(self, x): 38 | r""" 39 | Feedforwards a batch of noise vectors into a batch of fake images. 40 | 41 | Args: 42 | x (Tensor): A batch of noise vectors of shape (N, nz). 43 | 44 | Returns: 45 | Tensor: A batch of fake images of shape (N, C, H, W).
46 | """ 47 | h = self.l1(x) 48 | h = h.view(x.shape[0], -1, self.bottom_width, self.bottom_width) 49 | h = self.block2(h) 50 | h = self.block3(h) 51 | h = self.block4(h) 52 | h = self.block5(h) 53 | h = self.b6(h) 54 | h = self.activation(h) 55 | h = torch.tanh(self.c6(h)) 56 | 57 | return h 58 | 59 | 60 | class WGANGPDiscriminator64(wgan_gp_base.WGANGPBaseDiscriminator): 61 | r""" 62 | ResNet backbone discriminator for WGAN-GP. 63 | 64 | Attributes: 65 | ndf (int): Variable controlling discriminator feature map sizes. 66 | loss_type (str): Name of loss to use for GAN loss. 67 | gp_scale (float): Lamda parameter for gradient penalty. 68 | """ 69 | def __init__(self, ndf=1024, **kwargs): 70 | super().__init__(ndf=ndf, **kwargs) 71 | 72 | # Build layers 73 | self.block1 = DBlockOptimized(3, self.ndf >> 4) 74 | self.block2 = DBlock(self.ndf >> 4, self.ndf >> 3, downsample=True) 75 | self.block3 = DBlock(self.ndf >> 3, self.ndf >> 2, downsample=True) 76 | self.block4 = DBlock(self.ndf >> 2, self.ndf >> 1, downsample=True) 77 | self.block5 = DBlock(self.ndf >> 1, self.ndf, downsample=True) 78 | self.l6 = nn.Linear(self.ndf, 1) 79 | self.activation = nn.ReLU(True) 80 | 81 | # Initialise the weights 82 | nn.init.xavier_uniform_(self.l6.weight.data, 1.0) 83 | 84 | def forward(self, x): 85 | r""" 86 | Feedforwards a batch of real/fake images and produces a batch of GAN logits. 87 | 88 | Args: 89 | x (Tensor): A batch of images of shape (N, C, H, W). 90 | 91 | Returns: 92 | Tensor: A batch of GAN logits of shape (N, 1). 93 | """ 94 | h = x 95 | h = self.block1(h) 96 | h = self.block2(h) 97 | h = self.block3(h) 98 | h = self.block4(h) 99 | h = self.block5(h) 100 | h = self.activation(h) 101 | 102 | # Global average pooling 103 | h = torch.mean(h, dim=(2, 3)) # WGAN uses mean pooling 104 | output = self.l6(h) 105 | 106 | return output 107 | -------------------------------------------------------------------------------- /torch_mimicry/training/__init__.py: -------------------------------------------------------------------------------- 1 | from .logger import * 2 | from .metric_log import * 3 | from .scheduler import * 4 | from .trainer import * -------------------------------------------------------------------------------- /torch_mimicry/training/metric_log.py: -------------------------------------------------------------------------------- 1 | """ 2 | MetricLog object for intelligently logging data to display them more intuitively. 3 | """ 4 | 5 | 6 | class MetricLog: 7 | """ 8 | A dictionary-like object that logs data, and includes an extra dict to map the metrics 9 | to its group name, if any, and the corresponding precision to print out. 10 | 11 | Attributes: 12 | metrics_dict (dict): A dictionary mapping to another dict containing 13 | the corresponding value, precision, and the group this metric belongs to. 14 | """ 15 | def __init__(self, **kwargs): 16 | self.metrics_dict = {} 17 | 18 | def add_metric(self, name, value, group=None, precision=4): 19 | """ 20 | Logs metric to internal dict, but with an additional option 21 | of grouping certain metrics together. 22 | 23 | Args: 24 | name (str): Name of metric to log. 25 | value (Tensor/Float): Value of the metric to log. 26 | group (str): Name of the group to classify different metrics together. 27 | precision (int): The number of floating point precision to represent the value. 
28 | 29 | Returns: 30 | None 31 | """ 32 | # Grab tensor values only 33 | try: 34 | value = value.item() 35 | except AttributeError: 36 | value = value 37 | 38 | self.metrics_dict[name] = dict(value=value, 39 | group=group, 40 | precision=precision) 41 | 42 | def __getitem__(self, key): 43 | return round(self.metrics_dict[key]['value'], 44 | self.metrics_dict[key]['precision']) 45 | 46 | def get_group_name(self, name): 47 | """ 48 | Obtains the group name of a particular metric. For example, errD and errG 49 | which represent the discriminator/generator losses could fall under a 50 | group name called "loss". 51 | 52 | Args: 53 | name (str): The name of the metric to retrieve group name. 54 | 55 | Returns: 56 | str: A string representing the group name of the metric. 57 | """ 58 | return self.metrics_dict[name]['group'] 59 | 60 | def keys(self): 61 | """ 62 | Dict-like functionality for retrieving keys. 63 | """ 64 | return self.metrics_dict.keys() 65 | 66 | def items(self): 67 | """ 68 | Dict-like functionality for retrieving items. 69 | """ 70 | return self.metrics_dict.items() 71 | -------------------------------------------------------------------------------- /torch_mimicry/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .common import * 2 | -------------------------------------------------------------------------------- /torch_mimicry/utils/common.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script for common utility functions. 3 | """ 4 | import json 5 | import os 6 | 7 | import numpy as np 8 | import torch 9 | from skimage import io 10 | 11 | 12 | def write_to_json(dict_to_write, output_file): 13 | """ 14 | Outputs a given dictionary as a JSON file with indents. 15 | 16 | Args: 17 | dict_to_write (dict): Input dictionary to output. 18 | output_file (str): File path to write the dictionary. 19 | 20 | Returns: 21 | None 22 | """ 23 | with open(output_file, 'w') as file: 24 | json.dump(dict_to_write, file, indent=4) 25 | 26 | 27 | def load_from_json(json_file): 28 | """ 29 | Loads a JSON file as a dictionary and returns it. 30 | 31 | Args: 32 | json_file (str): Input JSON file to read. 33 | 34 | Returns: 35 | dict: Dictionary loaded from the JSON file. 36 | """ 37 | with open(json_file, 'r') as file: 38 | return json.load(file) 39 | 40 | 41 | def save_tensor_image(x, output_file): 42 | """ 43 | Saves an input image tensor as an image file via a numpy array; useful for tests. 44 | 45 | Args: 46 | x (Tensor): A 3D tensor image of shape (3, H, W). 47 | output_file (str): The output image file to save the tensor. 48 | 49 | Returns: 50 | None 51 | """ 52 | folder = os.path.dirname(output_file) 53 | if not os.path.exists(folder): 54 | os.makedirs(folder) 55 | 56 | x = x.permute(1, 2, 0).numpy() 57 | io.imsave(output_file, x) 58 | 59 | 60 | def load_images(n=1, size=32): 61 | """ 62 | Loads n random image tensors with zero-valued fake labels. 63 | 64 | Args: 65 | n (int): Number of random images to load. 66 | size (int): Spatial size of random image. 67 | 68 | Returns: 69 | tuple: Random images of shape (n, 3, size, size) and n 0-valued labels. 70 | """ 71 | images = torch.randn(n, 3, size, size) 72 | labels = torch.from_numpy(np.array([0] * n)) 73 | 74 | return images, labels --------------------------------------------------------------------------------
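The network classes above all follow the same pattern: a generator maps a batch of (N, nz) noise vectors to (N, 3, H, W) images through a linear projection and a stack of upsampling GBlocks, while a discriminator stacks downsampling DBlocks and pools its final feature map (global sum pooling for the spectral-norm models, global mean pooling for WGAN-GP) before a single linear output layer. The following is a minimal usage sketch, not taken from the repository itself, assuming the package is installed and the modules are importable exactly as laid out in the directory tree; the batch size is illustrative only.

    # Minimal forward-pass sketch for the SNGAN 32x32 models defined above.
    import torch
    from torch_mimicry.nets.sngan.sngan_32 import SNGANGenerator32, SNGANDiscriminator32

    netG = SNGANGenerator32()        # defaults: nz=128, ngf=256, bottom_width=4
    netD = SNGANDiscriminator32()    # default: ndf=128

    noise = torch.randn(8, 128)      # (N, nz) noise vectors
    fake_images = netG(noise)        # (8, 3, 32, 32), tanh-bounded to [-1, 1]
    logits = netD(fake_images)       # (8, 1) GAN logits after global sum pooling

The 48, 64 and 128 pixel variants are called the same way; only the default feature map sizes and the expected image resolution change, and the SSGAN discriminators additionally return a second tensor of rotation class logits.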
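The training and utility modules shown last are similarly small. Below is a minimal sketch of MetricLog and the common helpers, again assuming the import paths from the directory tree; the metric names and values are made up for illustration.

    # Minimal sketch of MetricLog and the common utilities defined above.
    import torch
    from torch_mimicry.training.metric_log import MetricLog
    from torch_mimicry.utils import common

    log = MetricLog()
    log.add_metric('errD', torch.tensor(0.7321), group='loss')    # tensor values are unwrapped via .item()
    log.add_metric('errG', 1.2598, group='loss', precision=2)     # plain floats are stored as-is

    print(log['errD'], log['errG'])      # values rounded to 4 and 2 decimal places on read
    print(log.get_group_name('errG'))    # 'loss'

    images, labels = common.load_images(n=4, size=32)             # random (4, 3, 32, 32) images with zero labels

Note that MetricLog rounds values in __getitem__ rather than when they are stored, so the full-precision value remains available in metrics_dict.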