├── .circleci
│   └── config.yml
├── .coveragerc
├── .readthedocs.yml
├── .style.yapf
├── LICENSE
├── README.md
├── docs
│   ├── Makefile
│   ├── gallery
│   │   └── README.md
│   ├── images
│   │   ├── celeba_128.png
│   │   ├── cifar10.png
│   │   ├── cifar100.png
│   │   ├── imagenet_32.png
│   │   ├── lsun_bedroom_128.png
│   │   ├── mimicry_logo.png
│   │   └── stl10_48.png
│   ├── requirements.txt
│   └── source
│       ├── _static
│       │   ├── css
│       │   │   └── custom.css
│       │   └── img
│       │       └── mimicry_logo.png
│       ├── conf.py
│       ├── guides
│       │   ├── baselines.rst
│       │   ├── images
│       │   │   ├── fake_vis.png
│       │   │   ├── fixed_fake_vis.png
│       │   │   ├── output.gif
│       │   │   ├── resnet_backbone.png
│       │   │   ├── ssgan.png
│       │   │   └── tensorboard.png
│       │   ├── introduction.rst
│       │   └── tutorial.rst
│       ├── index.rst
│       └── modules
│           ├── datasets.rst
│           ├── metrics.rst
│           ├── modules.rst
│           ├── nets.rst
│           ├── training.rst
│           └── utils.rst
├── examples
│   ├── eval_pretrained.py
│   ├── sngan_example.py
│   └── ssgan_tutorial.py
├── pytest.ini
├── requirements.txt
├── setup.cfg
├── setup.py
├── tests
│   ├── datasets
│   │   ├── imagenet
│   │   │   └── test.bin
│   │   ├── test_data_utils.py
│   │   └── test_image_loader.py
│   ├── metrics
│   │   ├── fid
│   │   │   └── test_fid.py
│   │   ├── inception_model
│   │   │   └── test_inception_utils.py
│   │   ├── inception_score
│   │   │   └── test_inception_score.py
│   │   ├── kid
│   │   │   └── test_kid.py
│   │   ├── test_compute_fid.py
│   │   ├── test_compute_is.py
│   │   ├── test_compute_kid.py
│   │   └── test_compute_metrics.py
│   ├── modules
│   │   ├── test_layers.py
│   │   ├── test_losses.py
│   │   ├── test_resblocks.py
│   │   └── test_spectral_norm.py
│   ├── nets
│   │   ├── basemodel
│   │   │   └── test_basemodel.py
│   │   ├── cgan_pd
│   │   │   ├── test_cgan_pd_128.py
│   │   │   └── test_cgan_pd_32.py
│   │   ├── dcgan
│   │   │   ├── test_dcgan_128.py
│   │   │   ├── test_dcgan_32.py
│   │   │   ├── test_dcgan_48.py
│   │   │   ├── test_dcgan_64.py
│   │   │   └── test_dcgan_cifar.py
│   │   ├── gan
│   │   │   ├── test_cgan.py
│   │   │   └── test_gan.py
│   │   ├── infomax_gan
│   │   │   ├── test_infomax_gan_128.py
│   │   │   ├── test_infomax_gan_32.py
│   │   │   ├── test_infomax_gan_48.py
│   │   │   ├── test_infomax_gan_64.py
│   │   │   └── test_infomax_gan_base.py
│   │   ├── sagan
│   │   │   ├── test_sagan_128.py
│   │   │   └── test_sagan_32.py
│   │   ├── sngan
│   │   │   ├── test_sngan_128.py
│   │   │   ├── test_sngan_32.py
│   │   │   ├── test_sngan_48.py
│   │   │   └── test_sngan_64.py
│   │   ├── ssgan
│   │   │   ├── test_ssgan_128.py
│   │   │   ├── test_ssgan_32.py
│   │   │   ├── test_ssgan_48.py
│   │   │   ├── test_ssgan_64.py
│   │   │   └── test_ssgan_base.py
│   │   └── wgan_gp
│   │       ├── test_wgan_gp_128.py
│   │       ├── test_wgan_gp_32.py
│   │       ├── test_wgan_gp_48.py
│   │       ├── test_wgan_gp_64.py
│   │       └── test_wgan_gp_resblocks.py
│   ├── training
│   │   ├── test_logger.py
│   │   ├── test_metric_log.py
│   │   ├── test_scheduler.py
│   │   └── test_trainer.py
│   └── utils
│       └── test_common.py
└── torch_mimicry
    ├── __init__.py
    ├── datasets
    │   ├── __init__.py
    │   ├── data_utils.py
    │   ├── image_loader.py
    │   └── imagenet
    │       ├── __init__.py
    │       └── imagenet.py
    ├── metrics
    │   ├── __init__.py
    │   ├── compute_fid.py
    │   ├── compute_is.py
    │   ├── compute_kid.py
    │   ├── compute_metrics.py
    │   ├── fid
    │   │   ├── __init__.py
    │   │   └── fid_utils.py
    │   ├── inception_model
    │   │   ├── __init__.py
    │   │   └── inception_utils.py
    │   ├── inception_score
    │   │   ├── __init__.py
    │   │   └── inception_score_utils.py
    │   └── kid
    │       ├── __init__.py
    │       └── kid_utils.py
    ├── modules
    │   ├── __init__.py
    │   ├── layers.py
    │   ├── losses.py
    │   ├── resblocks.py
    │   └── spectral_norm.py
    ├── nets
    │   ├── __init__.py
    │   ├── basemodel
    │   │   ├── __init__.py
    │   │   └── basemodel.py
    │   ├── cgan_pd
    │   │   ├── __init__.py
    │   │   ├── cgan_pd_128.py
    │   │   ├── cgan_pd_32.py
    │   │   └── cgan_pd_base.py
    │   ├── dcgan
    │   │   ├── __init__.py
    │   │   ├── dcgan_128.py
    │   │   ├── dcgan_32.py
    │   │   ├── dcgan_48.py
    │   │   ├── dcgan_64.py
    │   │   ├── dcgan_base.py
    │   │   └── dcgan_cifar.py
    │   ├── gan
    │   │   ├── __init__.py
    │   │   ├── cgan.py
    │   │   └── gan.py
    │   ├── infomax_gan
    │   │   ├── __init__.py
    │   │   ├── infomax_gan_128.py
    │   │   ├── infomax_gan_32.py
    │   │   ├── infomax_gan_48.py
    │   │   ├── infomax_gan_64.py
    │   │   └── infomax_gan_base.py
    │   ├── sagan
    │   │   ├── __init__.py
    │   │   ├── sagan_128.py
    │   │   ├── sagan_32.py
    │   │   └── sagan_base.py
    │   ├── sngan
    │   │   ├── __init__.py
    │   │   ├── sngan_128.py
    │   │   ├── sngan_32.py
    │   │   ├── sngan_48.py
    │   │   ├── sngan_64.py
    │   │   └── sngan_base.py
    │   ├── ssgan
    │   │   ├── __init__.py
    │   │   ├── ssgan_128.py
    │   │   ├── ssgan_32.py
    │   │   ├── ssgan_48.py
    │   │   ├── ssgan_64.py
    │   │   └── ssgan_base.py
    │   └── wgan_gp
    │       ├── __init__.py
    │       ├── wgan_gp_128.py
    │       ├── wgan_gp_32.py
    │       ├── wgan_gp_48.py
    │       ├── wgan_gp_64.py
    │       ├── wgan_gp_base.py
    │       └── wgan_gp_resblocks.py
    ├── training
    │   ├── __init__.py
    │   ├── logger.py
    │   ├── metric_log.py
    │   ├── scheduler.py
    │   └── trainer.py
    └── utils
        ├── __init__.py
        └── common.py
/.circleci/config.yml:
--------------------------------------------------------------------------------
1 | # Python CircleCI 2.1 configuration file
2 | version: 2.1
3 | orbs:
4 | codecov: codecov/codecov@1.0.5
5 | jobs:
6 | build:
7 | # machine:
8 | # image: ubuntu-1604:201903-01 # recommended linux image - includes Ubuntu 16.04, docker 18.09.3, docker-compose 1.23.1
9 |
10 | docker:
11 | # specify the version you desire here
12 | # use `-browsers` prefix for selenium tests, e.g. `3.6.1-browsers`
13 | - image: cimg/python:3.8.13
14 |
15 | # Specify service dependencies here if necessary
16 | # CircleCI maintains a library of pre-built images
17 | # documented at https://circleci.com/docs/2.0/circleci-images/
18 | # - image: circleci/postgres:9.4
19 |
20 | working_directory: ~/repo
21 | parallelism: 4
22 | resource_class: large
23 |
24 | steps:
25 | - checkout
26 |
27 | # Download and cache dependencies
28 | - restore_cache:
29 | keys:
30 | - v1-dependencies-{{ checksum "requirements.txt" }}
31 | # fallback to using the latest cache if no exact match is found
32 | - v1-dependencies-
33 |
34 | - run:
35 | name: install dependencies
36 | command: |
37 | python3 -m venv venv
38 | . venv/bin/activate
39 | pip install pip==22.1.2 && pip install -r requirements.txt
40 |
41 | - save_cache:
42 | paths:
43 | - ./venv
44 | key: v1-dependencies-{{ checksum "requirements.txt" }}
45 |
46 | # run tests!
47 | # this example uses Django's built-in test-runner
48 | # other common Python testing frameworks include pytest and nose
49 | # https://pytest.org
50 | # https://nose.readthedocs.io
51 | - run:
52 | name: run tests
53 | command: |
54 | . venv/bin/activate
55 | python -m pytest --cov=./ --cov-report=xml:coverage.xml tests
56 |
57 | - codecov/upload:
58 | file: ./coverage.xml
59 | token: e963353d-df2c-4b11-bcc8-c8295b40bd7f
60 |
61 | - store_artifacts:
62 | path: test-reports
63 | destination: test-reports
64 |
--------------------------------------------------------------------------------
/.coveragerc:
--------------------------------------------------------------------------------
1 | [report]
2 | exclude_lines =
3 | if __name__ == .__main__.:
4 |
5 | [run]
6 | omit =
7 | *test*
8 | torch_mimicry/datasets/imagenet/*
9 | setup.py
10 |
--------------------------------------------------------------------------------
/.readthedocs.yml:
--------------------------------------------------------------------------------
1 | # .readthedocs.yml
2 | # Read the Docs configuration file
3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
4 |
5 | # Required
6 | version: 2
7 |
8 | build:
9 | image: stable
10 |
11 | python:
12 | version: 3.6
13 | system_packages: true
14 | install:
15 | - requirements: docs/requirements.txt
16 | - method: setuptools
17 | path: .
18 |
19 | formats: []
20 |
--------------------------------------------------------------------------------
/.style.yapf:
--------------------------------------------------------------------------------
1 | [style]
2 | based_on_style = pep8
3 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Kwot Sin Lee
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | SPHINXBUILD = sphinx-build
2 | SPHINXPROJ = torch_mimicry
3 | SOURCEDIR = source
4 | BUILDDIR = build
5 |
6 | .PHONY: help Makefile
7 |
8 | %: Makefile
9 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)"
10 |
--------------------------------------------------------------------------------
/docs/gallery/README.md:
--------------------------------------------------------------------------------
1 | # Gallery
2 | We provide randomly sampled, non-cherry picked images for various datasets as produced by InfoMax-GAN.
3 |
4 | ### LSUN-Bedroom
5 | Resolution: 128 x 128
6 |
7 | ")
8 |
9 | ### CelebA
10 | Resolution: 128 x 128
11 |
12 | ")
13 |
14 | ### STL-10
15 | Resolution: 48 x 48
16 |
17 | ")
18 |
19 | ### ImageNet
20 | Resolution: 32 x 32
21 |
22 | ")
23 |
24 | ### CIFAR-10
25 | Resolution: 32 x 32
26 |
27 | ")
28 |
29 | ### CIFAR-100
30 | Resolution: 32 x 32
31 |
32 | ")
33 |
--------------------------------------------------------------------------------
/docs/images/celeba_128.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kwotsin/mimicry/a7fda06c4aff1e6af8dc4c4a35ed6636e434c766/docs/images/celeba_128.png
--------------------------------------------------------------------------------
/docs/images/cifar10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kwotsin/mimicry/a7fda06c4aff1e6af8dc4c4a35ed6636e434c766/docs/images/cifar10.png
--------------------------------------------------------------------------------
/docs/images/cifar100.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kwotsin/mimicry/a7fda06c4aff1e6af8dc4c4a35ed6636e434c766/docs/images/cifar100.png
--------------------------------------------------------------------------------
/docs/images/imagenet_32.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kwotsin/mimicry/a7fda06c4aff1e6af8dc4c4a35ed6636e434c766/docs/images/imagenet_32.png
--------------------------------------------------------------------------------
/docs/images/lsun_bedroom_128.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kwotsin/mimicry/a7fda06c4aff1e6af8dc4c4a35ed6636e434c766/docs/images/lsun_bedroom_128.png
--------------------------------------------------------------------------------
/docs/images/mimicry_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kwotsin/mimicry/a7fda06c4aff1e6af8dc4c4a35ed6636e434c766/docs/images/mimicry_logo.png
--------------------------------------------------------------------------------
/docs/images/stl10_48.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kwotsin/mimicry/a7fda06c4aff1e6af8dc4c4a35ed6636e434c766/docs/images/stl10_48.png
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | python_version >= '3.6'
2 | scipy==1.4.1
3 | sphinx
4 | sphinx_rtd_theme
5 | torch
6 | torchvision
7 | numpy
8 | -e git+https://github.com/kwotsin/mimicry.git#egg=torch_mimicry
9 |
--------------------------------------------------------------------------------
/docs/source/_static/css/custom.css:
--------------------------------------------------------------------------------
1 | /* White logo background. */
2 | .wy-side-nav-search {
3 | background-color: #fff;
4 | }
5 |
6 | .wy-side-nav-search > div.version {
7 | color: #000;
8 | }
9 |
--------------------------------------------------------------------------------
/docs/source/_static/img/mimicry_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kwotsin/mimicry/a7fda06c4aff1e6af8dc4c4a35ed6636e434c766/docs/source/_static/img/mimicry_logo.png
--------------------------------------------------------------------------------
/docs/source/conf.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import sphinx_rtd_theme
3 | import doctest
4 | import torch_mimicry
5 |
6 | extensions = [
7 | 'sphinx.ext.autodoc',
8 | 'sphinx.ext.doctest',
9 | 'sphinx.ext.intersphinx',
10 | 'sphinx.ext.mathjax',
11 | 'sphinx.ext.napoleon',
12 | 'sphinx.ext.viewcode',
13 | 'sphinx.ext.githubpages',
14 | ]
15 |
16 | source_suffix = '.rst'
17 | master_doc = 'index'
18 |
19 | author = 'Kwot Sin Lee'
20 | project = 'torch_mimicry'
21 | copyright = '{}, {}'.format(datetime.datetime.now().year, author)
22 |
23 | version = 'master'
24 | release = torch_mimicry.__version__
25 |
26 | html_theme = 'sphinx_rtd_theme'
27 | html_theme_path = [sphinx_rtd_theme.get_html_theme_path()]
28 |
29 | doctest_default_flags = doctest.NORMALIZE_WHITESPACE
30 | intersphinx_mapping = {'python': ('https://docs.python.org/', None)}
31 |
32 | html_theme_options = {
33 | 'collapse_navigation': False,
34 | 'display_version': True,
35 | 'logo_only': True,
36 | }
37 |
38 | html_logo = '_static/img/mimicry_logo.png'
39 | html_static_path = ['_static']
40 | html_context = {'css_files': ['_static/css/custom.css']}
41 |
42 | add_module_names = False
43 |
44 |
45 | def setup(app):
46 | def skip(app, what, name, obj, skip, options):
47 | members = [
48 | '__init__',
49 | '__repr__',
50 | '__weakref__',
51 | '__dict__',
52 | '__module__',
53 | ]
54 | return True if name in members else skip
55 |
56 | app.connect('autodoc-skip-member', skip)
57 |
--------------------------------------------------------------------------------
/docs/source/guides/images/fake_vis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kwotsin/mimicry/a7fda06c4aff1e6af8dc4c4a35ed6636e434c766/docs/source/guides/images/fake_vis.png
--------------------------------------------------------------------------------
/docs/source/guides/images/fixed_fake_vis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kwotsin/mimicry/a7fda06c4aff1e6af8dc4c4a35ed6636e434c766/docs/source/guides/images/fixed_fake_vis.png
--------------------------------------------------------------------------------
/docs/source/guides/images/output.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kwotsin/mimicry/a7fda06c4aff1e6af8dc4c4a35ed6636e434c766/docs/source/guides/images/output.gif
--------------------------------------------------------------------------------
/docs/source/guides/images/resnet_backbone.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kwotsin/mimicry/a7fda06c4aff1e6af8dc4c4a35ed6636e434c766/docs/source/guides/images/resnet_backbone.png
--------------------------------------------------------------------------------
/docs/source/guides/images/ssgan.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kwotsin/mimicry/a7fda06c4aff1e6af8dc4c4a35ed6636e434c766/docs/source/guides/images/ssgan.png
--------------------------------------------------------------------------------
/docs/source/guides/images/tensorboard.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kwotsin/mimicry/a7fda06c4aff1e6af8dc4c4a35ed6636e434c766/docs/source/guides/images/tensorboard.png
--------------------------------------------------------------------------------
/docs/source/guides/introduction.rst:
--------------------------------------------------------------------------------
1 | Introduction
2 | ============
3 |
4 | Installation
5 | ------------
6 | For best performance, we recommend installing the GPU versions of both TensorFlow and PyTorch, both of which are used by this library.
7 |
8 | First create a new environment with :code:`conda` using Python 3.6 or 3.7 (Python 3.8 is currently not supported by TensorFlow):
9 |
10 | .. code-block:: none
11 |
12 | $ conda create -n mimicry python=3.6
13 |
14 | Install TensorFlow (GPU):
15 |
16 | .. code-block:: none
17 |
18 | $ conda install -c anaconda tensorflow-gpu
19 |
20 | Install PyTorch (GPU):
21 |
22 | .. code-block:: none
23 |
24 | $ conda install pytorch torchvision cudatoolkit=9.2 -c pytorch
25 |
26 | For installing the CUDA version matching your drivers, see the `official instructions `_.
27 |
28 | Mimicry can be installed using :code:`pip` directly:
29 |
30 | .. code-block:: none
31 |
32 | $ pip install torch-mimicry
33 |
34 | .. To install the GPU versions of TensorFlow and PyTorch see the official `PyTorch `_ and `TensorFlow `_ pages based on your system. The library is currently only compatible with :code:`TensorFlow 1.x` to accommodate the original implementations of the GAN metrics.
35 |
36 | Quick Start
37 | -----------
38 | We provide a sample training script for training the `Spectral Normalization GAN `_ on the CIFAR-10 dataset, using the same training hyperparameters that reproduce the results reported in the paper.
39 |
40 |
41 | .. code-block:: python
42 |
43 | import torch
44 | import torch.optim as optim
45 | import torch_mimicry as mmc
46 | from torch_mimicry.nets import sngan
47 |
48 |
49 | # Data handling objects
50 | device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")
51 | dataset = mmc.datasets.load_dataset(root='./datasets', name='cifar10')
52 | dataloader = torch.utils.data.DataLoader(
53 | dataset, batch_size=64, shuffle=True, num_workers=4)
54 |
55 | # Define models and optimizers
56 | netG = sngan.SNGANGenerator32().to(device)
57 | netD = sngan.SNGANDiscriminator32().to(device)
58 | optD = optim.Adam(netD.parameters(), 2e-4, betas=(0.0, 0.9))
59 | optG = optim.Adam(netG.parameters(), 2e-4, betas=(0.0, 0.9))
60 |
61 | # Start training
62 | trainer = mmc.training.Trainer(
63 | netD=netD,
64 | netG=netG,
65 | optD=optD,
66 | optG=optG,
67 | n_dis=5,
68 | num_steps=100000,
69 | lr_decay='linear',
70 | dataloader=dataloader,
71 | log_dir='./log/example',
72 | device=device)
73 | trainer.train()
74 |
75 | To evaluate its FID, we can simply run the following:
76 |
77 | .. code-block:: python
78 |
79 | # Evaluate fid
80 | mmc.metrics.evaluate(
81 | metric='fid',
82 | log_dir='./log/example',
83 | netG=netG,
84 | dataset='cifar10',
85 | num_real_samples=50000,
86 | num_fake_samples=50000,
87 | evaluate_step=100000,
88 | device=device)
89 |
90 | Alternatively, one could evaluate FID progressively over an interval by swapping the `evaluate_step` argument for `evaluate_range`:
91 |
92 | .. code-block:: python
93 |
94 | # Evaluate fid
95 | mmc.metrics.evaluate(
96 | metric='fid',
97 | log_dir='./log/example',
98 | netG=netG,
99 | dataset='cifar10',
100 | num_real_samples=50000,
101 | num_fake_samples=50000,
102 | evaluate_range=(5000, 100000, 5000),
103 | device=device)
104 |
105 | We support other datasets and models. See `datasets `_ and `nets `_ for more information.
106 |
107 | Visualizations
108 | --------------
109 | Mimicry provides TensorBoard support for visualizing the following:
110 |
111 | - Loss and probability curves for monitoring GAN training.
112 | - Randomly generated images for checking diversity.
113 | - Generated images from a fixed set of noise vectors.
114 |
115 | .. code-block:: none
116 |
117 | $ tensorboard --logdir=./log/example
118 |
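119 | Beyond TensorBoard, it is often handy to eyeball samples from a trained checkpoint directly. The snippet below is a minimal sketch: it assumes the checkpoint produced by the Quick Start run above (saved under :code:`./log/example/checkpoints/netG/`) and the default noise dimension of 128 for :code:`SNGANGenerator32`.
120 | 
121 | .. code-block:: python
122 | 
123 |     import torch
124 |     import torchvision
125 |     from torch_mimicry.nets import sngan
126 | 
127 |     device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")
128 | 
129 |     # Restore the generator weights saved by the Trainer above.
130 |     netG = sngan.SNGANGenerator32().to(device)
131 |     netG.restore_checkpoint('./log/example/checkpoints/netG/netG_100000_steps.pth')
132 | 
133 |     # Sample a batch of images from random noise (assumes nz=128, the default)
134 |     # and save them as a single grid for quick inspection.
135 |     with torch.no_grad():
136 |         fake_images = netG(torch.randn(64, 128, device=device))
137 | 
138 |     torchvision.utils.save_image(fake_images, 'samples.png', normalize=True)
139 | 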
--------------------------------------------------------------------------------
/docs/source/index.rst:
--------------------------------------------------------------------------------
1 | :github_url: https://github.com/kwotsin/mimicry
2 |
3 | Mimicry Documentation
4 | =====================
5 |
6 | `Mimicry `_ is a lightweight PyTorch library aimed towards the reproducibility of GAN research.
7 |
8 | Comparing GANs is often difficult - mild differences in implementations and evaluation methodologies can result in huge performance differences. Mimicry aims to resolve this by providing:
9 |
10 | - Standardized implementations of popular GANs that closely reproduce reported scores
11 |
12 | - Baseline scores of GANs trained and evaluated under the same conditions
13 |
14 | - A framework for researchers to focus on implementing GANs without rewriting most of the GAN training boilerplate code, with support for multiple GAN evaluation metrics.
15 |
16 | .. toctree::
17 | :glob:
18 | :maxdepth: 1
19 | :caption: Guides
20 |
21 | guides/introduction
22 | guides/tutorial
23 | guides/baselines
24 |
25 | .. toctree::
26 | :caption: API Reference
27 | :maxdepth: 1
28 |
29 | modules/nets
30 | modules/modules
31 | modules/training
32 | modules/metrics
33 | modules/datasets
34 | modules/utils
35 |
--------------------------------------------------------------------------------
/docs/source/modules/datasets.rst:
--------------------------------------------------------------------------------
1 | torch_mimicry.datasets
2 | ======================
3 |
4 | .. contents:: Contents
5 | :local:
6 |
7 |
8 | Dataset Loaders
9 | ---------------
10 | .. automodule:: torch_mimicry.datasets.data_utils
11 | :members:
12 |
13 |
14 | Image Loaders
15 | -------------
16 | .. automodule:: torch_mimicry.datasets.image_loader
17 | :members:
--------------------------------------------------------------------------------
/docs/source/modules/metrics.rst:
--------------------------------------------------------------------------------
1 | torch_mimicry.metrics
2 | =====================
3 |
4 | .. contents:: Contents
5 | :local:
6 |
7 |
8 | FID
9 | ---
10 | .. automodule:: torch_mimicry.metrics.compute_fid
11 | :members:
12 |
13 | KID
14 | ---
15 | .. automodule:: torch_mimicry.metrics.compute_kid
16 | :members:
17 |
18 | Inception Score
19 | ---------------
20 | .. automodule:: torch_mimicry.metrics.compute_is
21 | :members:
22 |
23 | Metrics
24 | ---------------
25 | .. automodule:: torch_mimicry.metrics.compute_metrics
26 | :members:
27 |
--------------------------------------------------------------------------------
/docs/source/modules/modules.rst:
--------------------------------------------------------------------------------
1 | torch_mimicry.modules
2 | =====================
3 |
4 | .. contents:: Contents
5 | :local:
6 |
7 | Layers
8 | ------
9 | .. automodule:: torch_mimicry.modules.layers
10 | :members:
11 |
12 | Residual Blocks
13 | ---------------
14 | .. automodule:: torch_mimicry.modules.resblocks
15 | :members:
16 |
17 | Losses
18 | ------
19 | .. automodule:: torch_mimicry.modules.losses
20 | :members:
21 |
--------------------------------------------------------------------------------
/docs/source/modules/nets.rst:
--------------------------------------------------------------------------------
1 | torch_mimicry.nets
2 | ==================
3 |
4 | .. contents:: Contents
5 | :local:
6 |
7 | torch_mimicry.nets.basemodel
8 | ----------------------------
9 |
10 | .. automodule:: torch_mimicry.nets.basemodel.basemodel
11 | :members:
12 |
13 | torch_mimicry.nets.gan
14 | ----------------------
15 |
16 | Base Unconditional GAN
17 | ^^^^^^^^^^^^^^^^^^^^^^
18 |
19 | .. automodule:: torch_mimicry.nets.gan.gan
20 | :members:
21 |
22 | Base Conditional GAN
23 | ^^^^^^^^^^^^^^^^^^^^
24 | .. automodule:: torch_mimicry.nets.gan.cgan
25 | :members:
26 |
27 |
28 | torch_mimicry.nets.dcgan
29 | ------------------------
30 | DCGAN CIFAR
31 | ^^^^^^^^^^^
32 | .. automodule:: torch_mimicry.nets.dcgan.dcgan_cifar
33 | :members:
34 |
35 | DCGAN 32
36 | ^^^^^^^^
37 | .. automodule:: torch_mimicry.nets.dcgan.dcgan_32
38 | :members:
39 |
40 | DCGAN 48
41 | ^^^^^^^^
42 | .. automodule:: torch_mimicry.nets.dcgan.dcgan_48
43 | :members:
44 |
45 | DCGAN 64
46 | ^^^^^^^^
47 | .. automodule:: torch_mimicry.nets.dcgan.dcgan_64
48 | :members:
49 |
50 |
51 | DCGAN 128
52 | ^^^^^^^^^
53 | .. automodule:: torch_mimicry.nets.dcgan.dcgan_128
54 | :members:
55 |
56 | DCGAN Base
57 | ^^^^^^^^^^
58 | .. automodule:: torch_mimicry.nets.dcgan.dcgan_base
59 | :members:
60 |
61 | torch_mimicry.nets.wgan_gp
62 | --------------------------
63 | WGAN-GP 32
64 | ^^^^^^^^^^
65 | .. automodule:: torch_mimicry.nets.wgan_gp.wgan_gp_32
66 | :members:
67 |
68 | WGAN-GP 48
69 | ^^^^^^^^^^
70 | .. automodule:: torch_mimicry.nets.wgan_gp.wgan_gp_48
71 | :members:
72 |
73 | WGAN-GP 64
74 | ^^^^^^^^^^
75 | .. automodule:: torch_mimicry.nets.wgan_gp.wgan_gp_64
76 | :members:
77 |
78 | WGAN-GP 128
79 | ^^^^^^^^^^^
80 | .. automodule:: torch_mimicry.nets.wgan_gp.wgan_gp_128
81 | :members:
82 |
83 | WGAN-GP Base
84 | ^^^^^^^^^^^^
85 | .. automodule:: torch_mimicry.nets.wgan_gp.wgan_gp_base
86 | :members:
87 |
88 | torch_mimicry.nets.sngan
89 | ------------------------
90 | SNGAN 32
91 | ^^^^^^^^
92 | .. automodule:: torch_mimicry.nets.sngan.sngan_32
93 | :members:
94 |
95 | SNGAN 48
96 | ^^^^^^^^
97 | .. automodule:: torch_mimicry.nets.sngan.sngan_48
98 | :members:
99 |
100 | SNGAN 64
101 | ^^^^^^^^
102 | .. automodule:: torch_mimicry.nets.sngan.sngan_64
103 | :members:
104 |
105 | SNGAN 128
106 | ^^^^^^^^^
107 | .. automodule:: torch_mimicry.nets.sngan.sngan_128
108 | :members:
109 |
110 | SNGAN Base
111 | ^^^^^^^^^^
112 | .. automodule:: torch_mimicry.nets.sngan.sngan_base
113 | :members:
114 |
115 | torch_mimicry.nets.cgan_pd
116 | --------------------------
117 | CGAN-PD 32
118 | ^^^^^^^^^^
119 | .. automodule:: torch_mimicry.nets.cgan_pd.cgan_pd_32
120 | :members:
121 |
122 | CGAN-PD 128
123 | ^^^^^^^^^^^
124 | .. automodule:: torch_mimicry.nets.cgan_pd.cgan_pd_128
125 | :members:
126 |
127 | CGAN-PD Base
128 | ^^^^^^^^^^^^
129 | .. automodule:: torch_mimicry.nets.cgan_pd.cgan_pd_base
130 | :members:
131 |
132 | torch_mimicry.nets.ssgan
133 | ------------------------
134 | SSGAN 32
135 | ^^^^^^^^
136 | .. automodule:: torch_mimicry.nets.ssgan.ssgan_32
137 | :members:
138 |
139 | SSGAN 48
140 | ^^^^^^^^
141 | .. automodule:: torch_mimicry.nets.ssgan.ssgan_48
142 | :members:
143 |
144 | SSGAN 64
145 | ^^^^^^^^
146 | .. automodule:: torch_mimicry.nets.ssgan.ssgan_64
147 | :members:
148 |
149 | SSGAN 128
150 | ^^^^^^^^^
151 | .. automodule:: torch_mimicry.nets.ssgan.ssgan_128
152 | :members:
153 |
154 | SSGAN Base
155 | ^^^^^^^^^^
156 | .. automodule:: torch_mimicry.nets.ssgan.ssgan_base
157 | :members:
158 |
159 | torch_mimicry.nets.infomax_gan
160 | ------------------------------
161 | InfoMax-GAN 32
162 | ^^^^^^^^^^^^^^
163 | .. automodule:: torch_mimicry.nets.infomax_gan.infomax_gan_32
164 | :members:
165 |
166 | InfoMax-GAN 48
167 | ^^^^^^^^^^^^^^
168 | .. automodule:: torch_mimicry.nets.infomax_gan.infomax_gan_48
169 | :members:
170 |
171 | InfoMax-GAN 64
172 | ^^^^^^^^^^^^^^
173 | .. automodule:: torch_mimicry.nets.infomax_gan.infomax_gan_64
174 | :members:
175 |
176 | InfoMax-GAN 128
177 | ^^^^^^^^^^^^^^^
178 | .. automodule:: torch_mimicry.nets.infomax_gan.infomax_gan_128
179 | :members:
180 |
181 | InfoMax-GAN Base
182 | ^^^^^^^^^^^^^^^^
183 | .. automodule:: torch_mimicry.nets.infomax_gan.infomax_gan_base
184 | :members:
185 |
--------------------------------------------------------------------------------
/docs/source/modules/training.rst:
--------------------------------------------------------------------------------
1 | torch_mimicry.training
2 | ======================
3 |
4 | .. contents:: Contents
5 | :local:
6 |
7 |
8 | Trainer
9 | -------
10 | .. automodule:: torch_mimicry.training.trainer
11 | :members:
12 |
13 |
14 | Logger
15 | ------
16 | .. automodule:: torch_mimicry.training.logger
17 | :members:
18 |
19 | MetricLog
20 | ----------
21 | .. automodule:: torch_mimicry.training.metric_log
22 | :members:
23 |
24 | Scheduler
25 | ---------
26 | .. automodule:: torch_mimicry.training.scheduler
27 | :members:
28 |
--------------------------------------------------------------------------------
/docs/source/modules/utils.rst:
--------------------------------------------------------------------------------
1 | torch_mimicry.utils
2 | ===================
3 |
4 | .. contents:: Contents
5 | :local:
6 |
7 |
8 | Common Utilities
9 | ----------------
10 | .. automodule:: torch_mimicry.utils.common
11 | :members:
--------------------------------------------------------------------------------
/examples/eval_pretrained.py:
--------------------------------------------------------------------------------
1 | """
2 | Example script of evaluating a pretrained generator.
3 | """
4 | import torch
5 | import torch_mimicry as mmc
6 | from torch_mimicry.nets import sngan
7 |
8 | ######################################################
9 | # Computing Metrics with Default Datasets
10 | ######################################################
11 |
12 | # Download cifar10 checkpoint: https://drive.google.com/uc?id=1Gn4ouslRAHq3D7AP_V-T2x8Wi1S1hTXJ&export=download
13 | ckpt_file = "./log/sngan_example/checkpoints/netG/netG_100000_steps.pth"
14 |
15 | # Default variables
16 | log_dir = './examples/example_log'
17 | dataset = 'cifar10'
18 | seed = 0
19 | device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")
20 |
21 | # Restore model
22 | netG = sngan.SNGANGenerator32().to(device)
23 | netG.restore_checkpoint(ckpt_file)
24 |
25 | # Metrics with a known/popular dataset.
26 | mmc.metrics.fid_score(num_real_samples=50000,
27 | num_fake_samples=50000,
28 | netG=netG,
29 | seed=seed,
30 | dataset=dataset,
31 | log_dir=log_dir,
32 | device=device)
33 |
34 | mmc.metrics.kid_score(num_samples=50000,
35 | netG=netG,
36 | seed=seed,
37 | dataset=dataset,
38 | log_dir=log_dir,
39 | device=device)
40 |
41 | mmc.metrics.inception_score(num_samples=50000,
42 | netG=netG,
43 | seed=seed,
44 | log_dir=log_dir,
45 | device=device)
46 |
47 | ######################################################
48 | # Computing Metrics with Custom Datasets
49 | ######################################################
50 | """
51 | To compute FID/KID on a custom dataset, simply define the dataset as below and
52 | pass a stats_file/feat_file path for caching the computed statistics, since the
53 | library cannot infer a cache filename for an arbitrary dataset.
54 | """
55 |
56 |
57 | class CustomDataset(torch.utils.data.Dataset):
58 | def __init__(self):
59 | super().__init__()
60 | self.data = torch.ones(1000, 3, 32, 32)
61 |
62 | def __len__(self):
63 | return self.data.shape[0]
64 |
65 | def __getitem__(self, idx):
66 | return self.data[idx]
67 |
68 |
69 | custom_dataset = CustomDataset()
70 |
71 | # Metrics with a custom dataset.
72 | mmc.metrics.fid_score(num_real_samples=1000,
73 | num_fake_samples=1000,
74 | netG=netG,
75 | seed=seed,
76 | dataset=custom_dataset,
77 | log_dir=log_dir,
78 | device=device,
79 | stats_file='./examples/example_log/fid_stats.npz')
80 |
81 | mmc.metrics.kid_score(num_samples=1000,
82 | netG=netG,
83 | seed=seed,
84 | dataset=custom_dataset,
85 | log_dir=log_dir,
86 | device=device,
87 | feat_file='./examples/example_log/kid_stats.npz')
88 |
89 | # Using the evaluate API, which assumes a more fixed directory.
90 | netG = sngan.SNGANGenerator32().to(device)
91 | mmc.metrics.evaluate(metric='fid',
92 | log_dir='./log/sngan_example/',
93 | netG=netG,
94 | dataset=custom_dataset,
95 | num_real_samples=1000,
96 | num_fake_samples=1000,
97 | evaluate_step=100000,
98 | stats_file='./examples/example_log/fid_stats.npz',
99 | device=device)
--------------------------------------------------------------------------------
/examples/sngan_example.py:
--------------------------------------------------------------------------------
1 | """
2 | Typical usage example.
3 | """
4 |
5 | import torch
6 | import torch.optim as optim
7 | import torch_mimicry as mmc
8 | from torch_mimicry.nets import sngan
9 |
10 | if __name__ == "__main__":
11 | # Data handling objects
12 | device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")
13 | dataset = mmc.datasets.load_dataset(root='./datasets', name='cifar10')
14 | dataloader = torch.utils.data.DataLoader(dataset,
15 | batch_size=64,
16 | shuffle=True,
17 | num_workers=4)
18 |
19 | # Define models and optimizers
20 | netG = sngan.SNGANGenerator32().to(device)
21 | netD = sngan.SNGANDiscriminator32().to(device)
22 | optD = optim.Adam(netD.parameters(), 2e-4, betas=(0.0, 0.9))
23 | optG = optim.Adam(netG.parameters(), 2e-4, betas=(0.0, 0.9))
24 |
25 | # Start training
26 | trainer = mmc.training.Trainer(netD=netD,
27 | netG=netG,
28 | optD=optD,
29 | optG=optG,
30 | n_dis=5,
31 | num_steps=30,
32 | lr_decay='linear',
33 | dataloader=dataloader,
34 | log_dir='./log/example',
35 | device=device)
36 | trainer.train()
37 |
38 | # Evaluate fid
39 | mmc.metrics.evaluate(metric='fid',
40 | log_dir='./log/example',
41 | netG=netG,
42 | dataset='cifar10',
43 | num_real_samples=50000,
44 | num_fake_samples=50000,
45 | evaluate_step=30,
46 | device=device)
47 |
48 | # Evaluate kid
49 | mmc.metrics.evaluate(metric='kid',
50 | log_dir='./log/example',
51 | netG=netG,
52 | dataset='cifar10',
53 | num_samples=50000,
54 | evaluate_step=30,
55 | device=device)
56 |
57 | # Evaluate inception score
58 | mmc.metrics.evaluate(metric='inception_score',
59 | log_dir='./log/example',
60 | netG=netG,
61 | num_samples=50000,
62 | evaluate_step=30,
63 | device=device)
64 |
--------------------------------------------------------------------------------
/pytest.ini:
--------------------------------------------------------------------------------
1 | [pytest]
2 | filterwarnings =
3 | ignore::DeprecationWarning
4 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy==1.23.1
2 | scipy==1.9.0
3 | requests==2.28.1
4 | torch==1.12.1
5 | tensorflow==2.9.1
6 | torchvision==0.13.1
7 | six==1.16.0
8 | matplotlib==3.5.2
9 | Pillow==9.2.0
10 | scikit-image==0.19.3
11 | pytest==7.1.2
12 | scikit-learn==1.1.2
13 | future==0.18.2
14 | pytest-cov==3.0.0
15 | pandas==1.4.3
16 | psutil==5.9.1
17 | yapf==0.32.0
18 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [metadata]
2 | description-file = README.md
3 |
4 | [aliases]
5 | test=pytest
6 |
7 | [tool:pytest]
8 | addopts = --capture=no --cov
9 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | __version__ = '0.1.16'
4 | url = 'https://github.com/kwotsin/mimicry'
5 |
6 | install_requires = [
7 | 'numpy',
8 | 'scipy',
9 | 'requests',
10 | 'torch',
11 | 'tensorflow',
12 | 'torchvision',
13 | 'six',
14 | 'matplotlib',
15 | 'Pillow',
16 | 'scikit-image',
17 | 'pytest',
18 | 'scikit-learn',
19 | 'future',
20 | 'pytest-cov',
21 | 'pandas',
22 | 'psutil',
23 | 'yapf',
24 | 'lmdb',
25 | ]
26 |
27 | setup_requires = ['pytest-runner']
28 | tests_require = ['pytest', 'pytest-cov', 'mock']
29 |
30 | long_description = """
31 | Mimicry is a lightweight PyTorch library aimed towards the reproducibility of GAN research.
32 |
33 | Comparing GANs is often difficult - mild differences in implementations and evaluation methodologies can result in huge performance differences.
34 | Mimicry aims to resolve this by providing:
35 | (a) Standardized implementations of popular GANs that closely reproduce reported scores;
36 | (b) Baseline scores of GANs trained and evaluated under the same conditions;
37 | (c) A framework for researchers to focus on implementing GANs without rewriting most of the GAN training boilerplate code, with support for multiple GAN evaluation metrics.
38 |
39 | We provide a model zoo and set of baselines to benchmark different GANs of the same model size trained under the same conditions, using multiple metrics. To ensure reproducibility, we verify scores of our implemented models against reported scores in literature.
40 | """
41 |
42 | setup(
43 | name='torch_mimicry',
44 | version=__version__,
45 | long_description=long_description,
46 | long_description_content_type='text/markdown',
47 | description='Mimicry: Towards the Reproducibility of GAN Research',
48 | author='Kwot Sin Lee',
49 | author_email='ksl36@cam.ac.uk',
50 | url=url,
51 | download_url='{}/archive/{}.tar.gz'.format(url, __version__),
52 | keywords=[
53 | 'pytorch',
54 | 'generative-adversarial-networks',
55 | 'gans',
56 | 'GAN',
57 | ],
58 | python_requires='>=3.6',
59 | install_requires=install_requires,
60 | setup_requires=setup_requires,
61 | tests_require=tests_require,
62 | packages=find_packages(),
63 | )
64 |
--------------------------------------------------------------------------------
/tests/datasets/imagenet/test.bin:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kwotsin/mimicry/a7fda06c4aff1e6af8dc4c4a35ed6636e434c766/tests/datasets/imagenet/test.bin
--------------------------------------------------------------------------------
/tests/metrics/fid/test_fid.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import numpy as np
3 | import tensorflow as tf
4 |
5 | from torch_mimicry.metrics.fid import fid_utils
6 | from torch_mimicry.metrics.inception_model import inception_utils
7 |
8 |
9 | class TestFID:
10 | def setup(self):
11 | self.images = np.ones((4, 32, 32, 3))
12 | self.sess = tf.compat.v1.Session()
13 |
14 | def test_calculate_activation_statistics(self):
15 | inception_path = './metrics/inception_model'
16 | inception_utils.create_inception_graph(inception_path)
17 |
18 | mu, sigma = fid_utils.calculate_activation_statistics(
19 | images=self.images, sess=self.sess)
20 |
21 | assert mu.shape == (2048, )
22 | assert sigma.shape == (2048, 2048)
23 |
24 | def test_calculate_frechet_distance(self):
25 | mu1, sigma1 = np.ones((16, )), np.ones((16, 16))
26 | mu2, sigma2 = mu1 * 2, sigma1 * 2
27 |
28 | score = fid_utils.calculate_frechet_distance(mu1=mu1,
29 | mu2=mu2,
30 | sigma1=sigma1,
31 | sigma2=sigma2)
32 |
33 | assert type(score) == np.float64
34 |
35 | # Inputs check
36 | bad_mu2, bad_sigma2 = np.ones((15, 15)), np.ones((15, 15))
37 | with pytest.raises(ValueError):
38 | fid_utils.calculate_frechet_distance(mu1=mu1,
39 | mu2=bad_mu2,
40 | sigma1=sigma1,
41 | sigma2=bad_sigma2)
42 |
43 | def teardown(self):
44 | del self.images
45 | self.sess.close()
46 |
47 |
48 | if __name__ == "__main__":
49 | test = TestFID()
50 | test.setup()
51 | test.test_calculate_activation_statistics()
52 | test.test_calculate_frechet_distance()
53 | test.teardown()
54 |
--------------------------------------------------------------------------------
/tests/metrics/inception_model/test_inception_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 |
4 | from torch_mimicry.metrics.inception_model import inception_utils
5 |
6 |
7 | class TestInceptionUtils:
8 | def test_get_activations(self):
9 | for inception_path in ['./metrics/inception_model', None]:
10 | inception_utils.create_inception_graph(inception_path)
11 |
12 | images = np.ones((4, 32, 32, 3))
13 | with tf.compat.v1.Session() as sess:
14 | feat = inception_utils.get_activations(images=images,
15 | sess=sess)
16 |
17 | assert feat.shape == (4, 2048)
18 |
19 |
20 | if __name__ == "__main__":
21 | test = TestInceptionUtils()
22 | test.test_get_activations()
23 |
--------------------------------------------------------------------------------
/tests/metrics/inception_score/test_inception_score.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import numpy as np
3 | import torch
4 |
5 | from torch_mimicry.metrics.inception_model import inception_utils
6 | from torch_mimicry.metrics.inception_score import inception_score_utils
7 |
8 |
9 | class TestInceptionScore:
10 | def test_get_predictions(self):
11 | inception_utils.create_inception_graph('./metrics/inception_model')
12 |
13 | images = np.ones((4, 32, 32, 3))
14 | preds = inception_score_utils.get_predictions(images)
15 | assert preds.shape == (4, 1008)
16 |
17 | preds = inception_score_utils.get_predictions(
18 | images, device=torch.device('cpu'))
19 | assert preds.shape == (4, 1008)
20 |
21 | def test_get_inception_score(self):
22 | images = np.ones((4, 32, 32, 3))
23 | mean, std = inception_score_utils.get_inception_score(images)
24 |
25 | assert type(mean) == float
26 | assert type(std) == float
27 |
28 | with pytest.raises(ValueError):
29 | images *= -1
30 | inception_score_utils.get_inception_score(images)
31 |
32 |
33 | if __name__ == "__main__":
34 | test = TestInceptionScore()
35 | test.test_get_predictions()
36 | test.test_get_inception_score()
37 |
--------------------------------------------------------------------------------
/tests/metrics/kid/test_kid.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import numpy as np
3 | from sklearn.metrics.pairwise import polynomial_kernel
4 |
5 | from torch_mimicry.metrics.kid import kid_utils
6 |
7 |
8 | class TestKID:
9 | def setup(self):
10 | self.codes_g = np.ones((4, 16))
11 | self.codes_r = np.ones((4, 16))
12 |
13 | def test_polynomial_mmd(self):
14 | score = kid_utils.polynomial_mmd(codes_g=self.codes_g,
15 | codes_r=self.codes_r)
16 |
17 | assert type(score) == np.float64
18 | assert score < 1e-5
19 |
20 | def test_polynomial_mmd_averages(self):
21 |
22 | scores = kid_utils.polynomial_mmd_averages(codes_g=self.codes_g,
23 | codes_r=self.codes_r,
24 | n_subsets=4,
25 | subset_size=1)
26 |
27 | assert len(scores) == 4
28 | assert type(scores[0]) == np.float64
29 |
30 | def test_compute_mmd2(self):
31 | X = self.codes_g
32 | Y = self.codes_r
33 | K_XX = polynomial_kernel(X)
34 | K_YY = polynomial_kernel(Y)
35 | K_XY = polynomial_kernel(X, Y)
36 |
37 | mmd_est_args = ['u-statistic', 'unbiased']
38 |
39 | for mmd_est in mmd_est_args:
40 | for unit_diagonal in [True, False]:
41 | mmd2_score = kid_utils._compute_mmd2(
42 | K_XX=K_XX,
43 | K_YY=K_YY,
44 | K_XY=K_XY,
45 | mmd_est=mmd_est,
46 | unit_diagonal=unit_diagonal)
47 |
48 | assert type(mmd2_score) == np.float64
49 |
50 | # Input checks
51 | with pytest.raises(ValueError):
52 | kid_utils._compute_mmd2(K_XX=K_XX,
53 | K_YY=K_YY,
54 | K_XY=K_XY,
55 | mmd_est='invalid_option',
56 | unit_diagonal=unit_diagonal)
57 |
58 | m = K_XX.shape[0]
59 | with pytest.raises(ValueError):
60 | bad_K_XX = np.ones((m + 1, m + 1))
61 | kid_utils._compute_mmd2(K_XX=bad_K_XX,
62 | K_YY=K_YY,
63 | K_XY=K_XY,
64 | mmd_est='unbiased',
65 | unit_diagonal=unit_diagonal)
66 |
67 | with pytest.raises(ValueError):
68 | bad_K_YY = np.ones((m + 1, m + 1))
69 | kid_utils._compute_mmd2(K_XX=K_XX,
70 | K_YY=bad_K_YY,
71 | K_XY=K_XY,
72 | mmd_est='unbiased',
73 | unit_diagonal=unit_diagonal)
74 |
75 | with pytest.raises(ValueError):
76 | bad_K_XY = np.ones((m + 1, m + 1))
77 | kid_utils._compute_mmd2(K_XX=K_XX,
78 | K_YY=K_YY,
79 | K_XY=bad_K_XY,
80 | mmd_est='unbiased',
81 | unit_diagonal=unit_diagonal)
82 |
83 | def teardown(self):
84 | del self.codes_g
85 | del self.codes_r
86 |
87 |
88 | if __name__ == "__main__":
89 | test = TestKID()
90 | test.setup()
91 | test.test_polynomial_mmd()
92 | test.test_polynomial_mmd_averages()
93 | test.test_compute_mmd2()
94 | test.teardown()
95 |
--------------------------------------------------------------------------------
/tests/metrics/test_compute_is.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_mimicry.metrics import compute_is
4 | from torch_mimicry.nets.gan import gan
5 |
6 |
7 | class ExampleGen(gan.BaseGenerator):
8 | def __init__(self,
9 | bottom_width=4,
10 | nz=4,
11 | ngf=256,
12 | loss_type='gan',
13 | *args,
14 | **kwargs):
15 | super().__init__(nz=nz,
16 | ngf=ngf,
17 | bottom_width=bottom_width,
18 | loss_type=loss_type,
19 | *args,
20 | **kwargs)
21 |
22 | def forward(self, x):
23 | output = torch.ones(x.shape[0], 3, 32, 32)
24 |
25 | return output
26 |
27 |
28 | class TestComputeIS:
29 | def setup(self):
30 | self.device = torch.device('cpu')
31 | self.netG = ExampleGen()
32 |
33 | def test_compute_inception_score(self):
34 | mean, std = compute_is.inception_score(netG=self.netG,
35 | device=self.device,
36 | num_samples=10,
37 | batch_size=10)
38 |
39 | assert type(mean) == float
40 | assert type(std) == float
41 |
42 | def teardown(self):
43 | del self.netG
44 |
45 |
46 | if __name__ == "__main__":
47 | test = TestComputeIS()
48 | test.setup()
49 | test.test_compute_inception_score()
50 |
--------------------------------------------------------------------------------
/tests/modules/test_layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from torch_mimicry.modules import layers
5 |
6 |
7 | class TestLayers:
8 | def setup(self):
9 | self.N, self.C, self.H, self.W = (4, 3, 32, 32)
10 | self.n_out = 8
11 |
12 | def test_ConditionalBatchNorm2d(self):
13 | num_classes = 10
14 | X = torch.ones(self.N, self.C, self.H, self.W)
15 | y = torch.randint(low=0, high=num_classes, size=(self.N, ))
16 |
17 | # Setup cond. BN --> Note: because we usually do
18 | # BN-ReLU-Conv, we take num feat as input channels
19 | conv = nn.Conv2d(self.C, self.n_out, 1, 1, 0)
20 | bn = layers.ConditionalBatchNorm2d(num_features=self.C,
21 | num_classes=num_classes)
22 |
23 | output = bn(X, y)
24 | output = conv(output)
25 |
26 | assert output.shape == (self.N, self.n_out, self.H, self.W)
27 |
28 | def test_SelfAttention(self):
29 | ngf = 16
30 | X = torch.randn(self.N, ngf, 16, 16)
31 | for spectral_norm in [True, False]:
32 | sa_layer = layers.SelfAttention(ngf, spectral_norm=spectral_norm)
33 |
34 | assert sa_layer(X).shape == (self.N, ngf, 16, 16)
35 |
36 | def test_SNConv2d(self):
37 | X = torch.ones(self.N, self.C, self.H, self.W)
38 | for default in [True, False]:
39 | layer = layers.SNConv2d(self.C,
40 | self.n_out,
41 | 1,
42 | 1,
43 | 0,
44 | default=default)
45 |
46 | assert layer(X).shape == (self.N, self.n_out, self.H, self.W)
47 |
48 | def test_SNLinear(self):
49 | X = torch.ones(self.N, self.C)
50 | for default in [True, False]:
51 | layer = layers.SNLinear(self.C, self.n_out, default=default)
52 |
53 | assert layer(X).shape == (self.N, self.n_out)
54 |
55 | def test_SNEmbedding(self):
56 | num_classes = 10
57 | X = torch.ones(self.N, dtype=torch.int64)
58 | for default in [True, False]:
59 | layer = layers.SNEmbedding(num_classes,
60 | self.n_out,
61 | default=default)
62 |
63 | assert layer(X).shape == (self.N, self.n_out)
64 |
65 |
66 | if __name__ == "__main__":
67 | test = TestLayers()
68 | test.setup()
69 | test.test_ConditionalBatchNorm2d()
70 | test.test_SelfAttention()
71 | test.test_SNConv2d()
72 | test.test_SNLinear()
73 | test.test_SNEmbedding()
74 |
--------------------------------------------------------------------------------
/tests/modules/test_losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_mimicry.modules import losses
4 |
5 |
6 | class TestLosses:
7 | def setup(self):
8 | self.output_real = torch.ones(4, 1)
9 | self.output_fake = torch.ones(4, 1)
10 | self.device = torch.device('cpu')
11 |
12 | def test_minimax_loss(self):
13 | loss_gen = losses.minimax_loss_gen(output_fake=self.output_fake)
14 |
15 | loss_dis = losses.minimax_loss_dis(output_fake=self.output_fake,
16 | output_real=self.output_real)
17 |
18 | assert loss_gen.dtype == torch.float32
19 | assert loss_dis.dtype == torch.float32
20 |         assert abs(loss_gen.item() - 0.3133) < 1e-2
21 |         assert abs(loss_dis.item() - 1.6265) < 1e-2
22 |
23 | def test_ns_loss(self):
24 | loss_gen = losses.ns_loss_gen(self.output_fake)
25 |
26 | assert loss_gen.dtype == torch.float32
27 |         assert abs(loss_gen.item() - 0.3133) < 1e-2
28 |
29 | def test_wasserstein_loss(self):
30 | loss_gen = losses.wasserstein_loss_gen(self.output_fake)
31 | loss_dis = losses.wasserstein_loss_dis(output_real=self.output_real,
32 | output_fake=self.output_fake)
33 |
34 | assert loss_gen.dtype == torch.float32
35 | assert loss_dis.dtype == torch.float32
36 |         assert abs(loss_gen.item() + 1.0) < 1e-2
37 |         assert abs(loss_dis.item()) < 1e-2
38 |
39 | def test_hinge_loss(self):
40 | loss_gen = losses.hinge_loss_gen(output_fake=self.output_fake)
41 | loss_dis = losses.hinge_loss_dis(output_fake=self.output_fake,
42 | output_real=self.output_real)
43 |
44 | assert loss_gen.dtype == torch.float32
45 | assert loss_dis.dtype == torch.float32
46 |         assert abs(loss_gen.item() + 1.0) < 1e-2
47 |         assert abs(loss_dis.item() - 2.0) < 1e-2
48 |
49 | def teardown(self):
50 | del self.output_real
51 | del self.output_fake
52 | del self.device
53 |
54 |
55 | if __name__ == "__main__":
56 | test = TestLosses()
57 | test.setup()
58 | test.test_minimax_loss()
59 | test.test_ns_loss()
60 | test.test_wasserstein_loss()
61 | test.test_hinge_loss()
62 | test.teardown()
63 |
--------------------------------------------------------------------------------
/tests/modules/test_resblocks.py:
--------------------------------------------------------------------------------
1 | from itertools import product
2 |
3 | import torch
4 |
5 | from torch_mimicry.modules import resblocks
6 |
7 |
8 | class TestResBlocks:
9 | def setup(self):
10 | self.images = torch.ones(4, 3, 16, 16)
11 |
12 | def test_GBlock(self):
13 | # Arguments
14 | num_classes_list = [0, 10]
15 | spectral_norm_list = [True, False]
16 | in_channels = 3
17 | out_channels = 8
18 | args_comb = product(num_classes_list, spectral_norm_list)
19 |
20 | for args in args_comb:
21 | num_classes = args[0]
22 | spectral_norm = args[1]
23 |
24 | if num_classes > 0:
25 | y = torch.ones((4, ), dtype=torch.int64)
26 | else:
27 | y = None
28 |
29 | gen_block_up = resblocks.GBlock(in_channels=in_channels,
30 | out_channels=out_channels,
31 | upsample=True,
32 | num_classes=num_classes,
33 | spectral_norm=spectral_norm)
34 |
35 | gen_block = resblocks.GBlock(in_channels=in_channels,
36 | out_channels=out_channels,
37 | upsample=False,
38 | num_classes=num_classes,
39 | spectral_norm=spectral_norm)
40 |
41 | gen_block_no_sc = resblocks.GBlock(in_channels=in_channels,
42 | out_channels=in_channels,
43 | upsample=False,
44 | num_classes=num_classes,
45 | spectral_norm=spectral_norm)
46 |
47 | assert gen_block_up(self.images, y).shape == (4, 8, 32, 32)
48 | assert gen_block(self.images, y).shape == (4, 8, 16, 16)
49 | assert gen_block_no_sc(self.images, y).shape == (4, 3, 16, 16)
50 |
51 | def test_DBlocks(self):
52 | in_channels = 3
53 | out_channels = 8
54 |
55 | for spectral_norm in [True, False]:
56 | dis_block_down = resblocks.DBlock(in_channels=in_channels,
57 | out_channels=out_channels,
58 | downsample=True,
59 | spectral_norm=spectral_norm)
60 |
61 | dis_block = resblocks.DBlock(in_channels=in_channels,
62 | out_channels=out_channels,
63 | downsample=False,
64 | spectral_norm=spectral_norm)
65 |
66 | dis_block_opt = resblocks.DBlockOptimized(
67 | in_channels=in_channels,
68 | out_channels=out_channels,
69 | spectral_norm=spectral_norm)
70 |
71 | assert dis_block(self.images).shape == (4, out_channels, 16, 16)
72 | assert dis_block_down(self.images).shape == (4, out_channels, 8, 8)
73 | assert dis_block_opt(self.images).shape == (4, out_channels, 8, 8)
74 |
75 | def teardown(self):
76 | del self.images
77 |
78 |
79 | if __name__ == "__main__":
80 | test = TestResBlocks()
81 | test.setup()
82 | test.test_GBlock()
83 | test.test_DBlocks()
84 | test.teardown()
85 |
--------------------------------------------------------------------------------
/tests/modules/test_spectral_norm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from torch_mimicry.modules import spectral_norm
5 |
6 |
7 | class TestSpectralNorm:
8 | def setup(self):
9 | torch.manual_seed(0)
10 | self.N, self.C, self.H, self.W = (32, 16, 32, 32)
11 | self.n_in = self.C
12 | self.n_out = 32
13 |
14 | def test_SNConv2d(self):
15 | conv = spectral_norm.SNConv2d(self.n_in, self.n_out, 1, 1, 0)
16 | conv_def = nn.utils.spectral_norm(
17 | nn.Conv2d(self.n_in, self.n_out, 1, 1, 0))
18 |
19 | # Init with ones to test implementation without randomness.
20 | nn.init.ones_(conv.weight.data)
21 | nn.init.ones_(conv_def.weight.data)
22 |
23 | # Get outputs
24 | X = torch.ones(self.N, self.C, self.H, self.W)
25 | output = conv(X)
26 | output_def = conv_def(X)
27 |
28 | # Test valid shape
29 | assert output.shape == output_def.shape == (32, 32, 32, 32)
30 |
31 |         # Test that the output mean is very close to the default implementation's,
32 |         # to preserve correctness even when users toggle between implementations.
33 | assert abs(torch.mean(output_def) - torch.mean(output)) < 1
34 |
35 | def test_SNLinear(self):
36 | linear = spectral_norm.SNLinear(self.n_in, self.n_out)
37 | linear_def = nn.utils.spectral_norm(nn.Linear(self.n_in, self.n_out))
38 |
39 | nn.init.ones_(linear.weight.data)
40 | nn.init.ones_(linear_def.weight.data)
41 |
42 | X = torch.ones(self.N, self.n_in)
43 | output = linear(X)
44 | output_def = linear_def(X)
45 |
46 | assert output.shape == output_def.shape == (32, 32)
47 | assert abs(torch.mean(output_def) - torch.mean(output)) < 1
48 |
49 | def test_SNEmbedding(self):
50 | embedding = spectral_norm.SNEmbedding(self.N, self.n_out)
51 | embedding_def = nn.utils.spectral_norm(nn.Embedding(
52 | self.N, self.n_out))
53 |
54 | nn.init.ones_(embedding.weight.data)
55 | nn.init.ones_(embedding_def.weight.data)
56 |
57 | X = torch.ones(self.N, dtype=torch.int64)
58 | output = embedding(X)
59 | output_def = embedding_def(X)
60 |
61 | assert output.shape == output_def.shape == (32, 32)
62 | assert abs(torch.mean(output_def) - torch.mean(output)) < 1
63 |
64 |
65 | if __name__ == "__main__":
66 | test = TestSpectralNorm()
67 | test.setup()
68 | test.test_SNConv2d()
69 | test.test_SNLinear()
70 | test.test_SNEmbedding()
71 |
--------------------------------------------------------------------------------
/tests/nets/basemodel/test_basemodel.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import pytest
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.optim as optim
8 |
9 | from torch_mimicry.nets.basemodel.basemodel import BaseModel
10 |
11 |
12 | class ExampleModel(BaseModel):
13 | def __init__(self, *args, **kwargs):
14 | super().__init__(*args, **kwargs)
15 | self.linear = nn.Linear(1, 4)
16 | nn.init.xavier_uniform_(self.linear.weight.data)
17 |
18 | def forward(self, x):
19 | return
20 |
21 |
22 | class TestBaseModel:
23 | def setup(self):
24 | self.model = ExampleModel()
25 | self.opt = optim.Adam(self.model.parameters(), 2e-4, betas=(0.0, 0.9))
26 | self.global_step = 0
27 |
28 | self.log_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)),
29 | "test_log")
30 |
31 | def test_save_and_restore_checkpoint(self):
32 | ckpt_dir = os.path.join(self.log_dir, 'checkpoints/model')
33 | ckpt_file = os.path.join(ckpt_dir,
34 | "model_{}_steps.pth".format(self.global_step))
35 |
36 | self.model.save_checkpoint(directory=ckpt_dir,
37 | optimizer=self.opt,
38 | global_step=self.global_step)
39 |
40 | restored_model = ExampleModel()
41 | restored_opt = optim.Adam(self.model.parameters(),
42 | 2e-4,
43 | betas=(0.0, 0.9))
44 |
45 |         restored_model.restore_checkpoint(ckpt_file=ckpt_file,
46 |                                           optimizer=restored_opt)
47 |
48 | # Check weights are preserved
49 | assert all(
50 | (restored_model.linear.weight == self.model.linear.weight) == 1)
51 |
52 | with pytest.raises(ValueError):
53 | restored_model.restore_checkpoint(ckpt_file=None,
54 | optimizer=self.opt)
55 |
56 | # Check optimizers have same state dict
57 | assert self.opt.state_dict() == restored_opt.state_dict()
58 |
59 | def test_count_params(self):
60 | num_total_params, num_trainable_params = self.model.count_params()
61 |
62 |         assert num_trainable_params == num_total_params == 8  # Linear(1, 4): 4 weights + 4 biases
63 |
64 | def test_get_device(self):
65 | assert type(self.model.device) == torch.device
66 |
67 | def teardown(self):
68 | if os.path.exists(self.log_dir):
69 | shutil.rmtree(self.log_dir)
70 |
71 | del self.model
72 | del self.opt
73 |
74 |
75 | if __name__ == "__main__":
76 | test = TestBaseModel()
77 | test.setup()
78 | test.test_save_and_restore_checkpoint()
79 | test.test_count_params()
80 | test.test_get_device()
81 | test.teardown()
82 |
--------------------------------------------------------------------------------
/tests/nets/cgan_pd/test_cgan_pd_128.py:
--------------------------------------------------------------------------------
1 | """
2 | Test functions for cGAN-PD for image size 128.
3 | """
4 | import torch
5 | import torch.optim as optim
6 |
7 | from torch_mimicry.nets.cgan_pd.cgan_pd_128 import CGANPDGenerator128, CGANPDDiscriminator128
8 | from torch_mimicry.training import metric_log
9 | from torch_mimicry.utils import common
10 |
11 |
12 | class TestCGANPD128:
13 | def setup(self):
14 | self.num_classes = 10
15 | self.nz = 128
16 | self.N, self.C, self.H, self.W = (4, 3, 128, 128)
17 |
18 | self.noise = torch.ones(self.N, self.nz)
19 | self.images = torch.ones(self.N, self.C, self.H, self.W)
20 | self.Y = torch.randint(low=0, high=self.num_classes, size=(self.N, ))
21 |
22 | self.ngf = 16
23 | self.ndf = 16
24 |
25 | self.netG = CGANPDGenerator128(num_classes=self.num_classes,
26 | ngf=self.ngf)
27 | self.netD = CGANPDDiscriminator128(num_classes=self.num_classes,
28 | ndf=self.ndf)
29 |
30 | def test_CGANPDGenerator128(self):
31 | images = self.netG(self.noise, self.Y)
32 | assert images.shape == (self.N, self.C, self.H, self.W)
33 |
34 | images = self.netG(self.noise, None)
35 | assert images.shape == (self.N, self.C, self.H, self.W)
36 |
37 | def test_CGANPDDiscriminator128(self):
38 | output = self.netD(self.images, self.Y)
39 |
40 | assert output.shape == (self.N, 1)
41 |
42 | def test_train_steps(self):
43 | real_batch = common.load_images(self.N)
44 |
45 | # Setup optimizers
46 | optD = optim.Adam(self.netD.parameters(), 2e-4, betas=(0.0, 0.9))
47 | optG = optim.Adam(self.netG.parameters(), 2e-4, betas=(0.0, 0.9))
48 |
49 | # Log statistics to check
50 | log_data = metric_log.MetricLog()
51 |
52 | # Test D train step
53 | log_data = self.netD.train_step(real_batch=real_batch,
54 | netG=self.netG,
55 | optD=optD,
56 | device='cpu',
57 | log_data=log_data)
58 |
59 | log_data = self.netG.train_step(real_batch=real_batch,
60 | netD=self.netD,
61 | optG=optG,
62 | log_data=log_data,
63 | device='cpu')
64 |
65 | for name, metric_dict in log_data.items():
66 | assert type(name) == str
67 | assert type(metric_dict['value']) == float
68 |
69 | def teardown(self):
70 | del self.noise
71 | del self.images
72 | del self.Y
73 | del self.netG
74 | del self.netD
75 |
76 |
77 | if __name__ == "__main__":
78 | test = TestCGANPD128()
79 | test.setup()
80 | test.test_CGANPDGenerator128()
81 | test.test_CGANPDDiscriminator128()
82 | test.test_train_steps()
83 | test.teardown()
84 |
--------------------------------------------------------------------------------
/tests/nets/cgan_pd/test_cgan_pd_32.py:
--------------------------------------------------------------------------------
1 | """
2 | Test functions for cGAN-PD for image size 32.
3 | """
4 | import torch
5 | import torch.optim as optim
6 |
7 | from torch_mimicry.nets.cgan_pd.cgan_pd_32 import CGANPDGenerator32, CGANPDDiscriminator32
8 | from torch_mimicry.training import metric_log
9 | from torch_mimicry.utils import common
10 |
11 |
12 | class TestCGANPD32:
13 | def setup(self):
14 | self.num_classes = 10
15 | self.nz = 128
16 | self.N, self.C, self.H, self.W = (4, 3, 32, 32)
17 |
18 | self.noise = torch.ones(self.N, self.nz)
19 | self.images = torch.ones(self.N, self.C, self.H, self.W)
20 | self.Y = torch.randint(low=0, high=self.num_classes, size=(self.N, ))
21 |
22 | self.ngf = 16
23 | self.ndf = 16
24 |
25 | self.netG = CGANPDGenerator32(num_classes=self.num_classes,
26 | ngf=self.ngf)
27 | self.netD = CGANPDDiscriminator32(num_classes=self.num_classes,
28 | ndf=self.ndf)
29 |
30 | def test_CGANPDGenerator32(self):
31 | images = self.netG(self.noise, self.Y)
32 | assert images.shape == (self.N, self.C, self.H, self.W)
33 |
34 | images = self.netG(self.noise, None)
35 | assert images.shape == (self.N, self.C, self.H, self.W)
36 |
37 | def test_CGANPDDiscriminator32(self):
38 | output = self.netD(self.images, self.Y)
39 | assert output.shape == (self.N, 1)
40 |
41 | def test_train_steps(self):
42 | real_batch = common.load_images(self.N)
43 |
44 | # Setup optimizers
45 | optD = optim.Adam(self.netD.parameters(), 2e-4, betas=(0.0, 0.9))
46 | optG = optim.Adam(self.netG.parameters(), 2e-4, betas=(0.0, 0.9))
47 |
48 | # Log statistics to check
49 | log_data = metric_log.MetricLog()
50 |
51 | # Test D train step
52 | log_data = self.netD.train_step(real_batch=real_batch,
53 | netG=self.netG,
54 | optD=optD,
55 | device='cpu',
56 | log_data=log_data)
57 |
58 | log_data = self.netG.train_step(real_batch=real_batch,
59 | netD=self.netD,
60 | optG=optG,
61 | log_data=log_data,
62 | device='cpu')
63 |
64 | for name, metric_dict in log_data.items():
65 | assert type(name) == str
66 | assert type(metric_dict['value']) == float
67 |
68 | def teardown(self):
69 | del self.noise
70 | del self.images
71 | del self.Y
72 | del self.netG
73 | del self.netD
74 |
75 |
76 | if __name__ == "__main__":
77 | test = TestCGANPD32()
78 | test.setup()
79 | test.test_CGANPDGenerator32()
80 | test.test_CGANPDDiscriminator32()
81 | test.test_train_steps()
82 | test.teardown()
83 |
--------------------------------------------------------------------------------
/tests/nets/dcgan/test_dcgan_128.py:
--------------------------------------------------------------------------------
1 | """
2 | Test functions for DCGAN for image size 128.
3 |
4 | """
5 | import torch
6 | import torch.optim as optim
7 |
8 | from torch_mimicry.nets.dcgan.dcgan_128 import DCGANGenerator128, DCGANDiscriminator128
9 | from torch_mimicry.training import metric_log
10 | from torch_mimicry.utils import common
11 |
12 |
13 | class TestDCGAN128:
14 | def setup(self):
15 | self.nz = 128
16 | self.N, self.C, self.H, self.W = (8, 3, 128, 128)
17 | self.ngf = 16
18 | self.ndf = 16
19 |
20 | self.netG = DCGANGenerator128(ngf=self.ngf)
21 | self.netD = DCGANDiscriminator128(ndf=self.ndf)
22 |
23 | def test_DCGANGenerator128(self):
24 | noise = torch.ones(self.N, self.nz)
25 | output = self.netG(noise)
26 |
27 | assert output.shape == (self.N, self.C, self.H, self.W)
28 |
29 | def test_DCGANDiscriminator128(self):
30 | images = torch.ones(self.N, self.C, self.H, self.W)
31 | output = self.netD(images)
32 |
33 | assert output.shape == (self.N, 1)
34 |
35 | def test_train_steps(self):
36 | real_batch = common.load_images(self.N, size=self.H)
37 |
38 | # Setup optimizers
39 | optD = optim.Adam(self.netD.parameters(), 2e-4, betas=(0.0, 0.9))
40 | optG = optim.Adam(self.netG.parameters(), 2e-4, betas=(0.0, 0.9))
41 |
42 | # Log statistics to check
43 | log_data = metric_log.MetricLog()
44 |
45 | # Test D train step
46 | log_data = self.netD.train_step(real_batch=real_batch,
47 | netG=self.netG,
48 | optD=optD,
49 | device='cpu',
50 | log_data=log_data)
51 |
52 | log_data = self.netG.train_step(real_batch=real_batch,
53 | netD=self.netD,
54 | optG=optG,
55 | log_data=log_data,
56 | device='cpu')
57 |
58 | for name, metric_dict in log_data.items():
59 | assert type(name) == str
60 | assert type(metric_dict['value']) == float
61 |
62 | def teardown(self):
63 | del self.netG
64 | del self.netD
65 |
66 |
67 | if __name__ == "__main__":
68 | test = TestDCGAN128()
69 | test.setup()
70 | test.test_DCGANGenerator128()
71 | test.test_DCGANDiscriminator128()
72 | test.test_train_steps()
73 | test.teardown()
74 |
--------------------------------------------------------------------------------
/tests/nets/dcgan/test_dcgan_32.py:
--------------------------------------------------------------------------------
1 | """
2 | Test functions for DCGAN for image size 32.
3 | """
4 | import torch
5 | import torch.optim as optim
6 |
7 | from torch_mimicry.nets.dcgan.dcgan_32 import DCGANGenerator32, DCGANDiscriminator32
8 | from torch_mimicry.training import metric_log
9 | from torch_mimicry.utils import common
10 |
11 |
12 | class TestDCGAN32:
13 | def setup(self):
14 | self.nz = 128
15 | self.N, self.C, self.H, self.W = (8, 3, 32, 32)
16 | self.ngf = 16
17 | self.ndf = 16
18 |
19 | self.netG = DCGANGenerator32(ngf=self.ngf)
20 | self.netD = DCGANDiscriminator32(ndf=self.ndf)
21 |
22 | def test_DCGANGenerator32(self):
23 | noise = torch.ones(self.N, self.nz)
24 | output = self.netG(noise)
25 |
26 | assert output.shape == (self.N, self.C, self.H, self.W)
27 |
28 | def test_DCGANDiscriminator32(self):
29 | images = torch.ones(self.N, self.C, self.H, self.W)
30 | output = self.netD(images)
31 |
32 | assert output.shape == (self.N, 1)
33 |
34 | def test_train_steps(self):
35 | real_batch = common.load_images(self.N, size=self.H)
36 |
37 | # Setup optimizers
38 | optD = optim.Adam(self.netD.parameters(), 2e-4, betas=(0.0, 0.9))
39 | optG = optim.Adam(self.netG.parameters(), 2e-4, betas=(0.0, 0.9))
40 |
41 | # Log statistics to check
42 | log_data = metric_log.MetricLog()
43 |
44 | # Test D train step
45 | log_data = self.netD.train_step(real_batch=real_batch,
46 | netG=self.netG,
47 | optD=optD,
48 | device='cpu',
49 | log_data=log_data)
50 |
51 | log_data = self.netG.train_step(real_batch=real_batch,
52 | netD=self.netD,
53 | optG=optG,
54 | log_data=log_data,
55 | device='cpu')
56 |
57 | for name, metric_dict in log_data.items():
58 | assert type(name) == str
59 | assert type(metric_dict['value']) == float
60 |
61 | def teardown(self):
62 | del self.netG
63 | del self.netD
64 |
65 |
66 | if __name__ == "__main__":
67 | test = TestDCGAN32()
68 | test.setup()
69 | test.test_DCGANGenerator32()
70 | test.test_DCGANDiscriminator32()
71 | test.test_train_steps()
72 | test.teardown()
73 |
--------------------------------------------------------------------------------
/tests/nets/dcgan/test_dcgan_48.py:
--------------------------------------------------------------------------------
1 | """
2 | Test functions for DCGAN for image size 48.
3 | """
4 | import torch
5 | import torch.optim as optim
6 |
7 | from torch_mimicry.nets.dcgan.dcgan_48 import DCGANGenerator48, DCGANDiscriminator48
8 | from torch_mimicry.training import metric_log
9 | from torch_mimicry.utils import common
10 |
11 |
12 | class TestDCGAN48:
13 | def setup(self):
14 | self.nz = 128
15 | self.N, self.C, self.H, self.W = (8, 3, 48, 48)
16 | self.ngf = 16
17 | self.ndf = 16
18 |
19 | self.netG = DCGANGenerator48(ngf=self.ngf)
20 | self.netD = DCGANDiscriminator48(ndf=self.ndf)
21 |
22 | def test_DCGANGenerator48(self):
23 | noise = torch.ones(self.N, self.nz)
24 | output = self.netG(noise)
25 |
26 | assert output.shape == (self.N, self.C, self.H, self.W)
27 |
28 | def test_DCGANDiscriminator48(self):
29 | images = torch.ones(self.N, self.C, self.H, self.W)
30 | output = self.netD(images)
31 |
32 | assert output.shape == (self.N, 1)
33 |
34 | def test_train_steps(self):
35 | real_batch = common.load_images(self.N, size=self.H)
36 |
37 | # Setup optimizers
38 | optD = optim.Adam(self.netD.parameters(), 2e-4, betas=(0.0, 0.9))
39 | optG = optim.Adam(self.netG.parameters(), 2e-4, betas=(0.0, 0.9))
40 |
41 | # Log statistics to check
42 | log_data = metric_log.MetricLog()
43 |
44 | # Test D train step
45 | log_data = self.netD.train_step(real_batch=real_batch,
46 | netG=self.netG,
47 | optD=optD,
48 | device='cpu',
49 | log_data=log_data)
50 |
51 | log_data = self.netG.train_step(real_batch=real_batch,
52 | netD=self.netD,
53 | optG=optG,
54 | log_data=log_data,
55 | device='cpu')
56 |
57 | for name, metric_dict in log_data.items():
58 | assert type(name) == str
59 | assert type(metric_dict['value']) == float
60 |
61 | def teardown(self):
62 | del self.netG
63 | del self.netD
64 |
65 |
66 | if __name__ == "__main__":
67 | test = TestDCGAN48()
68 | test.setup()
69 | test.test_DCGANGenerator48()
70 | test.test_DCGANDiscriminator48()
71 | test.test_train_steps()
72 | test.teardown()
73 |
--------------------------------------------------------------------------------
/tests/nets/dcgan/test_dcgan_64.py:
--------------------------------------------------------------------------------
1 | """
2 | Test functions for DCGAN for image size 64.
3 |
4 | """
5 | import torch
6 | import torch.optim as optim
7 |
8 | from torch_mimicry.nets.dcgan.dcgan_64 import DCGANGenerator64, DCGANDiscriminator64
9 | from torch_mimicry.training import metric_log
10 | from torch_mimicry.utils import common
11 |
12 |
13 | class TestDCGAN64:
14 | def setup(self):
15 | self.nz = 128
16 | self.N, self.C, self.H, self.W = (8, 3, 64, 64)
17 | self.ngf = 16
18 | self.ndf = 16
19 |
20 | self.netG = DCGANGenerator64(ngf=self.ngf)
21 | self.netD = DCGANDiscriminator64(ndf=self.ndf)
22 |
23 | def test_DCGANGenerator64(self):
24 | noise = torch.ones(self.N, self.nz)
25 | output = self.netG(noise)
26 |
27 | assert output.shape == (self.N, self.C, self.H, self.W)
28 |
29 | def test_DCGANDiscriminator64(self):
30 | images = torch.ones(self.N, self.C, self.H, self.W)
31 | output = self.netD(images)
32 |
33 | assert output.shape == (self.N, 1)
34 |
35 | def test_train_steps(self):
36 | real_batch = common.load_images(self.N, size=self.H)
37 |
38 | # Setup optimizers
39 | optD = optim.Adam(self.netD.parameters(), 2e-4, betas=(0.0, 0.9))
40 | optG = optim.Adam(self.netG.parameters(), 2e-4, betas=(0.0, 0.9))
41 |
42 | # Log statistics to check
43 | log_data = metric_log.MetricLog()
44 |
45 | # Test D train step
46 | log_data = self.netD.train_step(real_batch=real_batch,
47 | netG=self.netG,
48 | optD=optD,
49 | device='cpu',
50 | log_data=log_data)
51 |
52 | log_data = self.netG.train_step(real_batch=real_batch,
53 | netD=self.netD,
54 | optG=optG,
55 | log_data=log_data,
56 | device='cpu')
57 |
58 | for name, metric_dict in log_data.items():
59 | assert type(name) == str
60 | assert type(metric_dict['value']) == float
61 |
62 | def teardown(self):
63 | del self.netG
64 | del self.netD
65 |
66 |
67 | if __name__ == "__main__":
68 | test = TestDCGAN64()
69 | test.setup()
70 | test.test_DCGANGenerator64()
71 | test.test_DCGANDiscriminator64()
72 | test.test_train_steps()
73 | test.teardown()
74 |
--------------------------------------------------------------------------------
/tests/nets/dcgan/test_dcgan_cifar.py:
--------------------------------------------------------------------------------
1 | """
2 | Test functions for the DCGAN CIFAR variant (image size 32).
3 | """
4 | import torch
5 | import torch.optim as optim
6 |
7 | from torch_mimicry.nets.dcgan.dcgan_cifar import DCGANGeneratorCIFAR, DCGANDiscriminatorCIFAR
8 | from torch_mimicry.training import metric_log
9 | from torch_mimicry.utils import common
10 |
11 |
12 | class TestDCGANCIFAR:
13 | def setup(self):
14 | self.nz = 128
15 | self.N, self.C, self.H, self.W = (8, 3, 32, 32)
16 | self.ngf = 16
17 | self.ndf = 16
18 |
19 | self.netG = DCGANGeneratorCIFAR(ngf=self.ngf)
20 | self.netD = DCGANDiscriminatorCIFAR(ndf=self.ndf)
21 |
22 | def test_DCGANGeneratorCIFAR(self):
23 | noise = torch.ones(self.N, self.nz)
24 | output = self.netG(noise)
25 |
26 | assert output.shape == (self.N, self.C, self.H, self.W)
27 |
28 | def test_DCGANDiscriminatorCIFAR(self):
29 | images = torch.ones(self.N, self.C, self.H, self.W)
30 | output = self.netD(images)
31 |
32 | assert output.shape == (self.N, 1)
33 |
34 | def test_train_steps(self):
35 | real_batch = common.load_images(self.N, size=self.H)
36 |
37 | # Setup optimizers
38 | optD = optim.Adam(self.netD.parameters(), 2e-4, betas=(0.0, 0.9))
39 | optG = optim.Adam(self.netG.parameters(), 2e-4, betas=(0.0, 0.9))
40 |
41 | # Log statistics to check
42 | log_data = metric_log.MetricLog()
43 |
44 | # Test D train step
45 | log_data = self.netD.train_step(real_batch=real_batch,
46 | netG=self.netG,
47 | optD=optD,
48 | device='cpu',
49 | log_data=log_data)
50 |
51 | log_data = self.netG.train_step(real_batch=real_batch,
52 | netD=self.netD,
53 | optG=optG,
54 | log_data=log_data,
55 | device='cpu')
56 |
57 | for name, metric_dict in log_data.items():
58 | assert type(name) == str
59 | assert type(metric_dict['value']) == float
60 |
61 | def teardown(self):
62 | del self.netG
63 | del self.netD
64 |
65 |
66 | if __name__ == "__main__":
67 |     test = TestDCGANCIFAR()
68 | test.setup()
69 | test.test_DCGANGeneratorCIFAR()
70 | test.test_DCGANDiscriminatorCIFAR()
71 | test.test_train_steps()
72 | test.teardown()
73 |
--------------------------------------------------------------------------------
/tests/nets/gan/test_cgan.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | import torch.nn as nn
4 |
5 | from torch_mimicry.nets.gan.cgan import BaseConditionalGenerator, BaseConditionalDiscriminator
6 |
7 |
8 | class ExampleGenerator(BaseConditionalGenerator):
9 | def __init__(self, *args, **kwargs):
10 | super().__init__(*args, **kwargs)
11 | self.l1 = nn.Linear(1, 1)
12 |
13 | def forward(self, x, y):
14 | return torch.ones(x.shape[0], 3, 32, 32)
15 |
16 |
17 | class ExampleDiscriminator(BaseConditionalDiscriminator):
18 | def __init__(self, *args, **kwargs):
19 | super().__init__(*args, **kwargs)
20 | self.l1 = nn.Linear(1, 1)
21 |
22 | def forward(self, x, y):
23 | return torch.ones(x.shape[0])
24 |
25 |
26 | class TestBaseGAN:
27 | def setup(self):
28 | self.N = 1
29 | self.device = "cpu"
30 |
31 | self.nz = 16
32 | self.ngf = 16
33 | self.ndf = 16
34 | self.bottom_width = 4
35 | self.num_classes = 10
36 | self.loss_type = 'gan'
37 |
38 | self.netG = ExampleGenerator(num_classes=self.num_classes,
39 | ngf=self.ngf,
40 | bottom_width=self.bottom_width,
41 | nz=self.nz,
42 | loss_type=self.loss_type)
43 | self.netD = ExampleDiscriminator(num_classes=self.num_classes,
44 | ndf=self.ndf,
45 | loss_type=self.loss_type)
46 | self.output_fake = torch.ones(self.N, 1)
47 | self.output_real = torch.ones(self.N, 1)
48 |
49 | def test_generate_images(self):
50 | with pytest.raises(ValueError):
51 | images = self.netG.generate_images(10, c=self.num_classes + 1)
52 |
53 | images = self.netG.generate_images(10)
54 | assert images.shape == (10, 3, 32, 32)
55 | assert images.device == self.netG.device
56 |
57 | images = self.netG.generate_images(10, c=0)
58 | assert images.shape == (10, 3, 32, 32)
59 | assert images.device == self.netG.device
60 |
61 | def test_generate_images_with_labels(self):
62 | with pytest.raises(ValueError):
63 | images, labels = self.netG.generate_images_with_labels(
64 | 10, c=self.num_classes + 1)
65 |
66 | images, labels = self.netG.generate_images_with_labels(10)
67 | assert images.shape == (10, 3, 32, 32)
68 | assert images.device == self.netG.device
69 | assert labels.shape == (10, )
70 | assert labels.device == self.netG.device
71 |
72 | images, labels = self.netG.generate_images_with_labels(10, c=0)
73 | assert images.shape == (10, 3, 32, 32)
74 | assert images.device == self.netG.device
75 | assert labels.shape == (10, )
76 | assert labels.device == self.netG.device
77 |
78 | def test_compute_GAN_loss(self):
79 | losses = ['gan', 'ns', 'hinge', 'wasserstein']
80 |
81 | for loss_type in losses:
82 | self.netG.loss_type = loss_type
83 | self.netD.loss_type = loss_type
84 |
85 | errG = self.netG.compute_gan_loss(output=self.output_fake)
86 | errD = self.netD.compute_gan_loss(output_real=self.output_real,
87 | output_fake=self.output_fake)
88 |
89 | assert type(errG.item()) == float
90 | assert type(errD.item()) == float
91 |
92 | def teardown(self):
93 | del self.netG
94 | del self.netD
95 | del self.output_real
96 | del self.output_fake
97 |
98 |
99 | if __name__ == "__main__":
100 | test = TestBaseGAN()
101 | test.setup()
102 | test.test_generate_images()
103 | test.test_generate_images_with_labels()
104 | test.test_compute_GAN_loss()
105 | test.teardown()
106 |
--------------------------------------------------------------------------------
/tests/nets/gan/test_gan.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | import torch.nn as nn
4 |
5 | from torch_mimicry.nets.gan.gan import BaseGenerator, BaseDiscriminator
6 |
7 |
8 | class ExampleGenerator(BaseGenerator):
9 | def __init__(self, *args, **kwargs):
10 | super().__init__(*args, **kwargs)
11 | self.l1 = nn.Linear(1, 1)
12 |
13 | def forward(self, x):
14 | return torch.ones(x.shape[0], 3, 32, 32)
15 |
16 |
17 | class ExampleDiscriminator(BaseDiscriminator):
18 | def __init__(self, *args, **kwargs):
19 | super().__init__(*args, **kwargs)
20 | self.l1 = nn.Linear(1, 1)
21 |
22 | def forward(self, x):
23 | return torch.ones(x.shape[0])
24 |
25 |
26 | class TestBaseGAN:
27 | def setup(self):
28 | self.N = 1
29 | self.device = "cpu"
30 | self.real_label_val = 1.0
31 | self.fake_label_val = 0.0
32 |
33 | self.nz = 16
34 | self.ngf = 16
35 | self.ndf = 16
36 | self.bottom_width = 4
37 | self.loss_type = 'gan'
38 |
39 | self.netG = ExampleGenerator(ngf=self.ngf,
40 | bottom_width=self.bottom_width,
41 | nz=self.nz,
42 | loss_type=self.loss_type)
43 | self.netD = ExampleDiscriminator(ndf=self.ndf,
44 | loss_type=self.loss_type)
45 | self.output_fake = torch.ones(self.N, 1)
46 | self.output_real = torch.ones(self.N, 1)
47 |
48 | def test_generate_images(self):
49 | images = self.netG.generate_images(10)
50 |
51 | assert images.shape == (10, 3, 32, 32)
52 | assert images.device == self.netG.device
53 |
54 | def test_compute_GAN_loss(self):
55 | losses = ['gan', 'ns', 'hinge', 'wasserstein']
56 |
57 | for loss_type in losses:
58 | self.netG.loss_type = loss_type
59 | self.netD.loss_type = loss_type
60 |
61 | errG = self.netG.compute_gan_loss(output=self.output_fake)
62 | errD = self.netD.compute_gan_loss(output_real=self.output_real,
63 | output_fake=self.output_fake)
64 |
65 | assert type(errG.item()) == float
66 | assert type(errD.item()) == float
67 |
68 | with pytest.raises(ValueError):
69 | self.netG.loss_type = 'invalid'
70 | self.netG.compute_gan_loss(output=self.output_fake)
71 |
72 | with pytest.raises(ValueError):
73 | self.netD.loss_type = 'invalid'
74 | self.netD.compute_gan_loss(output_real=self.output_real,
75 | output_fake=self.output_fake)
76 |
77 | def teardown(self):
78 | del self.netG
79 | del self.netD
80 | del self.output_real
81 | del self.output_fake
82 |
83 |
84 | if __name__ == "__main__":
85 | test = TestBaseGAN()
86 | test.setup()
87 | test.test_generate_images()
88 | test.test_compute_GAN_loss()
89 | test.teardown()
90 |
--------------------------------------------------------------------------------
/tests/nets/infomax_gan/test_infomax_gan_128.py:
--------------------------------------------------------------------------------
1 | """
2 | Test functions for InfoMaxGAN for image size 128.
3 | """
4 | import torch
5 | import torch.optim as optim
6 |
7 | from torch_mimicry.nets.infomax_gan.infomax_gan_128 import InfoMaxGANGenerator128, InfoMaxGANDiscriminator128
8 | from torch_mimicry.training import metric_log
9 | from torch_mimicry.utils import common
10 |
11 |
12 | class TestInfoMaxGAN128:
13 | def setup(self):
14 | self.nz = 128
15 | self.N, self.C, self.H, self.W = (8, 3, 128, 128)
16 | self.ngf = 16
17 | self.ndf = 16
18 |
19 | self.netG = InfoMaxGANGenerator128(ngf=self.ngf)
20 | self.netD = InfoMaxGANDiscriminator128(ndf=self.ndf)
21 |
22 | def test_InfoMaxGANGenerator128(self):
23 | noise = torch.ones(self.N, self.nz)
24 | output = self.netG(noise)
25 |
26 | assert output.shape == (self.N, self.C, self.H, self.W)
27 |
28 | def test_InfoMaxGANDiscriminator128(self):
29 | images = torch.ones(self.N, self.C, self.H, self.W)
30 | output, local_feat, global_feat = self.netD(images)
31 |
32 | assert output.shape == (self.N, 1)
33 | assert local_feat.shape == (self.N, self.netD.ndf, self.H >> 5,
34 | self.W >> 5)
35 | assert global_feat.shape == (self.N, self.netD.ndf)
36 |
37 | def test_train_steps(self):
38 | real_batch = common.load_images(self.N, size=self.H)
39 |
40 | # Setup optimizers
41 | optD = optim.Adam(self.netD.parameters(), 2e-4, betas=(0.0, 0.9))
42 | optG = optim.Adam(self.netG.parameters(), 2e-4, betas=(0.0, 0.9))
43 |
44 | # Log statistics to check
45 | log_data = metric_log.MetricLog()
46 |
47 | # Test D train step
48 | log_data = self.netD.train_step(real_batch=real_batch,
49 | netG=self.netG,
50 | optD=optD,
51 | device='cpu',
52 | log_data=log_data)
53 |
54 | log_data = self.netG.train_step(real_batch=real_batch,
55 | netD=self.netD,
56 | optG=optG,
57 | log_data=log_data,
58 | device='cpu')
59 |
60 | for name, metric_dict in log_data.items():
61 | assert type(name) == str
62 | assert type(metric_dict['value']) == float
63 |
64 | def teardown(self):
65 | del self.netG
66 | del self.netD
67 |
68 |
69 | if __name__ == "__main__":
70 | test = TestInfoMaxGAN128()
71 | test.setup()
72 | test.test_InfoMaxGANGenerator128()
73 | test.test_InfoMaxGANDiscriminator128()
74 | test.test_train_steps()
75 | test.teardown()
76 |
--------------------------------------------------------------------------------
/tests/nets/infomax_gan/test_infomax_gan_32.py:
--------------------------------------------------------------------------------
1 | """
2 | Test functions for InfoMaxGAN for image size 32.
3 | """
4 | import torch
5 | import torch.optim as optim
6 |
7 | from torch_mimicry.nets.infomax_gan.infomax_gan_32 import InfoMaxGANGenerator32, InfoMaxGANDiscriminator32
8 | from torch_mimicry.training import metric_log
9 | from torch_mimicry.utils import common
10 |
11 |
12 | class TestInfoMaxGAN32:
13 | def setup(self):
14 | self.nz = 128
15 | self.N, self.C, self.H, self.W = (8, 3, 32, 32)
16 | self.ngf = 16
17 | self.ndf = 16
18 |
19 | self.netG = InfoMaxGANGenerator32(ngf=self.ngf)
20 | self.netD = InfoMaxGANDiscriminator32(ndf=self.ndf)
21 |
22 | def test_InfoMaxGANGenerator32(self):
23 | noise = torch.ones(self.N, self.nz)
24 | output = self.netG(noise)
25 |
26 | assert output.shape == (self.N, self.C, self.H, self.W)
27 |
28 | def test_InfoMaxGANDiscriminator32(self):
29 | images = torch.ones(self.N, self.C, self.H, self.W)
30 | output, local_feat, global_feat = self.netD(images)
31 |
32 | assert output.shape == (self.N, 1)
33 | assert local_feat.shape == (self.N, self.netD.ndf, self.H >> 2,
34 | self.W >> 2)
35 | assert global_feat.shape == (self.N, self.netD.ndf)
36 |
37 | def test_train_steps(self):
38 | real_batch = common.load_images(self.N, size=self.H)
39 |
40 | # Setup optimizers
41 | optD = optim.Adam(self.netD.parameters(), 2e-4, betas=(0.0, 0.9))
42 | optG = optim.Adam(self.netG.parameters(), 2e-4, betas=(0.0, 0.9))
43 |
44 | # Log statistics to check
45 | log_data = metric_log.MetricLog()
46 |
47 | # Test D train step
48 | log_data = self.netD.train_step(real_batch=real_batch,
49 | netG=self.netG,
50 | optD=optD,
51 | device='cpu',
52 | log_data=log_data)
53 |
54 | log_data = self.netG.train_step(real_batch=real_batch,
55 | netD=self.netD,
56 | optG=optG,
57 | log_data=log_data,
58 | device='cpu')
59 |
60 | for name, metric_dict in log_data.items():
61 | assert type(name) == str
62 | assert type(metric_dict['value']) == float
63 |
64 | def teardown(self):
65 | del self.netG
66 | del self.netD
67 |
68 |
69 | if __name__ == "__main__":
70 | test = TestInfoMaxGAN32()
71 | test.setup()
72 | test.test_InfoMaxGANGenerator32()
73 | test.test_InfoMaxGANDiscriminator32()
74 | test.test_train_steps()
75 | test.teardown()
76 |
--------------------------------------------------------------------------------
/tests/nets/infomax_gan/test_infomax_gan_48.py:
--------------------------------------------------------------------------------
1 | """
2 | Test functions for InfoMaxGAN for image size 48.
3 | """
4 | import torch
5 | import torch.optim as optim
6 |
7 | from torch_mimicry.nets.infomax_gan.infomax_gan_48 import InfoMaxGANGenerator48, InfoMaxGANDiscriminator48
8 | from torch_mimicry.training import metric_log
9 | from torch_mimicry.utils import common
10 |
11 |
12 | class TestInfoMaxGAN48:
13 | def setup(self):
14 | self.nz = 128
15 | self.N, self.C, self.H, self.W = (8, 3, 48, 48)
16 | self.ngf = 16
17 | self.ndf = 16
18 |
19 | self.netG = InfoMaxGANGenerator48(ngf=self.ngf)
20 | self.netD = InfoMaxGANDiscriminator48(ndf=self.ndf)
21 |
22 | def test_InfoMaxGANGenerator48(self):
23 | noise = torch.ones(self.N, self.nz)
24 | output = self.netG(noise)
25 |
26 | assert output.shape == (self.N, self.C, self.H, self.W)
27 |
28 | def test_InfoMaxGANDiscriminator48(self):
29 | images = torch.ones(self.N, self.C, self.H, self.W)
30 | output, local_feat, global_feat = self.netD(images)
31 |
32 | assert output.shape == (self.N, 1)
33 | assert local_feat.shape == (self.N, self.netD.ndf >> 1, self.H >> 4,
34 | self.W >> 4)
35 | assert global_feat.shape == (self.N, self.netD.ndf)
36 |
37 | def test_train_steps(self):
38 | real_batch = common.load_images(self.N, size=self.H)
39 |
40 | # Setup optimizers
41 | optD = optim.Adam(self.netD.parameters(), 2e-4, betas=(0.0, 0.9))
42 | optG = optim.Adam(self.netG.parameters(), 2e-4, betas=(0.0, 0.9))
43 |
44 | # Log statistics to check
45 | log_data = metric_log.MetricLog()
46 |
47 | # Test D train step
48 | log_data = self.netD.train_step(real_batch=real_batch,
49 | netG=self.netG,
50 | optD=optD,
51 | device='cpu',
52 | log_data=log_data)
53 |
54 | log_data = self.netG.train_step(real_batch=real_batch,
55 | netD=self.netD,
56 | optG=optG,
57 | log_data=log_data,
58 | device='cpu')
59 |
60 | for name, metric_dict in log_data.items():
61 | assert type(name) == str
62 | assert type(metric_dict['value']) == float
63 |
64 | def teardown(self):
65 | del self.netG
66 | del self.netD
67 |
68 |
69 | if __name__ == "__main__":
70 | test = TestInfoMaxGAN48()
71 | test.setup()
72 | test.test_InfoMaxGANGenerator48()
73 | test.test_InfoMaxGANDiscriminator48()
74 | test.test_train_steps()
75 | test.teardown()
76 |
--------------------------------------------------------------------------------
/tests/nets/infomax_gan/test_infomax_gan_64.py:
--------------------------------------------------------------------------------
1 | """
2 | Test functions for InfoMaxGAN for image size 64.
3 | """
4 | import torch
5 | import torch.optim as optim
6 |
7 | from torch_mimicry.nets.infomax_gan.infomax_gan_64 import InfoMaxGANGenerator64, InfoMaxGANDiscriminator64
8 | from torch_mimicry.training import metric_log
9 | from torch_mimicry.utils import common
10 |
11 |
12 | class TestInfoMaxGAN64:
13 | def setup(self):
14 | self.nz = 128
15 | self.N, self.C, self.H, self.W = (8, 3, 64, 64)
16 | self.ngf = 16
17 | self.ndf = 16
18 |
19 | self.netG = InfoMaxGANGenerator64(ngf=self.ngf)
20 | self.netD = InfoMaxGANDiscriminator64(ndf=self.ndf)
21 |
22 | def test_InfoMaxGANGenerator64(self):
23 | noise = torch.ones(self.N, self.nz)
24 | output = self.netG(noise)
25 |
26 | assert output.shape == (self.N, self.C, self.H, self.W)
27 |
28 | def test_InfoMaxGANDiscriminator64(self):
29 | images = torch.ones(self.N, self.C, self.H, self.W)
30 | output, local_feat, global_feat = self.netD(images)
31 |
32 | assert output.shape == (self.N, 1)
33 | assert local_feat.shape == (self.N, self.netD.ndf >> 1, self.H >> 4,
34 | self.W >> 4)
35 | assert global_feat.shape == (self.N, self.netD.ndf)
36 |
37 | def test_train_steps(self):
38 | real_batch = common.load_images(self.N, size=self.H)
39 |
40 | # Setup optimizers
41 | optD = optim.Adam(self.netD.parameters(), 2e-4, betas=(0.0, 0.9))
42 | optG = optim.Adam(self.netG.parameters(), 2e-4, betas=(0.0, 0.9))
43 |
44 | # Log statistics to check
45 | log_data = metric_log.MetricLog()
46 |
47 | # Test D train step
48 | log_data = self.netD.train_step(real_batch=real_batch,
49 | netG=self.netG,
50 | optD=optD,
51 | device='cpu',
52 | log_data=log_data)
53 |
54 | log_data = self.netG.train_step(real_batch=real_batch,
55 | netD=self.netD,
56 | optG=optG,
57 | log_data=log_data,
58 | device='cpu')
59 |
60 | for name, metric_dict in log_data.items():
61 | assert type(name) == str
62 | assert type(metric_dict['value']) == float
63 |
64 | def teardown(self):
65 | del self.netG
66 | del self.netD
67 |
68 |
69 | if __name__ == "__main__":
70 | test = TestInfoMaxGAN64()
71 | test.setup()
72 | test.test_InfoMaxGANGenerator64()
73 | test.test_InfoMaxGANDiscriminator64()
74 | test.test_train_steps()
75 | test.teardown()
76 |
--------------------------------------------------------------------------------
/tests/nets/infomax_gan/test_infomax_gan_base.py:
--------------------------------------------------------------------------------
1 | """
2 | Test for InfoMaxGAN-specific functions at the discriminator.
3 | """
4 | import pytest
5 | import math
6 | import torch
7 |
8 | from torch_mimicry.nets.infomax_gan.infomax_gan_base import BaseDiscriminator
9 |
10 |
11 | class ExampleDiscriminator(BaseDiscriminator):
12 | def __init__(self, *args, **kwargs):
13 | super().__init__(*args, **kwargs)
14 |
15 | def forward(self, x):
16 | return
17 |
18 |
19 | class TestInfoMaxGANBase:
20 | def setup(self):
21 | self.ndf = 16
22 | self.nrkhs = 32
23 | self.N = 4
24 | self.netD = ExampleDiscriminator(ndf=self.ndf, nrkhs=self.nrkhs)
25 |
26 | def test_infonce_loss(self):
27 | l = torch.ones(self.N, self.nrkhs, 1)
28 | m = torch.ones(self.N, self.nrkhs, 1)
29 |
30 | loss = self.netD.infonce_loss(l=l, m=m)
31 | prob = math.exp(-1 * loss.item())
32 |
33 | assert type(loss.item()) == float
34 |
35 |         # All similarity scores are equal, so each positive gets softmax probability 1/4.
36 | assert abs(prob - 0.25) < 1e-2
37 |
38 | def test_compute_infomax_loss(self):
39 | with pytest.raises(ValueError):
40 | local_feat = torch.ones(self.N, self.ndf, 4, 4)
41 | global_feat = torch.ones(self.N, self.nrkhs)
42 | scale = 0.2
43 | self.netD.compute_infomax_loss(local_feat, global_feat, scale)
44 |
45 | def teardown(self):
46 | del self.netD
47 |
48 |
49 | if __name__ == "__main__":
50 | test = TestInfoMaxGANBase()
51 | test.setup()
52 | test.test_infonce_loss()
53 | test.test_compute_infomax_loss()
54 | test.teardown()
55 |
--------------------------------------------------------------------------------
/tests/nets/sagan/test_sagan_128.py:
--------------------------------------------------------------------------------
1 | """
2 | Test functions for SAGAN for image size 128.
3 | """
4 | import torch
5 | import torch.optim as optim
6 |
7 | from torch_mimicry.nets.sagan.sagan_128 import SAGANGenerator128, SAGANDiscriminator128
8 | from torch_mimicry.training import metric_log
9 | from torch_mimicry.utils import common
10 |
11 |
12 | class TestSAGAN128:
13 | def setup(self):
14 | self.num_classes = 10
15 | self.nz = 128
16 | self.N, self.C, self.H, self.W = (4, 3, 128, 128)
17 |
18 | self.noise = torch.ones(self.N, self.nz)
19 | self.images = torch.ones(self.N, self.C, self.H, self.W)
20 | self.Y = torch.randint(low=0, high=self.num_classes, size=(self.N, ))
21 |
22 | self.ngf = 32
23 | self.ndf = 64
24 |
25 | self.netG = SAGANGenerator128(num_classes=self.num_classes,
26 | ngf=self.ngf)
27 | self.netD = SAGANDiscriminator128(num_classes=self.num_classes,
28 | ndf=self.ndf)
29 |
30 | def test_SAGANGenerator128(self):
31 | images = self.netG(self.noise, self.Y)
32 | assert images.shape == (self.N, self.C, self.H, self.W)
33 |
34 | images = self.netG(self.noise, None)
35 | assert images.shape == (self.N, self.C, self.H, self.W)
36 |
37 | def test_SAGANDiscriminator128(self):
38 | output = self.netD(self.images, self.Y)
39 |
40 | assert output.shape == (self.N, 1)
41 |
42 | def test_train_steps(self):
43 | real_batch = common.load_images(self.N)
44 |
45 | # Setup optimizers
46 | optD = optim.Adam(self.netD.parameters(), 2e-4, betas=(0.0, 0.9))
47 | optG = optim.Adam(self.netG.parameters(), 2e-4, betas=(0.0, 0.9))
48 |
49 | # Log statistics to check
50 | log_data = metric_log.MetricLog()
51 |
52 | # Test D train step
53 | log_data = self.netD.train_step(real_batch=real_batch,
54 | netG=self.netG,
55 | optD=optD,
56 | device='cpu',
57 | log_data=log_data)
58 |
59 | log_data = self.netG.train_step(real_batch=real_batch,
60 | netD=self.netD,
61 | optG=optG,
62 | log_data=log_data,
63 | device='cpu')
64 |
65 | for name, metric_dict in log_data.items():
66 | assert type(name) == str
67 | assert type(metric_dict['value']) == float
68 |
69 | def teardown(self):
70 | del self.noise
71 | del self.images
72 | del self.Y
73 | del self.netG
74 | del self.netD
75 |
76 |
77 | if __name__ == "__main__":
78 | test = TestSAGAN128()
79 | test.setup()
80 | test.test_SAGANGenerator128()
81 | test.test_SAGANDiscriminator128()
82 | test.test_train_steps()
83 | test.teardown()
84 |
--------------------------------------------------------------------------------
/tests/nets/sagan/test_sagan_32.py:
--------------------------------------------------------------------------------
1 | """
2 | Test functions for SAGAN for image size 32.
3 | """
4 | import torch
5 | import torch.optim as optim
6 |
7 | from torch_mimicry.nets.sagan.sagan_32 import SAGANGenerator32, SAGANDiscriminator32
8 | from torch_mimicry.training import metric_log
9 | from torch_mimicry.utils import common
10 |
11 |
12 | class TestSAGAN32:
13 | def setup(self):
14 | self.num_classes = 10
15 | self.nz = 128
16 | self.N, self.C, self.H, self.W = (4, 3, 32, 32)
17 |
18 | self.noise = torch.ones(self.N, self.nz)
19 | self.images = torch.ones(self.N, self.C, self.H, self.W)
20 | self.Y = torch.randint(low=0, high=self.num_classes, size=(self.N, ))
21 |
22 | self.ngf = 32
23 | self.ndf = 64
24 |
25 | self.netG = SAGANGenerator32(num_classes=self.num_classes,
26 | ngf=self.ngf)
27 | self.netD = SAGANDiscriminator32(num_classes=self.num_classes,
28 | ndf=self.ndf)
29 |
30 | def test_SAGANGenerator32(self):
31 | images = self.netG(self.noise, self.Y)
32 | assert images.shape == (self.N, self.C, self.H, self.W)
33 |
34 | images = self.netG(self.noise, None)
35 | assert images.shape == (self.N, self.C, self.H, self.W)
36 |
37 | def test_SAGANDiscriminator32(self):
38 | output = self.netD(self.images, self.Y)
39 | assert output.shape == (self.N, 1)
40 |
41 | def test_train_steps(self):
42 | real_batch = common.load_images(self.N)
43 |
44 | # Setup optimizers
45 | optD = optim.Adam(self.netD.parameters(), 2e-4, betas=(0.0, 0.9))
46 | optG = optim.Adam(self.netG.parameters(), 2e-4, betas=(0.0, 0.9))
47 |
48 | # Log statistics to check
49 | log_data = metric_log.MetricLog()
50 |
51 | # Test D train step
52 | log_data = self.netD.train_step(real_batch=real_batch,
53 | netG=self.netG,
54 | optD=optD,
55 | device='cpu',
56 | log_data=log_data)
57 |
58 | log_data = self.netG.train_step(real_batch=real_batch,
59 | netD=self.netD,
60 | optG=optG,
61 | log_data=log_data,
62 | device='cpu')
63 |
64 | for name, metric_dict in log_data.items():
65 | assert type(name) == str
66 | assert type(metric_dict['value']) == float
67 |
68 | def teardown(self):
69 | del self.noise
70 | del self.images
71 | del self.Y
72 | del self.netG
73 | del self.netD
74 |
75 |
76 | if __name__ == "__main__":
77 | test = TestSAGAN32()
78 | test.setup()
79 | test.test_SAGANGenerator32()
80 | test.test_SAGANDiscriminator32()
81 | test.test_train_steps()
82 | test.teardown()
83 |
--------------------------------------------------------------------------------
/tests/nets/sngan/test_sngan_128.py:
--------------------------------------------------------------------------------
1 | """
2 | Test functions for SNGAN for image size 128.
3 | """
4 | import torch
5 | import torch.optim as optim
6 |
7 | from torch_mimicry.nets.sngan.sngan_128 import SNGANGenerator128, SNGANDiscriminator128
8 | from torch_mimicry.training import metric_log
9 | from torch_mimicry.utils import common
10 |
11 |
12 | class TestSNGAN128:
13 | def setup(self):
14 | self.nz = 128
15 | self.N, self.C, self.H, self.W = (8, 3, 128, 128)
16 | self.ngf = 16
17 | self.ndf = 16
18 |
19 | self.netG = SNGANGenerator128(ngf=self.ngf)
20 | self.netD = SNGANDiscriminator128(ndf=self.ndf)
21 |
22 | def test_SNGANGenerator128(self):
23 | noise = torch.ones(self.N, self.nz)
24 | output = self.netG(noise)
25 |
26 | assert output.shape == (self.N, self.C, self.H, self.W)
27 |
28 | def test_SNGANDiscriminator128(self):
29 | images = torch.ones(self.N, self.C, self.H, self.W)
30 | output = self.netD(images)
31 |
32 | assert output.shape == (self.N, 1)
33 |
34 | def test_train_steps(self):
35 | real_batch = common.load_images(self.N, size=self.H)
36 |
37 | # Setup optimizers
38 | optD = optim.Adam(self.netD.parameters(), 2e-4, betas=(0.0, 0.9))
39 | optG = optim.Adam(self.netG.parameters(), 2e-4, betas=(0.0, 0.9))
40 |
41 | # Log statistics to check
42 | log_data = metric_log.MetricLog()
43 |
44 | # Test D train step
45 | log_data = self.netD.train_step(real_batch=real_batch,
46 | netG=self.netG,
47 | optD=optD,
48 | device='cpu',
49 | log_data=log_data)
50 |
51 | log_data = self.netG.train_step(real_batch=real_batch,
52 | netD=self.netD,
53 | optG=optG,
54 | log_data=log_data,
55 | device='cpu')
56 |
57 | for name, metric_dict in log_data.items():
58 | assert type(name) == str
59 | assert type(metric_dict['value']) == float
60 |
61 | def teardown(self):
62 | del self.netG
63 | del self.netD
64 |
65 |
66 | if __name__ == "__main__":
67 | test = TestSNGAN128()
68 | test.setup()
69 | test.test_SNGANGenerator128()
70 | test.test_SNGANDiscriminator128()
71 | test.test_train_steps()
72 | test.teardown()
73 |
--------------------------------------------------------------------------------
/tests/nets/sngan/test_sngan_32.py:
--------------------------------------------------------------------------------
1 | """
2 | Test functions for SNGAN for image size 32.
3 | """
4 | import torch
5 | import torch.optim as optim
6 |
7 | from torch_mimicry.nets.sngan.sngan_32 import SNGANGenerator32, SNGANDiscriminator32
8 | from torch_mimicry.training import metric_log
9 | from torch_mimicry.utils import common
10 |
11 |
12 | class TestSNGAN32:
13 | def setup(self):
14 | self.nz = 128
15 | self.N, self.C, self.H, self.W = (8, 3, 32, 32)
16 | self.ngf = 16
17 | self.ndf = 16
18 |
19 | self.netG = SNGANGenerator32(ngf=self.ngf)
20 | self.netD = SNGANDiscriminator32(ndf=self.ndf)
21 |
22 | def test_SNGANGenerator32(self):
23 | noise = torch.ones(self.N, self.nz)
24 | output = self.netG(noise)
25 |
26 | assert output.shape == (self.N, self.C, self.H, self.W)
27 |
28 | def test_SNGANDiscriminator32(self):
29 | images = torch.ones(self.N, self.C, self.H, self.W)
30 | output = self.netD(images)
31 |
32 | assert output.shape == (self.N, 1)
33 |
34 | def test_train_steps(self):
35 | # Get real and fake images
36 | real_batch = common.load_images(self.N, size=self.H)
37 |
38 | # Setup optimizers
39 | optD = optim.Adam(self.netD.parameters(), 2e-4, betas=(0.0, 0.9))
40 | optG = optim.Adam(self.netG.parameters(), 2e-4, betas=(0.0, 0.9))
41 |
42 | # Log statistics to check
43 | log_data = metric_log.MetricLog()
44 |
45 | # Test D train step
46 | log_data = self.netD.train_step(real_batch=real_batch,
47 | netG=self.netG,
48 | optD=optD,
49 | device='cpu',
50 | log_data=log_data)
51 |
52 | log_data = self.netG.train_step(real_batch=real_batch,
53 | netD=self.netD,
54 | optG=optG,
55 | log_data=log_data,
56 | device='cpu')
57 |
58 | for name, metric_dict in log_data.items():
59 | assert type(name) == str
60 | assert type(metric_dict['value']) == float
61 |
62 | def teardown(self):
63 | del self.netG
64 | del self.netD
65 |
66 |
67 | if __name__ == "__main__":
68 | test = TestSNGAN32()
69 | test.setup()
70 | test.test_SNGANGenerator32()
71 | test.test_SNGANDiscriminator32()
72 | test.test_train_steps()
73 | test.teardown()
74 |
--------------------------------------------------------------------------------
/tests/nets/sngan/test_sngan_48.py:
--------------------------------------------------------------------------------
1 | """
2 | Test functions for SNGAN for image size 48.
3 | """
4 | import torch
5 | import torch.optim as optim
6 |
7 | from torch_mimicry.nets.sngan.sngan_48 import SNGANGenerator48, SNGANDiscriminator48
8 | from torch_mimicry.training import metric_log
9 | from torch_mimicry.utils import common
10 |
11 |
12 | class TestSNGAN48:
13 | def setup(self):
14 | self.nz = 128
15 | self.N, self.C, self.H, self.W = (8, 3, 48, 48)
16 | self.ngf = 16
17 | self.ndf = 16
18 |
19 | self.netG = SNGANGenerator48(ngf=self.ngf)
20 | self.netD = SNGANDiscriminator48(ndf=self.ndf)
21 |
22 | def test_SNGANGenerator48(self):
23 | noise = torch.ones(self.N, self.nz)
24 | output = self.netG(noise)
25 |
26 | assert output.shape == (self.N, self.C, self.H, self.W)
27 |
28 | def test_SNGANDiscriminator48(self):
29 | images = torch.ones(self.N, self.C, self.H, self.W)
30 | output = self.netD(images)
31 |
32 | assert output.shape == (self.N, 1)
33 |
34 | def test_train_steps(self):
35 | real_batch = common.load_images(self.N, size=self.H)
36 |
37 | # Setup optimizers
38 | optD = optim.Adam(self.netD.parameters(), 2e-4, betas=(0.0, 0.9))
39 | optG = optim.Adam(self.netG.parameters(), 2e-4, betas=(0.0, 0.9))
40 |
41 | # Log statistics to check
42 | log_data = metric_log.MetricLog()
43 |
44 | # Test D train step
45 | log_data = self.netD.train_step(real_batch=real_batch,
46 | netG=self.netG,
47 | optD=optD,
48 | device='cpu',
49 | log_data=log_data)
50 |
51 | log_data = self.netG.train_step(real_batch=real_batch,
52 | netD=self.netD,
53 | optG=optG,
54 | log_data=log_data,
55 | device='cpu')
56 |
57 | for name, metric_dict in log_data.items():
58 | assert type(name) == str
59 | assert type(metric_dict['value']) == float
60 |
61 | def teardown(self):
62 | del self.netG
63 | del self.netD
64 |
65 |
66 | if __name__ == "__main__":
67 | test = TestSNGAN48()
68 | test.setup()
69 | test.test_SNGANGenerator48()
70 | test.test_SNGANDiscriminator48()
71 | test.test_train_steps()
72 | test.teardown()
73 |
--------------------------------------------------------------------------------
/tests/nets/sngan/test_sngan_64.py:
--------------------------------------------------------------------------------
1 | """
2 | Test functions for SNGAN for image size 64.
3 | """
4 | import torch
5 | import torch.optim as optim
6 |
7 | from torch_mimicry.nets.sngan.sngan_64 import SNGANGenerator64, SNGANDiscriminator64
8 | from torch_mimicry.training import metric_log
9 | from torch_mimicry.utils import common
10 |
11 |
12 | class TestSNGAN64:
13 | def setup(self):
14 | self.nz = 128
15 | self.N, self.C, self.H, self.W = (8, 3, 64, 64)
16 | self.ngf = 16
17 | self.ndf = 16
18 |
19 | self.netG = SNGANGenerator64(ngf=self.ngf)
20 | self.netD = SNGANDiscriminator64(ndf=self.ndf)
21 |
22 | def test_SNGANGenerator64(self):
23 | noise = torch.ones(self.N, self.nz)
24 | output = self.netG(noise)
25 |
26 | assert output.shape == (self.N, self.C, self.H, self.W)
27 |
28 | def test_SNGANDiscriminator64(self):
29 | images = torch.ones(self.N, self.C, self.H, self.W)
30 | output = self.netD(images)
31 |
32 | assert output.shape == (self.N, 1)
33 |
34 | def test_train_steps(self):
35 | real_batch = common.load_images(self.N, size=self.H)
36 |
37 | # Setup optimizers
38 | optD = optim.Adam(self.netD.parameters(), 2e-4, betas=(0.0, 0.9))
39 | optG = optim.Adam(self.netG.parameters(), 2e-4, betas=(0.0, 0.9))
40 |
41 | # Log statistics to check
42 | log_data = metric_log.MetricLog()
43 |
44 | # Test D train step
45 | log_data = self.netD.train_step(real_batch=real_batch,
46 | netG=self.netG,
47 | optD=optD,
48 | device='cpu',
49 | log_data=log_data)
50 |
51 | log_data = self.netG.train_step(real_batch=real_batch,
52 | netD=self.netD,
53 | optG=optG,
54 | log_data=log_data,
55 | device='cpu')
56 |
57 | for name, metric_dict in log_data.items():
58 | assert type(name) == str
59 | assert type(metric_dict['value']) == float
60 |
61 | def teardown(self):
62 | del self.netG
63 | del self.netD
64 |
65 |
66 | if __name__ == "__main__":
67 | test = TestSNGAN64()
68 | test.setup()
69 | test.test_SNGANGenerator64()
70 | test.test_SNGANDiscriminator64()
71 | test.test_train_steps()
72 | test.teardown()
73 |
--------------------------------------------------------------------------------
/tests/nets/ssgan/test_ssgan_128.py:
--------------------------------------------------------------------------------
1 | """
2 | Test functions for SSGAN for image size 128.
3 | """
4 | import torch
5 | import torch.optim as optim
6 |
7 | from torch_mimicry.nets.ssgan.ssgan_128 import SSGANGenerator128, SSGANDiscriminator128
8 | from torch_mimicry.training import metric_log
9 | from torch_mimicry.utils import common
10 |
11 |
12 | class TestSSGAN128:
13 | def setup(self):
14 | self.nz = 128
15 | self.N, self.C, self.H, self.W = (8, 3, 128, 128)
16 | self.ngf = 16
17 | self.ndf = 16
18 | self.device = 'cpu'
19 |
20 | self.netG = SSGANGenerator128(ngf=self.ngf)
21 | self.netD = SSGANDiscriminator128(ndf=self.ndf)
22 |
23 | def test_SSGANGenerator128(self):
24 | noise = torch.ones(self.N, self.nz)
25 | output = self.netG(noise)
26 |
27 | assert output.shape == (self.N, self.C, self.H, self.W)
28 |
29 | def test_SSGANDiscriminator128(self):
30 | images = torch.ones(self.N, self.C, self.H, self.W)
31 | output, labels = self.netD(images)
32 |
33 | assert output.shape == (self.N, 1)
34 | assert labels.shape == (self.N, 4)
35 |
36 | def test_train_steps(self):
37 | real_batch = common.load_images(self.N, size=self.H)
38 |
39 | # Setup optimizers
40 | optD = optim.Adam(self.netD.parameters(), 2e-4, betas=(0.0, 0.9))
41 | optG = optim.Adam(self.netG.parameters(), 2e-4, betas=(0.0, 0.9))
42 |
43 | # Log statistics to check
44 | log_data = metric_log.MetricLog()
45 |
46 | # Test D train step
47 | log_data = self.netD.train_step(real_batch=real_batch,
48 | netG=self.netG,
49 | optD=optD,
50 | device=self.device,
51 | log_data=log_data)
52 |
53 | log_data = self.netG.train_step(real_batch=real_batch,
54 | netD=self.netD,
55 | optG=optG,
56 | log_data=log_data,
57 | device=self.device)
58 |
59 | for name, metric_dict in log_data.items():
60 | assert type(name) == str
61 | assert type(metric_dict['value']) == float
62 |
63 | def teardown(self):
64 | del self.netG
65 | del self.netD
66 |
67 |
68 | if __name__ == "__main__":
69 | test = TestSSGAN128()
70 | test.setup()
71 | test.test_SSGANGenerator128()
72 | test.test_SSGANDiscriminator128()
73 | test.test_train_steps()
74 | test.teardown()
75 |
--------------------------------------------------------------------------------
/tests/nets/ssgan/test_ssgan_32.py:
--------------------------------------------------------------------------------
1 | """
2 | Test functions for SSGAN for image size 32.
3 | """
4 | import torch
5 | import torch.optim as optim
6 |
7 | from torch_mimicry.nets.ssgan.ssgan_32 import SSGANGenerator32, SSGANDiscriminator32
8 | from torch_mimicry.training import metric_log
9 | from torch_mimicry.utils import common
10 |
11 |
12 | class TestSSGAN32:
13 | def setup(self):
14 | self.nz = 128
15 | self.N, self.C, self.H, self.W = (8, 3, 32, 32)
16 | self.ngf = 16
17 | self.ndf = 16
18 |
19 | self.netG = SSGANGenerator32(ngf=self.ngf)
20 | self.netD = SSGANDiscriminator32(ndf=self.ndf)
21 |
22 | def test_SSGANGenerator32(self):
23 | noise = torch.ones(self.N, self.nz)
24 | output = self.netG(noise)
25 |
26 | assert output.shape == (self.N, self.C, self.H, self.W)
27 |
28 | def test_SSGANDiscriminator32(self):
29 | images = torch.ones(self.N, self.C, self.H, self.W)
30 | output, labels = self.netD(images)
31 |
32 | assert output.shape == (self.N, 1)
33 | assert labels.shape == (self.N, 4)
34 |
35 | def test_train_steps(self):
36 | real_batch = common.load_images(self.N, size=self.H)
37 |
38 | # Setup optimizers
39 | optD = optim.Adam(self.netD.parameters(), 2e-4, betas=(0.0, 0.9))
40 | optG = optim.Adam(self.netG.parameters(), 2e-4, betas=(0.0, 0.9))
41 |
42 | # Log statistics to check
43 | log_data = metric_log.MetricLog()
44 |
45 | # Test D train step
46 | log_data = self.netD.train_step(real_batch=real_batch,
47 | netG=self.netG,
48 | optD=optD,
49 | device='cpu',
50 | log_data=log_data)
51 |
52 | log_data = self.netG.train_step(real_batch=real_batch,
53 | netD=self.netD,
54 | optG=optG,
55 | log_data=log_data,
56 | device='cpu')
57 |
58 | for name, metric_dict in log_data.items():
59 | assert type(name) == str
60 | assert type(metric_dict['value']) == float
61 |
62 | def teardown(self):
63 | del self.netG
64 | del self.netD
65 |
66 |
67 | if __name__ == "__main__":
68 | test = TestSSGAN32()
69 | test.setup()
70 | test.test_SSGANGenerator32()
71 | test.test_SSGANDiscriminator32()
72 | test.test_train_steps()
73 | test.teardown()
74 |
--------------------------------------------------------------------------------
/tests/nets/ssgan/test_ssgan_48.py:
--------------------------------------------------------------------------------
1 | """
2 | Test functions for SSGAN for image size 48.
3 | """
4 | import torch
5 | import torch.optim as optim
6 |
7 | from torch_mimicry.nets.ssgan.ssgan_48 import SSGANGenerator48, SSGANDiscriminator48
8 | from torch_mimicry.training import metric_log
9 | from torch_mimicry.utils import common
10 |
11 |
12 | class TestSSGAN48:
13 | def setup(self):
14 | self.nz = 128
15 | self.N, self.C, self.H, self.W = (8, 3, 48, 48)
16 | self.ngf = 16
17 | self.ndf = 16
18 |
19 | self.netG = SSGANGenerator48(ngf=self.ngf)
20 | self.netD = SSGANDiscriminator48(ndf=self.ndf)
21 |
22 | def test_SSGANGenerator48(self):
23 | noise = torch.ones(self.N, self.nz)
24 | output = self.netG(noise)
25 |
26 | assert output.shape == (self.N, self.C, self.H, self.W)
27 |
28 | def test_SSGANDiscriminator48(self):
29 | images = torch.ones(self.N, self.C, self.H, self.W)
30 | output, labels = self.netD(images)
31 |
32 | assert output.shape == (self.N, 1)
33 | assert labels.shape == (self.N, 4)
34 |
35 | def test_train_steps(self):
36 | real_batch = common.load_images(self.N, size=self.H)
37 |
38 | # Setup optimizers
39 | optD = optim.Adam(self.netD.parameters(), 2e-4, betas=(0.0, 0.9))
40 | optG = optim.Adam(self.netG.parameters(), 2e-4, betas=(0.0, 0.9))
41 |
42 | # Log statistics to check
43 | log_data = metric_log.MetricLog()
44 |
45 | # Test D train step
46 | log_data = self.netD.train_step(real_batch=real_batch,
47 | netG=self.netG,
48 | optD=optD,
49 | device='cpu',
50 | log_data=log_data)
51 |
52 | log_data = self.netG.train_step(real_batch=real_batch,
53 | netD=self.netD,
54 | optG=optG,
55 | log_data=log_data,
56 | device='cpu')
57 |
58 | for name, metric_dict in log_data.items():
59 | assert type(name) == str
60 | assert type(metric_dict['value']) == float
61 |
62 | def teardown(self):
63 | del self.netG
64 | del self.netD
65 |
66 |
67 | if __name__ == "__main__":
68 | test = TestSSGAN48()
69 | test.setup()
70 | test.test_SSGANGenerator48()
71 | test.test_SSGANDiscriminator48()
72 | test.test_train_steps()
73 | test.teardown()
74 |
--------------------------------------------------------------------------------
/tests/nets/ssgan/test_ssgan_64.py:
--------------------------------------------------------------------------------
1 | """
2 | Test functions for SSGAN for image size 64.
3 | """
4 | import torch
5 | import torch.optim as optim
6 |
7 | from torch_mimicry.nets.ssgan.ssgan_64 import SSGANGenerator64, SSGANDiscriminator64
8 | from torch_mimicry.training import metric_log
9 | from torch_mimicry.utils import common
10 |
11 |
12 | class TestSSGAN64:
13 | def setup(self):
14 | self.nz = 128
15 | self.N, self.C, self.H, self.W = (8, 3, 64, 64)
16 | self.ngf = 16
17 | self.ndf = 16
18 |
19 | self.netG = SSGANGenerator64(ngf=self.ngf)
20 | self.netD = SSGANDiscriminator64(ndf=self.ndf)
21 |
22 | def test_SSGANGenerator64(self):
23 | noise = torch.ones(self.N, self.nz)
24 | output = self.netG(noise)
25 |
26 | assert output.shape == (self.N, self.C, self.H, self.W)
27 |
28 | def test_SSGANDiscriminator64(self):
29 | images = torch.ones(self.N, self.C, self.H, self.W)
30 | output, labels = self.netD(images)
31 |
32 | assert output.shape == (self.N, 1)
33 | assert labels.shape == (self.N, 4)
34 |
35 | def test_train_steps(self):
36 | real_batch = common.load_images(self.N, size=self.H)
37 |
38 | # Setup optimizers
39 | optD = optim.Adam(self.netD.parameters(), 2e-4, betas=(0.0, 0.9))
40 | optG = optim.Adam(self.netG.parameters(), 2e-4, betas=(0.0, 0.9))
41 |
42 | # Log statistics to check
43 | log_data = metric_log.MetricLog()
44 |
45 | # Test D train step
46 | log_data = self.netD.train_step(real_batch=real_batch,
47 | netG=self.netG,
48 | optD=optD,
49 | device='cpu',
50 | log_data=log_data)
51 |
52 | log_data = self.netG.train_step(real_batch=real_batch,
53 | netD=self.netD,
54 | optG=optG,
55 | log_data=log_data,
56 | device='cpu')
57 |
58 | for name, metric_dict in log_data.items():
59 | assert type(name) == str
60 | assert type(metric_dict['value']) == float
61 |
62 | def teardown(self):
63 | del self.netG
64 | del self.netD
65 |
66 |
67 | if __name__ == "__main__":
68 | test = TestSSGAN64()
69 | test.setup()
70 | test.test_SSGANGenerator64()
71 | test.test_SSGANDiscriminator64()
72 | test.test_train_steps()
73 | test.teardown()
74 |
--------------------------------------------------------------------------------
/tests/nets/ssgan/test_ssgan_base.py:
--------------------------------------------------------------------------------
1 | """
2 | Test for SSGAN specific functions at the discriminator.
3 | """
4 | import pytest
5 | import torch
6 |
7 | from torch_mimicry.nets.ssgan.ssgan_base import SSGANBaseDiscriminator
8 | from torch_mimicry.utils import common
9 |
10 |
11 | class ExampleDiscriminator(SSGANBaseDiscriminator):
12 | def __init__(self, *args, **kwargs):
13 | super().__init__(*args, **kwargs)
14 |
15 | def forward(self, x):
16 | return torch.ones(x.shape[0])
17 |
18 |
19 | class TestSSGANBase:
20 | def setup(self):
21 | self.netD = ExampleDiscriminator(ndf=16)
22 |
23 | def test_rot_tensor(self):
24 | # Load image and model
25 | image, _ = common.load_images(1, size=32)
26 |
27 |         # Each rotation is a multiple of 90 degrees, so applying the same
28 |         # rotation four times should recover the original pixel values.
29 | for deg in [0, 90, 180, 270]:
30 | x = image.clone()
31 | for _ in range(4):
32 | x = self.netD._rot_tensor(x, deg)
33 |
34 | assert torch.sum((x - image)**2) < 1e-5
35 |
36 | def test_rotate_batch(self):
37 | # Load image and model
38 | images, _ = common.load_images(8, size=32)
39 |
40 | check = images.clone()
41 | check, labels = self.netD._rotate_batch(check)
42 | degrees = [0, 90, 180, 270]
43 |
44 | # Rotate 3 more times to get back to original.
45 | for i in range(check.shape[0]):
46 | for _ in range(3):
47 | check[i] = self.netD._rot_tensor(check[i], degrees[labels[i]])
48 |
49 | assert torch.sum((images - check)**2) < 1e-5
50 |
51 | with pytest.raises(ValueError):
52 | self.netD._rot_tensor(check[i], 9999)
53 |
54 | def teardown(self):
55 | del self.netD
56 |
57 |
58 | if __name__ == "__main__":
59 | test = TestSSGANBase()
60 | test.setup()
61 | test.test_rot_tensor()
62 | test.test_rotate_batch()
63 | test.teardown()
64 |
--------------------------------------------------------------------------------
/tests/nets/wgan_gp/test_wgan_gp_128.py:
--------------------------------------------------------------------------------
1 | """
2 | Test functions for WGAN-GP for image size 128.
3 | """
4 | import torch
5 | import torch.optim as optim
6 |
7 | from torch_mimicry.nets.wgan_gp.wgan_gp_128 import WGANGPGenerator128, WGANGPDiscriminator128
8 | from torch_mimicry.training import metric_log
9 | from torch_mimicry.utils import common
10 |
11 |
12 | class TestWGANGP128:
13 | def setup(self):
14 | self.nz = 128
15 | self.N, self.C, self.H, self.W = (8, 3, 128, 128)
16 | self.ngf = 16
17 | self.ndf = 16
18 |
19 | self.netG = WGANGPGenerator128(ngf=self.ngf)
20 | self.netD = WGANGPDiscriminator128(ndf=self.ndf)
21 |
22 | def test_WGANGPGenerator128(self):
23 | noise = torch.ones(self.N, self.nz)
24 | output = self.netG(noise)
25 |
26 | assert output.shape == (self.N, self.C, self.H, self.W)
27 |
28 | def test_WGANGPDiscriminator128(self):
29 | images = torch.ones(self.N, self.C, self.H, self.W)
30 | output = self.netD(images)
31 |
32 | assert output.shape == (self.N, 1)
33 |
34 | def test_train_steps(self):
35 | # Get real and fake images
36 | real_batch = common.load_images(self.N, size=self.H)
37 |
38 | # Setup optimizers
39 | optD = optim.Adam(self.netD.parameters(), 2e-4, betas=(0.0, 0.9))
40 | optG = optim.Adam(self.netG.parameters(), 2e-4, betas=(0.0, 0.9))
41 |
42 | # Log statistics to check
43 | log_data = metric_log.MetricLog()
44 |
45 | # Test D train step
46 | log_data = self.netD.train_step(real_batch=real_batch,
47 | netG=self.netG,
48 | optD=optD,
49 | device='cpu',
50 | log_data=log_data)
51 |
52 | log_data = self.netG.train_step(real_batch=real_batch,
53 | netD=self.netD,
54 | optG=optG,
55 | log_data=log_data,
56 | device='cpu')
57 |
58 | for name, metric_dict in log_data.items():
59 | assert type(name) == str
60 | assert type(metric_dict['value']) == float
61 |
62 | def teardown(self):
63 | del self.netG
64 | del self.netD
65 |
66 |
67 | if __name__ == "__main__":
68 | test = TestWGANGP128()
69 | test.setup()
70 | test.test_WGANGPGenerator128()
71 | test.test_WGANGPDiscriminator128()
72 | test.test_train_steps()
73 | test.teardown()
74 |
--------------------------------------------------------------------------------
/tests/nets/wgan_gp/test_wgan_gp_32.py:
--------------------------------------------------------------------------------
1 | """
2 | Test functions for WGAN-GP for image size 32.
3 | """
4 | import torch
5 | import torch.optim as optim
6 |
7 | from torch_mimicry.nets.wgan_gp.wgan_gp_32 import WGANGPGenerator32, WGANGPDiscriminator32
8 | from torch_mimicry.training import metric_log
9 | from torch_mimicry.utils import common
10 |
11 |
12 | class TestWGANGP32:
13 | def setup(self):
14 | self.nz = 128
15 | self.N, self.C, self.H, self.W = (8, 3, 32, 32)
16 | self.ngf = 16
17 | self.ndf = 16
18 |
19 | self.netG = WGANGPGenerator32(ngf=self.ngf)
20 | self.netD = WGANGPDiscriminator32(ndf=self.ndf)
21 |
22 | def test_WGANGPGenerator32(self):
23 | noise = torch.ones(self.N, self.nz)
24 | output = self.netG(noise)
25 |
26 | assert output.shape == (self.N, self.C, self.H, self.W)
27 |
28 | def test_WGANGPDiscriminator32(self):
29 | images = torch.ones(self.N, self.C, self.H, self.W)
30 | output = self.netD(images)
31 |
32 | assert output.shape == (self.N, 1)
33 |
34 | def test_train_steps(self):
35 | # Get real and fake images
36 | real_batch = common.load_images(self.N, size=self.H)
37 |
38 | # Setup optimizers
39 | optD = optim.Adam(self.netD.parameters(), 2e-4, betas=(0.0, 0.9))
40 | optG = optim.Adam(self.netG.parameters(), 2e-4, betas=(0.0, 0.9))
41 |
42 | # Log statistics to check
43 | log_data = metric_log.MetricLog()
44 |
45 | # Test D train step
46 | log_data = self.netD.train_step(real_batch=real_batch,
47 | netG=self.netG,
48 | optD=optD,
49 | device='cpu',
50 | log_data=log_data)
51 |
52 | log_data = self.netG.train_step(real_batch=real_batch,
53 | netD=self.netD,
54 | optG=optG,
55 | log_data=log_data,
56 | device='cpu')
57 |
58 | for name, metric_dict in log_data.items():
59 | assert type(name) == str
60 | assert type(metric_dict['value']) == float
61 |
62 | def teardown(self):
63 | del self.netG
64 | del self.netD
65 |
66 |
67 | if __name__ == "__main__":
68 | test = TestWGANGP32()
69 | test.setup()
70 | test.test_WGANGPGenerator32()
71 | test.test_WGANGPDiscriminator32()
72 | test.test_train_steps()
73 | test.teardown()
74 |
--------------------------------------------------------------------------------
/tests/nets/wgan_gp/test_wgan_gp_48.py:
--------------------------------------------------------------------------------
1 | """
2 | Test functions for WGAN-GP for image size 48.
3 | """
4 | import torch
5 | import torch.optim as optim
6 |
7 | from torch_mimicry.nets.wgan_gp.wgan_gp_48 import WGANGPGenerator48, WGANGPDiscriminator48
8 | from torch_mimicry.training import metric_log
9 | from torch_mimicry.utils import common
10 |
11 |
12 | class TestWGANGP48:
13 | def setup(self):
14 | self.nz = 128
15 | self.N, self.C, self.H, self.W = (8, 3, 48, 48)
16 | self.ngf = 16
17 | self.ndf = 16
18 |
19 | self.netG = WGANGPGenerator48(ngf=self.ngf)
20 | self.netD = WGANGPDiscriminator48(ndf=self.ndf)
21 |
22 | def test_WGANGPGenerator48(self):
23 | noise = torch.ones(self.N, self.nz)
24 | output = self.netG(noise)
25 |
26 | assert output.shape == (self.N, self.C, self.H, self.W)
27 |
28 | def test_WGANGPDiscriminator48(self):
29 | images = torch.ones(self.N, self.C, self.H, self.W)
30 | output = self.netD(images)
31 |
32 | assert output.shape == (self.N, 1)
33 |
34 | def test_train_steps(self):
35 | # Get real and fake images
36 | real_batch = common.load_images(self.N, size=self.H)
37 |
38 | # Setup optimizers
39 | optD = optim.Adam(self.netD.parameters(), 2e-4, betas=(0.0, 0.9))
40 | optG = optim.Adam(self.netG.parameters(), 2e-4, betas=(0.0, 0.9))
41 |
42 | # Log statistics to check
43 | log_data = metric_log.MetricLog()
44 |
45 | # Test D train step
46 | log_data = self.netD.train_step(real_batch=real_batch,
47 | netG=self.netG,
48 | optD=optD,
49 | device='cpu',
50 | log_data=log_data)
51 |
52 | log_data = self.netG.train_step(real_batch=real_batch,
53 | netD=self.netD,
54 | optG=optG,
55 | log_data=log_data,
56 | device='cpu')
57 |
58 | for name, metric_dict in log_data.items():
59 | assert type(name) == str
60 | assert type(metric_dict['value']) == float
61 |
62 | def teardown(self):
63 | del self.netG
64 | del self.netD
65 |
66 |
67 | if __name__ == "__main__":
68 | test = TestWGANGP48()
69 | test.setup()
70 | test.test_WGANGPGenerator48()
71 | test.test_WGANGPDiscriminator48()
72 | test.test_train_steps()
73 | test.teardown()
74 |
--------------------------------------------------------------------------------
/tests/nets/wgan_gp/test_wgan_gp_64.py:
--------------------------------------------------------------------------------
1 | """
2 | Test functions for WGAN-GP for image size 64.
3 | """
4 | import torch
5 | import torch.optim as optim
6 |
7 | from torch_mimicry.nets.wgan_gp.wgan_gp_64 import WGANGPGenerator64, WGANGPDiscriminator64
8 | from torch_mimicry.training import metric_log
9 | from torch_mimicry.utils import common
10 |
11 |
12 | class TestWGANGP64:
13 | def setup(self):
14 | self.nz = 128
15 | self.N, self.C, self.H, self.W = (8, 3, 64, 64)
16 | self.ngf = 16
17 | self.ndf = 16
18 |
19 | self.netG = WGANGPGenerator64(ngf=self.ngf)
20 | self.netD = WGANGPDiscriminator64(ndf=self.ndf)
21 |
22 | def test_WGANGPGenerator64(self):
23 | noise = torch.ones(self.N, self.nz)
24 | output = self.netG(noise)
25 |
26 | assert output.shape == (self.N, self.C, self.H, self.W)
27 |
28 | def test_WGANGPDiscriminator64(self):
29 | images = torch.ones(self.N, self.C, self.H, self.W)
30 | output = self.netD(images)
31 |
32 | assert output.shape == (self.N, 1)
33 |
34 | def test_train_steps(self):
35 | # Get real and fake images
36 | real_batch = common.load_images(self.N, size=self.H)
37 |
38 | # Setup optimizers
39 | optD = optim.Adam(self.netD.parameters(), 2e-4, betas=(0.0, 0.9))
40 | optG = optim.Adam(self.netG.parameters(), 2e-4, betas=(0.0, 0.9))
41 |
42 | # Log statistics to check
43 | log_data = metric_log.MetricLog()
44 |
45 | # Test D train step
46 | log_data = self.netD.train_step(real_batch=real_batch,
47 | netG=self.netG,
48 | optD=optD,
49 | device='cpu',
50 | log_data=log_data)
51 |
52 | log_data = self.netG.train_step(real_batch=real_batch,
53 | netD=self.netD,
54 | optG=optG,
55 | log_data=log_data,
56 | device='cpu')
57 |
58 | for name, metric_dict in log_data.items():
59 | assert type(name) == str
60 | assert type(metric_dict['value']) == float
61 |
62 | def teardown(self):
63 | del self.netG
64 | del self.netD
65 |
66 |
67 | if __name__ == "__main__":
68 | test = TestWGANGP64()
69 | test.setup()
70 | test.test_WGANGPGenerator64()
71 | test.test_WGANGPDiscriminator64()
72 | test.test_train_steps()
73 | test.teardown()
74 |
--------------------------------------------------------------------------------
/tests/nets/wgan_gp/test_wgan_gp_resblocks.py:
--------------------------------------------------------------------------------
1 | from itertools import product
2 |
3 | import torch
4 |
5 | from torch_mimicry.nets.wgan_gp import wgan_gp_resblocks
6 |
7 |
8 | class TestResBlocks:
9 | def setup(self):
10 | self.images = torch.ones(4, 3, 16, 16)
11 |
12 | def test_GBlock(self):
13 | # Arguments
14 | num_classes_list = [0, 10]
15 | spectral_norm_list = [True, False]
16 | in_channels = 3
17 | out_channels = 8
18 | args_comb = product(num_classes_list, spectral_norm_list)
19 |
20 | for args in args_comb:
21 | num_classes = args[0]
22 | spectral_norm = args[1]
23 |
24 | if num_classes > 0:
25 | y = torch.ones((4, ), dtype=torch.int64)
26 | else:
27 | y = None
28 |
29 | gen_block_up = wgan_gp_resblocks.GBlock(
30 | in_channels=in_channels,
31 | out_channels=out_channels,
32 | upsample=True,
33 | num_classes=num_classes,
34 | spectral_norm=spectral_norm)
35 |
36 | gen_block = wgan_gp_resblocks.GBlock(in_channels=in_channels,
37 | out_channels=out_channels,
38 | upsample=False,
39 | num_classes=num_classes,
40 | spectral_norm=spectral_norm)
41 |
42 | assert gen_block_up(self.images, y).shape == (4, 8, 32, 32)
43 | assert gen_block(self.images, y).shape == (4, 8, 16, 16)
44 |
45 | def test_DBlocks(self):
46 | in_channels = 3
47 | out_channels = 8
48 |
49 | for spectral_norm in [True, False]:
50 | dis_block_down = wgan_gp_resblocks.DBlock(
51 | in_channels=in_channels,
52 | out_channels=out_channels,
53 | downsample=True,
54 | spectral_norm=spectral_norm)
55 |
56 | dis_block = wgan_gp_resblocks.DBlock(in_channels=in_channels,
57 | out_channels=out_channels,
58 | downsample=False,
59 | spectral_norm=spectral_norm)
60 |
61 | dis_block_opt = wgan_gp_resblocks.DBlockOptimized(
62 | in_channels=in_channels,
63 | out_channels=out_channels,
64 | spectral_norm=spectral_norm)
65 |
66 | assert dis_block(self.images).shape == (4, out_channels, 16, 16)
67 | assert dis_block_down(self.images).shape == (4, out_channels, 8, 8)
68 | assert dis_block_opt(self.images).shape == (4, out_channels, 8, 8)
69 |
70 | def teardown(self):
71 | del self.images
72 |
73 |
74 | if __name__ == "__main__":
75 | test = TestResBlocks()
76 | test.setup()
77 | test.test_GBlock()
78 | test.test_DBlocks()
79 | test.teardown()
80 |
--------------------------------------------------------------------------------
/tests/training/test_logger.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 |
4 | import torch
5 | import torch.nn as nn
6 |
7 | from torch_mimicry.nets.gan import gan, cgan
8 | from torch_mimicry.training import logger, metric_log
9 |
10 |
11 | class ExampleGen(gan.BaseGenerator):
12 | def __init__(self,
13 | bottom_width=4,
14 | nz=4,
15 | ngf=16,
16 | loss_type='gan',
17 | *args,
18 | **kwargs):
19 | super().__init__(nz=nz,
20 | ngf=ngf,
21 | bottom_width=bottom_width,
22 | loss_type=loss_type,
23 | *args,
24 | **kwargs)
25 | self.linear = nn.Linear(self.nz, 3072)
26 |
27 | def forward(self, x):
28 | output = self.linear(x)
29 | output = output.view(x.shape[0], 3, 32, 32)
30 |
31 | return output
32 |
33 |
34 | class ExampleConditionalGen(cgan.BaseConditionalGenerator):
35 | def __init__(self,
36 | bottom_width=4,
37 | nz=4,
38 | ngf=16,
39 | loss_type='gan',
40 | num_classes=10,
41 | **kwargs):
42 | super().__init__(nz=nz,
43 | ngf=ngf,
44 | bottom_width=bottom_width,
45 | loss_type=loss_type,
46 | num_classes=num_classes,
47 | **kwargs)
48 | self.linear = nn.Linear(self.nz, 3072)
49 |
50 | def forward(self, x, y=None):
51 | output = self.linear(x)
52 | output = output.view(x.shape[0], 3, 32, 32)
53 |
54 | return output
55 |
56 |
57 | class TestLogger:
58 | def setup(self):
59 | self.log_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)),
60 | "test_log")
61 |
62 | self.logger = logger.Logger(log_dir=self.log_dir,
63 | num_steps=100,
64 | dataset_size=50000,
65 | flush_secs=30,
66 | device=torch.device('cpu'))
67 |
68 | self.scalars = [
69 | 'errG',
70 | 'errD',
71 | 'D(x)',
72 | 'D(G(z))',
73 | 'img',
74 | 'lr_D',
75 | 'lr_G',
76 | ]
77 |
78 | def test_print_log(self):
79 | log_data = metric_log.MetricLog()
80 | global_step = 10
81 |
82 | # Populate log data with some value
83 | for scalar in self.scalars:
84 | if scalar == 'img':
85 | continue
86 |
87 | log_data.add_metric(scalar, 1.0)
88 |
89 | printed = self.logger.print_log(global_step=global_step,
90 | log_data=log_data,
91 | time_taken=10)
92 |
93 | assert printed == (
94 | 'INFO: [Epoch 1/1][Global Step: 10/100] ' +
95 | '\n| D(G(z)): 1.0\n| D(x): 1.0\n| errD: 1.0\n| errG: 1.0' +
96 | '\n| lr_D: 1.0\n| lr_G: 1.0\n| (10.0000 sec/idx)')
97 |
98 | def test_vis_images(self):
99 | netG = ExampleGen()
100 | netG_conditional = ExampleConditionalGen()
101 |
102 | global_step = 10
103 | num_images = 64
104 |
105 | # Test unconditional
106 | self.logger.vis_images(netG, global_step, num_images)
107 | img_dir = os.path.join(self.log_dir, 'images')
108 | filenames = os.listdir(img_dir)
109 | assert 'fake_samples_step_10.png' in filenames
110 | assert 'fixed_fake_samples_step_10.png' in filenames
111 |
112 | # Remove images
113 | for file in filenames:
114 | os.remove(os.path.join(img_dir, file))
115 |
116 | # Test conditional
117 | self.logger.vis_images(netG_conditional, global_step, num_images)
118 |         assert 'fake_samples_step_10.png' in os.listdir(img_dir)
119 |         assert 'fixed_fake_samples_step_10.png' in os.listdir(img_dir)
120 |
121 | def teardown(self):
122 | shutil.rmtree(self.log_dir)
123 |
124 |
125 | if __name__ == "__main__":
126 | test = TestLogger()
127 | test.setup()
128 | test.test_print_log()
129 | test.test_vis_images()
130 | test.teardown()
131 |
--------------------------------------------------------------------------------
/tests/training/test_metric_log.py:
--------------------------------------------------------------------------------
1 | from torch_mimicry.training import metric_log
2 |
3 |
4 | class TestMetricLog:
5 | def setup(self):
6 | self.log_data = metric_log.MetricLog()
7 |
8 | def test_add_metric(self):
9 | # Singular metric
10 | self.log_data.add_metric('singular', 1.0124214)
11 | assert self.log_data['singular'] == 1.0124
12 |
13 | # Multiple metrics under same group
14 | self.log_data.add_metric('errD', 1.00001, group='loss')
15 | self.log_data.add_metric('errG', 1.0011, group='loss')
16 |
17 | assert self.log_data.get_group_name(
18 | 'errD') == self.log_data.get_group_name('errG')
19 |
20 | def teardown(self):
21 | del self.log_data
22 |
23 |
24 | if __name__ == "__main__":
25 | test = TestMetricLog()
26 | test.setup()
27 | test.test_add_metric()
28 | test.teardown()
29 |
--------------------------------------------------------------------------------
/tests/training/test_scheduler.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.optim as optim
3 | import pytest
4 |
5 | from torch_mimicry.training import scheduler, metric_log
6 |
7 |
8 | class TestLRScheduler:
9 | def setup(self):
10 | self.netD = nn.Linear(10, 10)
11 | self.netG = nn.Linear(10, 10)
12 |
13 | self.num_steps = 10
14 | self.lr_D = 2e-4
15 | self.lr_G = 2e-4
16 |
17 | def get_lr(self, optimizer):
18 | return optimizer.param_groups[0]['lr']
19 |
20 | def test_linear_decay(self):
21 | optD = optim.Adam(self.netD.parameters(), self.lr_D, betas=(0.0, 0.9))
22 | optG = optim.Adam(self.netG.parameters(), self.lr_G, betas=(0.0, 0.9))
23 |
24 | lr_scheduler = scheduler.LRScheduler(lr_decay='linear',
25 | optD=optD,
26 | optG=optG,
27 | num_steps=self.num_steps,
28 | start_step=5)
29 |
30 | log_data = metric_log.MetricLog()
31 | for step in range(self.num_steps):
32 | lr_scheduler.step(log_data, step)
33 |
34 | if step < lr_scheduler.start_step:
35 | assert abs(2e-4 - self.get_lr(optD)) < 1e-5
36 | assert abs(2e-4 - self.get_lr(optG)) < 1e-5
37 |
38 | else:
39 | curr_lr = ((1 - (max(0, step - lr_scheduler.start_step) /
40 | (self.num_steps - lr_scheduler.start_step))) *
41 | self.lr_D)
42 |
43 | assert abs(curr_lr - self.get_lr(optD)) < 1e-5
44 | assert abs(curr_lr - self.get_lr(optG)) < 1e-5
45 |
46 | def test_no_decay(self):
47 | optD = optim.Adam(self.netD.parameters(), self.lr_D, betas=(0.0, 0.9))
48 | optG = optim.Adam(self.netG.parameters(), self.lr_G, betas=(0.0, 0.9))
49 |
50 | lr_scheduler = scheduler.LRScheduler(lr_decay='None',
51 | optD=optD,
52 | optG=optG,
53 | num_steps=self.num_steps)
54 |
55 | log_data = metric_log.MetricLog()
56 | for step in range(1, self.num_steps + 1):
57 | lr_scheduler.step(log_data, step)
58 |
59 | assert (self.lr_D == self.get_lr(optD))
60 | assert (self.lr_G == self.get_lr(optG))
61 |
62 | def test_arguments(self):
63 | with pytest.raises(NotImplementedError):
64 | optD = optim.Adam(self.netD.parameters(),
65 | self.lr_D,
66 | betas=(0.0, 0.9))
67 | optG = optim.Adam(self.netG.parameters(),
68 | self.lr_G,
69 | betas=(0.0, 0.9))
70 | scheduler.LRScheduler(lr_decay='does_not_exist',
71 | optD=optD,
72 | optG=optG,
73 | num_steps=self.num_steps)
74 |
75 |
76 |
77 |
78 | if __name__ == "__main__":
79 | test = TestLRScheduler()
80 | test.setup()
81 | test.test_arguments()
82 | test.test_linear_decay()
83 | test.test_no_decay()
84 |
--------------------------------------------------------------------------------
/tests/utils/test_common.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 |
4 | import numpy as np
5 | import torch
6 | from PIL import Image
7 |
8 | from torch_mimicry.utils import common
9 |
10 |
11 | class TestCommon:
12 | def setup(self):
13 | # Build directories
14 | self.log_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)),
15 | "test_log")
16 | if not os.path.exists(self.log_dir):
17 | os.makedirs(self.log_dir)
18 |
19 | def test_json_write_and_load(self):
20 | dict_to_write = dict(a=1, b=2, c=3)
21 | output_file = os.path.join(self.log_dir, 'output.json')
22 | common.write_to_json(dict_to_write, output_file)
23 | check = common.load_from_json(output_file)
24 |
25 | assert dict_to_write == check
26 |
27 | def test_load_and_save_image(self):
28 | image, label = common.load_images(n=1)
29 |
30 | image = torch.squeeze(image, dim=0)
31 | output_file = os.path.join(self.log_dir, 'images', 'test_img.png')
32 | common.save_tensor_image(image, output_file=output_file)
33 |
34 | check = np.array(Image.open(output_file))
35 |
36 | assert check.shape == (32, 32, 3)
37 | assert label.shape == (1, )
38 |
39 | def teardown(self):
40 | shutil.rmtree(self.log_dir)
41 |
42 |
43 | if __name__ == "__main__":
44 | test = TestCommon()
45 | test.setup()
46 | test.test_json_write_and_load()
47 | test.test_load_and_save_image()
48 | test.teardown()
49 |
--------------------------------------------------------------------------------
/torch_mimicry/__init__.py:
--------------------------------------------------------------------------------
1 | from torch_mimicry import nets, training, metrics, datasets, modules
2 |
3 | __version__ = "0.1.16"
4 |
--------------------------------------------------------------------------------
/torch_mimicry/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from . import imagenet
2 | from .data_utils import *
3 | from .image_loader import *
4 |
--------------------------------------------------------------------------------
/torch_mimicry/datasets/imagenet/__init__.py:
--------------------------------------------------------------------------------
1 | from .imagenet import *
2 |
--------------------------------------------------------------------------------
/torch_mimicry/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | from . import fid, kid, inception_score, inception_model
2 | from .compute_fid import *
3 | from .compute_is import *
4 | from .compute_kid import *
5 | from .compute_metrics import *
6 |
--------------------------------------------------------------------------------
/torch_mimicry/metrics/fid/__init__.py:
--------------------------------------------------------------------------------
1 | from .fid_utils import *
2 |
--------------------------------------------------------------------------------
/torch_mimicry/metrics/fid/fid_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Helper functions for calculating FID as adopted from the official FID code:
3 | https://github.com/bioinf-jku/TTUR/blob/master/fid.py
4 | """
5 | import numpy as np
6 | from scipy import linalg
7 |
8 | from torch_mimicry.metrics.inception_model import inception_utils
9 |
10 |
11 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
12 | """
13 | Numpy implementation of the Frechet Distance.
14 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
15 | and X_2 ~ N(mu_2, C_2) is
16 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
17 |
18 | Stable version by Dougal J. Sutherland.
19 |
20 | Args:
21 |         mu1 (ndarray): Sample mean over activations of the pool_3 layer of the
22 |             inception net for generated samples (e.g. as returned by
23 |             'calculate_activation_statistics').
24 |         mu2 (ndarray): Sample mean over activations of the pool_3 layer,
25 |             precalculated on a representative data set.
26 |         sigma1 (ndarray): Covariance matrix over activations of the pool_3 layer
27 |             for generated samples.
28 |         sigma2 (ndarray): Covariance matrix over activations of the pool_3 layer,
29 |             precalculated on a representative data set.
30 |
31 | Returns:
32 | np.float64: The Frechet Distance.
33 | """
34 | if mu1.shape != mu2.shape or sigma1.shape != sigma2.shape:
35 | raise ValueError(
36 | "(mu1, sigma1) should have exactly the same shape as (mu2, sigma2)."
37 | )
38 |
39 | mu1 = np.atleast_1d(mu1)
40 | mu2 = np.atleast_1d(mu2)
41 |
42 | sigma1 = np.atleast_2d(sigma1)
43 | sigma2 = np.atleast_2d(sigma2)
44 |
45 | diff = mu1 - mu2
46 |
47 | # Product might be almost singular
48 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
49 | if not np.isfinite(covmean).all():
50 | print(
51 | "WARNING: fid calculation produces singular product; adding {} to diagonal of cov estimates"
52 | .format(eps))
53 |
54 | offset = np.eye(sigma1.shape[0]) * eps
55 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
56 |
57 | # Numerical error might give slight imaginary component
58 | if np.iscomplexobj(covmean):
59 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
60 | m = np.max(np.abs(covmean.imag))
61 | raise ValueError("Imaginary component {}".format(m))
62 | covmean = covmean.real
63 |
64 | tr_covmean = np.trace(covmean)
65 |
66 | return diff.dot(diff) + np.trace(sigma1) + np.trace(
67 | sigma2) - 2 * tr_covmean
68 |
69 |
70 | def calculate_activation_statistics(images, sess, batch_size=50, verbose=True):
71 | """
72 | Calculation of the statistics used by the FID.
73 |
74 | Args:
75 | images (ndarray): Numpy array of shape (N, H, W, 3) and values in
76 | the range [0, 255].
77 | sess (Session): TensorFlow session object.
78 | batch_size (int): Batch size for inference.
79 | verbose (bool): If True, prints out logging information.
80 |
81 | Returns:
82 | ndarray: Mean of inception features from samples.
83 | ndarray: Covariance of inception features from samples.
84 | """
85 | act = inception_utils.get_activations(images, sess, batch_size, verbose)
86 | mu = np.mean(act, axis=0)
87 | sigma = np.cov(act, rowvar=False)
88 |
89 | return mu, sigma
90 |
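As a quick illustration of how these two helpers compose, the sketch below (not part of the library; the toy activations merely stand in for inception pool_3 features) estimates statistics for two sample sets and compares them:

import numpy as np

from torch_mimicry.metrics.fid import fid_utils

# Toy activations standing in for inception features; feature dim kept small.
rng = np.random.RandomState(0)
act1 = rng.randn(100, 8)
act2 = rng.randn(100, 8) + 0.5

# Same form of statistics that calculate_activation_statistics produces.
mu1, sigma1 = np.mean(act1, axis=0), np.cov(act1, rowvar=False)
mu2, sigma2 = np.mean(act2, axis=0), np.cov(act2, rowvar=False)

# Identical distributions give ~0; the mean shift above gives a positive value.
print(fid_utils.calculate_frechet_distance(mu1, sigma1, mu2, sigma2))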
--------------------------------------------------------------------------------
/torch_mimicry/metrics/inception_model/__init__.py:
--------------------------------------------------------------------------------
1 | from .inception_utils import *
2 |
--------------------------------------------------------------------------------
/torch_mimicry/metrics/inception_score/__init__.py:
--------------------------------------------------------------------------------
1 | from .inception_score_utils import *
2 |
--------------------------------------------------------------------------------
/torch_mimicry/metrics/kid/__init__.py:
--------------------------------------------------------------------------------
1 | from .kid_utils import *
2 |
--------------------------------------------------------------------------------
/torch_mimicry/modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .layers import *
2 | from .losses import *
3 | from .resblocks import *
4 | from .spectral_norm import *
--------------------------------------------------------------------------------
/torch_mimicry/nets/__init__.py:
--------------------------------------------------------------------------------
1 | from . import (basemodel, gan, dcgan, wgan_gp, sngan, cgan_pd, ssgan,
2 | infomax_gan, sagan)
3 |
--------------------------------------------------------------------------------
/torch_mimicry/nets/basemodel/__init__.py:
--------------------------------------------------------------------------------
1 | from .basemodel import *
2 |
--------------------------------------------------------------------------------
/torch_mimicry/nets/basemodel/basemodel.py:
--------------------------------------------------------------------------------
1 | """
2 | Implementation of BaseModel.
3 | """
4 | import os
5 | from abc import ABC, abstractmethod
6 |
7 | import torch
8 | import torch.nn as nn
9 |
10 |
11 | class BaseModel(nn.Module, ABC):
12 | r"""
13 | BaseModel with basic functionalities for checkpointing and restoration.
14 | """
15 | def __init__(self):
16 | super().__init__()
17 |
18 | @abstractmethod
19 | def forward(self, x):
20 | pass
21 |
22 | @property
23 | def device(self):
24 | return next(self.parameters()).device
25 |
26 | def restore_checkpoint(self, ckpt_file, optimizer=None):
27 | r"""
28 | Restores checkpoint from a pth file and restores optimizer state.
29 |
30 | Args:
31 | ckpt_file (str): A PyTorch pth file containing model weights.
32 |             optimizer (Optimizer): An optimizer whose state is restored from the checkpoint, if provided.
33 |
34 | Returns:
35 | int: Global step variable where the model was last checkpointed.
36 | """
37 | if not ckpt_file:
38 | raise ValueError("No checkpoint file to be restored.")
39 |
40 | try:
41 | ckpt_dict = torch.load(ckpt_file)
42 | except RuntimeError:
43 | ckpt_dict = torch.load(ckpt_file,
44 | map_location=lambda storage, loc: storage)
45 |
46 | # Restore model weights
47 | self.load_state_dict(ckpt_dict['model_state_dict'])
48 |
49 |         # Restore optimizer state if provided; evaluation does not need this.
50 | if optimizer:
51 | optimizer.load_state_dict(ckpt_dict['optimizer_state_dict'])
52 |
53 | # Return global step
54 | return ckpt_dict['global_step']
55 |
56 | def save_checkpoint(self,
57 | directory,
58 | global_step,
59 | optimizer=None,
60 | name=None):
61 | r"""
62 | Saves checkpoint at a certain global step during training. Optimizer state
63 | is also saved together.
64 |
65 | Args:
66 | directory (str): Path to save checkpoint to.
67 | global_step (int): The global step variable during training.
68 | optimizer (Optimizer): Optimizer state to be saved concurrently.
69 | name (str): The name to save the checkpoint file as.
70 |
71 | Returns:
72 | None
73 | """
74 | # Create directory to save to
75 | if not os.path.exists(directory):
76 | os.makedirs(directory)
77 |
78 | # Build checkpoint dict to save.
79 | ckpt_dict = {
80 | 'model_state_dict':
81 | self.state_dict(),
82 | 'optimizer_state_dict':
83 | optimizer.state_dict() if optimizer is not None else None,
84 | 'global_step':
85 | global_step
86 | }
87 |
88 | # Save the file with specific name
89 | if name is None:
90 | name = "{}_{}_steps.pth".format(
91 | os.path.basename(directory), # netD or netG
92 | global_step)
93 |
94 | torch.save(ckpt_dict, os.path.join(directory, name))
95 |
96 | def count_params(self):
97 | r"""
98 | Computes the number of parameters in this model.
99 |
100 | Args: None
101 |
102 | Returns:
103 | int: Total number of weight parameters for this model.
104 | int: Total number of trainable parameters for this model.
105 |
106 | """
107 | num_total_params = sum(p.numel() for p in self.parameters())
108 | num_trainable_params = sum(p.numel() for p in self.parameters()
109 | if p.requires_grad)
110 |
111 | return num_total_params, num_trainable_params
112 |
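A brief usage sketch of the checkpointing round trip follows (the TinyModel subclass and the ./log/netG path are hypothetical, introduced only to make the abstract BaseModel concrete):

import torch.nn as nn
import torch.optim as optim

from torch_mimicry.nets.basemodel.basemodel import BaseModel


class TinyModel(BaseModel):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)


model = TinyModel()
opt = optim.Adam(model.parameters(), lr=2e-4)

# Writes ./log/netG/netG_100_steps.pth, bundling model state, optimizer state
# and the global step.
model.save_checkpoint(directory='./log/netG', global_step=100, optimizer=opt)

# Restores both states and returns the saved global step (here, 100).
step = model.restore_checkpoint('./log/netG/netG_100_steps.pth', optimizer=opt)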
--------------------------------------------------------------------------------
/torch_mimicry/nets/cgan_pd/__init__.py:
--------------------------------------------------------------------------------
1 | from .cgan_pd_128 import *
2 | from .cgan_pd_32 import *
3 | from .cgan_pd_base import *
--------------------------------------------------------------------------------
/torch_mimicry/nets/cgan_pd/cgan_pd_base.py:
--------------------------------------------------------------------------------
1 | """
2 | Base class definition of cGAN-PD.
3 | """
4 |
5 | from torch_mimicry.nets.gan import cgan
6 |
7 |
8 | class CGANPDBaseGenerator(cgan.BaseConditionalGenerator):
9 | r"""
10 |     ResNet backbone generator for cGAN-PD.
11 |
12 | Attributes:
13 | num_classes (int): Number of classes, more than 0 for conditional GANs.
14 | nz (int): Noise dimension for upsampling.
15 | ngf (int): Variable controlling generator feature map sizes.
16 | bottom_width (int): Starting width for upsampling generator output to an image.
17 | loss_type (str): Name of loss to use for GAN loss.
18 | """
19 | def __init__(self,
20 | num_classes,
21 | bottom_width,
22 | nz,
23 | ngf,
24 | loss_type='hinge',
25 | **kwargs):
26 | super().__init__(nz=nz,
27 | ngf=ngf,
28 | bottom_width=bottom_width,
29 | loss_type=loss_type,
30 | num_classes=num_classes,
31 | **kwargs)
32 |
33 |
34 | class CGANPDBaseDiscriminator(cgan.BaseConditionalDiscriminator):
35 | r"""
36 | ResNet backbone discriminator for cGAN-PD.
37 |
38 | Attributes:
39 | num_classes (int): Number of classes, more than 0 for conditional GANs.
40 | ndf (int): Variable controlling discriminator feature map sizes.
41 | loss_type (str): Name of loss to use for GAN loss.
42 | """
43 | def __init__(self, num_classes, ndf, loss_type='hinge', **kwargs):
44 | super().__init__(ndf=ndf,
45 | loss_type=loss_type,
46 | num_classes=num_classes,
47 | **kwargs)
48 |
--------------------------------------------------------------------------------
/torch_mimicry/nets/dcgan/__init__.py:
--------------------------------------------------------------------------------
1 | from .dcgan_128 import *
2 | from .dcgan_32 import *
3 | from .dcgan_48 import *
4 | from .dcgan_64 import *
5 | from .dcgan_base import *
6 | from .dcgan_cifar import *
7 |
--------------------------------------------------------------------------------
/torch_mimicry/nets/dcgan/dcgan_32.py:
--------------------------------------------------------------------------------
1 | """
2 | Implementation of DCGAN for image size 32.
3 | """
4 | import torch
5 | import torch.nn as nn
6 |
7 | from torch_mimicry.nets.dcgan import dcgan_base
8 | from torch_mimicry.modules.resblocks import DBlockOptimized, DBlock, GBlock
9 |
10 |
11 | class DCGANGenerator32(dcgan_base.DCGANBaseGenerator):
12 | r"""
13 | ResNet backbone generator for ResNet DCGAN.
14 |
15 | Attributes:
16 | nz (int): Noise dimension for upsampling.
17 | ngf (int): Variable controlling generator feature map sizes.
18 | bottom_width (int): Starting width for upsampling generator output to an image.
19 | loss_type (str): Name of loss to use for GAN loss.
20 | """
21 | def __init__(self, nz=128, ngf=256, bottom_width=4, **kwargs):
22 | super().__init__(nz=nz, ngf=ngf, bottom_width=bottom_width, **kwargs)
23 |
24 | # Build the layers
25 | self.l1 = nn.Linear(self.nz, (self.bottom_width**2) * self.ngf)
26 | self.block2 = GBlock(self.ngf, self.ngf, upsample=True)
27 | self.block3 = GBlock(self.ngf, self.ngf, upsample=True)
28 | self.block4 = GBlock(self.ngf, self.ngf, upsample=True)
29 | self.b5 = nn.BatchNorm2d(self.ngf)
30 | self.c5 = nn.Conv2d(self.ngf, 3, 3, 1, padding=1)
31 | self.activation = nn.ReLU(True)
32 |
33 | # Initialise the weights
34 | nn.init.xavier_uniform_(self.l1.weight.data, 1.0)
35 | nn.init.xavier_uniform_(self.c5.weight.data, 1.0)
36 |
37 | def forward(self, x):
38 | r"""
39 | Feedforwards a batch of noise vectors into a batch of fake images.
40 |
41 | Args:
42 | x (Tensor): A batch of noise vectors of shape (N, nz).
43 |
44 | Returns:
45 | Tensor: A batch of fake images of shape (N, C, H, W).
46 | """
47 | h = self.l1(x)
48 | h = h.view(x.shape[0], -1, self.bottom_width, self.bottom_width)
49 | h = self.block2(h)
50 | h = self.block3(h)
51 | h = self.block4(h)
52 | h = self.b5(h)
53 | h = self.activation(h)
54 |         h = torch.tanh(self.c5(h))
55 |
56 | return h
57 |
58 |
59 | class DCGANDiscriminator32(dcgan_base.DCGANBaseDiscriminator):
60 | r"""
61 | ResNet backbone discriminator for ResNet DCGAN.
62 |
63 | Attributes:
64 | ndf (int): Variable controlling discriminator feature map sizes.
65 | loss_type (str): Name of loss to use for GAN loss.
66 | """
67 | def __init__(self, ndf=128, **kwargs):
68 | super().__init__(ndf=ndf, **kwargs)
69 |
70 | # Build layers
71 | self.block1 = DBlockOptimized(3, self.ndf, spectral_norm=False)
72 | self.block2 = DBlock(self.ndf,
73 | self.ndf,
74 | downsample=True,
75 | spectral_norm=False)
76 | self.block3 = DBlock(self.ndf,
77 | self.ndf,
78 | downsample=False,
79 | spectral_norm=False)
80 | self.block4 = DBlock(self.ndf,
81 | self.ndf,
82 | downsample=False,
83 | spectral_norm=False)
84 | self.l5 = nn.Linear(self.ndf, 1)
85 | self.activation = nn.ReLU(True)
86 |
87 | # Initialise the weights
88 | nn.init.xavier_uniform_(self.l5.weight.data, 1.0)
89 |
90 | def forward(self, x):
91 | r"""
92 | Feedforwards a batch of real/fake images and produces a batch of GAN logits.
93 |
94 | Args:
95 | x (Tensor): A batch of images of shape (N, C, H, W).
96 |
97 | Returns:
98 | Tensor: A batch of GAN logits of shape (N, 1).
99 | """
100 | h = x
101 | h = self.block1(h)
102 | h = self.block2(h)
103 | h = self.block3(h)
104 | h = self.block4(h)
105 | h = self.activation(h)
106 |
107 |         # Global sum pooling
108 | h = torch.sum(h, dim=(2, 3))
109 | output = self.l5(h)
110 |
111 | return output
112 |
--------------------------------------------------------------------------------
/torch_mimicry/nets/dcgan/dcgan_48.py:
--------------------------------------------------------------------------------
1 | """
2 | Implementation of DCGAN for image size 48.
3 | """
4 | import torch
5 | import torch.nn as nn
6 |
7 | from torch_mimicry.nets.dcgan import dcgan_base
8 | from torch_mimicry.modules.resblocks import DBlockOptimized, DBlock, GBlock
9 |
10 |
11 | class DCGANGenerator48(dcgan_base.DCGANBaseGenerator):
12 | r"""
13 | ResNet backbone generator for ResNet DCGAN.
14 |
15 | Attributes:
16 | nz (int): Noise dimension for upsampling.
17 | ngf (int): Variable controlling generator feature map sizes.
18 | bottom_width (int): Starting width for upsampling generator output to an image.
19 | loss_type (str): Name of loss to use for GAN loss.
20 | """
21 | def __init__(self, nz=128, ngf=512, bottom_width=6, **kwargs):
22 | super().__init__(nz=nz, ngf=ngf, bottom_width=bottom_width, **kwargs)
23 |
24 | # Build the layers
25 | self.l1 = nn.Linear(self.nz, (self.bottom_width**2) * self.ngf)
26 | self.block2 = GBlock(self.ngf, self.ngf >> 1, upsample=True)
27 | self.block3 = GBlock(self.ngf >> 1, self.ngf >> 2, upsample=True)
28 | self.block4 = GBlock(self.ngf >> 2, self.ngf >> 3, upsample=True)
29 | self.b5 = nn.BatchNorm2d(self.ngf >> 3)
30 | self.c5 = nn.Conv2d(self.ngf >> 3, 3, 3, 1, padding=1)
31 | self.activation = nn.ReLU(True)
32 |
33 | # Initialise the weights
34 | nn.init.xavier_uniform_(self.l1.weight.data, 1.0)
35 | nn.init.xavier_uniform_(self.c5.weight.data, 1.0)
36 |
37 | def forward(self, x):
38 | r"""
39 | Feedforwards a batch of noise vectors into a batch of fake images.
40 |
41 | Args:
42 | x (Tensor): A batch of noise vectors of shape (N, nz).
43 |
44 | Returns:
45 | Tensor: A batch of fake images of shape (N, C, H, W).
46 | """
47 | h = self.l1(x)
48 | h = h.view(x.shape[0], -1, self.bottom_width, self.bottom_width)
49 | h = self.block2(h)
50 | h = self.block3(h)
51 | h = self.block4(h)
52 | h = self.b5(h)
53 | h = self.activation(h)
54 | h = torch.tanh(self.c5(h))
55 |
56 | return h
57 |
58 |
59 | class DCGANDiscriminator48(dcgan_base.DCGANBaseDiscriminator):
60 | r"""
61 | ResNet backbone discriminator for ResNet DCGAN.
62 |
63 | Attributes:
64 | ndf (int): Variable controlling discriminator feature map sizes.
65 | loss_type (str): Name of loss to use for GAN loss.
66 | """
67 | def __init__(self, ndf=1024, **kwargs):
68 | super().__init__(ndf=ndf, **kwargs)
69 |
70 | # Build layers
71 | self.block1 = DBlockOptimized(3, self.ndf >> 4, spectral_norm=False)
72 | self.block2 = DBlock(self.ndf >> 4,
73 | self.ndf >> 3,
74 | downsample=True,
75 | spectral_norm=False)
76 | self.block3 = DBlock(self.ndf >> 3,
77 | self.ndf >> 2,
78 | downsample=True,
79 | spectral_norm=False)
80 | self.block4 = DBlock(self.ndf >> 2,
81 | self.ndf >> 1,
82 | downsample=True,
83 | spectral_norm=False)
84 | self.block5 = DBlock(self.ndf >> 1,
85 | self.ndf,
86 | downsample=False,
87 | spectral_norm=False)
88 | self.l5 = nn.Linear(self.ndf, 1)
89 |
90 | self.activation = nn.ReLU(True)
91 |
92 | # Initialise the weights
93 | nn.init.xavier_uniform_(self.l5.weight.data, 1.0)
94 |
95 | def forward(self, x):
96 | r"""
97 | Feedforwards a batch of real/fake images and produces a batch of GAN logits.
98 |
99 | Args:
100 | x (Tensor): A batch of images of shape (N, C, H, W).
101 |
102 | Returns:
103 | Tensor: A batch of GAN logits of shape (N, 1).
104 | """
105 | h = x
106 | h = self.block1(h)
107 | h = self.block2(h)
108 | h = self.block3(h)
109 | h = self.block4(h)
110 | h = self.block5(h)
111 | h = self.activation(h)
112 |
113 |         # Global sum pooling
114 | h = torch.sum(h, dim=(2, 3))
115 | output = self.l5(h)
116 |
117 | return output
118 |
--------------------------------------------------------------------------------
/torch_mimicry/nets/dcgan/dcgan_base.py:
--------------------------------------------------------------------------------
1 | """
2 | Base class definition of DCGAN.
3 | """
4 | from torch_mimicry.nets.gan import gan
5 |
6 |
7 | class DCGANBaseGenerator(gan.BaseGenerator):
8 | r"""
9 | ResNet backbone generator for ResNet DCGAN.
10 |
11 | Attributes:
12 | nz (int): Noise dimension for upsampling.
13 | ngf (int): Variable controlling generator feature map sizes.
14 | bottom_width (int): Starting width for upsampling generator output to an image.
15 | loss_type (str): Name of loss to use for GAN loss.
16 | """
17 | def __init__(self, nz, ngf, bottom_width, loss_type='ns', **kwargs):
18 | super().__init__(nz=nz,
19 | ngf=ngf,
20 | bottom_width=bottom_width,
21 | loss_type=loss_type,
22 | **kwargs)
23 |
24 |
25 | class DCGANBaseDiscriminator(gan.BaseDiscriminator):
26 | r"""
27 | ResNet backbone discriminator for ResNet DCGAN.
28 |
29 | Attributes:
30 | ndf (int): Variable controlling discriminator feature map sizes.
31 | loss_type (str): Name of loss to use for GAN loss.
32 | """
33 | def __init__(self, ndf, loss_type='ns', **kwargs):
34 | super().__init__(ndf=ndf, loss_type=loss_type, **kwargs)
35 |
--------------------------------------------------------------------------------
/torch_mimicry/nets/dcgan/dcgan_cifar.py:
--------------------------------------------------------------------------------
1 | """
2 | Implementation of DCGAN based on Kurach et al., specifically for CIFAR-10.
3 | The main difference from dcgan_32 is the use of a sigmoid, rather than tanh,
4 | as the final activation of the generator.
5 |
6 | To reproduce the reported scores, CIFAR-10 images should not be normalized to [-1, 1];
7 | they should instead take values in [0, 1], which is the default range when loading images as numpy arrays.
8 | """
9 | import torch
10 | import torch.nn as nn
11 |
12 | from torch_mimicry.nets.dcgan import dcgan_base
13 | from torch_mimicry.modules.resblocks import DBlockOptimized, DBlock, GBlock
14 |
15 |
16 | class DCGANGeneratorCIFAR(dcgan_base.DCGANBaseGenerator):
17 | r"""
18 | ResNet backbone generator for ResNet DCGAN.
19 |
20 | Attributes:
21 | nz (int): Noise dimension for upsampling.
22 | ngf (int): Variable controlling generator feature map sizes.
23 | bottom_width (int): Starting width for upsampling generator output to an image.
24 | loss_type (str): Name of loss to use for GAN loss.
25 | """
26 | def __init__(self, nz=128, ngf=256, bottom_width=4, **kwargs):
27 | super().__init__(nz=nz, ngf=ngf, bottom_width=bottom_width, **kwargs)
28 |
29 | # Build the layers
30 | self.l1 = nn.Linear(self.nz, (self.bottom_width**2) * self.ngf)
31 | self.block2 = GBlock(self.ngf, self.ngf, upsample=True)
32 | self.block3 = GBlock(self.ngf, self.ngf, upsample=True)
33 | self.block4 = GBlock(self.ngf, self.ngf, upsample=True)
34 | self.b5 = nn.BatchNorm2d(self.ngf)
35 | self.c5 = nn.Conv2d(self.ngf, 3, 3, 1, padding=1)
36 | self.activation = nn.ReLU(True)
37 |
38 | # Initialise the weights
39 | nn.init.xavier_uniform_(self.l1.weight.data, 1.0)
40 | nn.init.xavier_uniform_(self.c5.weight.data, 1.0)
41 |
42 | def forward(self, x):
43 | r"""
44 | Feedforwards a batch of noise vectors into a batch of fake images.
45 |
46 | Args:
47 | x (Tensor): A batch of noise vectors of shape (N, nz).
48 |
49 | Returns:
50 | Tensor: A batch of fake images of shape (N, C, H, W).
51 | """
52 | h = self.l1(x)
53 | h = h.view(x.shape[0], -1, self.bottom_width, self.bottom_width)
54 | h = self.block2(h)
55 | h = self.block3(h)
56 | h = self.block4(h)
57 | h = self.b5(h)
58 | h = self.activation(h)
59 | h = torch.sigmoid(self.c5(h))
60 |
61 | return h
62 |
63 |
64 | class DCGANDiscriminatorCIFAR(dcgan_base.DCGANBaseDiscriminator):
65 | r"""
66 | ResNet backbone discriminator for ResNet DCGAN.
67 |
68 | Attributes:
69 | ndf (int): Variable controlling discriminator feature map sizes.
70 | loss_type (str): Name of loss to use for GAN loss.
71 | """
72 | def __init__(self, ndf=128, **kwargs):
73 | super().__init__(ndf=ndf, **kwargs)
74 |
75 | # Build layers
76 | self.block1 = DBlockOptimized(3, self.ndf, spectral_norm=False)
77 | self.block2 = DBlock(self.ndf,
78 | self.ndf,
79 | downsample=True,
80 | spectral_norm=False)
81 | self.block3 = DBlock(self.ndf,
82 | self.ndf,
83 | downsample=False,
84 | spectral_norm=False)
85 | self.block4 = DBlock(self.ndf,
86 | self.ndf,
87 | downsample=False,
88 | spectral_norm=False)
89 | self.l5 = nn.Linear(self.ndf, 1)
90 | self.activation = nn.ReLU(True)
91 |
92 | # Initialise the weights
93 | nn.init.xavier_uniform_(self.l5.weight.data, 1.0)
94 |
95 | def forward(self, x):
96 | r"""
97 | Feedforwards a batch of real/fake images and produces a batch of GAN logits.
98 |
99 | Args:
100 | x (Tensor): A batch of images of shape (N, C, H, W).
101 |
102 | Returns:
103 | Tensor: A batch of GAN logits of shape (N, 1).
104 | """
105 | h = x
106 | h = self.block1(h)
107 | h = self.block2(h)
108 | h = self.block3(h)
109 | h = self.block4(h)
110 | h = self.activation(h)
111 |
112 |         # Global sum pooling
113 | h = torch.sum(h, dim=(2, 3))
114 | output = self.l5(h)
115 |
116 | return output
117 |
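The sketch below (assuming torchvision is available; the dataset root path is arbitrary) illustrates the intended data range: ToTensor() alone keeps CIFAR-10 pixels in [0, 1], which matches the sigmoid output of this generator, so no extra normalisation to [-1, 1] is applied:

import torch
from torchvision import datasets, transforms

from torch_mimicry.nets.dcgan.dcgan_cifar import (DCGANGeneratorCIFAR,
                                                  DCGANDiscriminatorCIFAR)

# ToTensor() maps PIL images to float tensors in [0, 1]; no Normalize transform.
dataset = datasets.CIFAR10(root='./datasets/cifar10',
                           train=True,
                           download=True,
                           transform=transforms.ToTensor())

netG = DCGANGeneratorCIFAR()
netD = DCGANDiscriminatorCIFAR()

noise = torch.randn(4, netG.nz)
fake_images = netG(noise)     # sigmoid output, values in [0, 1]
output = netD(fake_images)    # GAN logits of shape (4, 1)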
--------------------------------------------------------------------------------
/torch_mimicry/nets/gan/__init__.py:
--------------------------------------------------------------------------------
1 | from .cgan import *
2 | from .gan import *
--------------------------------------------------------------------------------
/torch_mimicry/nets/infomax_gan/__init__.py:
--------------------------------------------------------------------------------
1 | from .infomax_gan_128 import *
2 | from .infomax_gan_32 import *
3 | from .infomax_gan_48 import *
4 | from .infomax_gan_64 import *
5 | from .infomax_gan_base import *
6 |
--------------------------------------------------------------------------------
/torch_mimicry/nets/sagan/__init__.py:
--------------------------------------------------------------------------------
1 | from .sagan_128 import *
2 | from .sagan_32 import *
3 | from .sagan_base import *
4 |
--------------------------------------------------------------------------------
/torch_mimicry/nets/sagan/sagan_base.py:
--------------------------------------------------------------------------------
1 | """
2 | Base implementation of SAGAN.
3 | """
4 | from torch_mimicry.nets.gan import cgan
5 |
6 |
7 | class SAGANBaseGenerator(cgan.BaseConditionalGenerator):
8 | r"""
9 |     ResNet backbone generator for SAGAN.
10 |
11 | Attributes:
12 | num_classes (int): Number of classes, more than 0 for conditional GANs.
13 | nz (int): Noise dimension for upsampling.
14 | ngf (int): Variable controlling generator feature map sizes.
15 | bottom_width (int): Starting width for upsampling generator output to an image.
16 | loss_type (str): Name of loss to use for GAN loss.
17 | """
18 | def __init__(self,
19 | num_classes,
20 | bottom_width,
21 | nz,
22 | ngf,
23 | loss_type='hinge',
24 | **kwargs):
25 | super().__init__(nz=nz,
26 | ngf=ngf,
27 | bottom_width=bottom_width,
28 | loss_type=loss_type,
29 | num_classes=num_classes,
30 | **kwargs)
31 |
32 |
33 | class SAGANBaseDiscriminator(cgan.BaseConditionalDiscriminator):
34 | r"""
35 |     ResNet backbone discriminator for SAGAN.
36 |
37 | Attributes:
38 | num_classes (int): Number of classes, more than 0 for conditional GANs.
39 | ndf (int): Variable controlling discriminator feature map sizes.
40 | loss_type (str): Name of loss to use for GAN loss.
41 | """
42 | def __init__(self, num_classes, ndf, loss_type='hinge', **kwargs):
43 | super().__init__(ndf=ndf,
44 | loss_type=loss_type,
45 | num_classes=num_classes,
46 | **kwargs)
47 |
--------------------------------------------------------------------------------
/torch_mimicry/nets/sngan/__init__.py:
--------------------------------------------------------------------------------
1 | from .sngan_128 import *
2 | from .sngan_32 import *
3 | from .sngan_48 import *
4 | from .sngan_64 import *
5 | from .sngan_base import *
6 |
--------------------------------------------------------------------------------
/torch_mimicry/nets/sngan/sngan_128.py:
--------------------------------------------------------------------------------
1 | """
2 | Implementation of SNGAN for image size 128.
3 | """
4 | import torch
5 | import torch.nn as nn
6 |
7 | from torch_mimicry.modules.layers import SNLinear
8 | from torch_mimicry.modules.resblocks import DBlockOptimized, DBlock, GBlock
9 | from torch_mimicry.nets.sngan import sngan_base
10 |
11 |
12 | class SNGANGenerator128(sngan_base.SNGANBaseGenerator):
13 | r"""
14 | ResNet backbone generator for SNGAN.
15 |
16 | Attributes:
17 | nz (int): Noise dimension for upsampling.
18 | ngf (int): Variable controlling generator feature map sizes.
19 | bottom_width (int): Starting width for upsampling generator output to an image.
20 | loss_type (str): Name of loss to use for GAN loss.
21 | """
22 | def __init__(self, nz=128, ngf=1024, bottom_width=4, **kwargs):
23 | super().__init__(nz=nz, ngf=ngf, bottom_width=bottom_width, **kwargs)
24 |
25 | # Build the layers
26 | self.l1 = nn.Linear(self.nz, (self.bottom_width**2) * self.ngf)
27 | self.block2 = GBlock(self.ngf, self.ngf, upsample=True)
28 | self.block3 = GBlock(self.ngf, self.ngf >> 1, upsample=True)
29 | self.block4 = GBlock(self.ngf >> 1, self.ngf >> 2, upsample=True)
30 | self.block5 = GBlock(self.ngf >> 2, self.ngf >> 3, upsample=True)
31 | self.block6 = GBlock(self.ngf >> 3, self.ngf >> 4, upsample=True)
32 | self.b7 = nn.BatchNorm2d(self.ngf >> 4)
33 | self.c7 = nn.Conv2d(self.ngf >> 4, 3, 3, 1, padding=1)
34 | self.activation = nn.ReLU(True)
35 |
36 | # Initialise the weights
37 | nn.init.xavier_uniform_(self.l1.weight.data, 1.0)
38 | nn.init.xavier_uniform_(self.c7.weight.data, 1.0)
39 |
40 | def forward(self, x):
41 | r"""
42 | Feedforwards a batch of noise vectors into a batch of fake images.
43 |
44 | Args:
45 | x (Tensor): A batch of noise vectors of shape (N, nz).
46 |
47 | Returns:
48 | Tensor: A batch of fake images of shape (N, C, H, W).
49 | """
50 | h = self.l1(x)
51 | h = h.view(x.shape[0], -1, self.bottom_width, self.bottom_width)
52 | h = self.block2(h)
53 | h = self.block3(h)
54 | h = self.block4(h)
55 | h = self.block5(h)
56 | h = self.block6(h)
57 | h = self.b7(h)
58 | h = self.activation(h)
59 | h = torch.tanh(self.c7(h))
60 |
61 | return h
62 |
63 |
64 | class SNGANDiscriminator128(sngan_base.SNGANBaseDiscriminator):
65 | r"""
66 | ResNet backbone discriminator for SNGAN.
67 |
68 | Attributes:
69 | ndf (int): Variable controlling discriminator feature map sizes.
70 | loss_type (str): Name of loss to use for GAN loss.
71 | """
72 | def __init__(self, ndf=1024, **kwargs):
73 | super().__init__(ndf=ndf, **kwargs)
74 |
75 | # Build layers
76 | self.block1 = DBlockOptimized(3, self.ndf >> 4)
77 | self.block2 = DBlock(self.ndf >> 4, self.ndf >> 3, downsample=True)
78 | self.block3 = DBlock(self.ndf >> 3, self.ndf >> 2, downsample=True)
79 | self.block4 = DBlock(self.ndf >> 2, self.ndf >> 1, downsample=True)
80 | self.block5 = DBlock(self.ndf >> 1, self.ndf, downsample=True)
81 | self.block6 = DBlock(self.ndf, self.ndf, downsample=False)
82 | self.l7 = SNLinear(self.ndf, 1)
83 | self.activation = nn.ReLU(True)
84 |
85 | # Initialise the weights
86 | nn.init.xavier_uniform_(self.l7.weight.data, 1.0)
87 |
88 | def forward(self, x):
89 | r"""
90 | Feedforwards a batch of real/fake images and produces a batch of GAN logits.
91 |
92 | Args:
93 | x (Tensor): A batch of images of shape (N, C, H, W).
94 |
95 | Returns:
96 | Tensor: A batch of GAN logits of shape (N, 1).
97 | """
98 | h = x
99 | h = self.block1(h)
100 | h = self.block2(h)
101 | h = self.block3(h)
102 | h = self.block4(h)
103 | h = self.block5(h)
104 | h = self.block6(h)
105 | h = self.activation(h)
106 |
107 | # Global sum pooling
108 | h = torch.sum(h, dim=(2, 3))
109 | output = self.l7(h)
110 |
111 | return output
112 |
--------------------------------------------------------------------------------
/torch_mimicry/nets/sngan/sngan_32.py:
--------------------------------------------------------------------------------
1 | """
2 | Implementation of SNGAN for image size 32.
3 | """
4 | import torch
5 | import torch.nn as nn
6 |
7 | from torch_mimicry.modules.layers import SNLinear
8 | from torch_mimicry.modules.resblocks import DBlockOptimized, DBlock, GBlock
9 | from torch_mimicry.nets.sngan import sngan_base
10 |
11 |
12 | class SNGANGenerator32(sngan_base.SNGANBaseGenerator):
13 | r"""
14 | ResNet backbone generator for SNGAN.
15 |
16 | Attributes:
17 | nz (int): Noise dimension for upsampling.
18 | ngf (int): Variable controlling generator feature map sizes.
19 | bottom_width (int): Starting width for upsampling generator output to an image.
20 | loss_type (str): Name of loss to use for GAN loss.
21 | """
22 | def __init__(self, nz=128, ngf=256, bottom_width=4, **kwargs):
23 | super().__init__(nz=nz, ngf=ngf, bottom_width=bottom_width, **kwargs)
24 |
25 | # Build the layers
26 | self.l1 = nn.Linear(self.nz, (self.bottom_width**2) * self.ngf)
27 | self.block2 = GBlock(self.ngf, self.ngf, upsample=True)
28 | self.block3 = GBlock(self.ngf, self.ngf, upsample=True)
29 | self.block4 = GBlock(self.ngf, self.ngf, upsample=True)
30 | self.b5 = nn.BatchNorm2d(self.ngf)
31 | self.c5 = nn.Conv2d(self.ngf, 3, 3, 1, padding=1)
32 | self.activation = nn.ReLU(True)
33 |
34 | # Initialise the weights
35 | nn.init.xavier_uniform_(self.l1.weight.data, 1.0)
36 | nn.init.xavier_uniform_(self.c5.weight.data, 1.0)
37 |
38 | def forward(self, x):
39 | r"""
40 | Feedforwards a batch of noise vectors into a batch of fake images.
41 |
42 | Args:
43 | x (Tensor): A batch of noise vectors of shape (N, nz).
44 |
45 | Returns:
46 | Tensor: A batch of fake images of shape (N, C, H, W).
47 | """
48 | h = self.l1(x)
49 | h = h.view(x.shape[0], -1, self.bottom_width, self.bottom_width)
50 | h = self.block2(h)
51 | h = self.block3(h)
52 | h = self.block4(h)
53 | h = self.b5(h)
54 | h = self.activation(h)
55 | h = torch.tanh(self.c5(h))
56 |
57 | return h
58 |
59 |
60 | class SNGANDiscriminator32(sngan_base.SNGANBaseDiscriminator):
61 | r"""
62 | ResNet backbone discriminator for SNGAN.
63 |
64 | Attributes:
65 | ndf (int): Variable controlling discriminator feature map sizes.
66 | loss_type (str): Name of loss to use for GAN loss.
67 | """
68 | def __init__(self, ndf=128, **kwargs):
69 | super().__init__(ndf=ndf, **kwargs)
70 |
71 | # Build layers
72 | self.block1 = DBlockOptimized(3, self.ndf)
73 | self.block2 = DBlock(self.ndf, self.ndf, downsample=True)
74 | self.block3 = DBlock(self.ndf, self.ndf, downsample=False)
75 | self.block4 = DBlock(self.ndf, self.ndf, downsample=False)
76 | self.l5 = SNLinear(self.ndf, 1)
77 | self.activation = nn.ReLU(True)
78 |
79 | # Initialise the weights
80 | nn.init.xavier_uniform_(self.l5.weight.data, 1.0)
81 |
82 | def forward(self, x):
83 | r"""
84 | Feedforwards a batch of real/fake images and produces a batch of GAN logits.
85 |
86 | Args:
87 | x (Tensor): A batch of images of shape (N, C, H, W).
88 |
89 | Returns:
90 | Tensor: A batch of GAN logits of shape (N, 1).
91 | """
92 | h = x
93 | h = self.block1(h)
94 | h = self.block2(h)
95 | h = self.block3(h)
96 | h = self.block4(h)
97 | h = self.activation(h)
98 |
99 | # Global sum pooling
100 | h = torch.sum(h, dim=(2, 3))
101 | output = self.l5(h)
102 |
103 | return output
104 |
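A minimal usage sketch for the two classes above (not part of the source tree; it assumes only the constructors and forward passes shown, with tensor shapes taken from the docstrings):

import torch
from torch_mimicry.nets.sngan.sngan_32 import (SNGANGenerator32,
                                               SNGANDiscriminator32)

netG = SNGANGenerator32()      # nz=128, ngf=256, bottom_width=4
netD = SNGANDiscriminator32()  # ndf=128

noise = torch.randn(8, 128)    # (N, nz)
fake = netG(noise)             # (N, 3, 32, 32)
logits = netD(fake)            # (N, 1)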
--------------------------------------------------------------------------------
/torch_mimicry/nets/sngan/sngan_48.py:
--------------------------------------------------------------------------------
1 | """
2 | Implementation of SNGAN for image size 48.
3 | """
4 | import torch
5 | import torch.nn as nn
6 |
7 | from torch_mimicry.modules.layers import SNLinear
8 | from torch_mimicry.modules.resblocks import DBlockOptimized, DBlock, GBlock
9 | from torch_mimicry.nets.sngan import sngan_base
10 |
11 |
12 | class SNGANGenerator48(sngan_base.SNGANBaseGenerator):
13 | r"""
14 | ResNet backbone generator for SNGAN.
15 |
16 | Attributes:
17 | nz (int): Noise dimension for upsampling.
18 | ngf (int): Variable controlling generator feature map sizes.
19 | bottom_width (int): Starting width for upsampling generator output to an image.
20 | loss_type (str): Name of loss to use for GAN loss.
21 | """
22 | def __init__(self, nz=128, ngf=512, bottom_width=6, **kwargs):
23 | super().__init__(nz=nz, ngf=ngf, bottom_width=bottom_width, **kwargs)
24 |
25 | # Build the layers
26 | self.l1 = nn.Linear(self.nz, (self.bottom_width**2) * self.ngf)
27 | self.block2 = GBlock(self.ngf, self.ngf >> 1, upsample=True)
28 | self.block3 = GBlock(self.ngf >> 1, self.ngf >> 2, upsample=True)
29 | self.block4 = GBlock(self.ngf >> 2, self.ngf >> 3, upsample=True)
30 | self.b5 = nn.BatchNorm2d(self.ngf >> 3)
31 | self.c5 = nn.Conv2d(self.ngf >> 3, 3, 3, 1, padding=1)
32 | self.activation = nn.ReLU(True)
33 |
34 | # Initialise the weights
35 | nn.init.xavier_uniform_(self.l1.weight.data, 1.0)
36 | nn.init.xavier_uniform_(self.c5.weight.data, 1.0)
37 |
38 | def forward(self, x):
39 | r"""
40 | Feedforwards a batch of noise vectors into a batch of fake images.
41 |
42 | Args:
43 | x (Tensor): A batch of noise vectors of shape (N, nz).
44 |
45 | Returns:
46 | Tensor: A batch of fake images of shape (N, C, H, W).
47 | """
48 | h = self.l1(x)
49 | h = h.view(x.shape[0], -1, self.bottom_width, self.bottom_width)
50 | h = self.block2(h)
51 | h = self.block3(h)
52 | h = self.block4(h)
53 | h = self.b5(h)
54 | h = self.activation(h)
55 | h = torch.tanh(self.c5(h))
56 |
57 | return h
58 |
59 |
60 | class SNGANDiscriminator48(sngan_base.SNGANBaseDiscriminator):
61 | r"""
62 | ResNet backbone discriminator for SNGAN.
63 |
64 |     Attributes:
65 | ndf (int): Variable controlling discriminator feature map sizes.
66 | loss_type (str): Name of loss to use for GAN loss.
67 | """
68 | def __init__(self, ndf=1024, **kwargs):
69 | super().__init__(ndf=ndf, **kwargs)
70 |
71 | # Build layers
72 | self.block1 = DBlockOptimized(3, self.ndf >> 4)
73 | self.block2 = DBlock(self.ndf >> 4, self.ndf >> 3, downsample=True)
74 | self.block3 = DBlock(self.ndf >> 3, self.ndf >> 2, downsample=True)
75 | self.block4 = DBlock(self.ndf >> 2, self.ndf >> 1, downsample=True)
76 | self.block5 = DBlock(self.ndf >> 1, self.ndf, downsample=False)
77 | self.l5 = SNLinear(self.ndf, 1)
78 |
79 | self.activation = nn.ReLU(True)
80 |
81 | # Initialise the weights
82 | nn.init.xavier_uniform_(self.l5.weight.data, 1.0)
83 |
84 | def forward(self, x):
85 | r"""
86 | Feedforwards a batch of real/fake images and produces a batch of GAN logits.
87 |
88 | Args:
89 | x (Tensor): A batch of images of shape (N, C, H, W).
90 |
91 | Returns:
92 | Tensor: A batch of GAN logits of shape (N, 1).
93 | """
94 | h = x
95 | h = self.block1(h)
96 | h = self.block2(h)
97 | h = self.block3(h)
98 | h = self.block4(h)
99 | h = self.block5(h)
100 | h = self.activation(h)
101 |
102 | # Global sum pooling
103 | h = torch.sum(h, dim=(2, 3))
104 | output = self.l5(h)
105 |
106 | return output
107 |
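The `ngf >> k` and `ndf >> k` expressions above are integer right-shifts, i.e. division by 2**k, so the 48x48 generator tapers its channels 512 -> 256 -> 128 -> 64 while the discriminator widens them in reverse. A quick illustration of the equivalence:

# Right-shifting an integer by k halves it k times.
ngf = 512
assert [ngf >> k for k in range(4)] == [512, 256, 128, 64]
assert all((ngf >> k) == ngf // 2**k for k in range(4))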
--------------------------------------------------------------------------------
/torch_mimicry/nets/sngan/sngan_64.py:
--------------------------------------------------------------------------------
1 | """
2 | Implementation of SNGAN for image size 64.
3 | """
4 | import torch
5 | import torch.nn as nn
6 |
7 | from torch_mimicry.modules.layers import SNLinear
8 | from torch_mimicry.modules.resblocks import DBlockOptimized, DBlock, GBlock
9 | from torch_mimicry.nets.sngan import sngan_base
10 |
11 |
12 | class SNGANGenerator64(sngan_base.SNGANBaseGenerator):
13 | r"""
14 | ResNet backbone generator for SNGAN.
15 |
16 | Attributes:
17 | nz (int): Noise dimension for upsampling.
18 | ngf (int): Variable controlling generator feature map sizes.
19 | bottom_width (int): Starting width for upsampling generator output to an image.
20 | loss_type (str): Name of loss to use for GAN loss.
21 | """
22 | def __init__(self, nz=128, ngf=1024, bottom_width=4, **kwargs):
23 | super().__init__(nz=nz, ngf=ngf, bottom_width=bottom_width, **kwargs)
24 |
25 | # Build the layers
26 | self.l1 = nn.Linear(self.nz, (self.bottom_width**2) * self.ngf)
27 | self.block2 = GBlock(self.ngf, self.ngf >> 1, upsample=True)
28 | self.block3 = GBlock(self.ngf >> 1, self.ngf >> 2, upsample=True)
29 | self.block4 = GBlock(self.ngf >> 2, self.ngf >> 3, upsample=True)
30 | self.block5 = GBlock(self.ngf >> 3, self.ngf >> 4, upsample=True)
31 | self.b6 = nn.BatchNorm2d(self.ngf >> 4)
32 | self.c6 = nn.Conv2d(self.ngf >> 4, 3, 3, 1, padding=1)
33 | self.activation = nn.ReLU(True)
34 |
35 | # Initialise the weights
36 | nn.init.xavier_uniform_(self.l1.weight.data, 1.0)
37 | nn.init.xavier_uniform_(self.c6.weight.data, 1.0)
38 |
39 | def forward(self, x):
40 | r"""
41 | Feedforwards a batch of noise vectors into a batch of fake images.
42 |
43 | Args:
44 | x (Tensor): A batch of noise vectors of shape (N, nz).
45 |
46 | Returns:
47 | Tensor: A batch of fake images of shape (N, C, H, W).
48 | """
49 | h = self.l1(x)
50 | h = h.view(x.shape[0], -1, self.bottom_width, self.bottom_width)
51 | h = self.block2(h)
52 | h = self.block3(h)
53 | h = self.block4(h)
54 | h = self.block5(h)
55 | h = self.b6(h)
56 | h = self.activation(h)
57 | h = torch.tanh(self.c6(h))
58 |
59 | return h
60 |
61 |
62 | class SNGANDiscriminator64(sngan_base.SNGANBaseDiscriminator):
63 | r"""
64 | ResNet backbone discriminator for SNGAN.
65 |
66 | Attributes:
67 | ndf (int): Variable controlling discriminator feature map sizes.
68 | loss_type (str): Name of loss to use for GAN loss.
69 | """
70 | def __init__(self, ndf=1024, **kwargs):
71 | super().__init__(ndf=ndf, **kwargs)
72 |
73 | # Build layers
74 | self.block1 = DBlockOptimized(3, self.ndf >> 4)
75 | self.block2 = DBlock(self.ndf >> 4, self.ndf >> 3, downsample=True)
76 | self.block3 = DBlock(self.ndf >> 3, self.ndf >> 2, downsample=True)
77 | self.block4 = DBlock(self.ndf >> 2, self.ndf >> 1, downsample=True)
78 | self.block5 = DBlock(self.ndf >> 1, self.ndf, downsample=True)
79 | self.l6 = SNLinear(self.ndf, 1)
80 | self.activation = nn.ReLU(True)
81 |
82 | # Initialise the weights
83 | nn.init.xavier_uniform_(self.l6.weight.data, 1.0)
84 |
85 | def forward(self, x):
86 | r"""
87 | Feedforwards a batch of real/fake images and produces a batch of GAN logits.
88 |
89 | Args:
90 | x (Tensor): A batch of images of shape (N, C, H, W).
91 |
92 | Returns:
93 | Tensor: A batch of GAN logits of shape (N, 1).
94 | """
95 | h = x
96 | h = self.block1(h)
97 | h = self.block2(h)
98 | h = self.block3(h)
99 | h = self.block4(h)
100 | h = self.block5(h)
101 | h = self.activation(h)
102 |
103 | # Global sum pooling
104 | h = torch.sum(h, dim=(2, 3))
105 | output = self.l6(h)
106 |
107 | return output
108 |
--------------------------------------------------------------------------------
/torch_mimicry/nets/sngan/sngan_base.py:
--------------------------------------------------------------------------------
1 | """
2 | Base implementation of SNGAN with default variables.
3 | """
4 | from torch_mimicry.nets.gan import gan
5 |
6 |
7 | class SNGANBaseGenerator(gan.BaseGenerator):
8 | r"""
9 | ResNet backbone generator for SNGAN.
10 |
11 | Attributes:
12 | nz (int): Noise dimension for upsampling.
13 | ngf (int): Variable controlling generator feature map sizes.
14 | bottom_width (int): Starting width for upsampling generator output to an image.
15 | loss_type (str): Name of loss to use for GAN loss.
16 | """
17 | def __init__(self, nz, ngf, bottom_width, loss_type='hinge', **kwargs):
18 | super().__init__(nz=nz,
19 | ngf=ngf,
20 | bottom_width=bottom_width,
21 | loss_type=loss_type,
22 | **kwargs)
23 |
24 |
25 | class SNGANBaseDiscriminator(gan.BaseDiscriminator):
26 | r"""
27 | ResNet backbone discriminator for SNGAN.
28 |
29 | Attributes:
30 | ndf (int): Variable controlling discriminator feature map sizes.
31 | """
32 | def __init__(self, ndf, loss_type='hinge', **kwargs):
33 | super().__init__(ndf=ndf, loss_type=loss_type, **kwargs)
34 |
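The base classes above only pin the default loss_type='hinge' and forward every hyperparameter to the generic gan.BaseGenerator/BaseDiscriminator; the size-specific subclasses then add their own layers. Because of the **kwargs pass-through, the default loss can be overridden at construction time. A small sketch (the alternative loss name 'ns' is an assumption about what gan.BaseGenerator accepts, not something shown in this file):

from torch_mimicry.nets.sngan.sngan_32 import SNGANGenerator32

netG_hinge = SNGANGenerator32()             # inherits loss_type='hinge'
netG_ns = SNGANGenerator32(loss_type='ns')  # assumed-valid override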
--------------------------------------------------------------------------------
/torch_mimicry/nets/ssgan/__init__.py:
--------------------------------------------------------------------------------
1 | from .ssgan_128 import *
2 | from .ssgan_32 import *
3 | from .ssgan_48 import *
4 | from .ssgan_64 import *
5 | from .ssgan_base import *
6 |
--------------------------------------------------------------------------------
/torch_mimicry/nets/ssgan/ssgan_32.py:
--------------------------------------------------------------------------------
1 | """
2 | Implementation of SSGAN for image size 32.
3 | """
4 | import torch
5 | import torch.nn as nn
6 |
7 | from torch_mimicry.modules import SNLinear
8 | from torch_mimicry.modules.resblocks import DBlockOptimized, DBlock, GBlock
9 | from torch_mimicry.nets.ssgan import ssgan_base
10 |
11 |
12 | class SSGANGenerator32(ssgan_base.SSGANBaseGenerator):
13 | r"""
14 | ResNet backbone generator for SSGAN.
15 |
16 | Attributes:
17 | nz (int): Noise dimension for upsampling.
18 | ngf (int): Variable controlling generator feature map sizes.
19 | bottom_width (int): Starting width for upsampling generator output to an image.
20 | loss_type (str): Name of loss to use for GAN loss.
21 | ss_loss_scale (float): Self-supervised loss scale for generator.
22 | """
23 | def __init__(self, nz=128, ngf=256, bottom_width=4, **kwargs):
24 | super().__init__(nz=nz, ngf=ngf, bottom_width=bottom_width, **kwargs)
25 |
26 | # Build the layers
27 | self.l1 = nn.Linear(self.nz, (self.bottom_width**2) * self.ngf)
28 | self.block2 = GBlock(self.ngf, self.ngf, upsample=True)
29 | self.block3 = GBlock(self.ngf, self.ngf, upsample=True)
30 | self.block4 = GBlock(self.ngf, self.ngf, upsample=True)
31 | self.b5 = nn.BatchNorm2d(self.ngf)
32 |         self.c5 = nn.Conv2d(self.ngf, 3, 3, 1, padding=1)
33 | self.activation = nn.ReLU(True)
34 |
35 | # Initialise the weights
36 | nn.init.xavier_uniform_(self.l1.weight.data, 1.0)
37 | nn.init.xavier_uniform_(self.c5.weight.data, 1.0)
38 |
39 | def forward(self, x):
40 | r"""
41 | Feedforwards a batch of noise vectors into a batch of fake images.
42 |
43 | Args:
44 | x (Tensor): A batch of noise vectors of shape (N, nz).
45 |
46 | Returns:
47 | Tensor: A batch of fake images of shape (N, C, H, W).
48 | """
49 | h = self.l1(x)
50 | h = h.view(x.shape[0], -1, self.bottom_width, self.bottom_width)
51 | h = self.block2(h)
52 | h = self.block3(h)
53 | h = self.block4(h)
54 | h = self.b5(h)
55 | h = self.activation(h)
56 | h = torch.tanh(self.c5(h))
57 |
58 | return h
59 |
60 |
61 | class SSGANDiscriminator32(ssgan_base.SSGANBaseDiscriminator):
62 | r"""
63 | ResNet backbone discriminator for SSGAN.
64 |
65 | Attributes:
66 | ndf (int): Variable controlling discriminator feature map sizes.
67 | loss_type (str): Name of loss to use for GAN loss.
68 | ss_loss_scale (float): Self-supervised loss scale for discriminator.
69 | """
70 | def __init__(self, ndf=128, **kwargs):
71 | super().__init__(ndf=ndf, **kwargs)
72 |
73 | # Build layers
74 | self.block1 = DBlockOptimized(3, self.ndf)
75 | self.block2 = DBlock(self.ndf, self.ndf, downsample=True)
76 | self.block3 = DBlock(self.ndf, self.ndf, downsample=False)
77 | self.block4 = DBlock(self.ndf, self.ndf, downsample=False)
78 | self.l5 = SNLinear(self.ndf, 1)
79 |
80 | # Rotation class prediction layer
81 | self.l_y = SNLinear(self.ndf, self.num_classes)
82 |
83 | # Initialise the weights
84 | nn.init.xavier_uniform_(self.l5.weight.data, 1.0)
85 | nn.init.xavier_uniform_(self.l_y.weight.data, 1.0)
86 |
87 | self.activation = nn.ReLU(True)
88 |
89 | def forward(self, x):
90 | r"""
91 |         Feedforwards a batch of real/fake images and produces a batch of GAN logits
92 |         and rotation class logits.
93 |
94 | Args:
95 | x (Tensor): A batch of images of shape (N, C, H, W).
96 |
97 | Returns:
98 | Tensor: A batch of GAN logits of shape (N, 1).
99 |             Tensor: A batch of rotation class logits of shape (N, num_classes).
100 | """
101 | h = x
102 | h = self.block1(h)
103 | h = self.block2(h)
104 | h = self.block3(h)
105 | h = self.block4(h)
106 | h = self.activation(h)
107 |
108 | # Global sum pooling
109 | h = torch.sum(h, dim=(2, 3))
110 | output = self.l5(h)
111 |
112 | # Produce the class output logits
113 | output_classes = self.l_y(h)
114 |
115 | return output, output_classes
116 |
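Unlike the plain SNGAN discriminator, the forward pass above returns two tensors: GAN logits and rotation-class logits for the self-supervised task (num_classes and the rotation pipeline live in ssgan_base, which is not shown here). A hedged sketch of consuming both outputs, with hypothetical rotation labels:

import torch
import torch.nn.functional as F
from torch_mimicry.nets.ssgan.ssgan_32 import SSGANDiscriminator32

netD = SSGANDiscriminator32()
images = torch.randn(8, 3, 32, 32)

gan_logits, rot_logits = netD(images)   # (8, 1) and (8, num_classes)

# Hypothetical rotation targets; the real rotation/augmentation logic
# is implemented in ssgan_base and is not reproduced here.
rot_labels = torch.zeros(8, dtype=torch.long)
ss_loss = F.cross_entropy(rot_logits, rot_labels)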
--------------------------------------------------------------------------------
/torch_mimicry/nets/wgan_gp/__init__.py:
--------------------------------------------------------------------------------
1 | from .wgan_gp_128 import *
2 | from .wgan_gp_32 import *
3 | from .wgan_gp_48 import *
4 | from .wgan_gp_64 import *
5 | from .wgan_gp_base import *
6 | from .wgan_gp_resblocks import *
7 |
--------------------------------------------------------------------------------
/torch_mimicry/nets/wgan_gp/wgan_gp_128.py:
--------------------------------------------------------------------------------
1 | """
2 | Implementation of WGAN-GP for image size 128.
3 | """
4 | import torch
5 | import torch.nn as nn
6 |
7 | from torch_mimicry.nets.wgan_gp import wgan_gp_base
8 | from torch_mimicry.nets.wgan_gp.wgan_gp_resblocks import DBlockOptimized, DBlock, GBlock
9 |
10 |
11 | class WGANGPGenerator128(wgan_gp_base.WGANGPBaseGenerator):
12 | r"""
13 | ResNet backbone generator for WGAN-GP.
14 |
15 | Attributes:
16 | nz (int): Noise dimension for upsampling.
17 | ngf (int): Variable controlling generator feature map sizes.
18 | bottom_width (int): Starting width for upsampling generator output to an image.
19 | loss_type (str): Name of loss to use for GAN loss.
20 | """
21 | def __init__(self, nz=128, ngf=1024, bottom_width=4, **kwargs):
22 | super().__init__(nz=nz, ngf=ngf, bottom_width=bottom_width, **kwargs)
23 |
24 | # Build the layers
25 | self.l1 = nn.Linear(self.nz, (self.bottom_width**2) * self.ngf)
26 | self.block2 = GBlock(self.ngf, self.ngf, upsample=True)
27 | self.block3 = GBlock(self.ngf, self.ngf >> 1, upsample=True)
28 | self.block4 = GBlock(self.ngf >> 1, self.ngf >> 2, upsample=True)
29 | self.block5 = GBlock(self.ngf >> 2, self.ngf >> 3, upsample=True)
30 | self.block6 = GBlock(self.ngf >> 3, self.ngf >> 4, upsample=True)
31 | self.b7 = nn.BatchNorm2d(self.ngf >> 4)
32 | self.c7 = nn.Conv2d(self.ngf >> 4, 3, 3, 1, padding=1)
33 | self.activation = nn.ReLU(True)
34 |
35 | # Initialise the weights
36 | nn.init.xavier_uniform_(self.l1.weight.data, 1.0)
37 |
38 | def forward(self, x):
39 | r"""
40 | Feedforwards a batch of noise vectors into a batch of fake images.
41 |
42 | Args:
43 | x (Tensor): A batch of noise vectors of shape (N, nz).
44 |
45 | Returns:
46 | Tensor: A batch of fake images of shape (N, C, H, W).
47 | """
48 | h = self.l1(x)
49 | h = h.view(x.shape[0], -1, self.bottom_width, self.bottom_width)
50 | h = self.block2(h)
51 | h = self.block3(h)
52 | h = self.block4(h)
53 | h = self.block5(h)
54 | h = self.block6(h)
55 | h = self.b7(h)
56 | h = self.activation(h)
57 | h = torch.tanh(self.c7(h))
58 |
59 | return h
60 |
61 |
62 | class WGANGPDiscriminator128(wgan_gp_base.WGANGPBaseDiscriminator):
63 | r"""
64 | ResNet backbone discriminator for WGAN-GP.
65 |
66 | Attributes:
67 | ndf (int): Variable controlling discriminator feature map sizes.
68 | loss_type (str): Name of loss to use for GAN loss.
69 |         gp_scale (float): Lambda parameter for gradient penalty.
70 | """
71 | def __init__(self, ndf=1024, **kwargs):
72 | super().__init__(ndf=ndf, **kwargs)
73 |
74 | # Build layers
75 | self.block1 = DBlockOptimized(3, self.ndf >> 4)
76 | self.block2 = DBlock(self.ndf >> 4, self.ndf >> 3, downsample=True)
77 | self.block3 = DBlock(self.ndf >> 3, self.ndf >> 2, downsample=True)
78 | self.block4 = DBlock(self.ndf >> 2, self.ndf >> 1, downsample=True)
79 | self.block5 = DBlock(self.ndf >> 1, self.ndf, downsample=True)
80 | self.block6 = DBlock(self.ndf, self.ndf, downsample=False)
81 | self.l7 = nn.Linear(self.ndf, 1)
82 | self.activation = nn.ReLU(True)
83 |
84 | # Initialise the weights
85 | nn.init.xavier_uniform_(self.l7.weight.data, 1.0)
86 |
87 | def forward(self, x):
88 | r"""
89 | Feedforwards a batch of real/fake images and produces a batch of GAN logits.
90 |
91 | Args:
92 | x (Tensor): A batch of images of shape (N, C, H, W).
93 |
94 | Returns:
95 | Tensor: A batch of GAN logits of shape (N, 1).
96 | """
97 | h = x
98 | h = self.block1(h)
99 | h = self.block2(h)
100 | h = self.block3(h)
101 | h = self.block4(h)
102 | h = self.block5(h)
103 | h = self.block6(h)
104 | h = self.activation(h)
105 |
106 | # Global average pooling
107 | h = torch.mean(h, dim=(2, 3)) # WGAN uses mean pooling
108 | output = self.l7(h)
109 |
110 | return output
111 |
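gp_scale is the lambda weight on the gradient penalty; the penalty itself is computed in wgan_gp_base (not shown). As a rough, generic sketch of the standard WGAN-GP penalty that gp_scale multiplies, and not the library's exact implementation:

import torch

def gradient_penalty(netD, real, fake, gp_scale=10.0):
    # Penalise the critic's gradient norm on random interpolates
    # between real and fake images (Gulrajani et al.-style penalty).
    N = real.size(0)
    alpha = torch.rand(N, 1, 1, 1, device=real.device)
    interp = (alpha * real + (1 - alpha) * fake).requires_grad_(True)

    d_interp = netD(interp)
    grads = torch.autograd.grad(outputs=d_interp,
                                inputs=interp,
                                grad_outputs=torch.ones_like(d_interp),
                                create_graph=True,
                                retain_graph=True)[0]
    grad_norm = grads.view(N, -1).norm(2, dim=1)

    return gp_scale * ((grad_norm - 1) ** 2).mean()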
--------------------------------------------------------------------------------
/torch_mimicry/nets/wgan_gp/wgan_gp_32.py:
--------------------------------------------------------------------------------
1 | """
2 | Implementation of WGAN-GP for image size 32.
3 | """
4 | import torch
5 | import torch.nn as nn
6 |
7 | from torch_mimicry.nets.wgan_gp import wgan_gp_base
8 | from torch_mimicry.nets.wgan_gp.wgan_gp_resblocks import DBlockOptimized, DBlock, GBlock
9 |
10 |
11 | class WGANGPGenerator32(wgan_gp_base.WGANGPBaseGenerator):
12 | r"""
13 | ResNet backbone generator for WGAN-GP.
14 |
15 | Attributes:
16 | nz (int): Noise dimension for upsampling.
17 | ngf (int): Variable controlling generator feature map sizes.
18 | bottom_width (int): Starting width for upsampling generator output to an image.
19 | loss_type (str): Name of loss to use for GAN loss.
20 | """
21 | def __init__(self, nz=128, ngf=256, bottom_width=4, **kwargs):
22 | super().__init__(nz=nz, ngf=ngf, bottom_width=bottom_width, **kwargs)
23 |
24 | # Build the layers
25 | self.l1 = nn.Linear(self.nz, (self.bottom_width**2) * self.ngf)
26 | self.block2 = GBlock(self.ngf, self.ngf, upsample=True)
27 | self.block3 = GBlock(self.ngf, self.ngf, upsample=True)
28 | self.block4 = GBlock(self.ngf, self.ngf, upsample=True)
29 | self.b5 = nn.BatchNorm2d(self.ngf)
30 | self.c5 = nn.Conv2d(self.ngf, 3, 3, 1, padding=1)
31 | self.activation = nn.ReLU(True)
32 |
33 | # Initialise the weights
34 | nn.init.xavier_uniform_(self.l1.weight.data, 1.0)
35 |
36 | def forward(self, x):
37 | r"""
38 | Feedforwards a batch of noise vectors into a batch of fake images.
39 |
40 | Args:
41 | x (Tensor): A batch of noise vectors of shape (N, nz).
42 |
43 | Returns:
44 | Tensor: A batch of fake images of shape (N, C, H, W).
45 | """
46 | h = self.l1(x)
47 | h = h.view(x.shape[0], -1, self.bottom_width, self.bottom_width)
48 | h = self.block2(h)
49 | h = self.block3(h)
50 | h = self.block4(h)
51 | h = self.b5(h)
52 | h = self.activation(h)
53 | h = torch.tanh(self.c5(h))
54 |
55 | return h
56 |
57 |
58 | class WGANGPDiscriminator32(wgan_gp_base.WGANGPBaseDiscriminator):
59 | r"""
60 | ResNet backbone discriminator for WGAN-GP.
61 |
62 | Attributes:
63 | ndf (int): Variable controlling discriminator feature map sizes.
64 | loss_type (str): Name of loss to use for GAN loss.
65 |         gp_scale (float): Lambda parameter for gradient penalty.
66 | """
67 | def __init__(self, ndf=128, **kwargs):
68 | super().__init__(ndf=ndf, **kwargs)
69 |
70 | # Build layers
71 | self.block1 = DBlockOptimized(3, self.ndf)
72 | self.block2 = DBlock(self.ndf, self.ndf, downsample=True)
73 | self.block3 = DBlock(self.ndf, self.ndf, downsample=False)
74 | self.block4 = DBlock(self.ndf, self.ndf, downsample=False)
75 | self.l5 = nn.Linear(self.ndf, 1)
76 |
77 | self.activation = nn.ReLU(True)
78 |
79 | # Initialise the weights
80 | nn.init.xavier_uniform_(self.l5.weight.data, 1.0)
81 |
82 | def forward(self, x):
83 | r"""
84 | Feedforwards a batch of real/fake images and produces a batch of GAN logits.
85 |
86 | Args:
87 | x (Tensor): A batch of images of shape (N, C, H, W).
88 |
89 | Returns:
90 | Tensor: A batch of GAN logits of shape (N, 1).
91 | """
92 | h = x
93 | h = self.block1(h)
94 | h = self.block2(h)
95 | h = self.block3(h)
96 | h = self.block4(h)
97 | h = self.activation(h)
98 |
99 | # Global average pooling
100 | h = torch.mean(h, dim=(2, 3)) # WGAN uses mean pooling
101 | output = self.l5(h)
102 |
103 | return output
104 |
--------------------------------------------------------------------------------
/torch_mimicry/nets/wgan_gp/wgan_gp_48.py:
--------------------------------------------------------------------------------
1 | """
2 | Implementation of WGAN-GP for image size 48.
3 | """
4 | import torch
5 | import torch.nn as nn
6 |
7 | from torch_mimicry.nets.wgan_gp import wgan_gp_base
8 | from torch_mimicry.nets.wgan_gp.wgan_gp_resblocks import DBlockOptimized, DBlock, GBlock
9 |
10 |
11 | class WGANGPGenerator48(wgan_gp_base.WGANGPBaseGenerator):
12 | r"""
13 | ResNet backbone generator for WGAN-GP.
14 |
15 | Attributes:
16 | nz (int): Noise dimension for upsampling.
17 | ngf (int): Variable controlling generator feature map sizes.
18 | bottom_width (int): Starting width for upsampling generator output to an image.
19 | loss_type (str): Name of loss to use for GAN loss.
20 | """
21 | def __init__(self, nz=128, ngf=512, bottom_width=6, **kwargs):
22 | super().__init__(nz=nz, ngf=ngf, bottom_width=bottom_width, **kwargs)
23 |
24 | # Build the layers
25 | self.l1 = nn.Linear(self.nz, (self.bottom_width**2) * self.ngf)
26 | self.block2 = GBlock(self.ngf, self.ngf >> 1, upsample=True)
27 | self.block3 = GBlock(self.ngf >> 1, self.ngf >> 2, upsample=True)
28 | self.block4 = GBlock(self.ngf >> 2, self.ngf >> 3, upsample=True)
29 | self.b5 = nn.BatchNorm2d(self.ngf >> 3)
30 | self.c5 = nn.Conv2d(self.ngf >> 3, 3, 3, 1, padding=1)
31 | self.activation = nn.ReLU(True)
32 |
33 | # Initialise the weights
34 | nn.init.xavier_uniform_(self.l1.weight.data, 1.0)
35 |
36 | def forward(self, x):
37 | r"""
38 | Feedforwards a batch of noise vectors into a batch of fake images.
39 |
40 | Args:
41 | x (Tensor): A batch of noise vectors of shape (N, nz).
42 |
43 | Returns:
44 | Tensor: A batch of fake images of shape (N, C, H, W).
45 | """
46 | h = self.l1(x)
47 | h = h.view(x.shape[0], -1, self.bottom_width, self.bottom_width)
48 | h = self.block2(h)
49 | h = self.block3(h)
50 | h = self.block4(h)
51 | h = self.b5(h)
52 | h = self.activation(h)
53 | h = torch.tanh(self.c5(h))
54 |
55 | return h
56 |
57 |
58 | class WGANGPDiscriminator48(wgan_gp_base.WGANGPBaseDiscriminator):
59 | r"""
60 | ResNet backbone discriminator for WGAN-GP.
61 |
62 | Attributes:
63 | ndf (int): Variable controlling discriminator feature map sizes.
64 | loss_type (str): Name of loss to use for GAN loss.
65 |         gp_scale (float): Lambda parameter for gradient penalty.
66 | """
67 | def __init__(self, ndf=1024, **kwargs):
68 | super().__init__(ndf=ndf, **kwargs)
69 |
70 | # Build layers
71 | self.block1 = DBlockOptimized(3, self.ndf >> 4)
72 | self.block2 = DBlock(self.ndf >> 4, self.ndf >> 3, downsample=True)
73 | self.block3 = DBlock(self.ndf >> 3, self.ndf >> 2, downsample=True)
74 | self.block4 = DBlock(self.ndf >> 2, self.ndf >> 1, downsample=True)
75 | self.block5 = DBlock(self.ndf >> 1, self.ndf, downsample=False)
76 | self.l6 = nn.Linear(self.ndf, 1)
77 |
78 | self.activation = nn.ReLU(True)
79 |
80 | # Initialise the weights
81 | nn.init.xavier_uniform_(self.l6.weight.data, 1.0)
82 |
83 | def forward(self, x):
84 | r"""
85 | Feedforwards a batch of real/fake images and produces a batch of GAN logits.
86 |
87 | Args:
88 | x (Tensor): A batch of images of shape (N, C, H, W).
89 |
90 | Returns:
91 | Tensor: A batch of GAN logits of shape (N, 1).
92 | """
93 | h = x
94 | h = self.block1(h)
95 | h = self.block2(h)
96 | h = self.block3(h)
97 | h = self.block4(h)
98 | h = self.block5(h)
99 | h = self.activation(h)
100 |
101 | # Global average pooling
102 | h = torch.mean(h, dim=(2, 3)) # WGAN uses mean pooling
103 | output = self.l6(h)
104 |
105 | return output
106 |
--------------------------------------------------------------------------------
/torch_mimicry/nets/wgan_gp/wgan_gp_64.py:
--------------------------------------------------------------------------------
1 | """
2 | Implementation of WGAN-GP for image size 64.
3 | """
4 | import torch
5 | import torch.nn as nn
6 |
7 | from torch_mimicry.nets.wgan_gp import wgan_gp_base
8 | from torch_mimicry.nets.wgan_gp.wgan_gp_resblocks import DBlockOptimized, DBlock, GBlock
9 |
10 |
11 | class WGANGPGenerator64(wgan_gp_base.WGANGPBaseGenerator):
12 | r"""
13 | ResNet backbone generator for WGAN-GP.
14 |
15 | Attributes:
16 | nz (int): Noise dimension for upsampling.
17 | ngf (int): Variable controlling generator feature map sizes.
18 | bottom_width (int): Starting width for upsampling generator output to an image.
19 | loss_type (str): Name of loss to use for GAN loss.
20 | """
21 | def __init__(self, nz=128, ngf=1024, bottom_width=4, **kwargs):
22 | super().__init__(nz=nz, ngf=ngf, bottom_width=bottom_width, **kwargs)
23 |
24 | # Build the layers
25 | self.l1 = nn.Linear(self.nz, (self.bottom_width**2) * self.ngf)
26 | self.block2 = GBlock(self.ngf, self.ngf >> 1, upsample=True)
27 | self.block3 = GBlock(self.ngf >> 1, self.ngf >> 2, upsample=True)
28 | self.block4 = GBlock(self.ngf >> 2, self.ngf >> 3, upsample=True)
29 | self.block5 = GBlock(self.ngf >> 3, self.ngf >> 4, upsample=True)
30 | self.b6 = nn.BatchNorm2d(self.ngf >> 4)
31 | self.c6 = nn.Conv2d(self.ngf >> 4, 3, 3, 1, padding=1)
32 | self.activation = nn.ReLU(True)
33 |
34 | # Initialise the weights
35 | nn.init.xavier_uniform_(self.l1.weight.data, 1.0)
36 |
37 | def forward(self, x):
38 | r"""
39 | Feedforwards a batch of noise vectors into a batch of fake images.
40 |
41 | Args:
42 | x (Tensor): A batch of noise vectors of shape (N, nz).
43 |
44 | Returns:
45 | Tensor: A batch of fake images of shape (N, C, H, W).
46 | """
47 | h = self.l1(x)
48 | h = h.view(x.shape[0], -1, self.bottom_width, self.bottom_width)
49 | h = self.block2(h)
50 | h = self.block3(h)
51 | h = self.block4(h)
52 | h = self.block5(h)
53 | h = self.b6(h)
54 | h = self.activation(h)
55 | h = torch.tanh(self.c6(h))
56 |
57 | return h
58 |
59 |
60 | class WGANGPDiscriminator64(wgan_gp_base.WGANGPBaseDiscriminator):
61 | r"""
62 | ResNet backbone discriminator for WGAN-GP.
63 |
64 | Attributes:
65 | ndf (int): Variable controlling discriminator feature map sizes.
66 | loss_type (str): Name of loss to use for GAN loss.
67 |         gp_scale (float): Lambda parameter for gradient penalty.
68 | """
69 | def __init__(self, ndf=1024, **kwargs):
70 | super().__init__(ndf=ndf, **kwargs)
71 |
72 | # Build layers
73 | self.block1 = DBlockOptimized(3, self.ndf >> 4)
74 | self.block2 = DBlock(self.ndf >> 4, self.ndf >> 3, downsample=True)
75 | self.block3 = DBlock(self.ndf >> 3, self.ndf >> 2, downsample=True)
76 | self.block4 = DBlock(self.ndf >> 2, self.ndf >> 1, downsample=True)
77 | self.block5 = DBlock(self.ndf >> 1, self.ndf, downsample=True)
78 | self.l6 = nn.Linear(self.ndf, 1)
79 | self.activation = nn.ReLU(True)
80 |
81 | # Initialise the weights
82 | nn.init.xavier_uniform_(self.l6.weight.data, 1.0)
83 |
84 | def forward(self, x):
85 | r"""
86 | Feedforwards a batch of real/fake images and produces a batch of GAN logits.
87 |
88 | Args:
89 | x (Tensor): A batch of images of shape (N, C, H, W).
90 |
91 | Returns:
92 | Tensor: A batch of GAN logits of shape (N, 1).
93 | """
94 | h = x
95 | h = self.block1(h)
96 | h = self.block2(h)
97 | h = self.block3(h)
98 | h = self.block4(h)
99 | h = self.block5(h)
100 | h = self.activation(h)
101 |
102 | # Global average pooling
103 | h = torch.mean(h, dim=(2, 3)) # WGAN uses mean pooling
104 | output = self.l6(h)
105 |
106 | return output
107 |
--------------------------------------------------------------------------------
/torch_mimicry/training/__init__.py:
--------------------------------------------------------------------------------
1 | from .logger import *
2 | from .metric_log import *
3 | from .scheduler import *
4 | from .trainer import *
--------------------------------------------------------------------------------
/torch_mimicry/training/metric_log.py:
--------------------------------------------------------------------------------
1 | """
2 | MetricLog object for logging metrics so that they can be displayed more intuitively.
3 | """
4 |
5 |
6 | class MetricLog:
7 | """
8 |     A dictionary-like object that logs data, with an extra dict mapping each metric
9 |     to its group name (if any) and the precision used when printing its value.
10 |
11 | Attributes:
12 |         metrics_dict (dict): A dictionary mapping each metric name to a dict containing
13 |             its value, precision, and the group it belongs to.
14 | """
15 | def __init__(self, **kwargs):
16 | self.metrics_dict = {}
17 |
18 | def add_metric(self, name, value, group=None, precision=4):
19 | """
20 |         Logs a metric to the internal dict, with an additional option
21 |         of grouping certain metrics together.
22 |
23 | Args:
24 | name (str): Name of metric to log.
25 | value (Tensor/Float): Value of the metric to log.
26 | group (str): Name of the group to classify different metrics together.
27 |             precision (int): Number of decimal places used to represent the value.
28 |
29 | Returns:
30 | None
31 | """
32 |         # Extract the scalar value if a tensor was given.
33 |         try:
34 |             value = value.item()
35 |         except AttributeError:
36 |             pass
37 |
38 | self.metrics_dict[name] = dict(value=value,
39 | group=group,
40 | precision=precision)
41 |
42 | def __getitem__(self, key):
43 | return round(self.metrics_dict[key]['value'],
44 | self.metrics_dict[key]['precision'])
45 |
46 | def get_group_name(self, name):
47 | """
48 |         Obtains the group name of a particular metric. For example, errD and errG,
49 |         which represent the discriminator and generator losses, could fall under a
50 |         group name called "loss".
51 |
52 | Args:
53 | name (str): The name of the metric to retrieve group name.
54 |
55 | Returns:
56 | str: A string representing the group name of the metric.
57 | """
58 | return self.metrics_dict[name]['group']
59 |
60 | def keys(self):
61 | """
62 |         Dict-like functionality for retrieving keys.
63 | """
64 | return self.metrics_dict.keys()
65 |
66 | def items(self):
67 | """
68 |         Dict-like functionality for retrieving items.
69 | """
70 | return self.metrics_dict.items()
71 |
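Example usage of MetricLog, based only on the methods defined above:

from torch_mimicry.training.metric_log import MetricLog

log = MetricLog()
log.add_metric('errD', 0.123456, group='loss')
log.add_metric('errG', 0.654321, group='loss', precision=2)

print(log['errD'])                 # 0.1235 (rounded to 4 decimal places)
print(log['errG'])                 # 0.65
print(log.get_group_name('errG'))  # 'loss'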
--------------------------------------------------------------------------------
/torch_mimicry/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .common import *
2 |
--------------------------------------------------------------------------------
/torch_mimicry/utils/common.py:
--------------------------------------------------------------------------------
1 | """
2 | Script for common utility functions.
3 | """
4 | import json
5 | import os
6 |
7 | import numpy as np
8 | import torch
9 | from skimage import io
10 |
11 |
12 | def write_to_json(dict_to_write, output_file):
13 | """
14 | Outputs a given dictionary as a JSON file with indents.
15 |
16 | Args:
17 | dict_to_write (dict): Input dictionary to output.
18 | output_file (str): File path to write the dictionary.
19 |
20 | Returns:
21 | None
22 | """
23 | with open(output_file, 'w') as file:
24 | json.dump(dict_to_write, file, indent=4)
25 |
26 |
27 | def load_from_json(json_file):
28 | """
29 |     Loads a JSON file as a dictionary and returns it.
30 |
31 | Args:
32 | json_file (str): Input JSON file to read.
33 |
34 | Returns:
35 | dict: Dictionary loaded from the JSON file.
36 | """
37 | with open(json_file, 'r') as file:
38 | return json.load(file)
39 |
40 |
41 | def save_tensor_image(x, output_file):
42 | """
43 |     Saves an input image tensor to an image file; useful for tests.
44 |
45 | Args:
46 | x (Tensor): A 3D tensor image of shape (3, H, W).
47 | output_file (str): The output image file to save the tensor.
48 |
49 | Returns:
50 | None
51 | """
52 | folder = os.path.dirname(output_file)
53 | if not os.path.exists(folder):
54 | os.makedirs(folder)
55 |
56 | x = x.permute(1, 2, 0).numpy()
57 | io.imsave(output_file, x)
58 |
59 |
60 | def load_images(n=1, size=32):
61 | """
62 |     Loads n random image tensors with fake, 0-valued labels.
63 |
64 | Args:
65 | n (int): Number of random images to load.
66 | size (int): Spatial size of random image.
67 |
68 | Returns:
69 |         tuple: Random images of shape (n, 3, size, size) and 0-valued labels of shape (n,).
70 | """
71 | images = torch.randn(n, 3, size, size)
72 |     labels = torch.from_numpy(np.array([0] * n))
73 |
74 | return images, labels
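A short round-trip through the utilities above (the file name config.json is just an example path):

from torch_mimicry.utils import common

common.write_to_json({'seed': 0, 'lr': 2e-4}, 'config.json')
config = common.load_from_json('config.json')

images, labels = common.load_images(n=4, size=32)
print(images.shape)  # torch.Size([4, 3, 32, 32])
print(labels.shape)  # torch.Size([4]); one 0-valued label per image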
--------------------------------------------------------------------------------