├── docs
├── zh_cn
│ ├── changelog.md
│ ├── tutorials
│ │ ├── config.md
│ │ ├── customize_dataset.md
│ │ ├── customize_models.md
│ │ ├── customize_runtime.md
│ │ ├── customize_losses.md
│ │ └── index.rst
│ ├── quick_run.md
│ ├── faq.md
│ ├── Makefile
│ ├── index.rst
│ ├── make.bat
│ ├── stat.py
│ └── api.rst
└── en
│ ├── switch_language.md
│ ├── _static
│ └── css
│ │ └── readthedocs.css
│ ├── tutorials
│ └── index.rst
│ ├── Makefile
│ ├── index.rst
│ ├── make.bat
│ ├── faq.md
│ ├── stat.py
│ └── api.rst
├── requirements
├── readthedocs.txt
├── mminstall.txt
├── tests.txt
├── runtime.txt
└── docs.txt
├── requirements.txt
├── tests
├── test_models
│ ├── test_base_gan.py
│ └── test_base_ddpm.py
├── data
│ ├── image
│ │ ├── baboon.png
│ │ └── img_root
│ │ │ ├── GT05.jpg
│ │ │ ├── baboon.png
│ │ │ ├── audi_a5.jpeg
│ │ │ ├── horse
│ │ │ └── horse.jpeg
│ │ │ └── grass
│ │ │ └── grass1.jpeg
│ ├── paired
│ │ ├── test
│ │ │ ├── 3.jpg
│ │ │ └── 33_AB.jpg
│ │ └── train
│ │ │ ├── 1.jpg
│ │ │ └── 2.jpg
│ └── unpaired
│ │ ├── testA
│ │ └── 5.jpg
│ │ ├── testB
│ │ └── 6.jpg
│ │ ├── trainA
│ │ ├── 1.jpg
│ │ └── 2.jpg
│ │ └── trainB
│ │ ├── 3.jpg
│ │ └── 4.jpg
├── test_datasets
│ ├── test_quicktest_dataset.py
│ ├── test_dataset_wrappers.py
│ ├── test_singan_dataset.py
│ ├── test_unconditional_image_dataset.py
│ ├── test_persistent_worker.py
│ └── test_pipelines
│ │ ├── test_loading.py
│ │ └── test_compose.py
├── test_utils
│ └── test_io_utils.py
├── test_modules
│ ├── test_lpips.py
│ ├── test_fid_inception.py
│ └── test_mspie_archs.py
└── test_ops
│ └── test_conv_gradfix.py
├── configs
├── sagan
│ ├── sagan_128_cvt_studioGAN.py
│ └── sagan_32_cvt_studioGAN.py
├── sngan_proj
│ ├── sngan_proj_128_cvt_studioGAN.py
│ └── sngan_proj_32_cvt_studioGAN.py
├── _base_
│ ├── datasets
│ │ ├── singan.py
│ │ ├── unconditional_imgs_128x128.py
│ │ ├── unconditional_imgs_flip_256x256.py
│ │ ├── unconditional_imgs_flip_512x512.py
│ │ ├── unconditional_imgs_64x64.py
│ │ ├── Inception_Score.py
│ │ ├── unconditional_imgs_flip_lanczos_resize_256x256.py
│ │ ├── lsun-car_pad_512.py
│ │ ├── cifar10_rgb.py
│ │ ├── grow_scale_imgs_ffhq_styleganv1.py
│ │ ├── cifar10_inception_stat.py
│ │ ├── grow_scale_imgs_128x128.py
│ │ ├── ffhq_flip.py
│ │ ├── lsun_stylegan.py
│ │ ├── cifar10_nopad.py
│ │ ├── cifar10.py
│ │ ├── grow_scale_imgs_celeba-hq.py
│ │ ├── cifar10_noaug.py
│ │ ├── imagenet_rgb.py
│ │ └── cifar10_random_noise.py
│ ├── models
│ │ ├── sngan_proj
│ │ │ ├── sngan_proj_32x32.py
│ │ │ └── sngan_proj_128x128.py
│ │ ├── dcgan
│ │ │ ├── dcgan_64x64.py
│ │ │ └── dcgan_128x128.py
│ │ ├── lsgan
│ │ │ └── lsgan_128x128.py
│ │ ├── sagan
│ │ │ ├── sagan_32x32.py
│ │ │ └── sagan_128x128.py
│ │ ├── biggan
│ │ │ ├── biggan_32x32.py
│ │ │ └── biggan_128x128.py
│ │ ├── wgangp
│ │ │ └── wgangp_base.py
│ │ ├── stylegan
│ │ │ ├── styleganv1_base.py
│ │ │ ├── stylegan2_base.py
│ │ │ └── stylegan3_base.py
│ │ ├── singan
│ │ │ └── singan.py
│ │ ├── pix2pix
│ │ │ └── pix2pix_vanilla_unet_bn.py
│ │ ├── pggan
│ │ │ └── pggan_128x128.py
│ │ └── improved_ddpm
│ │ │ ├── ddpm_64x64.py
│ │ │ └── ddpm_32x32.py
│ ├── default_metrics.py
│ └── default_runtime.py
├── positional_encoding_in_gans
│ ├── singan_interp-pad_balloons.py
│ ├── singan_interp-pad_disc-nobn_balloons.py
│ ├── singan_interp-pad_disc-nobn_fish.py
│ ├── singan_csg_fish.py
│ ├── singan_csg_bohemian.py
│ ├── singan_spe-dim4_fish.py
│ ├── singan_spe-dim4_bohemian.py
│ ├── singan_spe-dim8_bohemian.py
│ ├── stylegan2_c2_ffhq_256_b3x8_1100k.py
│ └── stylegan2_c2_ffhq_512_b3x8_1100k.py
├── styleganv2
│ ├── stylegan2_c2_apex_fp16_PL-R1-no-scaler_ffhq_256_b4x8_800k.py
│ ├── stylegan2_c2_fp16-globalG-partialD_PL-R1-no-scaler_ffhq_256_b4x8_800k.py
│ ├── stylegan2_c2_fp16_partial-GD_PL-no-scaler_ffhq_256_b4x8_800k.py
│ ├── stylegan2_c2_apex_fp16_quicktest_ffhq_256_b4x8_800k.py
│ ├── stylegan2_c2_fp16_quicktest_ffhq_256_b4x8_800k.py
│ ├── stylegan2_c2_fp16-global_quicktest_ffhq_256_b4x8_800k.py
│ ├── stylegan2_c2_lsun-cat_256_b4x8_800k.py
│ ├── stylegan2_c2_lsun-horse_256_b4x8_800k.py
│ ├── stylegan2_c2_lsun-church_256_b4x8_800k.py
│ ├── stylegan2_c2_lsun-car_384x512_b4x8.py
│ └── stylegan2_c2_ffhq_1024_b4x8.py
├── styleganv3
│ ├── stylegan3_t_afhqv2_512_b4x8_cvt_official_rgb.py
│ ├── stylegan3_t_ffhq_1024_b4x8_cvt_official_rgb.py
│ ├── stylegan3_t_ffhqu_256_b4x8_cvt_official_rgb.py
│ ├── stylegan3_r_ffhqu_256_b4x8_cvt_official_rgb.py
│ ├── stylegan3_r_ffhq_1024_b4x8_cvt_official_rgb.py
│ └── stylegan3_r_afhqv2_512_b4x8_cvt_official_rgb.py
├── wgan-gp
│ ├── wgangp_GN_GP-50_lsun-bedroom_128_b64x1_160kiter.py
│ ├── wgangp_GN_celeba-cropped_128_b64x1_160kiter.py
│ └── metafile.yml
├── ada
│ └── metafile.yml
├── dcgan
│ ├── dcgan_lsun-bedroom_64x64_b128x1_5e.py
│ └── dcgan_celeba-cropped_64_b128x1_300k.py
├── biggan
│ ├── biggan_128x128_cvt_BigGAN-PyTorch_rgb.py
│ ├── biggan-deep_128x128_cvt_hugging-face_rgb.py
│ ├── biggan-deep_256x256_cvt_hugging-face_rgb.py
│ └── biggan-deep_512x512_cvt_hugging-face_rgb.py
├── singan
│ ├── singan_balloons.py
│ ├── singan_fish.py
│ ├── singan_bohemian.py
│ └── metafile.yml
├── pggan
│ ├── pggan_celeba-hq_1024_g8_12Mimg.py
│ ├── pggan_lsun-bedroom_128_g8_12Mimgs.py
│ └── pggan_celeba-cropped_128_g8_12Mimgs.py
├── styleganv1
│ ├── metafile.yml
│ ├── styleganv1_ffhq_256_g8_25Mimg.py
│ └── styleganv1_ffhq_1024_g8_25Mimg.py
├── improved_ddpm
│ ├── ddpm_cosine_hybird_timestep-4k_drop0.3_cifar10_32x32_b8x16_500k.py
│ ├── ddpm_cosine_hybird_timestep-4k_imagenet1k_64x64_b8x16_1500k.py
│ └── ddpm_cosine_hybird_timestep-4k_drop0.3_imagenet1k_64x64_b8x16_1500k.py
└── lsgan
│ ├── lsgan_lsgan-archi_lr-1e-4_lsun-bedroom_128_b64x1_10m.py
│ ├── lsgan_dcgan-archi_lr-1e-3_celeba-cropped_64_b128x1_12m.py
│ ├── lsgan_dcgan-archi_lr-1e-4_lsun-bedroom_64_b128x1_12m.py
│ └── lsgan_dcgan-archi_lr-1e-4_celeba-cropped_128_b64x1_10m.py
├── mmgen
├── core
│ ├── optimizer
│ │ └── __init__.py
│ ├── scheduler
│ │ └── __init__.py
│ ├── runners
│ │ ├── __init__.py
│ │ └── apex_amp_utils.py
│ ├── __init__.py
│ ├── hooks
│ │ ├── __init__.py
│ │ └── pggan_fetch_data_hook.py
│ ├── evaluation
│ │ └── __init__.py
│ └── registry.py
├── models
│ ├── architectures
│ │ ├── arcface
│ │ │ └── __init__.py
│ │ ├── dcgan
│ │ │ └── __init__.py
│ │ ├── lsgan
│ │ │ └── __init__.py
│ │ ├── wgan_gp
│ │ │ └── __init__.py
│ │ ├── cyclegan
│ │ │ └── __init__.py
│ │ ├── pix2pix
│ │ │ └── __init__.py
│ │ ├── sngan_proj
│ │ │ └── __init__.py
│ │ ├── ddpm
│ │ │ └── __init__.py
│ │ ├── singan
│ │ │ └── __init__.py
│ │ ├── lpips
│ │ │ └── __init__.py
│ │ ├── stylegan
│ │ │ ├── modules
│ │ │ │ └── __init__.py
│ │ │ ├── __init__.py
│ │ │ └── ada
│ │ │ │ └── misc.py
│ │ ├── common.py
│ │ ├── pggan
│ │ │ └── __init__.py
│ │ └── biggan
│ │ │ └── __init__.py
│ ├── diffusions
│ │ ├── __init__.py
│ │ └── sampler.py
│ ├── common
│ │ ├── __init__.py
│ │ └── dist_utils.py
│ ├── translation_models
│ │ └── __init__.py
│ ├── __init__.py
│ ├── gans
│ │ └── __init__.py
│ ├── builder.py
│ └── losses
│ │ └── __init__.py
├── datasets
│ ├── samplers
│ │ └── __init__.py
│ ├── quick_test_dataset.py
│ ├── pipelines
│ │ └── __init__.py
│ ├── __init__.py
│ └── dataset_wrappers.py
├── ops
│ ├── __init__.py
│ └── stylegan3
│ │ ├── ops
│ │ ├── __init__.py
│ │ ├── bias_act.h
│ │ ├── filtered_lrelu_rd.cu
│ │ ├── filtered_lrelu_wr.cu
│ │ └── filtered_lrelu_ns.cu
│ │ └── __init__.py
├── utils
│ ├── __init__.py
│ ├── logger.py
│ └── dist_util.py
├── apis
│ └── __init__.py
├── version.py
└── __init__.py
├── .github
├── ISSUE_TEMPLATE
│ ├── general_questions.md
│ ├── config.yml
│ ├── 4-documentation.yml
│ ├── feature_request.md
│ ├── 2-feature-request.yml
│ ├── 3-new-model.yml
│ └── error-report.md
└── workflows
│ ├── scripts
│ └── get_mmcv_var.sh
│ ├── publish-to-pypi.yml
│ ├── lint.yml
│ └── test_mim.yml
├── .readthedocs.yml
├── tools
├── eval.sh
├── dist_train.sh
├── dist_eval.sh
├── slurm_eval.sh
├── slurm_eval_multi_gpu.sh
├── slurm_train.sh
├── publish_model.py
└── misc
│ └── print_config.py
├── CITATION.cff
├── MANIFEST.in
├── .circleci
├── docker
│ └── Dockerfile
├── scripts
│ └── get_mmcv_var.sh
└── config.yml
├── model-index.yml
├── setup.cfg
├── docker
└── Dockerfile
└── LICENSES.md
/docs/zh_cn/changelog.md:
--------------------------------------------------------------------------------
1 | # 版本更新日志
2 |
--------------------------------------------------------------------------------
/requirements/readthedocs.txt:
--------------------------------------------------------------------------------
1 | mmcv
2 | torch
3 | torchvision
4 |
--------------------------------------------------------------------------------
/docs/zh_cn/tutorials/config.md:
--------------------------------------------------------------------------------
1 | # Tutorial 1: 配置系统 (config files)
2 |
--------------------------------------------------------------------------------
/docs/zh_cn/tutorials/customize_dataset.md:
--------------------------------------------------------------------------------
1 | # Tutorial 2: 自定义数据集
2 |
--------------------------------------------------------------------------------
/docs/zh_cn/tutorials/customize_models.md:
--------------------------------------------------------------------------------
1 | # Tutorial 3: 自定义模型
2 |
--------------------------------------------------------------------------------
/docs/zh_cn/tutorials/customize_runtime.md:
--------------------------------------------------------------------------------
1 | # Tutorial 6: 自定义配置
2 |
--------------------------------------------------------------------------------
/docs/zh_cn/tutorials/customize_losses.md:
--------------------------------------------------------------------------------
1 | # Tutorial 4: 损失函数模块的设计思路
2 |
--------------------------------------------------------------------------------
/requirements/mminstall.txt:
--------------------------------------------------------------------------------
1 | mmcls>=0.18.0
2 | mmcv-full>=1.3.0,<=1.8.0
3 |
--------------------------------------------------------------------------------
/docs/zh_cn/quick_run.md:
--------------------------------------------------------------------------------
1 | # 1: 在标准的数据集上训练和推理现有的模型
2 |
3 | ## 用现有的生成模型来生成图像
4 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | -r requirements/runtime.txt
2 | -r requirements/tests.txt
3 |
--------------------------------------------------------------------------------
/tests/test_models/test_base_gan.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 |
--------------------------------------------------------------------------------
/configs/sagan/sagan_128_cvt_studioGAN.py:
--------------------------------------------------------------------------------
1 | _base_ = ['../_base_/models/sagan/sagan_128x128.py']
2 |
--------------------------------------------------------------------------------
/configs/sagan/sagan_32_cvt_studioGAN.py:
--------------------------------------------------------------------------------
1 | _base_ = ['../_base_/models/sagan/sagan_32x32.py']
2 |
--------------------------------------------------------------------------------
/tests/data/image/baboon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-mmlab/mmgeneration/HEAD/tests/data/image/baboon.png
--------------------------------------------------------------------------------
/configs/sngan_proj/sngan_proj_128_cvt_studioGAN.py:
--------------------------------------------------------------------------------
1 | _base_ = ['../_base_/models/sngan_proj/sngan_proj_128x128.py']
2 |
--------------------------------------------------------------------------------
/configs/sngan_proj/sngan_proj_32_cvt_studioGAN.py:
--------------------------------------------------------------------------------
1 | _base_ = ['../_base_/models/sngan_proj/sngan_proj_32x32.py']
2 |
--------------------------------------------------------------------------------
/tests/data/paired/test/3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-mmlab/mmgeneration/HEAD/tests/data/paired/test/3.jpg
--------------------------------------------------------------------------------
/tests/data/paired/train/1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-mmlab/mmgeneration/HEAD/tests/data/paired/train/1.jpg
--------------------------------------------------------------------------------
/tests/data/paired/train/2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-mmlab/mmgeneration/HEAD/tests/data/paired/train/2.jpg
--------------------------------------------------------------------------------
/tests/data/paired/test/33_AB.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-mmlab/mmgeneration/HEAD/tests/data/paired/test/33_AB.jpg
--------------------------------------------------------------------------------
/tests/data/unpaired/testA/5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-mmlab/mmgeneration/HEAD/tests/data/unpaired/testA/5.jpg
--------------------------------------------------------------------------------
/tests/data/unpaired/testB/6.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-mmlab/mmgeneration/HEAD/tests/data/unpaired/testB/6.jpg
--------------------------------------------------------------------------------
/tests/data/unpaired/trainA/1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-mmlab/mmgeneration/HEAD/tests/data/unpaired/trainA/1.jpg
--------------------------------------------------------------------------------
/tests/data/unpaired/trainA/2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-mmlab/mmgeneration/HEAD/tests/data/unpaired/trainA/2.jpg
--------------------------------------------------------------------------------
/tests/data/unpaired/trainB/3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-mmlab/mmgeneration/HEAD/tests/data/unpaired/trainB/3.jpg
--------------------------------------------------------------------------------
/tests/data/unpaired/trainB/4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-mmlab/mmgeneration/HEAD/tests/data/unpaired/trainB/4.jpg
--------------------------------------------------------------------------------
/tests/data/image/img_root/GT05.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-mmlab/mmgeneration/HEAD/tests/data/image/img_root/GT05.jpg
--------------------------------------------------------------------------------
/requirements/tests.txt:
--------------------------------------------------------------------------------
1 | coverage < 7.0.0
2 | # codecov
3 | flake8
4 | interrogate
5 | isort==4.3.21
6 | pytest
7 | pytest-runner
8 |
--------------------------------------------------------------------------------
/tests/data/image/img_root/baboon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-mmlab/mmgeneration/HEAD/tests/data/image/img_root/baboon.png
--------------------------------------------------------------------------------
/requirements/runtime.txt:
--------------------------------------------------------------------------------
1 | mmcls
2 | ninja
3 | numpy
4 | prettytable
5 | requests
6 | scikit-image
7 | scipy
8 | tqdm
9 | yapf
10 |
--------------------------------------------------------------------------------
/tests/data/image/img_root/audi_a5.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-mmlab/mmgeneration/HEAD/tests/data/image/img_root/audi_a5.jpeg
--------------------------------------------------------------------------------
/tests/data/image/img_root/horse/horse.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-mmlab/mmgeneration/HEAD/tests/data/image/img_root/horse/horse.jpeg
--------------------------------------------------------------------------------
/tests/data/image/img_root/grass/grass1.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-mmlab/mmgeneration/HEAD/tests/data/image/img_root/grass/grass1.jpeg
--------------------------------------------------------------------------------
/mmgen/core/optimizer/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .builder import build_optimizers
3 |
4 | __all__ = ['build_optimizers']
5 |
--------------------------------------------------------------------------------
/mmgen/models/architectures/arcface/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .id_loss import IDLossModel
3 |
4 | __all__ = ['IDLossModel']
5 |
--------------------------------------------------------------------------------
/mmgen/core/scheduler/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .lr_updater import LinearLrUpdaterHook
3 |
4 | __all__ = ['LinearLrUpdaterHook']
5 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/general_questions.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: General questions
3 | about: Ask general questions to get help
4 | title: ''
5 | labels: ''
6 | assignees: ''
7 | ---
8 |
--------------------------------------------------------------------------------
/docs/en/switch_language.md:
--------------------------------------------------------------------------------
1 | ## English
2 |
3 | ## 简体中文
4 |
--------------------------------------------------------------------------------
/mmgen/datasets/samplers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .distributed_sampler import DistributedSampler
3 |
4 | __all__ = ['DistributedSampler']
5 |
--------------------------------------------------------------------------------
/.readthedocs.yml:
--------------------------------------------------------------------------------
1 | version: 2
2 |
3 | python:
4 | version: 3.7
5 | install:
6 | - requirements: requirements/docs.txt
7 | - requirements: requirements/readthedocs.txt
8 |
--------------------------------------------------------------------------------
/mmgen/core/runners/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .dynamic_iterbased_runner import DynamicIterBasedRunner
3 |
4 | __all__ = ['DynamicIterBasedRunner']
5 |
--------------------------------------------------------------------------------
/mmgen/models/architectures/dcgan/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .generator_discriminator import DCGANDiscriminator, DCGANGenerator
3 |
4 | __all__ = ['DCGANGenerator', 'DCGANDiscriminator']
5 |
--------------------------------------------------------------------------------
/mmgen/models/architectures/lsgan/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .generator_discriminator import LSGANDiscriminator, LSGANGenerator
3 |
4 | __all__ = ['LSGANDiscriminator', 'LSGANGenerator']
5 |
--------------------------------------------------------------------------------
/mmgen/models/architectures/wgan_gp/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .generator_discriminator import WGANGPDiscriminator, WGANGPGenerator
3 |
4 | __all__ = ['WGANGPDiscriminator', 'WGANGPGenerator']
5 |
--------------------------------------------------------------------------------
/tools/eval.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | set -x
4 |
5 | CONFIG=$1
6 | CKPT=$2
7 | PY_ARGS=${@:3}
8 |
9 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
10 | python -u tools/evaluation.py ${CONFIG} ${CKPT} --launcher="none" ${PY_ARGS}
11 |
--------------------------------------------------------------------------------
/docs/zh_cn/faq.md:
--------------------------------------------------------------------------------
1 | # 常见问题解答
2 |
3 | 我们在这里罗列了许多用户遇到的一些常见的问题和相应的解决方案。如果您发现任何常见问题并有办法帮助他人解决该问题,欢迎丰富该列表。如果此处的内容未涵盖您的问题,请使用[提供的模版](https://github.com/open-mmlab/mmgeneration/blob/master/.github/ISSUE_TEMPLATE/error-report.md)来创建新问题,并确保将模版中所有必需的信息填写完整。
4 |
--------------------------------------------------------------------------------
/docs/zh_cn/tutorials/index.rst:
--------------------------------------------------------------------------------
1 | .. toctree::
2 | :maxdepth: 2
3 |
4 | config.md
5 | customize_dataset.md
6 | customize_models.md
7 | customize_losses.md
8 | ddp_train_gans.md
9 | customize_runtime.md
10 | applications.md
11 |
--------------------------------------------------------------------------------
/mmgen/models/diffusions/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .base_diffusion import BasicGaussianDiffusion
3 | from .sampler import UniformTimeStepSampler
4 |
5 | __all__ = ['BasicGaussianDiffusion', 'UniformTimeStepSampler']
6 |
--------------------------------------------------------------------------------
/mmgen/models/common/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .dist_utils import AllGatherLayer
3 | from .model_utils import GANImageBuffer, set_requires_grad
4 |
5 | __all__ = ['set_requires_grad', 'AllGatherLayer', 'GANImageBuffer']
6 |
--------------------------------------------------------------------------------
/mmgen/ops/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .conv2d_gradfix import conv2d, conv_transpose2d
3 | from .stylegan3.ops import bias_act, filtered_lrelu
4 |
5 | __all__ = ['conv2d', 'conv_transpose2d', 'filtered_lrelu', 'bias_act']
6 |
--------------------------------------------------------------------------------
/mmgen/models/architectures/cyclegan/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .generator_discriminator import ResnetGenerator
3 | from .modules import ResidualBlockWithDropout
4 |
5 | __all__ = ['ResnetGenerator', 'ResidualBlockWithDropout']
6 |
--------------------------------------------------------------------------------
/docs/en/_static/css/readthedocs.css:
--------------------------------------------------------------------------------
1 | .header-logo {
2 | background-image: url("https://user-images.githubusercontent.com/12726765/114528756-de55af80-9c7b-11eb-94d7-d3224ada1585.png");
3 | background-size: 173px 40px;
4 | height: 40px;
5 | width: 173px;
6 | }
7 |
--------------------------------------------------------------------------------
/docs/en/tutorials/index.rst:
--------------------------------------------------------------------------------
1 | .. toctree::
2 | :maxdepth: 2
3 |
4 | config.md
5 | customize_dataset.md
6 | customize_models.md
7 | customize_losses.md
8 | inception_stat.md
9 | ddp_train_gans.md
10 | customize_runtime.md
11 | applications.md
12 |
--------------------------------------------------------------------------------
/CITATION.cff:
--------------------------------------------------------------------------------
1 | cff-version: 1.2.0
2 | message: "If you use this software, please cite it as below."
3 | authors:
4 | - name: "MMGeneration Contributors"
5 | title: "OpenMMLab's next-generation toolbox for generative models"
6 | date-released: 2020-07-10
7 | url: "https://github.com/open-mmlab/mmgeneration"
8 | license: Apache-2.0
9 |
--------------------------------------------------------------------------------
/configs/_base_/datasets/singan.py:
--------------------------------------------------------------------------------
1 | dataset_type = 'SinGANDataset'
2 |
3 | data = dict(
4 | samples_per_gpu=1,
5 | workers_per_gpu=4,
6 | drop_last=False,
7 | train=dict(
8 | type=dataset_type,
9 | img_path=None, # need to set
10 | min_size=25,
11 | max_size=250,
12 | scale_factor_init=0.75))
13 |
--------------------------------------------------------------------------------
/mmgen/core/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .evaluation import * # noqa: F401, F403
3 | from .hooks import * # noqa: F401, F403
4 | from .optimizer import * # noqa: F401, F403
5 | from .registry import * # noqa: F401, F403
6 | from .runners import * # noqa: F401, F403
7 | from .scheduler import * # noqa: F401, F403
8 |
--------------------------------------------------------------------------------
/configs/positional_encoding_in_gans/singan_interp-pad_balloons.py:
--------------------------------------------------------------------------------
1 | _base_ = ['../singan/singan_balloons.py']
2 |
3 | model = dict(
4 | type='PESinGAN',
5 | generator=dict(
6 | type='SinGANMSGeneratorPE', interp_pad=True, noise_with_pad=True))
7 |
8 | train_cfg = dict(fixed_noise_with_pad=True)
9 |
10 | dist_params = dict(backend='nccl')
11 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include mmgen/model-index.yml
2 | recursive-include mmgen/configs *.py *.yml
3 | recursive-include mmgen/tools *.sh *.py
4 |
5 | include requirements/*.txt
6 | include mmgen/VERSION
7 | include mmgen/.mim/model-index.yml
8 | include mmgen/.mim/demo/*/*
9 | recursive-include mmgen/.mim/configs *.py *.yml
10 | recursive-include mmgen/.mim/tools *.sh *.py
11 |
--------------------------------------------------------------------------------
/mmgen/models/architectures/pix2pix/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .generator_discriminator import PatchDiscriminator, UnetGenerator
3 | from .modules import UnetSkipConnectionBlock, generation_init_weights
4 |
5 | __all__ = [
6 | 'PatchDiscriminator', 'UnetGenerator', 'UnetSkipConnectionBlock',
7 | 'generation_init_weights'
8 | ]
9 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/config.yml:
--------------------------------------------------------------------------------
1 | blank_issues_enabled: false
2 |
3 | contact_links:
4 | - name: 💬 Forum
5 | url: https://github.com/open-mmlab/mmgeneration/discussions
6 | about: Ask general usage questions and discuss with other MMGeneration community members
7 | - name: 🌐 Explore OpenMMLab
8 | url: https://openmmlab.com/
9 | about: Get to know more about OpenMMLab
10 |
--------------------------------------------------------------------------------
/mmgen/models/translation_models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .base_translation_model import BaseTranslationModel
3 | from .cyclegan import CycleGAN
4 | from .pix2pix import Pix2Pix
5 | from .static_translation_gan import StaticTranslationGAN
6 |
7 | __all__ = [
8 | 'Pix2Pix', 'CycleGAN', 'BaseTranslationModel', 'StaticTranslationGAN'
9 | ]
10 |
--------------------------------------------------------------------------------
/mmgen/models/architectures/sngan_proj/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .generator_discriminator import ProjDiscriminator, SNGANGenerator
3 | from .modules import SNGANDiscHeadResBlock, SNGANDiscResBlock, SNGANGenResBlock
4 |
5 | __all__ = [
6 | 'ProjDiscriminator', 'SNGANGenerator', 'SNGANGenResBlock',
7 | 'SNGANDiscResBlock', 'SNGANDiscHeadResBlock'
8 | ]
9 |
--------------------------------------------------------------------------------
/mmgen/models/architectures/ddpm/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .denoising import DenoisingUnet
3 | from .modules import (DenoisingDownsample, DenoisingResBlock,
4 | DenoisingUpsample, TimeEmbedding)
5 |
6 | __all__ = [
7 | 'DenoisingUnet', 'TimeEmbedding', 'DenoisingDownsample',
8 | 'DenoisingUpsample', 'DenoisingResBlock'
9 | ]
10 |
--------------------------------------------------------------------------------
/configs/positional_encoding_in_gans/singan_interp-pad_disc-nobn_balloons.py:
--------------------------------------------------------------------------------
1 | _base_ = ['../singan/singan_balloons.py']
2 |
3 | model = dict(
4 | type='PESinGAN',
5 | generator=dict(
6 | type='SinGANMSGeneratorPE', interp_pad=True, noise_with_pad=True),
7 | discriminator=dict(norm_cfg=None))
8 |
9 | train_cfg = dict(fixed_noise_with_pad=True)
10 |
11 | dist_params = dict(backend='nccl')
12 |
--------------------------------------------------------------------------------
/mmgen/models/architectures/singan/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .generator_discriminator import (SinGANMultiScaleDiscriminator,
3 | SinGANMultiScaleGenerator)
4 | from .positional_encoding import SinGANMSGeneratorPE
5 |
6 | __all__ = [
7 | 'SinGANMultiScaleDiscriminator', 'SinGANMultiScaleGenerator',
8 | 'SinGANMSGeneratorPE'
9 | ]
10 |
--------------------------------------------------------------------------------
/mmgen/models/architectures/lpips/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | r"""
3 | The lpips module was adapted from https://github.com/rosinality/stylegan2-pytorch/tree/master/lpips , # noqa
4 | and you can see the original implementation in https://github.com/richzhang/PerceptualSimilarity/tree/master/lpips # noqa
5 | """
6 | from .perceptual_loss import PerceptualLoss
7 |
8 | __all__ = ['PerceptualLoss']
9 |
--------------------------------------------------------------------------------
/mmgen/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .collect_env import collect_env
3 | from .dist_util import check_dist_init, sync_random_seed
4 | from .io_utils import MMGEN_CACHE_DIR, download_from_url
5 | from .logger import get_root_logger
6 |
7 | __all__ = [
8 | 'collect_env', 'get_root_logger', 'download_from_url', 'check_dist_init',
9 | 'MMGEN_CACHE_DIR', 'sync_random_seed'
10 | ]
11 |
--------------------------------------------------------------------------------
/configs/styleganv2/stylegan2_c2_apex_fp16_PL-R1-no-scaler_ffhq_256_b4x8_800k.py:
--------------------------------------------------------------------------------
1 | """Config for the `config-f` setting in StyleGAN2."""
2 |
3 | _base_ = ['./stylegan2_c2_ffhq_256_b4x8_800k.py']
4 |
5 | model = dict(
6 | disc_auxiliary_loss=dict(use_apex_amp=False),
7 | gen_auxiliary_loss=dict(use_apex_amp=False),
8 | )
9 |
10 | total_iters = 800002
11 |
12 | apex_amp = dict(mode='gan', init_args=dict(opt_level='O1', num_losses=2))
13 | resume_from = None
14 |
--------------------------------------------------------------------------------
/configs/styleganv3/stylegan3_t_afhqv2_512_b4x8_cvt_official_rgb.py:
--------------------------------------------------------------------------------
1 | _base_ = ['./stylegan3_base.py']
2 |
3 | synthesis_cfg = {
4 | 'type': 'SynthesisNetwork',
5 | 'channel_base': 32768,
6 | 'channel_max': 512,
7 | 'magnitude_ema_beta': 0.999
8 | }
9 | model = dict(
10 | type='StaticUnconditionalGAN',
11 | generator=dict(
12 | out_size=512,
13 | img_channels=3,
14 | rgb2bgr=True,
15 | synthesis_cfg=synthesis_cfg),
16 | discriminator=dict(in_size=512))
17 |
--------------------------------------------------------------------------------
/tests/test_datasets/test_quicktest_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from mmgen.datasets.quick_test_dataset import QuickTestImageDataset
3 |
4 |
5 | class TestQuickTest:  # Smoke test for QuickTestImageDataset.
6 |
7 | @classmethod
8 | def setup_class(cls):  # Build one 256x256 dataset shared by the tests in this class.
9 | cls.dataset = QuickTestImageDataset(size=(256, 256))
10 |
11 | def test_quicktest_dataset(self):  # Check the reported length and the shape of one sample.
12 | assert len(self.dataset) == 10000
13 | img = self.dataset[2]
14 | assert img['real_img'].shape == (3, 256, 256)  # shape follows size=(256, 256) with 3 leading channels
15 |
--------------------------------------------------------------------------------
/mmgen/apis/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .inference import (init_model, sample_conditional_model,
3 | sample_ddpm_model, sample_img2img_model,
4 | sample_unconditional_model)
5 | from .train import set_random_seed, train_model
6 |
7 | __all__ = [
8 | 'set_random_seed', 'train_model', 'init_model', 'sample_img2img_model',
9 | 'sample_unconditional_model', 'sample_conditional_model',
10 | 'sample_ddpm_model'
11 | ]
12 |
--------------------------------------------------------------------------------
/mmgen/ops/stylegan3/ops/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | # empty
10 |
--------------------------------------------------------------------------------
/configs/styleganv3/stylegan3_t_ffhq_1024_b4x8_cvt_official_rgb.py:
--------------------------------------------------------------------------------
1 | _base_ = ['./stylegan3_base.py']
2 |
3 | synthesis_cfg = {
4 | 'type': 'SynthesisNetwork',
5 | 'channel_base': 32768,
6 | 'channel_max': 512,
7 | 'magnitude_ema_beta': 0.999
8 | }
9 |
10 | model = dict(
11 | type='StaticUnconditionalGAN',
12 | generator=dict(
13 | out_size=1024,
14 | img_channels=3,
15 | synthesis_cfg=synthesis_cfg,
16 | rgb2bgr=True),
17 | discriminator=dict(in_size=1024))
18 |
--------------------------------------------------------------------------------
/configs/wgan-gp/wgangp_GN_GP-50_lsun-bedroom_128_b64x1_160kiter.py:
--------------------------------------------------------------------------------
1 | _base_ = ['./wgangp_GN_celeba-cropped_128_b64x1_160kiter.py']
2 |
3 | model = dict(disc_auxiliary_loss=[
4 | dict(
5 | type='GradientPenaltyLoss',
6 | loss_weight=50,
7 | norm_mode='HWC',
8 | data_info=dict(
9 | discriminator='disc', real_data='real_imgs',
10 | fake_data='fake_imgs'))
11 | ])
12 |
13 | data = dict(
14 | samples_per_gpu=64, train=dict(imgs_root='./data/lsun/bedroom_train'))
15 |
--------------------------------------------------------------------------------
/configs/styleganv3/stylegan3_t_ffhqu_256_b4x8_cvt_official_rgb.py:
--------------------------------------------------------------------------------
1 | _base_ = ['./stylegan3_base.py']
2 |
3 | synthesis_cfg = {
4 | 'type': 'SynthesisNetwork',
5 | 'channel_base': 16384,
6 | 'channel_max': 512,
7 | 'magnitude_ema_beta': 0.999
8 | }
9 | model = dict(
10 | type='StaticUnconditionalGAN',
11 | generator=dict(
12 | out_size=256,
13 | img_channels=3,
14 | rgb2bgr=True,
15 | synthesis_cfg=synthesis_cfg),
16 | discriminator=dict(in_size=256, channel_multiplier=1))
17 |
--------------------------------------------------------------------------------
/tests/test_utils/test_io_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import os
3 | from tempfile import TemporaryDirectory
4 |
5 | from mmgen.utils import download_from_url
6 |
7 |
8 | def test_download_from_file():  # Fetch a remote image into a temp dir; presumably needs network access.
9 | img_url = 'https://user-images.githubusercontent.com/12726765/114528756-de55af80-9c7b-11eb-94d7-d3224ada1585.png' # noqa
10 | with TemporaryDirectory() as temp_dir:  # directory is removed again when the block exits
11 | local_file = download_from_url(url=img_url, dest_dir=temp_dir)
12 | assert os.path.exists(local_file)  # the returned path must exist on disk
13 |
--------------------------------------------------------------------------------
/mmgen/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .architectures import * # noqa: F401, F403
3 | from .builder import MODELS, MODULES, build_model, build_module
4 | from .common import * # noqa: F401, F403
5 | from .diffusions import * # noqa: F401, F403
6 | from .gans import * # noqa: F401, F403
7 | from .losses import * # noqa: F401, F403
8 | from .misc import * # noqa: F401, F403
9 | from .translation_models import * # noqa: F401, F403
10 |
11 | __all__ = ['build_model', 'MODELS', 'build_module', 'MODULES']
12 |
--------------------------------------------------------------------------------
/tools/dist_train.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | CONFIG=$1
4 | GPUS=$2
5 | NNODES=${NNODES:-1}
6 | NODE_RANK=${NODE_RANK:-0}
7 | PORT=${PORT:-29500}
8 | MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
9 |
10 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
11 | python -m torch.distributed.launch \
12 | --nnodes=$NNODES \
13 | --node_rank=$NODE_RANK \
14 | --master_addr=$MASTER_ADDR \
15 | --nproc_per_node=$GPUS \
16 | --master_port=$PORT \
17 | $(dirname "$0")/train.py \
18 | $CONFIG \
19 | --seed 0 \
20 | --launcher pytorch ${@:3}
21 |
--------------------------------------------------------------------------------
/configs/positional_encoding_in_gans/singan_interp-pad_disc-nobn_fish.py:
--------------------------------------------------------------------------------
1 | _base_ = ['../singan/singan_fish.py']
2 |
3 | model = dict(
4 | type='PESinGAN',
5 | generator=dict(
6 | type='SinGANMSGeneratorPE', interp_pad=True, noise_with_pad=True),
7 | discriminator=dict(norm_cfg=None))
8 |
9 | train_cfg = dict(fixed_noise_with_pad=True)
10 |
11 | data = dict(
12 | train=dict(
13 | img_path='./data/singan/fish-crop.jpg',
14 | min_size=25,
15 | max_size=300,
16 | ))
17 |
18 | dist_params = dict(backend='nccl')
19 |
--------------------------------------------------------------------------------
/mmgen/models/gans/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .base_gan import BaseGAN
3 | from .basic_conditional_gan import BasicConditionalGAN
4 | from .mspie_stylegan2 import MSPIEStyleGAN2
5 | from .progressive_growing_unconditional_gan import ProgressiveGrowingGAN
6 | from .singan import PESinGAN, SinGAN
7 | from .static_unconditional_gan import StaticUnconditionalGAN
8 |
9 | __all__ = [
10 | 'BaseGAN', 'StaticUnconditionalGAN', 'ProgressiveGrowingGAN', 'SinGAN',
11 | 'MSPIEStyleGAN2', 'PESinGAN', 'BasicConditionalGAN'
12 | ]
13 |
--------------------------------------------------------------------------------
/mmgen/ops/stylegan3/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | # Expose filtered_lrelu from the ops subpackage.
10 | from .ops import filtered_lrelu
11 |
12 | __all__ = ['filtered_lrelu']
13 |
--------------------------------------------------------------------------------
/tools/dist_eval.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Usage: dist_eval.sh CONFIG CHECKPOINT GPUS [extra evaluation args...]
3 | CONFIG=$1
4 | CHECKPOINT=$2
5 | GPUS=$3
6 | NNODES=${NNODES:-1}
7 | NODE_RANK=${NODE_RANK:-0}
8 | PORT=${PORT:-29500}
9 | MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
10 | # Launch evaluation.py, the same entry point the slurm_eval*.sh launchers use.
11 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
12 | python -m torch.distributed.launch \
13 | --nnodes=$NNODES \
14 | --node_rank=$NODE_RANK \
15 | --master_addr=$MASTER_ADDR \
16 | --nproc_per_node=$GPUS \
17 | --master_port=$PORT \
18 | $(dirname "$0")/evaluation.py \
19 | $CONFIG \
20 | $CHECKPOINT \
21 | --launcher pytorch \
22 | ${@:4}
23 |
--------------------------------------------------------------------------------
/.circleci/docker/Dockerfile:
--------------------------------------------------------------------------------
1 | ARG PYTORCH="1.8.1"
2 | ARG CUDA="10.2"
3 | ARG CUDNN="7"
4 |
5 | FROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}-devel
6 |
7 | # To fix GPG key error when running apt-get update
8 | RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub
9 | RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/7fa2af80.pub
10 |
11 | RUN apt-get update && apt-get install -y ninja-build libglib2.0-0 libsm6 libxrender-dev libxext6 libgl1-mesa-glx
12 |
--------------------------------------------------------------------------------
/configs/styleganv3/stylegan3_r_ffhqu_256_b4x8_cvt_official_rgb.py:
--------------------------------------------------------------------------------
1 | _base_ = ['./stylegan3_base.py']
2 |
3 | synthesis_cfg = {
4 | 'type': 'SynthesisNetwork',
5 | 'channel_base': 32768,
6 | 'channel_max': 1024,
7 | 'magnitude_ema_beta': 0.999,
8 | 'conv_kernel': 1,
9 | 'use_radial_filters': True
10 | }
11 | model = dict(
12 | type='StaticUnconditionalGAN',
13 | generator=dict(
14 | out_size=256,
15 | img_channels=3,
16 | rgb2bgr=True,
17 | synthesis_cfg=synthesis_cfg),
18 | discriminator=dict(in_size=256, channel_multiplier=1))
19 |
--------------------------------------------------------------------------------
/mmgen/core/hooks/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .ceph_hooks import PetrelUploadHook
3 | from .ema_hook import ExponentialMovingAverageHook
4 | from .pggan_fetch_data_hook import PGGANFetchDataHook
5 | from .pickle_data_hook import PickleDataHook
6 | from .visualization import VisualizationHook
7 | from .visualize_training_samples import VisualizeUnconditionalSamples
8 |
9 | __all__ = [
10 | 'VisualizeUnconditionalSamples', 'PGGANFetchDataHook',
11 | 'ExponentialMovingAverageHook', 'VisualizationHook', 'PickleDataHook',
12 | 'PetrelUploadHook'
13 | ]
14 |
--------------------------------------------------------------------------------
/tests/test_models/test_base_ddpm.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch
3 |
4 | from mmgen.models.diffusions import UniformTimeStepSampler
5 |
6 |
7 | def test_uniform_sampler():  # Check shape and value range of sampled timesteps.
8 | sampler = UniformTimeStepSampler(10)
9 | timesteps = sampler(2)  # request two timesteps through the call operator
10 | assert timesteps.shape == torch.Size([
11 | 2,
12 | ])
13 | assert timesteps.max() < 10 and timesteps.min() >= 0  # values must lie in [0, 10)
14 |
15 | timesteps = sampler.__call__(2)  # same request, invoking __call__ explicitly
16 | assert timesteps.shape == torch.Size([
17 | 2,
18 | ])
19 | assert timesteps.max() < 10 and timesteps.min() >= 0
20 |
--------------------------------------------------------------------------------
/configs/_base_/models/sngan_proj/sngan_proj_32x32.py:
--------------------------------------------------------------------------------
1 | # define GAN model
2 | model = dict(
3 | type='BasiccGAN',
4 | generator=dict(type='SNGANGenerator', output_scale=32, base_channels=256),
5 | discriminator=dict(
6 | type='ProjDiscriminator', input_scale=32, base_channels=128),
7 | gan_loss=dict(type='GANLoss', gan_type='hinge'))
8 |
9 | train_cfg = dict(disc_steps=5)
10 | test_cfg = None
11 |
12 | # define optimizer
13 | optimizer = dict(
14 | generator=dict(type='Adam', lr=0.0002, betas=(0.5, 0.999)),
15 | discriminator=dict(type='Adam', lr=0.0002, betas=(0.5, 0.999)))
16 |
--------------------------------------------------------------------------------
/configs/_base_/models/sngan_proj/sngan_proj_128x128.py:
--------------------------------------------------------------------------------
1 | # define GAN model
2 | model = dict(
3 | type='BasiccGAN',
4 | generator=dict(type='SNGANGenerator', output_scale=128, base_channels=64),
5 | discriminator=dict(
6 | type='ProjDiscriminator', input_scale=128, base_channels=64),
7 | gan_loss=dict(type='GANLoss', gan_type='hinge'))
8 |
9 | train_cfg = dict(disc_steps=2)
10 | test_cfg = None
11 |
12 | # define optimizer
13 | optimizer = dict(
14 | generator=dict(type='Adam', lr=0.0002, betas=(0.5, 0.999)),
15 | discriminator=dict(type='Adam', lr=0.0002, betas=(0.5, 0.999)))
16 |
--------------------------------------------------------------------------------
/configs/_base_/default_metrics.py:
--------------------------------------------------------------------------------
1 | metrics = dict(
2 | fid50k=dict(type='FID', num_images=50000),
3 | pr50k3=dict(type='PR', num_images=50000, k=3),
4 | is50k=dict(type='IS', num_images=50000),
5 | ppl_zfull=dict(type='PPL', space='Z', sampling='full', num_images=50000),
6 | ppl_wfull=dict(type='PPL', space='W', sampling='full', num_images=50000),
7 | ppl_zend=dict(type='PPL', space='Z', sampling='end', num_images=50000),
8 | ppl_wend=dict(type='PPL', space='W', sampling='end', num_images=50000),
9 | ms_ssim10k=dict(type='MS_SSIM', num_images=10000),
10 | swd16k=dict(type='SWD', num_images=16384))
11 |
--------------------------------------------------------------------------------
/model-index.yml:
--------------------------------------------------------------------------------
1 | Import:
2 | - configs/ada/metafile.yml
3 | - configs/biggan/metafile.yml
4 | - configs/cyclegan/metafile.yml
5 | - configs/dcgan/metafile.yml
6 | - configs/ggan/metafile.yml
7 | - configs/improved_ddpm/metafile.yml
8 | - configs/lsgan/metafile.yml
9 | - configs/pggan/metafile.yml
10 | - configs/pix2pix/metafile.yml
11 | - configs/positional_encoding_in_gans/metafile.yml
12 | - configs/sagan/metafile.yml
13 | - configs/singan/metafile.yml
14 | - configs/sngan_proj/metafile.yml
15 | - configs/styleganv1/metafile.yml
16 | - configs/styleganv2/metafile.yml
17 | - configs/styleganv3/metafile.yml
18 | - configs/wgan-gp/metafile.yml
19 |
--------------------------------------------------------------------------------
/configs/styleganv3/stylegan3_r_ffhq_1024_b4x8_cvt_official_rgb.py:
--------------------------------------------------------------------------------
1 | _base_ = ['./stylegan3_base.py']
2 |
3 | synthesis_cfg = {
4 | 'type': 'SynthesisNetwork',
5 | 'channel_base': 65536,
6 | 'channel_max': 1024,
7 | 'magnitude_ema_beta': 0.999,
8 | 'conv_kernel': 1,
9 | 'use_radial_filters': True
10 | }
11 |
12 | r1_gamma = 32.8
13 | d_reg_interval = 16
14 |
15 | model = dict(
16 | type='StaticUnconditionalGAN',
17 | generator=dict(
18 | out_size=1024,
19 | img_channels=3,
20 | synthesis_cfg=synthesis_cfg,
21 | rgb2bgr=True),
22 | discriminator=dict(type='StyleGAN2Discriminator', in_size=1024))
23 |
--------------------------------------------------------------------------------
/configs/_base_/models/dcgan/dcgan_64x64.py:
--------------------------------------------------------------------------------
1 | # define GAN model
2 | model = dict(
3 | type='StaticUnconditionalGAN',
4 | generator=dict(type='DCGANGenerator', output_scale=64, base_channels=1024),
5 | discriminator=dict(
6 | type='DCGANDiscriminator',
7 | input_scale=64,
8 | output_scale=4,
9 | out_channels=1),
10 | gan_loss=dict(type='GANLoss', gan_type='vanilla'))
11 |
12 | train_cfg = dict(disc_steps=1)
13 | test_cfg = None
14 |
15 | # define optimizer
16 | optimizer = dict(
17 | generator=dict(type='Adam', lr=0.0002, betas=(0.5, 0.999)),
18 | discriminator=dict(type='Adam', lr=0.0002, betas=(0.5, 0.999)))
19 |
--------------------------------------------------------------------------------
/mmgen/models/architectures/stylegan/modules/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .styleganv2_modules import (Blur, ConstantInput, ModulatedConv2d,
3 | ModulatedStyleConv, ModulatedToRGB,
4 | NoiseInjection)
5 | from .styleganv3_modules import (MappingNetwork, SynthesisInput,
6 | SynthesisLayer, SynthesisNetwork)
7 |
8 | __all__ = [
9 | 'Blur', 'ModulatedStyleConv', 'ModulatedToRGB', 'NoiseInjection',
10 | 'ConstantInput', 'MappingNetwork', 'SynthesisInput', 'SynthesisLayer',
11 | 'SynthesisNetwork', 'ModulatedConv2d'
12 | ]
13 |
--------------------------------------------------------------------------------
/mmgen/models/architectures/common.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch
3 |
4 |
5 | def get_module_device(module):
6 | """Get the device of a module, judged by its first parameter.
7 |
8 | Args:
9 | module (nn.Module): A module that contains at least one parameter; otherwise ``ValueError`` is raised.
10 |
11 | Returns:
12 | The result of ``get_device()`` on the first parameter when it is on CUDA (presumably a device index rather than a ``torch.device`` - confirm against callers), otherwise ``torch.device('cpu')``.
13 | """
14 | try:
15 | next(module.parameters())  # probe only: StopIteration means the module has no parameters
16 | except StopIteration:
17 | raise ValueError('The input module should contain parameters.')
18 |
19 | if next(module.parameters()).is_cuda:  # only the first parameter is inspected
20 | return next(module.parameters()).get_device()
21 |
22 | return torch.device('cpu')
23 |
--------------------------------------------------------------------------------
/.circleci/scripts/get_mmcv_var.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | TORCH=$1
4 | CUDA=$2
5 |
6 | # 10.2 -> cu102
7 | MMCV_CUDA="cu`echo ${CUDA} | tr -d '.'`"
8 |
9 | # MMCV only provides pre-compiled packages for torch 1.x.0
10 | # which works for any subversions of torch 1.x.
11 | # We force the torch version to be 1.x.0 to ease package searching
12 | # and avoid unnecessary rebuild during MMCV's installation.
13 | TORCH_VER_ARR=(${TORCH//./ })
14 | TORCH_VER_ARR[2]=0
15 | printf -v MMCV_TORCH "%s." "${TORCH_VER_ARR[@]}"
16 | MMCV_TORCH=${MMCV_TORCH%?} # Remove the last dot
17 |
18 | echo "export MMCV_CUDA=${MMCV_CUDA}" >> $BASH_ENV
19 | echo "export MMCV_TORCH=${MMCV_TORCH}" >> $BASH_ENV
20 |
--------------------------------------------------------------------------------
/.github/workflows/scripts/get_mmcv_var.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | TORCH=$1
4 | CUDA=$2
5 |
6 | # 10.2 -> cu102
7 | MMCV_CUDA="cu`echo ${CUDA} | tr -d '.'`"
8 |
9 | # MMCV only provides pre-compiled packages for torch 1.x.0
10 | # which works for any subversions of torch 1.x.
11 | # We force the torch version to be 1.x.0 to ease package searching
12 | # and avoid unnecessary rebuild during MMCV's installation.
13 | TORCH_VER_ARR=(${TORCH//./ })
14 | TORCH_VER_ARR[2]=0
15 | printf -v MMCV_TORCH "%s." "${TORCH_VER_ARR[@]}"
16 | MMCV_TORCH=${MMCV_TORCH%?} # Remove the last dot
17 |
18 | echo "MMCV_CUDA=${MMCV_CUDA}" >> $GITHUB_ENV
19 | echo "MMCV_TORCH=${MMCV_TORCH}" >> $GITHUB_ENV
20 |
--------------------------------------------------------------------------------
/tools/slurm_eval.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | set -x
4 |
5 | PARTITION=$1
6 | JOB_NAME=$2
7 | CONFIG=$3
8 | CKPT=$4
9 | GPUS=${GPUS:-1}
10 | GPUS_PER_NODE=${GPUS_PER_NODE:-1}
11 | CPUS_PER_TASK=${CPUS_PER_TASK:-5}
12 | PY_ARGS=${@:5}
13 | SRUN_ARGS=${SRUN_ARGS:-""}
14 |
15 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
16 | srun -p ${PARTITION} \
17 | --job-name=${JOB_NAME} \
18 | --gres=gpu:${GPUS_PER_NODE} \
19 | --ntasks=${GPUS} \
20 | --ntasks-per-node=${GPUS_PER_NODE} \
21 | --cpus-per-task=${CPUS_PER_TASK} \
22 | --kill-on-bad-exit=1 \
23 | ${SRUN_ARGS} \
24 | python -u tools/evaluation.py ${CONFIG} ${CKPT} --launcher="none" ${PY_ARGS}
25 |
--------------------------------------------------------------------------------
/configs/_base_/models/dcgan/dcgan_128x128.py:
--------------------------------------------------------------------------------
1 | # define GAN model
2 | model = dict(
3 | type='StaticUnconditionalGAN',
4 | generator=dict(
5 | type='DCGANGenerator', output_scale=128, base_channels=1024),
6 | discriminator=dict(
7 | type='DCGANDiscriminator',
8 | input_scale=128,
9 | output_scale=4,
10 | out_channels=100),
11 | gan_loss=dict(type='GANLoss', gan_type='vanilla'))
12 |
13 | train_cfg = dict(disc_steps=1)
14 | test_cfg = None
15 |
16 | # define optimizer
17 | optimizer = dict(
18 | generator=dict(type='Adam', lr=0.0002, betas=(0.5, 0.999)),
19 | discriminator=dict(type='Adam', lr=0.0002, betas=(0.5, 0.999)))
20 |
--------------------------------------------------------------------------------
/configs/_base_/models/lsgan/lsgan_128x128.py:
--------------------------------------------------------------------------------
1 | # define GAN model
2 | model = dict(
3 | type='StaticUnconditionalGAN',
4 | generator=dict(
5 | type='LSGANGenerator',
6 | output_scale=128,
7 | base_channels=256,
8 | noise_size=1024),
9 | discriminator=dict(
10 | type='LSGANDiscriminator', input_scale=128, base_channels=64),
11 | gan_loss=dict(type='GANLoss', gan_type='lsgan'))
12 |
13 | train_cfg = dict(disc_steps=1)
14 | test_cfg = None
15 |
16 | # define optimizer
17 | optimizer = dict(
18 | generator=dict(type='Adam', lr=0.0002, betas=(0.5, 0.999)),
19 | discriminator=dict(type='Adam', lr=0.0002, betas=(0.5, 0.999)))
20 |
--------------------------------------------------------------------------------
/tools/slurm_eval_multi_gpu.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | set -x
4 |
5 | PARTITION=$1
6 | JOB_NAME=$2
7 | CONFIG=$3
8 | CKPT=$4
9 | GPUS=${GPUS:-8}
10 | GPUS_PER_NODE=${GPUS_PER_NODE:-8}
11 | CPUS_PER_TASK=${CPUS_PER_TASK:-5}
12 | PY_ARGS=${@:5}
13 | SRUN_ARGS=${SRUN_ARGS:-""}
14 |
15 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
16 | srun -p ${PARTITION} \
17 | --job-name=${JOB_NAME} \
18 | --gres=gpu:${GPUS_PER_NODE} \
19 | --ntasks=${GPUS} \
20 | --ntasks-per-node=${GPUS_PER_NODE} \
21 | --cpus-per-task=${CPUS_PER_TASK} \
22 | --kill-on-bad-exit=1 \
23 | ${SRUN_ARGS} \
24 | python -u tools/evaluation.py ${CONFIG} ${CKPT} --launcher="slurm" ${PY_ARGS}
25 |
--------------------------------------------------------------------------------
/configs/styleganv2/stylegan2_c2_fp16-globalG-partialD_PL-R1-no-scaler_ffhq_256_b4x8_800k.py:
--------------------------------------------------------------------------------
1 | """Config for the `config-f` setting in StyleGAN2."""
2 |
3 | _base_ = ['./stylegan2_c2_ffhq_256_b4x8_800k.py']
4 |
5 | model = dict(
6 | generator=dict(out_size=256, fp16_enabled=True),
7 | discriminator=dict(in_size=256, fp16_enabled=False, num_fp16_scales=4),
8 | )
9 |
10 | total_iters = 800000
11 |
12 | # use ddp wrapper for faster training
13 | use_ddp_wrapper = True
14 | find_unused_parameters = False
15 |
16 | runner = dict(
17 | fp16_loss_scaler=dict(init_scale=512),
18 | is_dynamic_ddp= # noqa
19 | False, # Note that this flag should be False to use DDP wrapper.
20 | )
21 |
--------------------------------------------------------------------------------
/tools/slurm_train.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | set -x
4 |
5 | PARTITION=$1
6 | JOB_NAME=$2
7 | CONFIG=$3
8 | WORK_DIR=$4
9 | GPUS=${GPUS:-8}
10 | GPUS_PER_NODE=${GPUS_PER_NODE:-8}
11 | CPUS_PER_TASK=${CPUS_PER_TASK:-5}
12 | PY_ARGS=${@:5}
13 | SRUN_ARGS=${SRUN_ARGS:-""}
14 |
15 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
16 | srun -p ${PARTITION} \
17 | --job-name=${JOB_NAME} \
18 | --gres=gpu:${GPUS_PER_NODE} \
19 | --ntasks=${GPUS} \
20 | --ntasks-per-node=${GPUS_PER_NODE} \
21 | --cpus-per-task=${CPUS_PER_TASK} \
22 | --kill-on-bad-exit=1 \
23 | ${SRUN_ARGS} \
24 | python -u tools/train.py ${CONFIG} --work-dir=${WORK_DIR} --launcher="slurm" ${PY_ARGS}
25 |
--------------------------------------------------------------------------------
/requirements/docs.txt:
--------------------------------------------------------------------------------
1 | click
2 | docutils==0.16.0
3 | m2r
4 | mmcls==0.18.0
5 | myst-parser
6 | opencv-python!=4.5.5.62,!=4.5.5.64
7 | # Skip problematic opencv-python versions
8 | # MMCV depends on opencv-python instead of the headless build, thus we install opencv-python
9 | # Due to an upstream bug, we skip these two versions
10 | # https://github.com/opencv/opencv-python/issues/602
11 | # https://github.com/opencv/opencv/issues/21366
12 | # It seems to be fixed in https://github.com/opencv/opencv/pull/21382
13 | prettytable
14 | -e git+https://github.com/open-mmlab/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
15 | scipy
16 | sphinx==4.0.2
17 | sphinx-copybutton
18 | sphinx_markdown_tables
19 |
--------------------------------------------------------------------------------
/configs/styleganv3/stylegan3_r_afhqv2_512_b4x8_cvt_official_rgb.py:
--------------------------------------------------------------------------------
1 | _base_ = ['./stylegan3_base.py']
2 |
3 | synthesis_cfg = {
4 | 'type': 'SynthesisNetwork',
5 | 'channel_base': 65536,
6 | 'channel_max': 1024,
7 | 'magnitude_ema_beta': 0.999,
8 | 'conv_kernel': 1,
9 | 'use_radial_filters': True
10 | }
11 | model = dict(
12 | type='StaticUnconditionalGAN',
13 | generator=dict(
14 | type='StyleGANv3Generator',
15 | noise_size=512,
16 | style_channels=512,
17 | out_size=512,
18 | img_channels=3,
19 | rgb2bgr=True,
20 | synthesis_cfg=synthesis_cfg),
21 | discriminator=dict(type='StyleGAN2Discriminator', in_size=512))
22 |
--------------------------------------------------------------------------------
/mmgen/models/architectures/stylegan/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .generator_discriminator_v1 import (StyleGAN1Discriminator,
3 | StyleGANv1Generator)
4 | from .generator_discriminator_v2 import (StyleGAN2Discriminator,
5 | StyleGANv2Generator)
6 | from .generator_discriminator_v3 import StyleGANv3Generator
7 | from .mspie import MSStyleGAN2Discriminator, MSStyleGANv2Generator
8 |
9 | __all__ = [
10 | 'StyleGAN2Discriminator', 'StyleGANv2Generator', 'StyleGANv1Generator',
11 | 'StyleGAN1Discriminator', 'MSStyleGAN2Discriminator',
12 | 'MSStyleGANv2Generator', 'StyleGANv3Generator'
13 | ]
14 |
--------------------------------------------------------------------------------
/mmgen/core/evaluation/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .eval_hooks import GenerativeEvalHook, TranslationEvalHook
3 | from .evaluation import (make_metrics_table, make_vanilla_dataloader,
4 | offline_evaluation, online_evaluation)
5 | from .metric_utils import slerp
6 | from .metrics import (IS, MS_SSIM, PR, SWD, GaussianKLD, ms_ssim,
7 | sliced_wasserstein)
8 |
9 | __all__ = [
10 | 'MS_SSIM', 'SWD', 'ms_ssim', 'sliced_wasserstein', 'offline_evaluation',
11 | 'online_evaluation', 'PR', 'IS', 'slerp', 'GenerativeEvalHook',
12 | 'make_metrics_table', 'make_vanilla_dataloader', 'GaussianKLD',
13 | 'TranslationEvalHook'
14 | ]
15 |
--------------------------------------------------------------------------------
/docs/en/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = .
9 | BUILDDIR = _build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/docs/zh_cn/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = .
9 | BUILDDIR = _build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/mmgen/models/architectures/pggan/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .generator_discriminator import PGGANDiscriminator, PGGANGenerator
3 | from .modules import (EqualizedLR, EqualizedLRConvDownModule,
4 | EqualizedLRConvModule, EqualizedLRConvUpModule,
5 | EqualizedLRLinearModule, MiniBatchStddevLayer,
6 | PGGANNoiseTo2DFeat, PixelNorm, equalized_lr)
7 |
8 | __all__ = [
9 | 'EqualizedLR', 'equalized_lr', 'EqualizedLRConvModule',
10 | 'EqualizedLRLinearModule', 'EqualizedLRConvUpModule',
11 | 'EqualizedLRConvDownModule', 'PixelNorm', 'MiniBatchStddevLayer',
12 | 'PGGANNoiseTo2DFeat', 'PGGANGenerator', 'PGGANDiscriminator'
13 | ]
14 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [bdist_wheel]
2 | universal=1
3 |
4 | [aliases]
5 | test=pytest
6 |
7 | [yapf]
8 | based_on_style=pep8
9 | blank_line_before_nested_class_or_def=true
10 | split_before_expression_after_opening_paren=true
11 |
12 | [isort]
13 | line_length=79
14 | multi_line_output=0
15 | extra_standard_library=argparse,inspect,contextlib,hashlib,subprocess,unittest,tempfile,copy,pkg_resources,logging,pickle,platform,setuptools,abc,collections,functools,os,math,time,warnings,random,shutil,sys
16 | known_first_party=mmgen
17 | known_third_party=PIL,click,clip,cv2,imageio,mmcls,mmcv,numpy,prettytable,pytest,pytorch_sphinx_theme,recommonmark,requests,scipy,torch,torchvision,tqdm,ts
18 | no_lines_before=STDLIB,LOCALFOLDER
19 | default_section=THIRDPARTY
20 |
--------------------------------------------------------------------------------
/configs/positional_encoding_in_gans/singan_csg_fish.py:
--------------------------------------------------------------------------------
1 | _base_ = ['../singan/singan_fish.py']
2 |
3 | num_scales = 10 # start from zero
4 | model = dict(
5 | type='PESinGAN',
6 | generator=dict(
7 | type='SinGANMSGeneratorPE',
8 | num_scales=num_scales,
9 | padding=1,
10 | pad_at_head=False,
11 | first_stage_in_channels=2,
12 | positional_encoding=dict(type='CSG')),
13 | discriminator=dict(num_scales=num_scales))
14 |
15 | train_cfg = dict(first_fixed_noises_ch=2)
16 |
17 | data = dict(
18 | train=dict(
19 | img_path='./data/singan/fish-crop.jpg',
20 | min_size=25,
21 | max_size=300,
22 | ))
23 |
24 | dist_params = dict(backend='nccl')
25 | total_iters = 22000
26 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/4-documentation.yml:
--------------------------------------------------------------------------------
1 | name: 📚 Documentation
2 | description: Report an issue related to https://mmgeneration.readthedocs.io/en/latest/.
3 | labels: "docs"
4 | title: "[Docs] "
5 |
6 | body:
7 | - type: textarea
8 | attributes:
9 | label: 📚 The doc issue
10 | description: >
11 | A clear and concise description of what content in https://mmgeneration.readthedocs.io/en/latest/ is an issue.
12 | validations:
13 | required: true
14 |
15 | - type: textarea
16 | attributes:
17 | label: Suggest a potential alternative/fix
18 | description: >
19 | Tell us how we could improve the documentation in this regard.
20 | - type: markdown
21 | attributes:
22 | value: >
23 | Thanks for contributing 🎉!
24 |
--------------------------------------------------------------------------------
/configs/positional_encoding_in_gans/singan_csg_bohemian.py:
--------------------------------------------------------------------------------
1 | _base_ = ['../singan/singan_bohemian.py']
2 |
3 | num_scales = 10 # start from zero
4 | model = dict(
5 | type='PESinGAN',
6 | generator=dict(
7 | type='SinGANMSGeneratorPE',
8 | num_scales=num_scales,
9 | padding=1,
10 | pad_at_head=False,
11 | first_stage_in_channels=2,
12 | positional_encoding=dict(type='CSG')),
13 | discriminator=dict(num_scales=num_scales))
14 |
15 | train_cfg = dict(first_fixed_noises_ch=2)
16 |
17 | data = dict(
18 | train=dict(
19 | img_path='./data/singan/bohemian.png',
20 | min_size=25,
21 | max_size=500,
22 | ))
23 |
24 | dist_params = dict(backend='nccl')
25 | total_iters = 22000
26 |
--------------------------------------------------------------------------------
/mmgen/version.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | __version__ = '0.7.3'
3 |
4 |
def parse_version_info(version_str):
    """Parse a dotted version string into a tuple.

    Each dot-separated component is handled as follows:

    - a purely numeric component is converted to ``int``;
    - a component containing ``rc`` (e.g. ``0rc1``) contributes two items,
      the integer before ``rc`` and the string ``'rc<suffix>'``;
    - any other component is skipped.

    Args:
        version_str (str): Version string, e.g. ``'0.7.3'`` or
            ``'1.0.0rc1'``.

    Returns:
        tuple: Version information in tuple, e.g. ``(0, 7, 3)`` or
            ``(1, 0, 0, 'rc1')``.
    """
    version_info = []
    for x in version_str.split('.'):
        if x.isdigit():
            version_info.append(int(x))
        elif 'rc' in x:
            # Split e.g. '0rc1' into the numeric patch part and the rc tag.
            patch_version = x.split('rc')
            version_info.append(int(patch_version[0]))
            version_info.append(f'rc{patch_version[1]}')
    return tuple(version_info)
23 |
24 |
25 | version_info = parse_version_info(__version__)
26 |
--------------------------------------------------------------------------------
/mmgen/datasets/quick_test_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch
3 | from torch.utils.data import Dataset
4 |
5 | from .builder import DATASETS
6 |
7 |
@DATASETS.register_module()
class QuickTestImageDataset(Dataset):
    """Dataset for quickly testing the correctness.

    A single random tensor of shape ``(3, size[0], size[1])`` is drawn once
    at construction time and returned for every index.

    Args:
        size (tuple[int]): The size of the images. Although the keyword
            defaults to `None`, a two-element size must be given because it
            is indexed to build the image tensor.

    Raises:
        TypeError: If ``size`` is `None`.
    """

    def __init__(self, *args, size=None, **kwargs):
        super().__init__()
        # Fail with a clear message instead of the opaque
        # "'NoneType' object is not subscriptable" raised by indexing below.
        if size is None:
            raise TypeError(
                '`size` must be a sequence of two ints, but got None')
        self.size = size
        self.img_tensor = torch.randn(3, self.size[0], self.size[1])

    def __len__(self):
        # Fixed nominal length; every item is the same tensor.
        return 10000

    def __getitem__(self, idx):
        return dict(real_img=self.img_tensor)
26 |
--------------------------------------------------------------------------------
/mmgen/datasets/pipelines/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .augmentation import (CenterCropLongEdge, Flip, NumpyPad,
3 | RandomCropLongEdge, RandomImgNoise, Resize)
4 | from .compose import Compose
5 | from .crop import Crop, FixedCrop
6 | from .formatting import Collect, ImageToTensor, ToTensor
7 | from .loading import LoadImageFromFile
8 | from .normalize import Normalize
9 |
10 | __all__ = [
11 | 'LoadImageFromFile',
12 | 'Compose',
13 | 'ImageToTensor',
14 | 'Collect',
15 | 'ToTensor',
16 | 'Flip',
17 | 'Resize',
18 | 'RandomImgNoise',
19 | 'RandomCropLongEdge',
20 | 'CenterCropLongEdge',
21 | 'Normalize',
22 | 'NumpyPad',
23 | 'Crop',
24 | 'FixedCrop',
25 | ]
26 |
--------------------------------------------------------------------------------
/tests/test_datasets/test_dataset_wrappers.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from torch.utils.data import Dataset
3 |
4 | from mmgen.datasets import RepeatDataset
5 |
6 |
def test_repeat_dataset():
    """Check the length and indexing of ``RepeatDataset`` on a toy dataset."""

    class ToyDataset(Dataset):
        """Five-element dataset holding the integers 1 to 5."""

        def __init__(self):
            super().__init__()
            self.members = list(range(1, 6))

        def __len__(self):
            return len(self.members)

        def __getitem__(self, idx):
            return self.members[idx % 5]

    repeated = RepeatDataset(ToyDataset(), 2)
    assert len(repeated) == 10
    assert repeated[2] == 3
    assert repeated[8] == 4
26 |
--------------------------------------------------------------------------------
/docs/zh_cn/index.rst:
--------------------------------------------------------------------------------
1 | 欢迎来到 MMGeneration 的用户手册!
2 | =======================================
3 |
4 | .. toctree::
5 | :maxdepth: 2
6 | :caption: Get Started
7 |
8 | get_started.md
9 | modelzoo_statistics.md
10 |
11 | .. toctree::
12 | :maxdepth: 2
13 | :caption: Quick Run
14 |
15 | quick_run.md
16 |
17 | .. toctree::
18 | :maxdepth: 2
19 | :caption: Tutorials
20 |
21 | tutorials/index.rst
22 |
23 | .. toctree::
24 | :maxdepth: 2
25 | :caption: Notes
26 |
27 | changelog.md
28 | faq.md
29 |
30 | .. toctree::
31 | :caption: Switch Language
32 |
33 | switch_language.md
34 |
35 | .. toctree::
36 | :caption: API Reference
37 |
38 | api.rst
39 |
40 | Indices and tables
41 | ==================
42 |
43 | * :ref:`genindex`
44 | * :ref:`search`
45 |
--------------------------------------------------------------------------------
/configs/ada/metafile.yml:
--------------------------------------------------------------------------------
1 | Collections:
2 | - Metadata:
3 | Architecture:
4 | - ADA
5 | Name: ADA
6 | Paper:
7 | - https://arxiv.org/pdf/2006.06676.pdf
8 | README: configs/ada/README.md
9 | Models:
10 | - Config: https://github.com/open-mmlab/mmgeneration/tree/master/configs/styleganv3/stylegan3_t_ada_fp16_gamma6.6_metfaces_1024_b4x8.py
11 | In Collection: ADA
12 | Metadata:
13 | Training Data: Others
14 | Name: stylegan3_t_ada_fp16_gamma6.6_metfaces_1024_b4x8
15 | Results:
16 | - Dataset: Others
17 | Metrics:
18 | FID50k: 15.09
19 | Iter: 130000.0
20 | Log: '[log]'
21 | Task: Tricks for GANs
22 | Weights: https://download.openmmlab.com/mmgen/stylegan3/stylegan3_t_ada_fp16_gamma6.6_metfaces_1024_b4x8_best_fid_iter_130000_20220401_115101-f2ef498e.pth
23 |
--------------------------------------------------------------------------------
/docs/en/index.rst:
--------------------------------------------------------------------------------
1 | Welcome to MMGeneration's documentation!
2 | =======================================
3 |
4 | .. toctree::
5 | :maxdepth: 2
6 | :caption: Get Started
7 |
8 | get_started.md
9 | modelzoo_statistics.md
10 |
11 | .. toctree::
12 | :maxdepth: 2
13 | :caption: Quick Run
14 |
15 | quick_run.md
16 |
17 | .. toctree::
18 | :maxdepth: 2
19 | :caption: Tutorials
20 |
21 | tutorials/index.rst
22 |
23 | .. toctree::
24 | :maxdepth: 2
25 | :caption: Notes
26 |
27 | changelog.md
28 | faq.md
29 |
30 | .. toctree::
31 | :caption: Switch Language
32 |
33 | switch_language.md
34 |
35 | .. toctree::
36 | :caption: API Reference
37 |
38 | api.rst
39 |
40 | Indices and tables
41 | ==================
42 |
43 | * :ref:`genindex`
44 | * :ref:`search`
45 |
--------------------------------------------------------------------------------
/configs/styleganv2/stylegan2_c2_fp16_partial-GD_PL-no-scaler_ffhq_256_b4x8_800k.py:
--------------------------------------------------------------------------------
1 | """Config for the `config-f` setting in StyleGAN2."""
2 |
3 | _base_ = ['./stylegan2_c2_ffhq_256_b4x8_800k.py']
4 |
5 | model = dict(
6 | generator=dict(out_size=256, num_fp16_scales=4),
7 | discriminator=dict(in_size=256, num_fp16_scales=4),
8 | disc_auxiliary_loss=dict(data_info=dict(loss_scaler='loss_scaler')),
9 | # gen_auxiliary_loss=dict(data_info=dict(loss_scaler='loss_scaler')),
10 | )
11 |
12 | total_iters = 800002
13 |
14 | # use ddp wrapper for faster training
15 | use_ddp_wrapper = True
16 | find_unused_parameters = False
17 |
18 | runner = dict(
19 | fp16_loss_scaler=dict(init_scale=512),
20 | is_dynamic_ddp= # noqa
21 | False, # Note that this flag should be False to use DDP wrapper.
22 | )
23 |
--------------------------------------------------------------------------------
/mmgen/models/architectures/biggan/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .generator_discriminator import BigGANDiscriminator, BigGANGenerator
3 | from .generator_discriminator_deep import (BigGANDeepDiscriminator,
4 | BigGANDeepGenerator)
5 | from .modules import (BigGANConditionBN, BigGANDeepDiscResBlock,
6 | BigGANDeepGenResBlock, BigGANDiscResBlock,
7 | BigGANGenResBlock, SelfAttentionBlock, SNConvModule)
8 |
9 | __all__ = [
10 | 'BigGANGenerator', 'BigGANGenResBlock', 'BigGANConditionBN',
11 | 'BigGANDiscriminator', 'SelfAttentionBlock', 'BigGANDiscResBlock',
12 | 'BigGANDeepDiscriminator', 'BigGANDeepGenerator', 'BigGANDeepDiscResBlock',
13 | 'BigGANDeepGenResBlock', 'SNConvModule'
14 | ]
15 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Feature request
3 | about: Suggest an idea for this project
4 | title: ''
5 | labels: ''
6 | assignees: ''
7 | ---
8 |
9 | **Describe the feature**
10 |
11 | **Motivation**
12 | A clear and concise description of the motivation of the feature.
13 | Ex1. It is inconvenient when \[....\].
14 | Ex2. There is a recent paper \[....\], which is very helpful for \[....\].
15 |
16 | **Related resources**
17 | If there is an official code release or third-party implementations, please also provide the information here, which would be very helpful.
18 |
19 | **Additional context**
20 | Add any other context or screenshots about the feature request here.
21 | If you would like to implement the feature and create a PR, please leave a comment here and that would be much appreciated.
22 |
--------------------------------------------------------------------------------
/configs/wgan-gp/wgangp_GN_celeba-cropped_128_b64x1_160kiter.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 | '../_base_/datasets/unconditional_imgs_128x128.py',
3 | '../_base_/models/wgangp/wgangp_base.py'
4 | ]
5 |
6 | data = dict(
7 | samples_per_gpu=64,
8 | train=dict(imgs_root='./data/celeba-cropped/cropped_images_aligned_png/'))
9 |
10 | checkpoint_config = dict(interval=10000, by_epoch=False)
11 | log_config = dict(interval=100, hooks=[dict(type='TextLoggerHook')])
12 |
13 | custom_hooks = [
14 | dict(
15 | type='VisualizeUnconditionalSamples',
16 | output_dir='training_samples',
17 | interval=1000)
18 | ]
19 |
20 | lr_config = None
21 | total_iters = 160000
22 |
23 | metrics = dict(
24 | ms_ssim10k=dict(type='MS_SSIM', num_images=10000),
25 | swd16k=dict(type='SWD', num_images=16384, image_shape=(3, 128, 128)))
26 |
--------------------------------------------------------------------------------
/configs/_base_/datasets/unconditional_imgs_128x128.py:
--------------------------------------------------------------------------------
1 | dataset_type = 'UnconditionalImageDataset'
2 |
3 | train_pipeline = [
4 | dict(type='LoadImageFromFile', key='real_img', io_backend='disk'),
5 | dict(type='Resize', keys=['real_img'], scale=(128, 128)),
6 | dict(
7 | type='Normalize',
8 | keys=['real_img'],
9 | mean=[127.5] * 3,
10 | std=[127.5] * 3,
11 | to_rgb=False),
12 | dict(type='ImageToTensor', keys=['real_img']),
13 | dict(type='Collect', keys=['real_img'], meta_keys=['real_img_path'])
14 | ]
15 |
16 | # `samples_per_gpu` and `imgs_root` need to be set.
17 | data = dict(
18 | samples_per_gpu=None,
19 | workers_per_gpu=4,
20 | train=dict(type=dataset_type, imgs_root=None, pipeline=train_pipeline),
21 | val=dict(type=dataset_type, imgs_root=None, pipeline=train_pipeline))
22 |
--------------------------------------------------------------------------------
/configs/_base_/datasets/unconditional_imgs_flip_256x256.py:
--------------------------------------------------------------------------------
1 | dataset_type = 'UnconditionalImageDataset'
2 |
3 | train_pipeline = [
4 | dict(type='LoadImageFromFile', key='real_img', io_backend='disk'),
5 | dict(type='Resize', keys=['real_img'], scale=(256, 256)),
6 | dict(type='Flip', keys=['real_img'], direction='horizontal'),
7 | dict(
8 | type='Normalize',
9 | keys=['real_img'],
10 | mean=[127.5] * 3,
11 | std=[127.5] * 3,
12 | to_rgb=False),
13 | dict(type='ImageToTensor', keys=['real_img']),
14 | dict(type='Collect', keys=['real_img'], meta_keys=['real_img_path'])
15 | ]
16 |
17 | # `samples_per_gpu` and `imgs_root` need to be set.
18 | data = dict(
19 | samples_per_gpu=None,
20 | workers_per_gpu=4,
21 | train=dict(type=dataset_type, imgs_root=None, pipeline=train_pipeline))
22 |
--------------------------------------------------------------------------------
/configs/_base_/datasets/unconditional_imgs_flip_512x512.py:
--------------------------------------------------------------------------------
1 | dataset_type = 'UnconditionalImageDataset'
2 |
3 | train_pipeline = [
4 | dict(type='LoadImageFromFile', key='real_img', io_backend='disk'),
5 | dict(type='Resize', keys=['real_img'], scale=(512, 512)),
6 | dict(type='Flip', keys=['real_img'], direction='horizontal'),
7 | dict(
8 | type='Normalize',
9 | keys=['real_img'],
10 | mean=[127.5] * 3,
11 | std=[127.5] * 3,
12 | to_rgb=False),
13 | dict(type='ImageToTensor', keys=['real_img']),
14 | dict(type='Collect', keys=['real_img'], meta_keys=['real_img_path'])
15 | ]
16 |
17 | # `samples_per_gpu` and `imgs_root` need to be set.
18 | data = dict(
19 | samples_per_gpu=None,
20 | workers_per_gpu=4,
21 | train=dict(type=dataset_type, imgs_root=None, pipeline=train_pipeline))
22 |
--------------------------------------------------------------------------------
/configs/dcgan/dcgan_lsun-bedroom_64x64_b128x1_5e.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 | '../_base_/models/dcgan/dcgan_64x64.py',
3 | '../_base_/datasets/unconditional_imgs_64x64.py',
4 | '../_base_/default_runtime.py'
5 | ]
6 |
7 | # define dataset
8 | # you must set `samples_per_gpu` and `imgs_root`
9 | data = dict(
10 | samples_per_gpu=128, train=dict(imgs_root='data/lsun/bedroom_train'))
11 |
12 | # adjust running config
13 | lr_config = None
14 | checkpoint_config = dict(interval=100000, by_epoch=False)
15 | custom_hooks = [
16 | dict(
17 | type='VisualizeUnconditionalSamples',
18 | output_dir='training_samples',
19 | interval=10000)
20 | ]
21 |
22 | total_iters = 1500002
23 |
24 | metrics = dict(
25 | ms_ssim10k=dict(type='MS_SSIM', num_images=10000),
26 | swd16k=dict(type='SWD', num_images=16384, image_shape=(3, 64, 64)))
27 |
--------------------------------------------------------------------------------
/docker/Dockerfile:
--------------------------------------------------------------------------------
1 | ARG PYTORCH="1.8.0"
2 | ARG CUDA="11.1"
3 | ARG CUDNN="8"
4 |
5 | FROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}-devel
6 |
7 | ENV TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0+PTX"
8 | ENV TORCH_NVCC_FLAGS="-Xfatbin -compress-all"
9 | ENV CMAKE_PREFIX_PATH="$(dirname $(which conda))/../"
10 |
11 | RUN apt-get update && apt-get install -y ffmpeg libsm6 libxext6 git ninja-build libglib2.0-0 libsm6 libxrender-dev libxext6 \
12 | && apt-get clean \
13 | && rm -rf /var/lib/apt/lists/*
14 |
15 | # Install MMCV
16 | RUN pip install mmcv-full==1.3.16 -f https://download.openmmlab.com/mmcv/dist/cu111/torch1.8.0/index.html
17 |
18 | # Install MMGeneration
19 | RUN conda clean --all
20 | RUN git clone https://github.com/open-mmlab/mmgeneration.git /mmgen
21 | WORKDIR /mmgen
22 | ENV FORCE_CUDA="1"
23 | RUN pip install -r requirements.txt
24 | RUN pip install --no-cache-dir -e .
25 |
--------------------------------------------------------------------------------
/tests/test_datasets/test_singan_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import os.path as osp
3 |
4 | from mmgen.datasets import SinGANDataset
5 |
6 |
class TestSinGANDataset(object):
    """Tests for ``SinGANDataset`` using the baboon image under tests/data."""

    @classmethod
    def setup_class(cls):
        # Two ``dirname`` calls climb from this file to the tests root.
        tests_root = osp.dirname(osp.dirname(__file__))
        cls.imgs_root = osp.join(tests_root, 'data/image/baboon.png')
        cls.min_size, cls.max_size = 25, 250
        cls.scale_factor_init = 0.75

    def test_singan_dataset(self):
        dataset = SinGANDataset(
            self.imgs_root,
            min_size=self.min_size,
            max_size=self.max_size,
            scale_factor_init=self.scale_factor_init)
        assert len(dataset) == 1000000

        # The first item must expose the keys real_scale0 .. real_scale9.
        data_dict = dataset[0]
        assert all(f'real_scale{i}' in data_dict for i in range(10))
27 |
--------------------------------------------------------------------------------
/configs/_base_/datasets/unconditional_imgs_64x64.py:
--------------------------------------------------------------------------------
1 | dataset_type = 'UnconditionalImageDataset'
2 |
3 | train_pipeline = [
4 | dict(
5 | type='LoadImageFromFile',
6 | key='real_img',
7 | io_backend='disk',
8 | ),
9 | dict(type='Resize', keys=['real_img'], scale=(64, 64)),
10 | dict(
11 | type='Normalize',
12 | keys=['real_img'],
13 | mean=[127.5] * 3,
14 | std=[127.5] * 3,
15 | to_rgb=False),
16 | dict(type='ImageToTensor', keys=['real_img']),
17 | dict(type='Collect', keys=['real_img'], meta_keys=['real_img_path'])
18 | ]
19 |
20 | # `samples_per_gpu` and `imgs_root` need to be set.
21 | data = dict(
22 | samples_per_gpu=None,
23 | workers_per_gpu=4,
24 | train=dict(type=dataset_type, imgs_root=None, pipeline=train_pipeline),
25 | val=dict(type=dataset_type, imgs_root=None, pipeline=train_pipeline))
26 |
--------------------------------------------------------------------------------
/.github/workflows/publish-to-pypi.yml:
--------------------------------------------------------------------------------
1 | name: deploy
2 |
3 | on: push
4 |
5 | concurrency:
6 | group: ${{ github.workflow }}-${{ github.ref }}
7 | cancel-in-progress: true
8 |
9 | jobs:
10 | build-n-publish:
11 | runs-on: ubuntu-latest
12 | if: startsWith(github.event.ref, 'refs/tags')
13 | steps:
14 | - uses: actions/checkout@v2
15 | - name: Set up Python 3.7
16 | uses: actions/setup-python@v1
17 | with:
18 | python-version: 3.7
19 | - name: Build MMGeneration
20 | run: |
21 | pip install torch==1.8.1+cpu torchvision==0.9.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
22 | pip install wheel
23 | python setup.py sdist
24 | - name: Publish distribution to PyPI
25 | run: |
26 | pip install twine
27 | twine upload dist/* -u __token__ -p ${{ secrets.pypi_password }}
28 |
--------------------------------------------------------------------------------
/docs/en/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=.
11 | set BUILDDIR=_build
12 |
13 | if "%1" == "" goto help
14 |
15 | %SPHINXBUILD% >NUL 2>NUL
16 | if errorlevel 9009 (
17 | echo.
18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
19 | echo.installed, then set the SPHINXBUILD environment variable to point
20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
21 | echo.may add the Sphinx directory to PATH.
22 | echo.
23 | echo.If you don't have Sphinx installed, grab it from
24 | echo.http://sphinx-doc.org/
25 | exit /b 1
26 | )
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33 |
34 | :end
35 | popd
36 |
--------------------------------------------------------------------------------
/LICENSES.md:
--------------------------------------------------------------------------------
1 | # Licenses for special operations
2 |
3 | In this file, we list the operations that are released under licenses other than Apache 2.0. Users should be careful when adopting these operations for any commercial purposes.
4 |
5 | | Operation | Files | License |
6 | | :------------------: | :-----------------------------------------------------------------------------------------------------------------------------------: | :------------: |
7 | | conv2d_gradfix | [mmgen/ops/conv2d_gradfix.py](https://github.com/open-mmlab/mmgeneration/blob/master/mmgen/ops/conv2d_gradfix.py) | NVIDIA License |
8 | | compute_pr_distances | [mmgen/core/evaluation/metric_utils.py](https://github.com/open-mmlab/mmgeneration/blob/master/mmgen/core/evaluation/metric_utils.py) | NVIDIA License |
9 |
--------------------------------------------------------------------------------
/docs/zh_cn/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=.
11 | set BUILDDIR=_build
12 |
13 | if "%1" == "" goto help
14 |
15 | %SPHINXBUILD% >NUL 2>NUL
16 | if errorlevel 9009 (
17 | echo.
18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
19 | echo.installed, then set the SPHINXBUILD environment variable to point
20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
21 | echo.may add the Sphinx directory to PATH.
22 | echo.
23 | echo.If you don't have Sphinx installed, grab it from
24 | echo.http://sphinx-doc.org/
25 | exit /b 1
26 | )
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33 |
34 | :end
35 | popd
36 |
--------------------------------------------------------------------------------
/configs/_base_/models/sagan/sagan_32x32.py:
--------------------------------------------------------------------------------
1 | # define GAN model
2 | model = dict(
3 | type='BasiccGAN',
4 | generator=dict(
5 | type='SAGANGenerator',
6 | output_scale=32,
7 | base_channels=256,
8 | attention_cfg=dict(type='SelfAttentionBlock'),
9 | attention_after_nth_block=2,
10 | with_spectral_norm=True),
11 | discriminator=dict(
12 | type='ProjDiscriminator',
13 | input_scale=32,
14 | base_channels=128,
15 | attention_cfg=dict(type='SelfAttentionBlock'),
16 | attention_after_nth_block=1,
17 | with_spectral_norm=True),
18 | gan_loss=dict(type='GANLoss', gan_type='hinge'))
19 |
20 | train_cfg = dict(disc_steps=5)
21 | test_cfg = None
22 |
23 | # define optimizer
24 | optimizer = dict(
25 | generator=dict(type='Adam', lr=0.0002, betas=(0.5, 0.999)),
26 | discriminator=dict(type='Adam', lr=0.0002, betas=(0.5, 0.999)))
27 |
--------------------------------------------------------------------------------
/configs/_base_/models/sagan/sagan_128x128.py:
--------------------------------------------------------------------------------
1 | # define GAN model
2 | model = dict(
3 | type='BasiccGAN',
4 | generator=dict(
5 | type='SAGANGenerator',
6 | output_scale=128,
7 | base_channels=64,
8 | attention_cfg=dict(type='SelfAttentionBlock'),
9 | attention_after_nth_block=4,
10 | with_spectral_norm=True),
11 | discriminator=dict(
12 | type='ProjDiscriminator',
13 | input_scale=128,
14 | base_channels=64,
15 | attention_cfg=dict(type='SelfAttentionBlock'),
16 | attention_after_nth_block=1,
17 | with_spectral_norm=True),
18 | gan_loss=dict(type='GANLoss', gan_type='hinge'))
19 |
20 | train_cfg = dict(disc_steps=1)
21 | test_cfg = None
22 |
23 | # define optimizer
24 | optimizer = dict(
25 | generator=dict(type='Adam', lr=0.0001, betas=(0.0, 0.999)),
26 | discriminator=dict(type='Adam', lr=0.0004, betas=(0.0, 0.999)))
27 |
--------------------------------------------------------------------------------
/configs/positional_encoding_in_gans/singan_spe-dim4_fish.py:
--------------------------------------------------------------------------------
1 | _base_ = ['../singan/singan_fish.py']
2 |
3 | embedding_dim = 4
4 | num_scales = 10 # start from zero
5 | model = dict(
6 | type='PESinGAN',
7 | generator=dict(
8 | type='SinGANMSGeneratorPE',
9 | num_scales=num_scales,
10 | padding=1,
11 | pad_at_head=False,
12 | first_stage_in_channels=embedding_dim * 2,
13 | positional_encoding=dict(
14 | type='SPE',
15 | embedding_dim=embedding_dim,
16 | padding_idx=0,
17 | init_size=512,
18 | div_half_dim=False,
19 | center_shift=200)),
20 | discriminator=dict(num_scales=num_scales))
21 |
22 | data = dict(
23 | train=dict(
24 | img_path='./data/singan/fish-crop.jpg',
25 | min_size=25,
26 | max_size=300,
27 | ))
28 |
29 | dist_params = dict(backend='nccl')
30 | total_iters = 22000
31 |
--------------------------------------------------------------------------------
/docs/en/faq.md:
--------------------------------------------------------------------------------
1 | # FAQ
2 |
3 | We list some common troubles faced by many users and their corresponding solutions here. Feel free to enrich the list if you find any frequent issues and have ways to help others to solve them. If the contents here do not cover your issue, please create an issue using the [provided templates](https://github.com/open-mmlab/mmgeneration/blob/master/.github/ISSUE_TEMPLATE/error-report.md) and make sure you fill in all required information in the template.
4 |
5 | ## Installation
6 |
7 | - Compatible MMGeneration and MMCV versions are shown below. Please choose the correct version of MMCV to avoid installation issues.
8 |
9 | | MMGeneration version | MMCV version |
10 | | :------------------: | :--------------: |
11 | | master | mmcv-full>=1.3.0 |
12 |
13 | Note: You need to run `pip uninstall mmcv` first if you have mmcv installed.
14 | If mmcv and mmcv-full are both installed, there will be `ModuleNotFoundError`.
15 |
--------------------------------------------------------------------------------
/tests/test_datasets/test_unconditional_image_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import os.path as osp
3 |
4 | from mmgen.datasets import UnconditionalImageDataset
5 |
6 |
class TestUnconditionalImageDataset(object):
    """Tests for ``UnconditionalImageDataset`` on the tests/data images."""

    @classmethod
    def setup_class(cls):
        cls.imgs_root = osp.join(osp.dirname(__file__), '..', 'data/image')
        load_step = dict(
            type='LoadImageFromFile', io_backend='disk', key='real_img')
        cls.default_pipeline = [load_step]

    def test_unconditional_imgs_dataset(self):
        dataset = UnconditionalImageDataset(
            self.imgs_root, pipeline=self.default_pipeline)
        assert len(dataset) == 6

        sample = dataset[2]
        assert sample['real_img'].ndim == 3

        expected_repr = (
            f'dataset_name: {dataset.__class__}, '
            f'total {6} images in imgs_root: {self.imgs_root}')
        assert repr(dataset) == expected_repr
25 |
--------------------------------------------------------------------------------
/configs/_base_/datasets/Inception_Score.py:
--------------------------------------------------------------------------------
1 | dataset_type = 'UnconditionalImageDataset'
2 |
3 | # Note that the `Resize` operation with the `pillow` backend and
4 | # `bicubic` interpolation is required for correct IS evaluation.
5 | val_pipeline = [
6 | dict(
7 | type='LoadImageFromFile',
8 | key='real_img',
9 | io_backend='disk',
10 | ),
11 | dict(
12 | type='Resize',
13 | keys=['real_img'],
14 | scale=(299, 299),
15 | backend='pillow',
16 | interpolation='bicubic'),
17 | dict(
18 | type='Normalize',
19 | keys=['real_img'],
20 | mean=[127.5] * 3,
21 | std=[127.5] * 3,
22 | to_rgb=True),
23 | dict(type='ImageToTensor', keys=['real_img']),
24 | dict(type='Collect', keys=['real_img'], meta_keys=['real_img_path'])
25 | ]
26 |
27 | data = dict(
28 | samples_per_gpu=None,
29 | workers_per_gpu=4,
30 | val=dict(type=dataset_type, imgs_root=None, pipeline=val_pipeline))
31 |
--------------------------------------------------------------------------------
/.github/workflows/lint.yml:
--------------------------------------------------------------------------------
1 | name: lint
2 |
3 | on: [push, pull_request]
4 |
5 | concurrency:
6 | group: ${{ github.workflow }}-${{ github.ref }}
7 | cancel-in-progress: true
8 |
9 | jobs:
10 | lint:
11 | runs-on: ubuntu-latest
12 | steps:
13 | - uses: actions/checkout@v2
14 | - name: Set up Python 3.7
15 | uses: actions/setup-python@v2
16 | with:
17 | python-version: 3.7
18 | - name: Install pre-commit hook
19 | run: |
20 | pip install pre-commit
21 | pre-commit install
22 | - name: Linting
23 | run: pre-commit run --all-files
24 | - name: Check docstring coverage
25 | run: |
26 | pip install interrogate
27 | interrogate -v --ignore-init-method --ignore-module --ignore-nested-functions --exclude mmgen/ops --ignore-regex "__repr__" --fail-under 50 mmgen
28 | - name: Setup tmate session
29 | if: ${{ failure() }}
30 | uses: mxschmitt/action-tmate@v3
31 |
--------------------------------------------------------------------------------
/mmgen/core/registry.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from mmcv.utils import Registry, build_from_cfg
3 |
# Registry named 'metric'; `build_metric` in this module builds from it.
METRICS = Registry('metric')
5 |
6 |
def build(cfg, registry, default_args=None):
    """Build one module, or a list of modules, from config.

    Args:
        cfg (dict, list[dict]): The config of the module(s). It is either a
            single config dict or a list of config dicts.
        registry (:obj:`Registry`): The registry the module belongs to.
        default_args (dict, optional): Default arguments forwarded to
            ``build_from_cfg``. Defaults to None.

    Returns:
        object | list: The result of ``build_from_cfg`` for a single config,
        or a list with one built result per entry when ``cfg`` is a list.
    """
    if not isinstance(cfg, list):
        return build_from_cfg(cfg, registry, default_args)

    built = []
    for single_cfg in cfg:
        built.append(build_from_cfg(single_cfg, registry, default_args))
    return built
26 |
27 |
def build_metric(cfg):
    """Build a metric calculator.

    Forwards ``cfg`` to :func:`build` with the ``METRICS`` registry, so a
    list of configs yields a list of built metrics.
    """
    return build(cfg, METRICS)
31 |
--------------------------------------------------------------------------------
/configs/positional_encoding_in_gans/singan_spe-dim4_bohemian.py:
--------------------------------------------------------------------------------
_base_ = ['../singan/singan_bohemian.py']

# Positional-encoding width; the first generator stage and the first fixed
# noise both use twice this many channels.
embedding_dim = 4
num_scales = 10  # start from zero
model = {
    'type': 'PESinGAN',
    'generator': {
        'type': 'SinGANMSGeneratorPE',
        'num_scales': num_scales,
        'padding': 1,
        'pad_at_head': False,
        'first_stage_in_channels': embedding_dim * 2,
        'positional_encoding': {
            'type': 'SPE',
            'embedding_dim': embedding_dim,
            'padding_idx': 0,
            'init_size': 512,
            'div_half_dim': False,
            'center_shift': 200,
        },
    },
    'discriminator': {'num_scales': num_scales},
}

train_cfg = {'first_fixed_noises_ch': embedding_dim * 2}

data = {
    'train': {
        'img_path': './data/singan/bohemian.png',
        'min_size': 25,
        'max_size': 500,
    },
}

dist_params = {'backend': 'nccl'}
total_iters = 22000
33 |
--------------------------------------------------------------------------------
/configs/positional_encoding_in_gans/singan_spe-dim8_bohemian.py:
--------------------------------------------------------------------------------
_base_ = ['../singan/singan_bohemian.py']

# This is the `dim8` variant: the positional-encoding width must be 8.
# It was previously 4, which made this config identical to the `dim4` one.
embedding_dim = 8
num_scales = 10  # start from zero
model = dict(
    type='PESinGAN',
    generator=dict(
        type='SinGANMSGeneratorPE',
        num_scales=num_scales,
        padding=1,
        pad_at_head=False,
        # the first stage uses twice the embedding width as input channels
        first_stage_in_channels=embedding_dim * 2,
        positional_encoding=dict(
            type='SPE',
            embedding_dim=embedding_dim,
            padding_idx=0,
            init_size=512,
            div_half_dim=False,
            center_shift=200)),
    discriminator=dict(num_scales=num_scales))

# kept in sync with `first_stage_in_channels` above
train_cfg = dict(first_fixed_noises_ch=embedding_dim * 2)

data = dict(
    train=dict(
        img_path='./data/singan/bohemian.png',
        min_size=25,
        max_size=500,
    ))

dist_params = dict(backend='nccl')
total_iters = 22000
33 |
--------------------------------------------------------------------------------
/configs/_base_/default_runtime.py:
--------------------------------------------------------------------------------
checkpoint_config = {'interval': 10000, 'by_epoch': False}
# yapf:disable
log_config = {
    'interval': 100,
    'hooks': [
        {'type': 'TextLoggerHook'},
        # {'type': 'TensorboardLoggerHook'},
    ],
}
# yapf:enable

custom_hooks = [
    {
        'type': 'VisualizeUnconditionalSamples',
        'output_dir': 'training_samples',
        'interval': 1000,
    },
]

# use the dynamic iteration-based runner
runner = {
    'type': 'DynamicIterBasedRunner',
    'is_dynamic_ddp': True,
    'pass_training_status': True,
}

dist_params = {'backend': 'nccl'}
log_level = 'INFO'
load_from = None
resume_from = None
workflow = [('train', 10000)]
find_unused_parameters = True
cudnn_benchmark = True

# disable opencv multithreading so the system does not get overloaded
opencv_num_threads = 0
# use `fork` as the multi-process start method to speed up training
mp_start_method = 'fork'
36 |
--------------------------------------------------------------------------------
/configs/styleganv2/stylegan2_c2_apex_fp16_quicktest_ffhq_256_b4x8_800k.py:
--------------------------------------------------------------------------------
"""Quick-test config for the `config-f` setting in StyleGAN2 (apex AMP)."""

_base_ = ['./stylegan2_c2_ffhq_256_b4x8_800k.py']

model = {
    'generator': {'out_size': 256},
    'discriminator': {'in_size': 256, 'convert_input_fp32': False},
    # 'disc_auxiliary_loss': {'use_apex_amp': True},
    # 'gen_auxiliary_loss': {'use_apex_amp': True},
}

dataset_type = 'QuickTestImageDataset'
data = {
    'samples_per_gpu': 2,
    'train': {'type': dataset_type, 'size': (256, 256)},
    'val': {'type': dataset_type, 'size': (256, 256)},
}

# log every iteration
log_config = {'interval': 1}

total_iters = 800002

apex_amp = {
    'mode': 'gan',
    'init_args': {'opt_level': 'O1', 'num_losses': 2, 'loss_scale': 512.},
}

evaluation = {
    'type': 'GenerativeEvalHook',
    'interval': 10000,
    'metrics': {
        'type': 'FID',
        'num_images': 50000,
        'inception_pkl': None,
        'bgr2rgb': True,
    },
    'sample_kwargs': {'sample_model': 'ema'},
}
31 |
--------------------------------------------------------------------------------
/configs/styleganv2/stylegan2_c2_fp16_quicktest_ffhq_256_b4x8_800k.py:
--------------------------------------------------------------------------------
"""Quick-test config for the `config-f` setting in StyleGAN2 (fp16 scales)."""

_base_ = ['./stylegan2_c2_ffhq_256_b4x8_800k.py']

model = {
    'generator': {'out_size': 256, 'num_fp16_scales': 4},
    'discriminator': {'in_size': 256, 'num_fp16_scales': 4},
    'disc_auxiliary_loss': {'data_info': {'loss_scaler': 'loss_scaler'}},
    # 'gen_auxiliary_loss': {'data_info': {'loss_scaler': 'loss_scaler'}},
}

dataset_type = 'QuickTestImageDataset'
data = {
    'samples_per_gpu': 2,
    'train': {'type': dataset_type, 'size': (256, 256)},
    'val': {'type': dataset_type, 'size': (256, 256)},
}

# log every iteration
log_config = {'interval': 1}

total_iters = 800002

runner = {'fp16_loss_scaler': {'init_scale': 512}}

evaluation = {
    'type': 'GenerativeEvalHook',
    'interval': 10000,
    'metrics': {
        'type': 'FID',
        'num_images': 50000,
        'inception_pkl': None,
        'bgr2rgb': True,
    },
    'sample_kwargs': {'sample_model': 'ema'},
}
30 |
--------------------------------------------------------------------------------
/mmgen/models/common/dist_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch
3 | import torch.autograd as autograd
4 | import torch.distributed as dist
5 |
6 |
class AllGatherLayer(autograd.Function):
    """All gather layer with backward propagation path.

    Indeed, this module is to make ``dist.all_gather()`` in the backward graph.
    Such kind of operation has been widely used in Moco and other contrastive
    learning algorithms.
    """

    @staticmethod
    def forward(ctx, x):
        """Gather ``x`` from every process in the default group.

        Args:
            ctx: Autograd context (unused; nothing has to be saved).
            x (torch.Tensor): The local tensor to share.

        Returns:
            tuple[torch.Tensor]: One tensor per process, as filled in by
            ``dist.all_gather``; its length is ``dist.get_world_size()``.
        """
        gathered = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered, x)
        return tuple(gathered)

    @staticmethod
    def backward(ctx, *grad_outputs):
        """Return the gradient that belongs to the local input.

        Only the gradient at index ``dist.get_rank()`` is passed back. The
        previous implementation also allocated a zero tensor that was
        overwritten on the next line; that dead allocation (and the
        ``save_for_backward`` call that only served it) has been removed.
        """
        return grad_outputs[dist.get_rank()]
30 |
--------------------------------------------------------------------------------
/tests/test_modules/test_lpips.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import pytest
3 | import torch
4 |
5 | from mmgen.models.architectures import PerceptualLoss
6 |
7 |
class TestLpips:
    """Shape checks for ``PerceptualLoss`` built without pretrained weights."""

    @classmethod
    def setup_class(cls):
        cls.pretrained = False

    def test_lpips(self):
        """CPU path: the loss has shape ``(2, 1, 1, 1)`` for a batch of two."""
        input_shape = (2, 3, 256, 256)
        lpips = PerceptualLoss(use_gpu=False, pretrained=self.pretrained)
        first = torch.randn(input_shape)
        second = torch.randn(input_shape)
        assert lpips(first, second).shape == (2, 1, 1, 1)

    @pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
    def test_lpips_cuda(self):
        """CUDA path: same shape check with inputs moved to the GPU."""
        input_shape = (2, 3, 256, 256)
        lpips = PerceptualLoss(use_gpu=True, pretrained=self.pretrained)
        first = torch.randn(input_shape).cuda()
        second = torch.randn(input_shape).cuda()
        assert lpips(first, second).shape == (2, 1, 1, 1)
28 |
--------------------------------------------------------------------------------
/configs/styleganv2/stylegan2_c2_fp16-global_quicktest_ffhq_256_b4x8_800k.py:
--------------------------------------------------------------------------------
"""Quick-test config for the `config-f` setting in StyleGAN2 (global fp16)."""

_base_ = ['./stylegan2_c2_ffhq_256_b4x8_800k.py']

model = {
    'generator': {'out_size': 256, 'fp16_enabled': True},
    'discriminator': {'in_size': 256, 'fp16_enabled': True},
    'disc_auxiliary_loss': {'data_info': {'loss_scaler': 'loss_scaler'}},
    # 'gen_auxiliary_loss': {'data_info': {'loss_scaler': 'loss_scaler'}},
}

dataset_type = 'QuickTestImageDataset'
data = {
    'samples_per_gpu': 2,
    'train': {'type': dataset_type, 'size': (256, 256)},
    'val': {'type': dataset_type, 'size': (256, 256)},
}

# log every iteration
log_config = {'interval': 1}

total_iters = 800002

runner = {'fp16_loss_scaler': {'init_scale': 512}}

evaluation = {
    'type': 'GenerativeEvalHook',
    'interval': 10000,
    'metrics': {
        'type': 'FID',
        'num_images': 50000,
        'inception_pkl': None,
        'bgr2rgb': True,
    },
    'sample_kwargs': {'sample_model': 'ema'},
}
30 |
--------------------------------------------------------------------------------
/tests/test_datasets/test_persistent_worker.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import os.path as osp
3 |
4 | from mmgen.datasets.builder import build_dataloader, build_dataset
5 |
6 |
class TestPersistentWorker(object):
    """Check that a dataloader can be built with ``persistent_workers``."""

    @classmethod
    def setup_class(cls):
        tests_dir = osp.dirname(__file__)
        imgs_root = osp.join(tests_dir, '..', 'data/image')
        train_pipeline = [{
            'type': 'LoadImageFromFile',
            'io_backend': 'disk',
            'key': 'real_img',
        }]

        # keyword arguments forwarded to `build_dataloader`
        cls.config = {
            'samples_per_gpu': 1,
            'workers_per_gpu': 4,
            'drop_last': True,
            'persistent_workers': True,
        }

        cls.data_cfg = {
            'type': 'UnconditionalImageDataset',
            'imgs_root': imgs_root,
            'pipeline': train_pipeline,
            'test_mode': False,
        }

    def test_persistent_worker(self):
        """Build the dataset, then a dataloader with persistent workers."""
        # the loader config above sets `persistent_workers=True`
        built_dataset = build_dataset(self.data_cfg)
        build_dataloader(built_dataset, **self.config)
31 |
--------------------------------------------------------------------------------
/mmgen/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import mmcv
3 |
4 | from .version import __version__, parse_version_info, version_info
5 |
6 |
def digit_version(version_str):
    """Convert a dotted version string into a list of ints for comparison.

    Purely numeric components are appended as ints. A component containing
    ``rc`` (e.g. ``0rc1``) contributes two entries: the number before ``rc``
    minus one, followed by the number after it. Any other component is
    skipped.

    Args:
        version_str (str): Version string such as ``'1.3.0'`` or
            ``'1.3.0rc1'``.

    Returns:
        list[int]: The parsed version numbers.
    """
    parsed = []
    for component in version_str.split('.'):
        if component.isdigit():
            parsed.append(int(component))
            continue
        if 'rc' in component:
            pieces = component.split('rc')
            # the release candidate sorts below the final release
            parsed.append(int(pieces[0]) - 1)
            parsed.append(int(pieces[1]))
    return parsed
17 |
18 |
# Inclusive range of MMCV versions accepted at import time.
mmcv_minimum_version = '1.3.0'
mmcv_maximum_version = '1.8.0'
mmcv_version = digit_version(mmcv.__version__)


# Fail at import when the installed MMCV is outside the accepted range.
assert (digit_version(mmcv_minimum_version) <= mmcv_version <=
        digit_version(mmcv_maximum_version)), \
    f'MMCV=={mmcv.__version__} is used but incompatible. ' \
    f'Please install mmcv>={mmcv_minimum_version}, <={mmcv_maximum_version}.'

__all__ = ['__version__', 'version_info', 'parse_version_info']
30 |
--------------------------------------------------------------------------------
/configs/_base_/models/biggan/biggan_32x32.py:
--------------------------------------------------------------------------------
# Base model config: conditional GAN with BigGAN generator/discriminator
# at 32x32 scale and 10 classes.
model = {
    'type': 'BasiccGAN',
    'num_classes': 10,
    'generator': {
        'type': 'BigGANGenerator',
        'output_scale': 32,
        'noise_size': 128,
        'num_classes': 10,
        'base_channels': 64,
        'with_shared_embedding': False,
        'sn_eps': 1e-8,
        'sn_style': 'torch',
        'init_type': 'N02',
        'split_noise': False,
        'auto_sync_bn': False,
    },
    'discriminator': {
        'type': 'BigGANDiscriminator',
        'input_scale': 32,
        'num_classes': 10,
        'base_channels': 64,
        'sn_eps': 1e-8,
        'sn_style': 'torch',
        'init_type': 'N02',
        'with_spectral_norm': True,
    },
    'gan_loss': {'type': 'GANLoss', 'gan_type': 'hinge'},
}

train_cfg = {
    'disc_steps': 4,
    'gen_steps': 1,
    'batch_accumulation_steps': 1,
    'use_ema': True,
}
test_cfg = None
# generator and discriminator share the same Adam settings
optimizer = {
    'generator': {'type': 'Adam', 'lr': 0.0002, 'betas': (0.0, 0.999)},
    'discriminator': {'type': 'Adam', 'lr': 0.0002, 'betas': (0.0, 0.999)},
}
33 |
--------------------------------------------------------------------------------
/mmgen/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .builder import build_dataloader, build_dataset
3 | from .dataset_wrappers import RepeatDataset
4 | from .grow_scale_image_dataset import GrowScaleImgDataset
5 | from .paired_image_dataset import PairedImageDataset
6 | from .pipelines import (Collect, Compose, Flip, ImageToTensor,
7 | LoadImageFromFile, Normalize, Resize, ToTensor)
8 | from .quick_test_dataset import QuickTestImageDataset
9 | from .samplers import DistributedSampler
10 | from .singan_dataset import SinGANDataset
11 | from .unconditional_image_dataset import UnconditionalImageDataset
12 | from .unpaired_image_dataset import UnpairedImageDataset
13 |
14 | __all__ = [
15 | 'build_dataloader', 'build_dataset', 'LoadImageFromFile',
16 | 'DistributedSampler', 'UnconditionalImageDataset', 'Compose', 'ToTensor',
17 | 'ImageToTensor', 'Collect', 'Flip', 'Resize', 'RepeatDataset', 'Normalize',
18 | 'GrowScaleImgDataset', 'SinGANDataset', 'PairedImageDataset',
19 | 'UnpairedImageDataset', 'QuickTestImageDataset'
20 | ]
21 |
--------------------------------------------------------------------------------
/configs/_base_/datasets/unconditional_imgs_flip_lanczos_resize_256x256.py:
--------------------------------------------------------------------------------
dataset_type = 'UnconditionalImageDataset'

# Load -> resize to 256x256 (pillow, lanczos) -> horizontal flip ->
# normalize -> to tensor -> collect.
train_pipeline = [
    {'type': 'LoadImageFromFile', 'key': 'real_img'},
    {
        'type': 'Resize',
        'keys': ['real_img'],
        'scale': (256, 256),
        'interpolation': 'lanczos',
        'backend': 'pillow',
    },
    {'type': 'Flip', 'keys': ['real_img'], 'direction': 'horizontal'},
    {
        'type': 'Normalize',
        'keys': ['real_img'],
        'mean': [127.5] * 3,
        'std': [127.5] * 3,
        'to_rgb': False,
    },
    {'type': 'ImageToTensor', 'keys': ['real_img']},
    {
        'type': 'Collect',
        'keys': ['real_img'],
        'meta_keys': ['real_img_path'],
    },
]

# `samples_per_gpu` and `imgs_root` need to be set.
data = {
    'samples_per_gpu': None,
    'workers_per_gpu': 4,
    'train': {
        'type': 'RepeatDataset',
        'times': 100,
        'dataset': {
            'type': dataset_type,
            'imgs_root': None,
            'pipeline': train_pipeline,
        },
    },
    # validation reuses the training pipeline
    'val': {
        'type': dataset_type,
        'imgs_root': None,
        'pipeline': train_pipeline,
    },
}
32 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/2-feature-request.yml:
--------------------------------------------------------------------------------
1 | name: 🚀 Feature request
2 | description: Suggest an idea for this project
3 | labels: [feature-request]
4 | title: "[Feature] "
5 |
6 | body:
7 | - type: markdown
8 | attributes:
9 | value: |
10 | We strongly appreciate you creating a PR to implement this feature [here](https://github.com/open-mmlab/mmgeneration/pulls)!
11 | If you need our help, please fill in as much of the following form as you're able.
12 |
13 | **The less clear the description, the longer it will take to solve it.**
14 |
15 | - type: textarea
16 | attributes:
17 | label: What is the problem this feature will solve?
18 | placeholder: |
19 | E.g., It is inconvenient when \[....\].
20 | validations:
21 | required: true
22 |
23 | - type: textarea
24 | attributes:
25 | label: What is the feature?
26 | validations:
27 | required: true
28 |
29 | - type: textarea
30 | attributes:
31 | label: What alternatives have you considered?
32 | description: |
33 | Add any other context or screenshots about the feature request here.
34 |
--------------------------------------------------------------------------------
/configs/_base_/models/biggan/biggan_128x128.py:
--------------------------------------------------------------------------------
# Base model config: conditional GAN with BigGAN generator/discriminator
# at 128x128 scale and 1000 classes.
model = {
    'type': 'BasiccGAN',
    'generator': {
        'type': 'BigGANGenerator',
        'output_scale': 128,
        'noise_size': 120,
        'num_classes': 1000,
        'base_channels': 96,
        'shared_dim': 128,
        'with_shared_embedding': True,
        'sn_eps': 1e-6,
        'init_type': 'ortho',
        'act_cfg': {'type': 'ReLU', 'inplace': True},
        'split_noise': True,
        'auto_sync_bn': False,
    },
    'discriminator': {
        'type': 'BigGANDiscriminator',
        'input_scale': 128,
        'num_classes': 1000,
        'base_channels': 96,
        'sn_eps': 1e-6,
        'init_type': 'ortho',
        'act_cfg': {'type': 'ReLU', 'inplace': True},
        'with_spectral_norm': True,
    },
    'gan_loss': {'type': 'GANLoss', 'gan_type': 'hinge'},
}

train_cfg = {
    'disc_steps': 8,
    'gen_steps': 1,
    'batch_accumulation_steps': 8,
    'use_ema': True,
}
test_cfg = None
# the discriminator uses a larger learning rate than the generator
optimizer = {
    'generator': {
        'type': 'Adam',
        'lr': 0.0001,
        'betas': (0.0, 0.999),
        'eps': 1e-6,
    },
    'discriminator': {
        'type': 'Adam',
        'lr': 0.0004,
        'betas': (0.0, 0.999),
        'eps': 1e-6,
    },
}
33 |
--------------------------------------------------------------------------------
/docs/en/stat.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | import glob
3 | import os.path as osp
4 | import re
5 |
url_prefix = 'https://github.com/open-mmlab/mmgeneration/blob/master/'

# README.md of every directory directly under configs/, relative to the
# directory this script is run from.
files = sorted(glob.glob('../../configs/*/README.md'))

stats = []
titles = []
num_ckpts = 0

for f in files:
    # Turn the relative path into the GitHub URL of the config directory.
    url = osp.dirname(f.replace('../../', url_prefix))

    # Read as UTF-8 explicitly instead of relying on the platform default
    # encoding, which fails on non-ASCII README content on some systems.
    with open(f, 'r', encoding='utf-8') as content_file:
        content = content_file.read()

    # The first line of the README, without the markdown heading marker.
    title = content.split('\n')[0].replace('# ', '')

    titles.append(title)
    # Unique checkpoint links (case-insensitive) that mention 'mmgen'.
    ckpts = {
        x.lower().strip()
        for x in re.findall(r'https?://download.(.*?)\.pth', content)
        if 'mmgen' in x
    }

    num_ckpts += len(ckpts)
    statsmsg = f"""
\t* [{title}]({url}) ({len(ckpts)} ckpts)
"""
    stats.append((title, ckpts, statsmsg))

msglist = '\n'.join(x for _, _, x in stats)

modelzoo = f"""
# Model Zoo Statistics

* Number of papers: {len(titles)}
* Number of checkpoints: {num_ckpts}
{msglist}
"""

# Written as UTF-8 so the generated page matches the encoding it was read in.
with open('modelzoo_statistics.md', 'w', encoding='utf-8') as f:
    f.write(modelzoo)
45 |
--------------------------------------------------------------------------------
/docs/zh_cn/stat.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | import glob
3 | import os.path as osp
4 | import re
5 |
url_prefix = 'https://github.com/open-mmlab/mmgeneration/blob/master/'

# README.md of every directory directly under configs/, relative to the
# directory this script is run from.
files = sorted(glob.glob('../../configs/*/README.md'))


def config_url(readme_path):
    """Return the GitHub URL of the config directory holding ``readme_path``.

    The whole leading ``'../../'`` is swapped for ``url_prefix``. Replacing
    ``'../'`` alone would substitute both segments and duplicate the prefix,
    producing a broken link.

    Args:
        readme_path (str): A path of the form ``'../../configs/<name>/README.md'``.

    Returns:
        str: The URL of the directory that contains the README.
    """
    return osp.dirname(readme_path.replace('../../', url_prefix))


stats = []
titles = []
num_ckpts = 0

for f in files:
    url = config_url(f)

    # Read as UTF-8 explicitly instead of relying on the platform default
    # encoding, which fails on non-ASCII README content on some systems.
    with open(f, 'r', encoding='utf-8') as content_file:
        content = content_file.read()

    # The first line of the README, without the markdown heading marker.
    title = content.split('\n')[0].replace('# ', '')

    titles.append(title)
    # Unique checkpoint links (case-insensitive) that mention 'mmgen'.
    ckpts = {
        x.lower().strip()
        for x in re.findall(r'https?://download.(.*?)\.pth', content)
        if 'mmgen' in x
    }

    num_ckpts += len(ckpts)
    statsmsg = f"""
\t* [{title}]({url}) ({len(ckpts)} ckpts)
"""
    stats.append((title, ckpts, statsmsg))

msglist = '\n'.join(x for _, _, x in stats)

modelzoo = f"""
# Model Zoo Statistics

* Number of papers: {len(titles)}
* Number of checkpoints: {num_ckpts}
{msglist}
"""

# Written as UTF-8 so the generated page matches the encoding it was read in.
with open('modelzoo_statistics.md', 'w', encoding='utf-8') as f:
    f.write(modelzoo)
45 |
--------------------------------------------------------------------------------
/configs/dcgan/dcgan_celeba-cropped_64_b128x1_300k.py:
--------------------------------------------------------------------------------
_base_ = [
    '../_base_/models/dcgan/dcgan_64x64.py',
    '../_base_/datasets/unconditional_imgs_64x64.py',
    '../_base_/default_runtime.py',
]

# dataset definition
# `samples_per_gpu` and `imgs_root` must be set
data = {
    'samples_per_gpu': 128,
    'train': {
        'imgs_root': 'data/celeba-cropped/cropped_images_aligned_png',
    },
}

# running config adjustments
lr_config = None
checkpoint_config = {
    'interval': 10000,
    'by_epoch': False,
    'max_keep_ckpts': 20,
}
custom_hooks = [
    {
        'type': 'VisualizeUnconditionalSamples',
        'output_dir': 'training_samples',
        'interval': 10000,
    },
]

total_iters = 300002

# use the ddp wrapper for faster training
use_ddp_wrapper = True
find_unused_parameters = False

runner = {
    'type': 'DynamicIterBasedRunner',
    'is_dynamic_ddp': False,  # Note that this flag should be False.
    'pass_training_status': True,
}

metrics = {
    'ms_ssim10k': {'type': 'MS_SSIM', 'num_images': 10000},
    'swd16k': {
        'type': 'SWD',
        'num_images': 16384,
        'image_shape': (3, 64, 64),
    },
}
37 |
--------------------------------------------------------------------------------
/configs/_base_/datasets/lsun-car_pad_512.py:
--------------------------------------------------------------------------------
dataset_type = 'UnconditionalImageDataset'

# Load -> resize with scale (512, 384) -> pad 64 on both sides of the first
# axis -> horizontal flip -> normalize -> to tensor -> collect.
train_pipeline = [
    {
        'type': 'LoadImageFromFile',
        'key': 'real_img',
        'io_backend': 'disk',
    },
    {'type': 'Resize', 'keys': ['real_img'], 'scale': (512, 384)},
    {
        'type': 'NumpyPad',
        'keys': ['real_img'],
        'padding': ((64, 64), (0, 0), (0, 0)),
    },
    {'type': 'Flip', 'keys': ['real_img'], 'direction': 'horizontal'},
    {
        'type': 'Normalize',
        'keys': ['real_img'],
        'mean': [127.5] * 3,
        'std': [127.5] * 3,
        'to_rgb': False,
    },
    {'type': 'ImageToTensor', 'keys': ['real_img']},
    {
        'type': 'Collect',
        'keys': ['real_img'],
        'meta_keys': ['real_img_path'],
    },
]

# `samples_per_gpu` and `imgs_root` need to be set.
data = {
    'samples_per_gpu': None,
    'workers_per_gpu': 4,
    'train': {
        'type': 'RepeatDataset',
        'times': 5,
        'dataset': {
            'type': dataset_type,
            'imgs_root': None,
            'pipeline': train_pipeline,
        },
    },
    # validation reuses the training pipeline
    'val': {
        'type': dataset_type,
        'imgs_root': None,
        'pipeline': train_pipeline,
    },
}
36 |
--------------------------------------------------------------------------------
/configs/biggan/biggan_128x128_cvt_BigGAN-PyTorch_rgb.py:
--------------------------------------------------------------------------------
# BigGAN 128x128 model config with 1000 classes; the generator is
# additionally configured with `rgb2bgr=True`.
model = {
    'type': 'BasiccGAN',
    'generator': {
        'type': 'BigGANGenerator',
        'output_scale': 128,
        'noise_size': 120,
        'num_classes': 1000,
        'base_channels': 96,
        'shared_dim': 128,
        'with_shared_embedding': True,
        'sn_eps': 1e-6,
        'init_type': 'ortho',
        'act_cfg': {'type': 'ReLU', 'inplace': True},
        'split_noise': True,
        'auto_sync_bn': False,
        'rgb2bgr': True,
    },
    'discriminator': {
        'type': 'BigGANDiscriminator',
        'input_scale': 128,
        'num_classes': 1000,
        'base_channels': 96,
        'sn_eps': 1e-6,
        'init_type': 'ortho',
        'act_cfg': {'type': 'ReLU', 'inplace': True},
        'with_spectral_norm': True,
    },
    'gan_loss': {'type': 'GANLoss', 'gan_type': 'hinge'},
}

train_cfg = {
    'disc_steps': 8,
    'gen_steps': 1,
    'batch_accumulation_steps': 8,
    'use_ema': True,
}
test_cfg = None
# the discriminator uses a larger learning rate than the generator
optimizer = {
    'generator': {
        'type': 'Adam',
        'lr': 0.0001,
        'betas': (0.0, 0.999),
        'eps': 1e-6,
    },
    'discriminator': {
        'type': 'Adam',
        'lr': 0.0004,
        'betas': (0.0, 0.999),
        'eps': 1e-6,
    },
}
34 |
--------------------------------------------------------------------------------
/configs/singan/singan_balloons.py:
--------------------------------------------------------------------------------
_base_ = [
    '../_base_/models/singan/singan.py',
    '../_base_/datasets/singan.py',
    '../_base_/default_runtime.py',
]

num_scales = 8  # start from zero
# generator and discriminator share the same number of scales
model = {
    'generator': {'num_scales': num_scales},
    'discriminator': {'num_scales': num_scales},
}

train_cfg = {
    'noise_weight_init': 0.1,
    'iters_per_scale': 2000,
}

# Example of overriding `test_cfg`:
# test_cfg = dict(
#     _delete_=True,
#     pkl_data='path to pkl data',
# )

data = {'train': {'img_path': './data/singan/balloons.png'}}

optimizer = None
lr_config = None
checkpoint_config = {
    'by_epoch': False,
    'interval': 2000,
    'max_keep_ckpts': 3,
}

custom_hooks = [
    {
        'type': 'MMGenVisualizationHook',
        'output_dir': 'visual',
        'interval': 500,
        'bgr2rgb': True,
        'res_name_list': ['fake_imgs', 'recon_imgs', 'real_imgs'],
    },
    {
        'type': 'PickleDataHook',
        'output_dir': 'pickle',
        'interval': -1,
        'after_run': True,
        'data_name_list': ['noise_weights', 'fixed_noises', 'curr_stage'],
    },
]

total_iters = 18000
43 |
--------------------------------------------------------------------------------
/mmgen/datasets/dataset_wrappers.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .builder import DATASETS
3 |
4 |
@DATASETS.register_module()
class RepeatDataset:
    """Wrapper that makes a dataset appear ``times`` times longer.

    The reported length is ``times`` multiplied by the length of the wrapped
    dataset, and indices wrap around the original length. This is useful when
    the data loading time is long but the dataset is small, because it
    reduces the data loading time between epochs.

    Args:
        dataset (:obj:`Dataset`): The dataset to be repeated.
        times (int): Repeat times.
    """

    def __init__(self, dataset, times):
        self.dataset = dataset
        self.times = times

        # cache the original length once; it is used for index wrapping
        self._ori_len = len(self.dataset)

    def __getitem__(self, idx):
        """Fetch the item that ``idx`` maps to in the wrapped dataset.

        Args:
            idx (int): Index for getting each item.
        """
        wrapped_idx = idx % self._ori_len
        return self.dataset[wrapped_idx]

    def __len__(self):
        """Length of the repeated dataset.

        Returns:
            int: ``times`` multiplied by the original dataset length.
        """
        repeated_len = self.times * self._ori_len
        return repeated_len
40 |
--------------------------------------------------------------------------------
/mmgen/models/architectures/stylegan/ada/misc.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import numpy as np
3 | import torch
4 |
# Constant tensors are built once and cached, which avoids a CPU=>GPU copy
# when the same constant is requested multiple times.

_constant_cache = {}


def constant(value, shape=None, dtype=None, device=None, memory_format=None):
    """Return a cached constant tensor for ``value``.

    Args:
        value: Anything accepted by ``np.asarray``.
        shape (optional): Target shape the tensor is broadcast to.
            Defaults to None (keep the shape of ``value``).
        dtype (torch.dtype, optional): Defaults to
            ``torch.get_default_dtype()``.
        device (torch.device, optional): Defaults to ``torch.device('cpu')``.
        memory_format (optional): Defaults to ``torch.contiguous_format``.

    Returns:
        torch.Tensor: The contiguous tensor; the same object is returned for
        repeated calls with an identical cache key.
    """
    value = np.asarray(value)
    shape = None if shape is None else tuple(shape)
    dtype = torch.get_default_dtype() if dtype is None else dtype
    device = torch.device('cpu') if device is None else device
    if memory_format is None:
        memory_format = torch.contiguous_format

    # The raw bytes take part in the key so equal values share one entry.
    cache_key = (value.shape, value.dtype, value.tobytes(), shape, dtype,
                 device, memory_format)
    cached = _constant_cache.get(cache_key, None)
    if cached is not None:
        return cached

    result = torch.as_tensor(value.copy(), dtype=dtype, device=device)
    if shape is not None:
        result = torch.broadcast_tensors(result, torch.empty(shape))[0]
    result = result.contiguous(memory_format=memory_format)
    _constant_cache[cache_key] = result
    return result
32 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/3-new-model.yml:
--------------------------------------------------------------------------------
1 | name: "\U0001F31F New model/dataset/scheduler addition"
2 | description: Submit a proposal/request to implement a new model / dataset / scheduler
3 | labels: [ "feature-request" ]
4 | title: "[New Models] "
5 |
6 |
7 | body:
8 | - type: textarea
9 | id: description-request
10 | validations:
11 | required: true
12 | attributes:
13 | label: Model/Dataset/Scheduler description
14 | description: |
15 | Put any and all important information relative to the model/dataset/scheduler
16 |
17 | - type: checkboxes
18 | attributes:
19 | label: Open source status
20 | description: |
21 | Please provide the open-source status, which would be very helpful
22 | options:
23 | - label: "The model implementation is available"
24 | - label: "The model weights are available."
25 |
26 | - type: textarea
27 | id: additional-info
28 | attributes:
29 | label: Provide useful links for the implementation
30 | description: |
31 | Please provide information regarding the implementation, the weights, and the authors.
32 | Please mention the authors by @gh-username if you're aware of their usernames.
33 |
--------------------------------------------------------------------------------
/mmgen/core/hooks/pggan_fetch_data_hook.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch
3 | from mmcv.parallel import is_module_wrapper
4 | from mmcv.runner import HOOKS, Hook
5 |
6 |
@HOOKS.register_module()
class PGGANFetchDataHook(Hook):
    """PGGAN Fetch Data Hook.

    Before training data is fetched, this hook passes the model's
    ``_next_scale_int`` to ``runner.data_loader.update_dataloader``.

    Args:
        interval (int, optional): The interval of calling this hook.
            Defaults to 1.
    """

    def __init__(self, interval=1):
        super().__init__()
        self.interval = interval

    def before_fetch_train_data(self, runner):
        """Update the dataloader with the model's next scale.

        Args:
            runner (object): The runner.
        """
        if not self.every_n_iters(runner, self.interval):
            return

        # look through a module wrapper to reach the underlying model
        model = runner.model
        if is_module_wrapper(model):
            model = model.module

        next_scale = model._next_scale_int
        # a tensor value is converted to a plain Python number first
        if isinstance(next_scale, torch.Tensor):
            next_scale = next_scale.item()
        runner.data_loader.update_dataloader(next_scale)
35 |
--------------------------------------------------------------------------------
/tests/test_modules/test_fid_inception.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import pytest
3 | import torch
4 |
5 | from mmgen.models.architectures import InceptionV3
6 |
7 |
class TestFIDInception:
    """Output-shape checks for ``InceptionV3`` without FID weights loaded."""

    @classmethod
    def setup_class(cls):
        cls.load_fid_inception = False

    def test_fid_inception(self):
        """CPU path: first output is ``(2, 2048, 1, 1)`` for both sizes."""
        net = InceptionV3(load_fid_inception=self.load_fid_inception)
        for size in (256, 512):
            batch = torch.randn((2, 3, size, size))
            features = net(batch)[0]
            assert features.shape == (2, 2048, 1, 1)

    @pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
    def test_fid_inception_cuda(self):
        """CUDA path: same shape check with model and inputs on the GPU."""
        net = InceptionV3(
            load_fid_inception=self.load_fid_inception).cuda()
        for size in (256, 512):
            batch = torch.randn((2, 3, size, size)).cuda()
            features = net(batch)[0]
            assert features.shape == (2, 2048, 1, 1)
35 |
--------------------------------------------------------------------------------
/configs/_base_/models/wgangp/wgangp_base.py:
--------------------------------------------------------------------------------
1 | model = dict(
2 | type='StaticUnconditionalGAN',
3 | generator=dict(type='WGANGPGenerator', noise_size=128, out_scale=128),
4 | discriminator=dict(
5 | type='WGANGPDiscriminator',
6 | in_channel=3,
7 | in_scale=128,
8 | conv_module_cfg=dict(
9 | conv_cfg=None,
10 | kernel_size=3,
11 | stride=1,
12 | padding=1,
13 | bias=True,
14 | act_cfg=dict(type='LeakyReLU', negative_slope=0.2),
15 | norm_cfg=dict(type='GN'),
16 | order=('conv', 'norm', 'act'))),
17 | gan_loss=dict(type='GANLoss', gan_type='wgan'),
18 | disc_auxiliary_loss=[
19 | dict(
20 | type='GradientPenaltyLoss',
21 | loss_weight=10,
22 | norm_mode='HWC',
23 | data_info=dict(
24 | discriminator='disc',
25 | real_data='real_imgs',
26 | fake_data='fake_imgs'))
27 | ])
28 |
29 | train_cfg = dict(disc_steps=5)
30 |
31 | test_cfg = None
32 |
33 | optimizer = dict(
34 | generator=dict(type='Adam', lr=0.0001, betas=(0.5, 0.9)),
35 | discriminator=dict(type='Adam', lr=0.0001, betas=(0.5, 0.9)))
36 |
--------------------------------------------------------------------------------
/configs/_base_/models/stylegan/styleganv1_base.py:
--------------------------------------------------------------------------------
1 | model = dict(
2 | type='StyleGANV1',
3 | generator=dict(
4 | type='StyleGANv1Generator', out_size=None, style_channels=512),
5 | discriminator=dict(type='StyleGAN1Discriminator', in_size=None),
6 | gan_loss=dict(type='GANLoss', gan_type='wgan-logistic-ns'),
7 | disc_auxiliary_loss=[
8 | dict(
9 | type='R1GradientPenalty',
10 | loss_weight=10,
11 | norm_mode='HWC',
12 | data_info=dict(
13 | discriminator='disc_partial', real_data='real_imgs'))
14 | ])
15 |
16 | train_cfg = dict(
17 | use_ema=True,
18 | transition_kimgs=600,
19 | optimizer_cfg=dict(
20 | generator=dict(type='Adam', lr=0.001, betas=(0.0, 0.99)),
21 | discriminator=dict(type='Adam', lr=0.001, betas=(0.0, 0.99))),
22 | g_lr_base=0.001,
23 | d_lr_base=0.001,
24 | g_lr_schedule=dict({
25 | '128': 0.0015,
26 | '256': 0.002,
27 | '512': 0.003,
28 | '1024': 0.003
29 | }),
30 | d_lr_schedule=dict({
31 | '128': 0.0015,
32 | '256': 0.002,
33 | '512': 0.003,
34 | '1024': 0.003
35 | }))
36 |
37 | test_cfg = None
38 | optimizer = None
39 |
--------------------------------------------------------------------------------
/configs/singan/singan_fish.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 | '../_base_/models/singan/singan.py', '../_base_/datasets/singan.py',
3 | '../_base_/default_runtime.py'
4 | ]
5 |
6 | num_scales = 10 # start from zero
7 | model = dict(
8 | generator=dict(num_scales=num_scales),
9 | discriminator=dict(num_scales=num_scales))
10 |
11 | train_cfg = dict(
12 | noise_weight_init=0.1,
13 | iters_per_scale=2000,
14 | )
15 |
16 | # test_cfg = dict(
17 | # _delete_ = True,
18 | # pkl_data = 'path to pkl data'
19 | # )
20 |
21 | data = dict(
22 | train=dict(
23 | img_path='./data/singan/fish-crop.jpg', min_size=25, max_size=300))
24 |
25 | optimizer = None
26 | lr_config = None
27 | checkpoint_config = dict(by_epoch=False, interval=2000, max_keep_ckpts=3)
28 |
29 | custom_hooks = [
30 | dict(
31 | type='MMGenVisualizationHook',
32 | output_dir='visual',
33 | interval=500,
34 | bgr2rgb=True,
35 | res_name_list=['fake_imgs', 'recon_imgs', 'real_imgs']),
36 | dict(
37 | type='PickleDataHook',
38 | output_dir='pickle',
39 | interval=-1,
40 | after_run=True,
41 | data_name_list=['noise_weights', 'fixed_noises', 'curr_stage'])
42 | ]
43 |
44 | total_iters = 22000
45 |
--------------------------------------------------------------------------------
/configs/singan/singan_bohemian.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 | '../_base_/models/singan/singan.py', '../_base_/datasets/singan.py',
3 | '../_base_/default_runtime.py'
4 | ]
5 |
6 | num_scales = 10 # start from zero
7 | model = dict(
8 | generator=dict(num_scales=num_scales),
9 | discriminator=dict(num_scales=num_scales))
10 |
11 | train_cfg = dict(
12 | noise_weight_init=0.1,
13 | iters_per_scale=2000,
14 | )
15 |
16 | # test_cfg = dict(
17 | # _delete_ = True,
18 | # pkl_data = 'path to pkl data'
19 | # )
20 |
21 | data = dict(
22 | train=dict(
23 | img_path='./data/singan/bohemian.png', min_size=25, max_size=500))
24 |
25 | optimizer = None
26 | lr_config = None
27 | checkpoint_config = dict(by_epoch=False, interval=2000, max_keep_ckpts=3)
28 |
29 | custom_hooks = [
30 | dict(
31 | type='MMGenVisualizationHook',
32 | output_dir='visual',
33 | interval=500,
34 | bgr2rgb=True,
35 | res_name_list=['fake_imgs', 'recon_imgs', 'real_imgs']),
36 | dict(
37 | type='PickleDataHook',
38 | output_dir='pickle',
39 | interval=-1,
40 | after_run=True,
41 | data_name_list=['noise_weights', 'fixed_noises', 'curr_stage'])
42 | ]
43 |
44 | total_iters = 22000
45 |
--------------------------------------------------------------------------------
/docs/en/api.rst:
--------------------------------------------------------------------------------
1 | mmgen.apis
2 | --------------
3 | .. automodule:: mmgen.apis
4 | :members:
5 |
6 | mmgen.core
7 | --------------
8 |
9 | evaluation
10 | ^^^^^^^^^^
11 | .. automodule:: mmgen.core.evaluation
12 | :members:
13 |
14 | hooks
15 | ^^^^^^^^^^
16 | .. automodule:: mmgen.core.hooks
17 | :members:
18 |
19 | optimizer
20 | ^^^^^^^^^^
21 | .. automodule:: mmgen.core.optimizer
22 | :members:
23 |
24 | runners
25 | ^^^^^^^^^^
26 | .. automodule:: mmgen.core.runners
27 | :members:
28 |
29 | scheduler
30 | ^^^^^^^^^^
31 | .. automodule:: mmgen.core.scheduler
32 | :members:
33 |
34 |
35 |
36 | mmgen.datasets
37 | --------------
38 |
39 | datasets
40 | ^^^^^^^^^^
41 | .. automodule:: mmgen.datasets
42 | :members:
43 |
44 | pipelines
45 | ^^^^^^^^^^
46 | .. automodule:: mmgen.datasets.pipelines
47 | :members:
48 |
49 | mmgen.models
50 | --------------
51 |
52 | architectures
53 | ^^^^^^^^^^^^^
54 | .. automodule:: mmgen.models.architectures
55 | :members:
56 |
57 | common
58 | ^^^^^^^^^^
59 | .. automodule:: mmgen.models.common
60 | :members:
61 |
62 | gans
63 | ^^^^^^^^^^^^
64 | .. automodule:: mmgen.models.gans
65 | :members:
66 |
67 | losses
68 | ^^^^^^^^^^^^
69 | .. automodule:: mmgen.models.losses
70 | :members:
71 |
--------------------------------------------------------------------------------
/configs/biggan/biggan-deep_128x128_cvt_hugging-face_rgb.py:
--------------------------------------------------------------------------------
1 | model = dict(
2 | type='BasiccGAN',
3 | generator=dict(
4 | type='BigGANDeepGenerator',
5 | output_scale=128,
6 | noise_size=128,
7 | num_classes=1000,
8 | base_channels=128,
9 | shared_dim=128,
10 | with_shared_embedding=True,
11 | sn_eps=1e-6,
12 | sn_style='torch',
13 | init_type='ortho',
14 | act_cfg=dict(type='ReLU', inplace=True),
15 | concat_noise=True,
16 | auto_sync_bn=False,
17 | rgb2bgr=True),
18 | discriminator=dict(
19 | type='BigGANDeepDiscriminator',
20 | input_scale=128,
21 | num_classes=1000,
22 | base_channels=128,
23 | sn_eps=1e-6,
24 | sn_style='torch',
25 | init_type='ortho',
26 | act_cfg=dict(type='ReLU', inplace=True),
27 | with_spectral_norm=True),
28 | gan_loss=dict(type='GANLoss', gan_type='hinge'))
29 |
30 | train_cfg = dict(
31 | disc_steps=8, gen_steps=1, batch_accumulation_steps=8, use_ema=True)
32 | test_cfg = None
33 | optimizer = dict(
34 | generator=dict(type='Adam', lr=0.0001, betas=(0.0, 0.999), eps=1e-6),
35 | discriminator=dict(type='Adam', lr=0.0004, betas=(0.0, 0.999), eps=1e-6))
36 |
--------------------------------------------------------------------------------
/configs/biggan/biggan-deep_256x256_cvt_hugging-face_rgb.py:
--------------------------------------------------------------------------------
1 | model = dict(
2 | type='BasiccGAN',
3 | generator=dict(
4 | type='BigGANDeepGenerator',
5 | output_scale=256,
6 | noise_size=128,
7 | num_classes=1000,
8 | base_channels=128,
9 | shared_dim=128,
10 | with_shared_embedding=True,
11 | sn_eps=1e-6,
12 | sn_style='torch',
13 | init_type='ortho',
14 | act_cfg=dict(type='ReLU', inplace=True),
15 | concat_noise=True,
16 | auto_sync_bn=False,
17 | rgb2bgr=True),
18 | discriminator=dict(
19 | type='BigGANDeepDiscriminator',
20 | input_scale=256,
21 | num_classes=1000,
22 | base_channels=128,
23 | sn_eps=1e-6,
24 | sn_style='torch',
25 | init_type='ortho',
26 | act_cfg=dict(type='ReLU', inplace=True),
27 | with_spectral_norm=True),
28 | gan_loss=dict(type='GANLoss', gan_type='hinge'))
29 |
30 | train_cfg = dict(
31 | disc_steps=8, gen_steps=1, batch_accumulation_steps=8, use_ema=True)
32 | test_cfg = None
33 | optimizer = dict(
34 | generator=dict(type='Adam', lr=0.0001, betas=(0.0, 0.999), eps=1e-6),
35 | discriminator=dict(type='Adam', lr=0.0004, betas=(0.0, 0.999), eps=1e-6))
36 |
--------------------------------------------------------------------------------
/configs/biggan/biggan-deep_512x512_cvt_hugging-face_rgb.py:
--------------------------------------------------------------------------------
1 | model = dict(
2 | type='BasiccGAN',
3 | generator=dict(
4 | type='BigGANDeepGenerator',
5 | output_scale=512,
6 | noise_size=128,
7 | num_classes=1000,
8 | base_channels=128,
9 | shared_dim=128,
10 | with_shared_embedding=True,
11 | sn_eps=1e-6,
12 | sn_style='torch',
13 | init_type='ortho',
14 | act_cfg=dict(type='ReLU', inplace=True),
15 | concat_noise=True,
16 | auto_sync_bn=False,
17 | rgb2bgr=True),
18 | discriminator=dict(
19 | type='BigGANDeepDiscriminator',
20 | input_scale=512,
21 | num_classes=1000,
22 | base_channels=128,
23 | sn_eps=1e-6,
24 | sn_style='torch',
25 | init_type='ortho',
26 | act_cfg=dict(type='ReLU', inplace=True),
27 | with_spectral_norm=True),
28 | gan_loss=dict(type='GANLoss', gan_type='hinge'))
29 |
30 | train_cfg = dict(
31 | disc_steps=8, gen_steps=1, batch_accumulation_steps=8, use_ema=True)
32 | test_cfg = None
33 | optimizer = dict(
34 | generator=dict(type='Adam', lr=0.0001, betas=(0.0, 0.999), eps=1e-6),
35 | discriminator=dict(type='Adam', lr=0.0004, betas=(0.0, 0.999), eps=1e-6))
36 |
--------------------------------------------------------------------------------
/configs/pggan/pggan_celeba-hq_1024_g8_12Mimg.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 | '../_base_/models/pggan/pggan_1024.py',
3 | '../_base_/datasets/grow_scale_imgs_celeba-hq.py',
4 | '../_base_/default_runtime.py'
5 | ]
6 |
7 | optimizer = None
8 | checkpoint_config = dict(interval=5000, by_epoch=False, max_keep_ckpts=20)
9 |
10 | data = dict(
11 | samples_per_gpu=64,
12 | train=dict(
13 | gpu_samples_base=4,
14 | # note that this should be changed with total gpu number
15 | gpu_samples_per_scale={
16 | '4': 64,
17 | '8': 32,
18 | '16': 16,
19 | '32': 8,
20 | '64': 4
21 | },
22 | ))
23 |
24 | custom_hooks = [
25 | dict(
26 | type='VisualizeUnconditionalSamples',
27 | output_dir='training_samples',
28 | interval=5000),
29 | dict(type='PGGANFetchDataHook', interval=1),
30 | dict(
31 | type='ExponentialMovingAverageHook',
32 | module_keys=('generator_ema', ),
33 | interval=1,
34 | priority='VERY_HIGH')
35 | ]
36 |
37 | lr_config = None
38 |
39 | total_iters = 280000
40 |
41 | metrics = dict(
42 | ms_ssim10k=dict(type='MS_SSIM', num_images=10000),
43 | swd16k=dict(type='SWD', num_images=16384, image_shape=(3, 1024, 1024)))
44 |
--------------------------------------------------------------------------------
/mmgen/models/builder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch.nn as nn
3 | from mmcv.utils import Registry, build_from_cfg
4 |
5 | MODELS = Registry('model')
6 | MODULES = Registry('module')
7 |
8 |
9 | def build(cfg, registry, default_args=None):
10 | """Build a module.
11 |
12 | Args:
13 | cfg (dict, list[dict]): The config of modules, it is either a dict
14 | or a list of configs.
15 | registry (:obj:`Registry`): A registry the module belongs to.
16 | default_args (dict, optional): Default arguments to build the module.
17 | Defaults to None.
18 | Returns:
19 | nn.Module: A built nn module.
20 | """
21 | if isinstance(cfg, list):
22 | modules = [
23 | build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
24 | ]
25 | return nn.ModuleList(modules)
26 |
27 | return build_from_cfg(cfg, registry, default_args)
28 |
29 |
30 | def build_model(cfg, train_cfg=None, test_cfg=None):
31 | """Build a model (GAN) from ``MODELS`` with ``train_cfg``/``test_cfg``."""
32 | return build(cfg, MODELS, dict(train_cfg=train_cfg, test_cfg=test_cfg))
33 |
34 |
35 | def build_module(cfg, default_args=None):
36 | """Build a module, or several from a list of configs, from ``MODULES``."""
37 | return build(cfg, MODULES, default_args)
38 |
--------------------------------------------------------------------------------
/mmgen/models/losses/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .ddpm_loss import DDPMVLBLoss
3 | from .disc_auxiliary_loss import (DiscShiftLoss, GradientPenaltyLoss,
4 | R1GradientPenalty, disc_shift_loss,
5 | gradient_penalty_loss,
6 | r1_gradient_penalty_loss)
7 | from .gan_loss import GANLoss
8 | from .gen_auxiliary_loss import (CLIPLoss, FaceIdLoss,
9 | GeneratorPathRegularizer, PerceptualLoss,
10 | gen_path_regularizer)
11 | from .pixelwise_loss import (DiscretizedGaussianLogLikelihoodLoss,
12 | GaussianKLDLoss, L1Loss, MSELoss,
13 | discretized_gaussian_log_likelihood, gaussian_kld)
14 |
15 | __all__ = [
16 | 'GANLoss', 'DiscShiftLoss', 'disc_shift_loss', 'gradient_penalty_loss',
17 | 'GradientPenaltyLoss', 'R1GradientPenalty', 'r1_gradient_penalty_loss',
18 | 'GeneratorPathRegularizer', 'gen_path_regularizer', 'MSELoss', 'L1Loss',
19 | 'gaussian_kld', 'GaussianKLDLoss', 'DiscretizedGaussianLogLikelihoodLoss',
20 | 'DDPMVLBLoss', 'discretized_gaussian_log_likelihood', 'FaceIdLoss',
21 | 'CLIPLoss', 'PerceptualLoss'
22 | ]
23 |
--------------------------------------------------------------------------------
/docs/zh_cn/api.rst:
--------------------------------------------------------------------------------
1 | 接口参考手册
2 | =================
3 |
4 | mmgen.apis
5 | --------------
6 | .. automodule:: mmgen.apis
7 | :members:
8 |
9 | mmgen.core
10 | --------------
11 |
12 | evaluation
13 | ^^^^^^^^^^
14 | .. automodule:: mmgen.core.evaluation
15 | :members:
16 |
17 | hooks
18 | ^^^^^^^^^^
19 | .. automodule:: mmgen.core.hooks
20 | :members:
21 |
22 | optimizer
23 | ^^^^^^^^^^
24 | .. automodule:: mmgen.core.optimizer
25 | :members:
26 |
27 | runners
28 | ^^^^^^^^^^
29 | .. automodule:: mmgen.core.runners
30 | :members:
31 |
32 | scheduler
33 | ^^^^^^^^^^
34 | .. automodule:: mmgen.core.scheduler
35 | :members:
36 |
37 | mmgen.datasets
38 | --------------
39 |
40 | datasets
41 | ^^^^^^^^^^
42 | .. automodule:: mmgen.datasets
43 | :members:
44 |
45 | pipelines
46 | ^^^^^^^^^^
47 | .. automodule:: mmgen.datasets.pipelines
48 | :members:
49 |
50 | mmgen.models
51 | --------------
52 |
53 | architectures
54 | ^^^^^^^^^^^^^
55 | .. automodule:: mmgen.models.architectures
56 | :members:
57 |
58 | common
59 | ^^^^^^^^^^
60 | .. automodule:: mmgen.models.common
61 | :members:
62 |
63 | gans
64 | ^^^^^^^^^^^^
65 | .. automodule:: mmgen.models.gans
66 | :members:
67 |
68 | losses
69 | ^^^^^^^^^^^^
70 | .. automodule:: mmgen.models.losses
71 | :members:
72 |
--------------------------------------------------------------------------------
/configs/_base_/models/singan/singan.py:
--------------------------------------------------------------------------------
1 | model = dict(
2 | type='SinGAN',
3 | generator=dict(
4 | type='SinGANMultiScaleGenerator',
5 | in_channels=3,
6 | out_channels=3,
7 | num_scales=None, # need to be specified
8 | ),
9 | discriminator=dict(
10 | type='SinGANMultiScaleDiscriminator',
11 | in_channels=3,
12 | num_scales=None, # need to be specified
13 | ),
14 | gan_loss=dict(type='GANLoss', gan_type='wgan', loss_weight=1),
15 | disc_auxiliary_loss=[
16 | dict(
17 | type='GradientPenaltyLoss',
18 | loss_weight=0.1,
19 | norm_mode='pixel',
20 | data_info=dict(
21 | discriminator='disc_partial',
22 | real_data='real_imgs',
23 | fake_data='fake_imgs'))
24 | ],
25 | gen_auxiliary_loss=dict(
26 | type='MSELoss',
27 | loss_weight=10,
28 | data_info=dict(pred='recon_imgs', target='real_imgs'),
29 | ))
30 |
31 | train_cfg = dict(
32 | noise_weight_init=0.1,
33 | iters_per_scale=2000,
34 | curr_scale=-1,
35 | disc_steps=3,
36 | generator_steps=3,
37 | lr_d=0.0005,
38 | lr_g=0.0005,
39 | lr_scheduler_args=dict(milestones=[1600], gamma=0.1))
40 |
41 | test_cfg = None
42 |
--------------------------------------------------------------------------------
/mmgen/models/diffusions/sampler.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import numpy as np
3 | import torch
4 |
5 | from ..builder import MODULES
6 |
7 |
8 | @MODULES.register_module()
9 | class UniformTimeStepSampler:
10 | """Timestep sampler for DDPM-based models. This sampler samples all
11 | timesteps with the same probability.
12 |
13 | Args:
14 | num_timesteps (int): Total timesteps of the diffusion process.
15 | """
16 |
17 | def __init__(self, num_timesteps):
18 | self.num_timesteps = num_timesteps
19 | self.prob = [1 / self.num_timesteps for _ in range(self.num_timesteps)]
20 |
21 | def sample(self, batch_size):
22 | """Sample timesteps.
23 | Args:
24 | batch_size (int): The desired batch size of the sampled timesteps.
25 |
26 | Returns:
27 | torch.Tensor: Sampled timesteps, a long tensor of shape (batch_size, ).
28 | """
29 | # use numpy to make sure our implementation is consistent with the
30 | # official ones.
31 | return torch.from_numpy(
32 | np.random.choice(
33 | self.num_timesteps, size=(batch_size, ), p=self.prob)).long()
34 |
35 | def __call__(self, batch_size):
36 | """Return sampled results."""
37 | return self.sample(batch_size)
38 |
--------------------------------------------------------------------------------
/mmgen/utils/logger.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import logging
3 |
4 | from mmcv.utils import get_logger
5 |
6 |
7 | def get_root_logger(log_file=None, log_level=logging.INFO, file_mode='w'):
8 | """Initialize and get the logger named ``mmgen``.
9 |
10 | If the logger has not been initialized, this method will initialize the
11 | logger by adding one or two handlers, otherwise the initialized logger will
12 | be directly returned. During initialization, a StreamHandler will always be
13 | added. If `log_file` is specified and the process rank is 0, a FileHandler
14 | will also be added.
15 |
16 | Args:
17 | log_file (str | None): The log filename. If specified, a FileHandler
18 | will be added to the logger. Defaults to ``None``.
19 | log_level (int): The logger level. Note that only the process of
20 | rank 0 is affected, and other processes will set the level to
21 | "Error" thus be silent most of the time.
22 | Defaults to ``logging.INFO``.
23 | file_mode (str): The file mode used in opening log file.
24 | Defaults to 'w'.
25 |
26 | Returns:
27 | logging.Logger: The expected logger.
28 | """
29 | return get_logger('mmgen', log_file, log_level, file_mode=file_mode)
30 |
--------------------------------------------------------------------------------
/tests/test_datasets/test_pipelines/test_loading.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from pathlib import Path
3 |
4 | import mmcv
5 | import numpy as np
6 |
7 | from mmgen.datasets import LoadImageFromFile
8 |
9 |
10 | def test_load_image_from_file():
11 | path_baboon = Path(
12 | __file__).parent / '..' / '..' / 'data' / 'image' / 'baboon.png'
13 | img_baboon = mmcv.imread(str(path_baboon), flag='color')
14 |
15 | # read gt image
16 | # input path is a Path object; it should be stored back as a str
17 | results = dict(gt_path=path_baboon)
18 | config = dict(io_backend='disk', key='gt')
19 | image_loader = LoadImageFromFile(**config)
20 | results = image_loader(results)
21 | assert results['gt'].shape == (480, 500, 3)
22 | np.testing.assert_almost_equal(results['gt'], img_baboon)
23 | assert results['gt_path'] == str(path_baboon)
24 | # input path is a str
25 | results = dict(gt_path=str(path_baboon))
26 | results = image_loader(results)
27 | assert results['gt'].shape == (480, 500, 3)
28 | np.testing.assert_almost_equal(results['gt'], img_baboon)
29 | assert results['gt_path'] == str(path_baboon)
30 | # repr should list the loader's settings
31 | assert repr(image_loader) == (
32 | image_loader.__class__.__name__ +
33 | ('(io_backend=disk, key=gt, '
34 | 'flag=color, save_original_img=False)'))
35 |
--------------------------------------------------------------------------------
/configs/pggan/pggan_lsun-bedroom_128_g8_12Mimgs.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 | '../_base_/models/pggan/pggan_128x128.py',
3 | '../_base_/datasets/grow_scale_imgs_128x128.py',
4 | '../_base_/default_runtime.py'
5 | ]
6 |
7 | optimizer = None
8 | checkpoint_config = dict(interval=10000, by_epoch=False, max_keep_ckpts=20)
9 |
10 | data = dict(
11 | samples_per_gpu=64,
12 | train=dict(
13 | imgs_roots={'128': './data/lsun/bedroom_train'},
14 | gpu_samples_base=4,
15 | # note that this should be changed with total gpu number
16 | gpu_samples_per_scale={
17 | '4': 64,
18 | '8': 32,
19 | '16': 16,
20 | '32': 8,
21 | '64': 4
22 | },
23 | ))
24 |
25 | custom_hooks = [
26 | dict(
27 | type='VisualizeUnconditionalSamples',
28 | output_dir='training_samples',
29 | interval=5000),
30 | dict(type='PGGANFetchDataHook', interval=1),
31 | dict(
32 | type='ExponentialMovingAverageHook',
33 | module_keys=('generator_ema', ),
34 | interval=1,
35 | priority='VERY_HIGH')
36 | ]
37 |
38 | lr_config = None
39 |
40 | total_iters = 280000
41 |
42 | metrics = dict(
43 | ms_ssim10k=dict(type='MS_SSIM', num_images=10000),
44 | swd16k=dict(type='SWD', num_images=16384, image_shape=(3, 128, 128)))
45 |
--------------------------------------------------------------------------------
/tools/publish_model.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import argparse
3 | import subprocess
4 | from datetime import datetime
5 |
6 | import torch
7 |
8 |
9 | def parse_args():  # CLI: positional input and output checkpoint filenames
10 | parser = argparse.ArgumentParser(
11 | description='Process a checkpoint to be published')
12 | parser.add_argument('in_file', help='input checkpoint filename')
13 | parser.add_argument('out_file', help='output checkpoint filename')
14 | args = parser.parse_args()
15 | return args
16 |
17 |
18 | def process_checkpoint(in_file, out_file):
19 | checkpoint = torch.load(in_file, map_location='cpu')
20 | # remove optimizer for smaller file size
21 | if 'optimizer' in checkpoint:
22 | del checkpoint['optimizer']
23 | # if it is necessary to remove some sensitive data in checkpoint['meta'],
24 | # add the code here.
25 | torch.save(checkpoint, out_file)
26 | now = datetime.now()
27 | time = now.strftime('%Y%m%d_%H%M%S')
28 | sha = subprocess.check_output(['sha256sum', out_file]).decode()
29 | final_file = out_file.rstrip('.pth') + f'_{time}-{sha[:8]}.pth'
30 | subprocess.Popen(['mv', out_file, final_file])
31 |
32 |
33 | def main():  # script entry point: parse the CLI, then process the checkpoint
34 | args = parse_args()
35 | process_checkpoint(args.in_file, args.out_file)
36 |
37 |
38 | if __name__ == '__main__':
39 | main()
40 |
--------------------------------------------------------------------------------
/.github/workflows/test_mim.yml:
--------------------------------------------------------------------------------
1 | name: test-mim
2 |
3 | on:
4 | push:
5 | paths:
6 | - 'model-index.yml'
7 | - 'configs/**'
8 |
9 | pull_request:
10 | paths:
11 | - 'model-index.yml'
12 | - 'configs/**'
13 |
14 | concurrency:
15 | group: ${{ github.workflow }}-${{ github.ref }}
16 | cancel-in-progress: true
17 |
18 | jobs:
19 | build_cpu:
20 | runs-on: ubuntu-18.04
21 | strategy:
22 | matrix:
23 | python-version: [3.7]
24 | torch: [1.8.0]
25 | include:
26 | - torch: 1.8.0
27 | torch_version: torch1.8
28 | torchvision: 0.9.0
29 | steps:
30 | - uses: actions/checkout@v2
31 | - name: Set up Python ${{ matrix.python-version }}
32 | uses: actions/setup-python@v2
33 | with:
34 | python-version: ${{ matrix.python-version }}
35 | - name: Upgrade pip
36 | run: pip install pip --upgrade
37 | - name: Install PyTorch
38 | run: pip install torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu -f https://download.pytorch.org/whl/torch_stable.html
39 | - name: Install openmim
40 | run: pip install openmim
41 | - name: Build and install
42 | run: rm -rf .eggs && mim install -e .
43 | - name: test commands of mim
44 | run: mim search mmgen
45 |
--------------------------------------------------------------------------------
/configs/_base_/datasets/cifar10_rgb.py:
--------------------------------------------------------------------------------
1 | dataset_type = 'mmcls.CIFAR10'
2 |
3 | # different from mmcls, we adopt the setting used in BigGAN
4 | # Note that the pipelines below are from MMClassification
5 | img_norm_cfg = dict(
6 | mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=False)
7 | train_pipeline = [
8 | dict(type='RandomCrop', size=32, padding=4),
9 | dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
10 | dict(type='Normalize', **img_norm_cfg),
11 | dict(type='ImageToTensor', keys=['img']),
12 | dict(type='ToTensor', keys=['gt_label']),
13 | dict(type='Collect', keys=['img', 'gt_label'])
14 | ]
15 | test_pipeline = [
16 | dict(type='Normalize', **img_norm_cfg),
17 | dict(type='ImageToTensor', keys=['img']),
18 | dict(type='Collect', keys=['img'])
19 | ]
20 |
21 | # Different from the classification task, the val/test splits also use the
22 | # training part, which is the same as in StyleGAN-ADA.
23 | data = dict(
24 | samples_per_gpu=None,
25 | workers_per_gpu=4,
26 | train=dict(
27 | type=dataset_type, data_prefix='data/cifar10',
28 | pipeline=train_pipeline),
29 | val=dict(
30 | type=dataset_type, data_prefix='data/cifar10', pipeline=test_pipeline),
31 | test=dict(
32 | type=dataset_type, data_prefix='data/cifar10', pipeline=test_pipeline))
33 |
--------------------------------------------------------------------------------
/configs/_base_/datasets/grow_scale_imgs_ffhq_styleganv1.py:
--------------------------------------------------------------------------------
1 | dataset_type = 'GrowScaleImgDataset'
2 |
3 | train_pipeline = [
4 | dict(
5 | type='LoadImageFromFile',
6 | key='real_img',
7 | io_backend='disk',
8 | ),
9 | dict(type='Flip', keys=['real_img'], direction='horizontal'),
10 | dict(
11 | type='Normalize',
12 | keys=['real_img'],
13 | mean=[127.5, 127.5, 127.5],
14 | std=[127.5, 127.5, 127.5],
15 | to_rgb=False),
16 | dict(type='ImageToTensor', keys=['real_img']),
17 | dict(type='Collect', keys=['real_img'], meta_keys=['real_img_path'])
18 | ]
19 |
20 | data = dict(
21 | samples_per_gpu=64,
22 | workers_per_gpu=4,
23 | train=dict(
24 | type='GrowScaleImgDataset',
25 | imgs_roots=dict({
26 | '1024': './data/ffhq/images',
27 | '256': './data/ffhq/ffhq_imgs/ffhq_256',
28 | '64': './data/ffhq/ffhq_imgs/ffhq_64'
29 | }),
30 | pipeline=train_pipeline,
31 | gpu_samples_base=4,
32 | gpu_samples_per_scale={
33 | '4': 64,
34 | '8': 32,
35 | '16': 16,
36 | '32': 8,
37 | '64': 4,
38 | '128': 4,
39 | '256': 4,
40 | '512': 4,
41 | '1024': 4
42 | },
43 | len_per_stage=300000))
44 |
--------------------------------------------------------------------------------
/configs/pggan/pggan_celeba-cropped_128_g8_12Mimgs.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 | '../_base_/models/pggan/pggan_128x128.py',
3 | '../_base_/datasets/grow_scale_imgs_128x128.py',
4 | '../_base_/default_runtime.py'
5 | ]
6 |
7 | optimizer = None
8 | checkpoint_config = dict(interval=10000, by_epoch=False, max_keep_ckpts=20)
9 |
10 | data = dict(
11 | samples_per_gpu=64,
12 | train=dict(
13 | imgs_roots={'128': './data/celeba-cropped/cropped_images_aligned_png'},
14 | gpu_samples_base=4,
15 | # note that this should be changed with total gpu number
16 | gpu_samples_per_scale={
17 | '4': 64,
18 | '8': 32,
19 | '16': 16,
20 | '32': 8,
21 | '64': 4
22 | }))
23 |
24 | custom_hooks = [
25 | dict(
26 | type='VisualizeUnconditionalSamples',
27 | output_dir='training_samples',
28 | interval=5000),
29 | dict(type='PGGANFetchDataHook', interval=1),
30 | dict(
31 | type='ExponentialMovingAverageHook',
32 | module_keys=('generator_ema', ),
33 | interval=1,
34 | priority='VERY_HIGH')
35 | ]
36 |
37 | lr_config = None
38 |
39 | total_iters = 280000
40 |
41 | metrics = dict(
42 | ms_ssim10k=dict(type='MS_SSIM', num_images=10000),
43 | swd16k=dict(type='SWD', num_images=16384, image_shape=(3, 128, 128)))
44 |
--------------------------------------------------------------------------------
/configs/styleganv1/metafile.yml:
--------------------------------------------------------------------------------
1 | Collections:
2 | - Metadata:
3 | Architecture:
4 | - StyleGANv1
5 | Name: StyleGANv1
6 | Paper:
7 | - https://openaccess.thecvf.com/content_CVPR_2019/html/Karras_A_Style-Based_Generator_Architecture_for_Generative_Adversarial_Networks_CVPR_2019_paper.html
8 | README: configs/styleganv1/README.md
9 | Models:
10 | - Config: https://github.com/open-mmlab/mmgeneration/tree/master/configs/styleganv1/styleganv1_ffhq_256_g8_25Mimg.py
11 | In Collection: StyleGANv1
12 | Metadata:
13 | Training Data: FFHQ
14 | Name: styleganv1_ffhq_256_g8_25Mimg
15 | Results:
16 | - Dataset: FFHQ
17 | Metrics:
18 | FID50k: 6.09
19 | P&R50k_full: 70.228/27.050
20 | Task: Unconditional GANs
21 | Weights: https://download.openmmlab.com/mmgen/styleganv1/styleganv1_ffhq_256_g8_25Mimg_20210407_161748-0094da86.pth
22 | - Config: https://github.com/open-mmlab/mmgeneration/tree/master/configs/styleganv1/styleganv1_ffhq_1024_g8_25Mimg.py
23 | In Collection: StyleGANv1
24 | Metadata:
25 | Training Data: FFHQ
26 | Name: styleganv1_ffhq_1024_g8_25Mimg
27 | Results:
28 | - Dataset: FFHQ
29 | Metrics:
30 | FID50k: 4.056
31 | P&R50k_full: 70.302/36.869
32 | Task: Unconditional GANs
33 | Weights: https://download.openmmlab.com/mmgen/styleganv1/styleganv1_ffhq_1024_g8_25Mimg_20210407_161627-850a7234.pth
34 |
--------------------------------------------------------------------------------
/configs/_base_/models/pix2pix/pix2pix_vanilla_unet_bn.py:
--------------------------------------------------------------------------------
1 | source_domain = None # set by user
2 | target_domain = None # set by user
3 | # model settings
4 | model = dict(
5 | type='Pix2Pix',
6 | generator=dict(
7 | type='UnetGenerator',
8 | in_channels=3,
9 | out_channels=3,
10 | num_down=8,
11 | base_channels=64,
12 | norm_cfg=dict(type='BN'),
13 | use_dropout=True,
14 | init_cfg=dict(type='normal', gain=0.02)),
15 | discriminator=dict(
16 | type='PatchDiscriminator',
17 | in_channels=6,
18 | base_channels=64,
19 | num_conv=3,
20 | norm_cfg=dict(type='BN'),
21 | init_cfg=dict(type='normal', gain=0.02)),
22 | gan_loss=dict(
23 | type='GANLoss',
24 | gan_type='vanilla',
25 | real_label_val=1.0,
26 | fake_label_val=0.0,
27 | loss_weight=1.0),
28 | default_domain=target_domain,
29 | reachable_domains=[target_domain],
30 | related_domains=[target_domain, source_domain],
31 | gen_auxiliary_loss=dict(
32 | type='L1Loss',
33 | loss_weight=100.0,
34 | loss_name='pixel_loss',
35 | data_info=dict(
36 | pred=f'fake_{target_domain}', target=f'real_{target_domain}'),
37 | reduction='mean'))
38 | # model training and testing settings
39 | train_cfg = None
40 | test_cfg = None
41 |
--------------------------------------------------------------------------------
/configs/wgan-gp/metafile.yml:
--------------------------------------------------------------------------------
1 | Collections:
2 | - Metadata:
3 | Architecture:
4 | - WGAN-GP
5 | Name: WGAN-GP
6 | Paper:
7 | - https://arxiv.org/abs/1704.00028
8 | README: configs/wgan-gp/README.md
9 | Models:
10 | - Config: https://github.com/open-mmlab/mmgeneration/tree/master/configs/wgan-gp/wgangp_GN_celeba-cropped_128_b64x1_160kiter.py
11 | In Collection: WGAN-GP
12 | Metadata:
13 | Training Data: CELEBA
14 | Name: wgangp_GN_celeba-cropped_128_b64x1_160kiter
15 | Results:
16 | - Dataset: CELEBA
17 | Metrics:
18 | Details: GN
19 | MS-SSIM: 0.2601
20 | SWD: 5.87, 9.76, 9.43, 18.84/10.97
21 | Task: Unconditional GANs
22 | Weights: https://download.openmmlab.com/mmgen/wgangp/wgangp_GN_celeba-cropped_128_b64x1_160k_20210408_170611-f8a99336.pth
23 | - Config: https://github.com/open-mmlab/mmgeneration/tree/master/configs/wgan-gp/wgangp_GN_GP-50_lsun-bedroom_128_b64x1_160kiter.py
24 | In Collection: WGAN-GP
25 | Metadata:
26 | Training Data: LSUN
27 | Name: wgangp_GN_GP-50_lsun-bedroom_128_b64x1_160kiter
28 | Results:
29 | - Dataset: LSUN
30 | Metrics:
31 | Details: GN, GP-lambda = 50
32 | MS-SSIM: 0.059
33 | SWD: 11.7, 7.87, 9.82, 25.36/13.69
34 | Task: Unconditional GANs
35 | Weights: https://download.openmmlab.com/mmgen/wgangp/wgangp_GN_GP-50_lsun-bedroom_128_b64x1_130k_20210408_170509-56f2a37c.pth
36 |
--------------------------------------------------------------------------------
/configs/_base_/datasets/cifar10_inception_stat.py:
--------------------------------------------------------------------------------
1 | dataset_type = 'mmcls.CIFAR10'
2 |
3 | # This config is used to extract the inception statistics of the CIFAR dataset.
4 |
5 | # Different from mmcls, we adopt the setting used in BigGAN.
6 | # Note that the pipelines below are from MMClassification.
7 | # The default order in Cifar10 is RGB. Thus, we set `to_rgb` as `False`.
8 | img_norm_cfg = dict(
9 | mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=False)
10 | train_pipeline = [
11 | dict(type='Normalize', **img_norm_cfg),
12 | dict(type='ImageToTensor', keys=['img']),
13 | dict(type='ToTensor', keys=['gt_label']),
14 | dict(type='Collect', keys=['img', 'gt_label'])
15 | ]
16 | test_pipeline = [
17 | dict(type='Normalize', **img_norm_cfg),
18 | dict(type='ImageToTensor', keys=['img']),
19 | dict(type='Collect', keys=['img'])
20 | ]
21 |
22 | # Different from the classification task, the val/test splits also use the
23 | # training part, which is the same as in StyleGAN-ADA.
24 | data = dict(
25 | samples_per_gpu=None,
26 | workers_per_gpu=4,
27 | train=dict(
28 | type=dataset_type, data_prefix='data/cifar10',
29 | pipeline=train_pipeline),
30 | val=dict(
31 | type=dataset_type, data_prefix='data/cifar10', pipeline=test_pipeline),
32 | test=dict(
33 | type=dataset_type, data_prefix='data/cifar10', pipeline=test_pipeline))
34 |
--------------------------------------------------------------------------------
/configs/_base_/datasets/grow_scale_imgs_128x128.py:
--------------------------------------------------------------------------------
1 | dataset_type = 'GrowScaleImgDataset'
2 |
3 | train_pipeline = [
4 | dict(
5 | type='LoadImageFromFile',
6 | key='real_img',
7 | io_backend='disk',
8 | ),
9 | dict(type='Resize', keys=['real_img'], scale=(128, 128)),
10 | dict(type='Flip', keys=['real_img'], direction='horizontal'),
11 | dict(
12 | type='Normalize',
13 | keys=['real_img'],
14 | mean=[127.5] * 3,
15 | std=[127.5] * 3,
16 | to_rgb=False),
17 | dict(type='ImageToTensor', keys=['real_img']),
18 | dict(type='Collect', keys=['real_img'], meta_keys=['real_img_path'])
19 | ]
20 |
21 | # `samples_per_gpu` and `imgs_roots` need to be set.
22 | data = dict(
23 | # samples per gpu should be the same as the first scale, e.g. '4': 64
24 | # in this case
25 | samples_per_gpu=None,
26 | workers_per_gpu=4,
27 | train=dict(
28 | type=dataset_type,
29 | # just an example
30 | imgs_roots={'128': './data/lsun/bedroom_train'},
31 | pipeline=train_pipeline,
32 | gpu_samples_base=4,
33 | # note that this should be changed with total gpu number
34 | gpu_samples_per_scale={
35 | '4': 64,
36 | '8': 32,
37 | '16': 16,
38 | '32': 8,
39 | '64': 4
40 | },
41 | len_per_stage=-1))
42 |
--------------------------------------------------------------------------------
/mmgen/core/runners/apex_amp_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | try:
3 | from apex import amp
4 | except ImportError:
5 | amp = None
6 |
7 |
8 | def apex_amp_initialize(models, optimizers, init_args=None, mode='gan'):
9 | """Initialize apex.amp for mixed-precision training.
10 |
11 | Args:
12 | models (nn.Module | list[Module]): Modules to be wrapped with apex.amp.
13 | optimizers (dict): Optimizers keyed by 'generator' and 'discriminator'.
14 | init_args (dict | None, optional): Config for amp initialization.
15 | Defaults to None.
16 | mode (str, optional): The mode used to initialize apex.amp.
17 | Different modes lead to different wrapping mode for models and
18 | optimizers. Defaults to 'gan'.
19 |
20 | Returns:
21 | Module | list[Module], dict: Wrapped models and the optimizers dict.
22 | """
23 | init_args = init_args or dict()
24 |
25 | if mode == 'gan':
26 | _optmizers = [optimizers['generator'], optimizers['discriminator']]
27 |
28 | models, _optmizers = amp.initialize(models, _optmizers, **init_args)
29 | optimizers['generator'] = _optmizers[0]
30 | optimizers['discriminator'] = _optmizers[1]
31 |
32 | return models, optimizers
33 |
34 | else:
35 | raise NotImplementedError(
36 | f'Cannot initialize apex.amp with mode {mode}')
37 |
--------------------------------------------------------------------------------
/tests/test_datasets/test_pipelines/test_compose.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import numpy as np
3 | import pytest
4 |
5 | from mmgen.datasets.pipelines import Compose, ImageToTensor
6 |
7 |
8 | def check_keys_equal(result_keys, target_keys):
9 | """Check whether result_keys and target_keys contain the same set of keys."""
10 | return set(target_keys) == set(result_keys)
11 |
12 |
13 | def test_compose():
14 | with pytest.raises(TypeError):
15 | Compose('LoadAlpha')
16 |
17 | target_keys = ['img', 'meta']
18 |
19 | img = np.random.randn(256, 256, 3)
20 | results = dict(img=img, abandoned_key=None, img_name='test_image.png')
21 | test_pipeline = [
22 | dict(type='Collect', keys=['img'], meta_keys=['img_name']),
23 | dict(type='ImageToTensor', keys=['img'])
24 | ]
25 | compose = Compose(test_pipeline)
26 | compose_results = compose(results)
27 | assert check_keys_equal(compose_results.keys(), target_keys)
28 | assert check_keys_equal(compose_results['meta'].data.keys(), ['img_name'])
29 |
30 | results = None
31 | image_to_tensor = ImageToTensor(keys=[])
32 | test_pipeline = [image_to_tensor]
33 | compose = Compose(test_pipeline)
34 | compose_results = compose(results)
35 | assert compose_results is None
36 |
37 | assert repr(compose) == (
38 | compose.__class__.__name__ + f'(\n {image_to_tensor}\n)')
39 |
--------------------------------------------------------------------------------
/.circleci/config.yml:
--------------------------------------------------------------------------------
1 | version: 2.1
2 |
3 | # this allows you to use CircleCI's dynamic configuration feature
4 | setup: true
5 |
6 | # the path-filtering orb is required to continue a pipeline based on
7 | # the path of an updated fileset
8 | orbs:
9 | path-filtering: circleci/path-filtering@0.1.2
10 |
11 | workflows:
12 | # the always-run workflow is always triggered, regardless of the pipeline parameters.
13 | always-run:
14 | jobs:
15 | # the path-filtering/filter job determines which pipeline
16 | # parameters to update.
17 | - path-filtering/filter:
18 | name: check-updated-files
19 | # 3-column, whitespace-delimited mapping. One mapping per
20 | # line:
21 | #
22 | mapping: |
23 | mmgen/.* lint_only false
24 | requirements/.* lint_only false
25 | tests/.* lint_only false
26 | tools/.* lint_only false
27 | configs/.* lint_only false
28 | .circleci/.* lint_only false
29 | base-revision: master
30 | # this is the path of the configuration we should trigger once
31 | # path filtering and pipeline parameter value updates are
32 | # complete. In this case, we continue the pipeline with
33 | # `.circleci/test.yml`.
34 | config-path: .circleci/test.yml
35 |
--------------------------------------------------------------------------------
/tools/misc/print_config.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import argparse
3 |
4 | from mmcv import Config, DictAction
5 |
6 |
7 | def parse_args():
8 | parser = argparse.ArgumentParser(description='Print the whole config')
9 | parser.add_argument('config', help='config file path')
10 | parser.add_argument(
11 | '--cfg-options',
12 | nargs='+',
13 | action=DictAction,
14 | help='override some settings in the used config, the key-value pair '
15 | 'in xxx=yyy format will be merged into config file. If the value to '
16 | 'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
17 | 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
18 | 'Note that the quotation marks are necessary and that no white space '
19 | 'is allowed.')
20 | args = parser.parse_args()
21 |
22 | return args
23 |
24 |
25 | def main():
26 | args = parse_args()
27 |
28 | cfg = Config.fromfile(args.config)
29 | if args.cfg_options is not None:
30 | cfg.merge_from_dict(args.cfg_options)
31 | # import modules from string list.
32 | if cfg.get('custom_imports', None):
33 | from mmcv.utils import import_modules_from_strings
34 | import_modules_from_strings(**cfg['custom_imports'])
35 | print(f'Config:\n{cfg.pretty_text}')
36 |
37 |
38 | if __name__ == '__main__':
39 | main()
40 |
--------------------------------------------------------------------------------
/configs/_base_/datasets/ffhq_flip.py:
--------------------------------------------------------------------------------
1 | dataset_type = 'UnconditionalImageDataset'
2 |
3 | train_pipeline = [
4 | dict(
5 | type='LoadImageFromFile',
6 | key='real_img',
7 | io_backend='disk',
8 | ),
9 | dict(type='Flip', keys=['real_img'], direction='horizontal'),
10 | dict(
11 | type='Normalize',
12 | keys=['real_img'],
13 | mean=[127.5] * 3,
14 | std=[127.5] * 3,
15 | to_rgb=False),
16 | dict(type='ImageToTensor', keys=['real_img']),
17 | dict(type='Collect', keys=['real_img'], meta_keys=['real_img_path'])
18 | ]
19 |
20 | val_pipeline = [
21 | dict(
22 | type='LoadImageFromFile',
23 | key='real_img',
24 | io_backend='disk',
25 | ),
26 | dict(
27 | type='Normalize',
28 | keys=['real_img'],
29 | mean=[127.5] * 3,
30 | std=[127.5] * 3,
31 | to_rgb=True),
32 | dict(type='ImageToTensor', keys=['real_img']),
33 | dict(type='Collect', keys=['real_img'], meta_keys=['real_img_path'])
34 | ]
35 |
36 | # `samples_per_gpu` and `imgs_root` need to be set.
37 | data = dict(
38 | samples_per_gpu=None,
39 | workers_per_gpu=4,
40 | train=dict(
41 | type='RepeatDataset',
42 | times=100,
43 | dataset=dict(
44 | type=dataset_type, imgs_root=None, pipeline=train_pipeline)),
45 | val=dict(type=dataset_type, imgs_root=None, pipeline=val_pipeline))
46 |
--------------------------------------------------------------------------------
/configs/_base_/datasets/lsun_stylegan.py:
--------------------------------------------------------------------------------
1 | # Style-based GANs do not perform any augmentation for the LSUN datasets
2 | dataset_type = 'UnconditionalImageDataset'
3 |
4 | train_pipeline = [
5 | dict(
6 | type='LoadImageFromFile',
7 | key='real_img',
8 | io_backend='disk',
9 | ),
10 | dict(
11 | type='Normalize',
12 | keys=['real_img'],
13 | mean=[127.5] * 3,
14 | std=[127.5] * 3,
15 | to_rgb=False),
16 | dict(type='ImageToTensor', keys=['real_img']),
17 | dict(type='Collect', keys=['real_img'], meta_keys=['real_img_path'])
18 | ]
19 |
20 | val_pipeline = [
21 | dict(
22 | type='LoadImageFromFile',
23 | key='real_img',
24 | io_backend='disk',
25 | ),
26 | dict(
27 | type='Normalize',
28 | keys=['real_img'],
29 | mean=[127.5] * 3,
30 | std=[127.5] * 3,
31 | to_rgb=True),
32 | dict(type='ImageToTensor', keys=['real_img']),
33 | dict(type='Collect', keys=['real_img'], meta_keys=['real_img_path'])
34 | ]
35 |
36 | # `samples_per_gpu` and `imgs_root` need to be set.
37 | data = dict(
38 | samples_per_gpu=None,
39 | workers_per_gpu=4,
40 | train=dict(
41 | type='RepeatDataset',
42 | times=100,
43 | dataset=dict(
44 | type=dataset_type, imgs_root=None, pipeline=train_pipeline)),
45 | val=dict(type=dataset_type, imgs_root=None, pipeline=val_pipeline))
46 |
--------------------------------------------------------------------------------
/configs/_base_/datasets/cifar10_nopad.py:
--------------------------------------------------------------------------------
1 | dataset_type = 'mmcls.CIFAR10'
2 |
3 | # different from mmcls, we adopt the setting used in BigGAN
4 | # Note that the pipelines below are from MMClassification. Importantly,
5 | # `to_rgb` is set to `True` to convert images to BGR order. The default order
6 | # in Cifar10 is RGB. Thus, we have to convert it to BGR.
7 | img_norm_cfg = dict(
8 | mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=True)
9 | train_pipeline = [
10 | dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
11 | dict(type='Normalize', **img_norm_cfg),
12 | dict(type='ImageToTensor', keys=['img']),
13 | dict(type='ToTensor', keys=['gt_label']),
14 | dict(type='Collect', keys=['img', 'gt_label'])
15 | ]
16 | test_pipeline = [
17 | dict(type='Normalize', **img_norm_cfg),
18 | dict(type='ImageToTensor', keys=['img']),
19 | dict(type='Collect', keys=['img'])
20 | ]
21 |
22 | # Different from the classification task, the val/test splits also use the
23 | # training part, which is the same as in StyleGAN-ADA.
24 | data = dict(
25 | samples_per_gpu=None,
26 | workers_per_gpu=4,
27 | train=dict(
28 | type=dataset_type, data_prefix='data/cifar10',
29 | pipeline=train_pipeline),
30 | val=dict(
31 | type=dataset_type, data_prefix='data/cifar10', pipeline=test_pipeline),
32 | test=dict(
33 | type=dataset_type, data_prefix='data/cifar10', pipeline=test_pipeline))
34 |
--------------------------------------------------------------------------------
/mmgen/ops/stylegan3/ops/bias_act.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | //------------------------------------------------------------------------
10 | // CUDA kernel parameters.
11 |
12 | struct bias_act_kernel_params
13 | {
14 | const void* x; // [sizeX]
15 | const void* b; // [sizeB] or NULL
16 | const void* xref; // [sizeX] or NULL
17 | const void* yref; // [sizeX] or NULL
18 | const void* dy; // [sizeX] or NULL
19 | void* y; // [sizeX]
20 |
21 | int grad;
22 | int act;
23 | float alpha;
24 | float gain;
25 | float clamp;
26 |
27 | int sizeX;
28 | int sizeB;
29 | int stepB;
30 | int loopX;
31 | };
32 |
33 | //------------------------------------------------------------------------
34 | // CUDA kernel selection.
35 |
36 | template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p);
37 |
38 | //------------------------------------------------------------------------
39 |
--------------------------------------------------------------------------------
/configs/styleganv1/styleganv1_ffhq_256_g8_25Mimg.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 | '../_base_/models/stylegan/styleganv1_base.py',
3 | '../_base_/datasets/grow_scale_imgs_ffhq_styleganv1.py',
4 | '../_base_/default_runtime.py',
5 | ]
6 |
7 | model = dict(generator=dict(out_size=256), discriminator=dict(in_size=256))
8 |
9 | train_cfg = dict(nkimgs_per_scale={
10 | '8': 1200,
11 | '16': 1200,
12 | '32': 1200,
13 | '64': 1200,
14 | '128': 1200,
15 | '256': 190000
16 | })
17 |
18 | checkpoint_config = dict(interval=5000, by_epoch=False, max_keep_ckpts=20)
19 | lr_config = None
20 |
21 | ema_half_life = 10. # G_smoothing_kimg
22 |
23 | custom_hooks = [
24 | dict(
25 | type='VisualizeUnconditionalSamples',
26 | output_dir='training_samples',
27 | interval=5000),
28 | dict(type='PGGANFetchDataHook', interval=1),
29 | dict(
30 | type='ExponentialMovingAverageHook',
31 | module_keys=('generator_ema', ),
32 | interval=1,
33 | interp_cfg=dict(momentum=0.5**(32. / (ema_half_life * 1000.))),
34 | priority='VERY_HIGH')
35 | ]
36 |
37 | total_iters = 670000
38 |
39 | metrics = dict(
40 | fid50k=dict(
41 | type='FID',
42 | num_images=50000,
43 | inception_pkl='work_dirs/inception_pkl/ffhq-256-50k-rgb.pkl',
44 | bgr2rgb=True),
45 | pr50k3=dict(type='PR', num_images=50000, k=3),
46 | ppl_wend=dict(type='PPL', space='W', sampling='end', num_images=50000))
47 |
--------------------------------------------------------------------------------
/configs/_base_/models/stylegan/stylegan2_base.py:
--------------------------------------------------------------------------------
1 | # define GAN model
2 |
3 | d_reg_interval = 16
4 | g_reg_interval = 4
5 |
6 | g_reg_ratio = g_reg_interval / (g_reg_interval + 1)
7 | d_reg_ratio = d_reg_interval / (d_reg_interval + 1)
8 |
9 | model = dict(
10 | type='StaticUnconditionalGAN',
11 | generator=dict(
12 | type='StyleGANv2Generator',
13 | out_size=None, # Need to be set.
14 | style_channels=512,
15 | ),
16 | discriminator=dict(
17 | type='StyleGAN2Discriminator',
18 | in_size=None, # Need to be set.
19 | ),
20 | gan_loss=dict(type='GANLoss', gan_type='wgan-logistic-ns'),
21 | disc_auxiliary_loss=dict(
22 | type='R1GradientPenalty',
23 | loss_weight=10. / 2. * d_reg_interval,
24 | interval=d_reg_interval,
25 | norm_mode='HWC',
26 | data_info=dict(real_data='real_imgs', discriminator='disc')),
27 | gen_auxiliary_loss=dict(
28 | type='GeneratorPathRegularizer',
29 | loss_weight=2. * g_reg_interval,
30 | pl_batch_shrink=2,
31 | interval=g_reg_interval,
32 | data_info=dict(generator='gen', num_batches='batch_size')))
33 |
34 | train_cfg = dict(use_ema=True)
35 | test_cfg = None
36 |
37 | # define optimizer
38 | optimizer = dict(
39 | generator=dict(
40 | type='Adam', lr=0.002 * g_reg_ratio, betas=(0, 0.99**g_reg_ratio)),
41 | discriminator=dict(
42 | type='Adam', lr=0.002 * d_reg_ratio, betas=(0, 0.99**d_reg_ratio)))
43 |
--------------------------------------------------------------------------------
/configs/styleganv2/stylegan2_c2_lsun-cat_256_b4x8_800k.py:
--------------------------------------------------------------------------------
1 | """Note that this config is just for testing."""
2 |
3 | _base_ = [
4 | '../_base_/datasets/lsun_stylegan.py',
5 | '../_base_/models/stylegan/stylegan2_base.py',
6 | '../_base_/default_runtime.py'
7 | ]
8 |
9 | model = dict(generator=dict(out_size=256), discriminator=dict(in_size=256))
10 |
11 | data = dict(
12 | samples_per_gpu=4, train=dict(dataset=dict(imgs_root='./data/lsun-cat')))
13 |
14 | ema_half_life = 10. # G_smoothing_kimg
15 |
16 | custom_hooks = [
17 | dict(
18 | type='VisualizeUnconditionalSamples',
19 | output_dir='training_samples',
20 | interval=5000),
21 | dict(
22 | type='ExponentialMovingAverageHook',
23 | module_keys=('generator_ema', ),
24 | interval=1,
25 | interp_cfg=dict(momentum=0.5**(32. / (ema_half_life * 1000.))),
26 | priority='VERY_HIGH')
27 | ]
28 |
29 | checkpoint_config = dict(interval=10000, by_epoch=False, max_keep_ckpts=30)
30 | lr_config = None
31 |
32 | log_config = dict(
33 | interval=100,
34 | hooks=[
35 | dict(type='TextLoggerHook'),
36 | # dict(type='TensorboardLoggerHook'),
37 | ])
38 |
39 | total_iters = 800002 # need to modify
40 |
41 | metrics = dict(
42 | fid50k=dict(
43 | type='FID', num_images=50000, inception_pkl=None, bgr2rgb=True),
44 | pr50k3=dict(type='PR', num_images=50000, k=3),
45 | ppl_wend=dict(type='PPL', space='W', sampling='end', num_images=50000))
46 |
--------------------------------------------------------------------------------
/configs/_base_/datasets/cifar10.py:
--------------------------------------------------------------------------------
1 | dataset_type = 'mmcls.CIFAR10'
2 |
3 | # different from mmcls, we adopt the setting used in BigGAN
4 | # Note that the pipelines below are from MMClassification. Importantly,
5 | # `to_rgb` is set to `True` to convert images to BGR order. The default order
6 | # in Cifar10 is RGB. Thus, we have to convert it to BGR.
7 | img_norm_cfg = dict(
8 | mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=True)
9 | train_pipeline = [
10 | dict(type='RandomCrop', size=32, padding=4),
11 | dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
12 | dict(type='Normalize', **img_norm_cfg),
13 | dict(type='ImageToTensor', keys=['img']),
14 | dict(type='ToTensor', keys=['gt_label']),
15 | dict(type='Collect', keys=['img', 'gt_label'])
16 | ]
17 | test_pipeline = [
18 | dict(type='Normalize', **img_norm_cfg),
19 | dict(type='ImageToTensor', keys=['img']),
20 | dict(type='Collect', keys=['img'])
21 | ]
22 |
23 | # Different from the classification task, the val/test splits also use the
24 | # training part, which is the same as in StyleGAN-ADA.
25 | data = dict(
26 | samples_per_gpu=None,
27 | workers_per_gpu=4,
28 | train=dict(
29 | type=dataset_type, data_prefix='data/cifar10',
30 | pipeline=train_pipeline),
31 | val=dict(
32 | type=dataset_type, data_prefix='data/cifar10', pipeline=test_pipeline),
33 | test=dict(
34 | type=dataset_type, data_prefix='data/cifar10', pipeline=test_pipeline))
35 |
--------------------------------------------------------------------------------
/configs/styleganv2/stylegan2_c2_lsun-horse_256_b4x8_800k.py:
--------------------------------------------------------------------------------
1 | """Note that this config is just for testing."""
2 |
3 | _base_ = [
4 | '../_base_/datasets/lsun_stylegan.py',
5 | '../_base_/models/stylegan/stylegan2_base.py',
6 | '../_base_/default_runtime.py'
7 | ]
8 |
9 | model = dict(generator=dict(out_size=256), discriminator=dict(in_size=256))
10 |
11 | data = dict(
12 | samples_per_gpu=4, train=dict(dataset=dict(imgs_root='./data/lsun-horse')))
13 |
14 | ema_half_life = 10. # G_smoothing_kimg
15 |
16 | custom_hooks = [
17 | dict(
18 | type='VisualizeUnconditionalSamples',
19 | output_dir='training_samples',
20 | interval=5000),
21 | dict(
22 | type='ExponentialMovingAverageHook',
23 | module_keys=('generator_ema', ),
24 | interval=1,
25 | interp_cfg=dict(momentum=0.5**(32. / (ema_half_life * 1000.))),
26 | priority='VERY_HIGH')
27 | ]
28 |
29 | checkpoint_config = dict(interval=10000, by_epoch=False, max_keep_ckpts=30)
30 | lr_config = None
31 |
32 | log_config = dict(
33 | interval=100,
34 | hooks=[
35 | dict(type='TextLoggerHook'),
36 | # dict(type='TensorboardLoggerHook'),
37 | ])
38 |
39 | total_iters = 800002 # need to modify
40 |
41 | metrics = dict(
42 | fid50k=dict(
43 | type='FID', num_images=50000, inception_pkl=None, bgr2rgb=True),
44 | pr50k3=dict(type='PR', num_images=50000, k=3),
45 | ppl_wend=dict(type='PPL', space='W', sampling='end', num_images=50000))
46 |
--------------------------------------------------------------------------------
/configs/_base_/datasets/grow_scale_imgs_celeba-hq.py:
--------------------------------------------------------------------------------
1 | dataset_type = 'GrowScaleImgDataset'
2 |
3 | train_pipeline = [
4 | dict(
5 | type='LoadImageFromFile',
6 | key='real_img',
7 | io_backend='disk',
8 | ),
9 | dict(type='Flip', keys=['real_img'], direction='horizontal'),
10 | dict(
11 | type='Normalize',
12 | keys=['real_img'],
13 | mean=[127.5] * 3,
14 | std=[127.5] * 3,
15 | to_rgb=False),
16 | dict(type='ImageToTensor', keys=['real_img']),
17 | dict(type='Collect', keys=['real_img'], meta_keys=['real_img_path'])
18 | ]
19 |
20 | # `samples_per_gpu` and `imgs_roots` need to be set.
21 | data = dict(
22 | # samples per gpu should be the same as the first scale, e.g. '4': 64
23 | # in this case
24 | samples_per_gpu=None,
25 | workers_per_gpu=4,
26 | train=dict(
27 | type=dataset_type,
28 | # just an example
29 | imgs_roots={
30 | '64': './data/celebahq/imgs_64',
31 | '256': './data/celebahq/imgs_256',
32 | '512': './data/celebahq/imgs_512',
33 | '1024': './data/celebahq/imgs_1024'
34 | },
35 | pipeline=train_pipeline,
36 | gpu_samples_base=4,
37 | # note that this should be changed with total gpu number
38 | gpu_samples_per_scale={
39 | '4': 64,
40 | '8': 32,
41 | '16': 16,
42 | '32': 8,
43 | '64': 4
44 | },
45 | len_per_stage=300000))
46 |
--------------------------------------------------------------------------------
/configs/styleganv2/stylegan2_c2_lsun-church_256_b4x8_800k.py:
--------------------------------------------------------------------------------
1 | """Note that this config is just for testing."""
2 |
3 | _base_ = [
4 | '../_base_/datasets/lsun_stylegan.py',
5 | '../_base_/models/stylegan/stylegan2_base.py',
6 | '../_base_/default_runtime.py'
7 | ]
8 |
9 | model = dict(generator=dict(out_size=256), discriminator=dict(in_size=256))
10 |
11 | data = dict(
12 | samples_per_gpu=4,
13 | train=dict(dataset=dict(imgs_root='./data/lsun-church')))
14 |
15 | ema_half_life = 10. # G_smoothing_kimg
16 |
17 | custom_hooks = [
18 | dict(
19 | type='VisualizeUnconditionalSamples',
20 | output_dir='training_samples',
21 | interval=5000),
22 | dict(
23 | type='ExponentialMovingAverageHook',
24 | module_keys=('generator_ema', ),
25 | interval=1,
26 | interp_cfg=dict(momentum=0.5**(32. / (ema_half_life * 1000.))),
27 | priority='VERY_HIGH')
28 | ]
29 |
30 | checkpoint_config = dict(interval=10000, by_epoch=False, max_keep_ckpts=30)
31 | lr_config = None
32 |
33 | log_config = dict(
34 | interval=100,
35 | hooks=[
36 | dict(type='TextLoggerHook'),
37 | # dict(type='TensorboardLoggerHook'),
38 | ])
39 |
40 | total_iters = 800002 # need to modify
41 |
42 | metrics = dict(
43 | fid50k=dict(
44 | type='FID', num_images=50000, inception_pkl=None, bgr2rgb=True),
45 | pr50k3=dict(type='PR', num_images=50000, k=3),
46 | ppl_wend=dict(type='PPL', space='W', sampling='end', num_images=50000))
47 |
--------------------------------------------------------------------------------
/configs/positional_encoding_in_gans/stylegan2_c2_ffhq_256_b3x8_1100k.py:
--------------------------------------------------------------------------------
1 | """Config for the `config-f` setting in StyleGAN2."""
2 |
3 | _base_ = [
4 | '../_base_/datasets/ffhq_flip.py',
5 | '../_base_/models/stylegan/stylegan2_base.py',
6 | '../_base_/default_runtime.py'
7 | ]
8 |
9 | model = dict(generator=dict(out_size=256), discriminator=dict(in_size=256))
10 |
11 | data = dict(
12 | samples_per_gpu=3,
13 | train=dict(dataset=dict(imgs_root='./data/ffhq/ffhq_imgs/ffhq_256')))
14 |
15 | ema_half_life = 10. # G_smoothing_kimg
16 |
17 | custom_hooks = [
18 | dict(
19 | type='VisualizeUnconditionalSamples',
20 | output_dir='training_samples',
21 | interval=5000),
22 | dict(
23 | type='ExponentialMovingAverageHook',
24 | module_keys=('generator_ema', ),
25 | interval=1,
26 | interp_cfg=dict(momentum=0.5**(32. / (ema_half_life * 1000.))),
27 | priority='VERY_HIGH')
28 | ]
29 |
30 | metrics = dict(
31 | fid50k=dict(
32 | type='FID',
33 | num_images=50000,
34 | inception_pkl='work_dirs/inception_pkl/ffhq-256-50k-rgb.pkl',
35 | bgr2rgb=True),
36 | pr10k3=dict(type='PR', num_images=10000, k=3))
37 |
38 | checkpoint_config = dict(interval=10000, by_epoch=False, max_keep_ckpts=30)
39 | lr_config = None
40 |
41 | log_config = dict(
42 | interval=100,
43 | hooks=[
44 | dict(type='TextLoggerHook'),
45 | # dict(type='TensorboardLoggerHook'),
46 | ])
47 |
48 | total_iters = 1100002
49 |
--------------------------------------------------------------------------------
/configs/positional_encoding_in_gans/stylegan2_c2_ffhq_512_b3x8_1100k.py:
--------------------------------------------------------------------------------
1 | """Config for the `config-f` setting in StyleGAN2."""
2 |
3 | _base_ = [
4 | '../_base_/datasets/ffhq_flip.py',
5 | '../_base_/models/stylegan/stylegan2_base.py',
6 | '../_base_/default_runtime.py'
7 | ]
8 |
9 | model = dict(generator=dict(out_size=512), discriminator=dict(in_size=512))
10 |
11 | data = dict(
12 | samples_per_gpu=3,
13 | train=dict(dataset=dict(imgs_root='./data/ffhq/ffhq_imgs/ffhq_512')))
14 |
15 | ema_half_life = 10. # G_smoothing_kimg
16 |
17 | custom_hooks = [
18 | dict(
19 | type='VisualizeUnconditionalSamples',
20 | output_dir='training_samples',
21 | interval=5000),
22 | dict(
23 | type='ExponentialMovingAverageHook',
24 | module_keys=('generator_ema', ),
25 | interval=1,
26 | interp_cfg=dict(momentum=0.5**(32. / (ema_half_life * 1000.))),
27 | priority='VERY_HIGH')
28 | ]
29 |
30 | metrics = dict(
31 | fid50k=dict(
32 | type='FID',
33 | num_images=50000,
34 | inception_pkl='work_dirs/inception_pkl/ffhq-512-50k-rgb.pkl',
35 | bgr2rgb=True),
36 | pr10k3=dict(type='PR', num_images=10000, k=3))
37 |
38 | checkpoint_config = dict(interval=10000, by_epoch=False, max_keep_ckpts=30)
39 | lr_config = None
40 |
41 | log_config = dict(
42 | interval=100,
43 | hooks=[
44 | dict(type='TextLoggerHook'),
45 | # dict(type='TensorboardLoggerHook'),
46 | ])
47 |
48 | total_iters = 1100002
49 |
--------------------------------------------------------------------------------
/configs/_base_/models/pggan/pggan_128x128.py:
--------------------------------------------------------------------------------
1 | # define GAN model
2 | model = dict(
3 | type='ProgressiveGrowingGAN',
4 | generator=dict(type='PGGANGenerator', out_scale=128, noise_size=512),
5 | discriminator=dict(type='PGGANDiscriminator', in_scale=128),
6 | gan_loss=dict(type='GANLoss', gan_type='wgan'),
7 | disc_auxiliary_loss=[
8 | dict(
9 | type='DiscShiftLoss',
10 | loss_weight=0.001 * 0.5,
11 | data_info=dict(pred='disc_pred_fake')),
12 | dict(
13 | type='DiscShiftLoss',
14 | loss_weight=0.001 * 0.5,
15 | data_info=dict(pred='disc_pred_real')),
16 | dict(
17 | type='GradientPenaltyLoss',
18 | loss_weight=10,
19 | norm_mode='HWC',
20 | data_info=dict(
21 | discriminator='disc_partial',
22 | real_data='real_imgs',
23 | fake_data='fake_imgs'))
24 | ])
25 |
26 | train_cfg = dict(
27 | use_ema=True,
28 | nkimgs_per_scale={
29 | '4': 600,
30 | '8': 1200,
31 | '16': 1200,
32 | '32': 1200,
33 | '64': 1200,
34 | '128': 12000
35 | },
36 | transition_kimgs=600,
37 | optimizer_cfg=dict(
38 | generator=dict(type='Adam', lr=0.001, betas=(0., 0.99)),
39 | discriminator=dict(type='Adam', lr=0.001, betas=(0., 0.99))),
40 | g_lr_base=0.001,
41 | d_lr_base=0.001,
42 | g_lr_schedule={'128': 0.0015},
43 | d_lr_schedule={'128': 0.0015})
44 |
45 | test_cfg = None
46 |
--------------------------------------------------------------------------------
/configs/_base_/models/stylegan/stylegan3_base.py:
--------------------------------------------------------------------------------
1 | # define GAN model
2 |
3 | d_reg_interval = 16
4 | g_reg_interval = 4
5 |
6 | g_reg_ratio = g_reg_interval / (g_reg_interval + 1)
7 | d_reg_ratio = d_reg_interval / (d_reg_interval + 1)
8 |
9 | model = dict(
10 | type='StaticUnconditionalGAN',
11 | generator=dict(
12 | type='StyleGANv3Generator',
13 | noise_size=512,
14 | style_channels=512,
15 | out_size=None, # Need to be set.
16 | img_channels=3,
17 | ),
18 | discriminator=dict(
19 | type='StyleGAN2Discriminator',
20 | in_size=None, # Need to be set.
21 | ),
22 | gan_loss=dict(type='GANLoss', gan_type='wgan-logistic-ns'),
23 | disc_auxiliary_loss=dict(
24 | type='R1GradientPenalty',
25 | loss_weight=10. / 2. * d_reg_interval,
26 | interval=d_reg_interval,
27 | norm_mode='HWC',
28 | data_info=dict(real_data='real_imgs', discriminator='disc')),
29 | gen_auxiliary_loss=dict(
30 | type='GeneratorPathRegularizer',
31 | loss_weight=2. * g_reg_interval,
32 | pl_batch_shrink=2,
33 | interval=g_reg_interval,
34 | data_info=dict(generator='gen', num_batches='batch_size')))
35 |
36 | train_cfg = dict(use_ema=True)
37 | test_cfg = None
38 |
39 | # define optimizer
40 | optimizer = dict(
41 | generator=dict(
42 | type='Adam', lr=0.0025 * g_reg_ratio, betas=(0, 0.99**g_reg_ratio)),
43 | discriminator=dict(
44 | type='Adam', lr=0.002 * d_reg_ratio, betas=(0, 0.99**d_reg_ratio)))
45 |
--------------------------------------------------------------------------------
/configs/styleganv1/styleganv1_ffhq_1024_g8_25Mimg.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 | '../_base_/models/stylegan/styleganv1_base.py',
3 | '../_base_/datasets/grow_scale_imgs_ffhq_styleganv1.py',
4 | '../_base_/default_runtime.py',
5 | ]
6 |
7 | model = dict(generator=dict(out_size=1024), discriminator=dict(in_size=1024))
8 |
9 | train_cfg = dict(
10 | nkimgs_per_scale={
11 | '8': 1200,
12 | '16': 1200,
13 | '32': 1200,
14 | '64': 1200,
15 | '128': 1200,
16 | '256': 1200,
17 | '512': 1200,
18 | '1024': 166000
19 | })
20 |
21 | checkpoint_config = dict(interval=5000, by_epoch=False, max_keep_ckpts=20)
22 | lr_config = None
23 |
24 | ema_half_life = 10. # G_smoothing_kimg
25 |
26 | custom_hooks = [
27 | dict(
28 | type='VisualizeUnconditionalSamples',
29 | output_dir='training_samples',
30 | interval=5000),
31 | dict(type='PGGANFetchDataHook', interval=1),
32 | dict(
33 | type='ExponentialMovingAverageHook',
34 | module_keys=('generator_ema', ),
35 | interval=1,
36 | interp_cfg=dict(momentum=0.5**(32. / (ema_half_life * 1000.))),
37 | priority='VERY_HIGH')
38 | ]
39 |
40 | total_iters = 670000
41 |
42 | metrics = dict(
43 | fid50k=dict(
44 | type='FID',
45 | num_images=50000,
46 | inception_pkl='work_dirs/inception_pkl/ffhq-1024-50k-rgb.pkl',
47 | bgr2rgb=True),
48 | pr50k3=dict(type='PR', num_images=50000, k=3),
49 | ppl_wend=dict(type='PPL', space='W', sampling='end', num_images=50000))
50 |
--------------------------------------------------------------------------------
/configs/improved_ddpm/ddpm_cosine_hybird_timestep-4k_drop0.3_cifar10_32x32_b8x16_500k.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 | '../_base_/models/improved_ddpm/ddpm_32x32.py',
3 | '../_base_/datasets/cifar10_noaug.py', '../_base_/default_runtime.py'
4 | ]
5 |
6 | lr_config = None
7 | checkpoint_config = dict(interval=10000, by_epoch=False, max_keep_ckpts=20)
8 | custom_hooks = [
9 | dict(
10 | type='MMGenVisualizationHook',
11 | output_dir='training_samples',
12 | res_name_list=['real_imgs', 'x_0_pred', 'x_t', 'x_t_1'],
13 | padding=1,
14 | interval=1000),
15 | dict(
16 | type='ExponentialMovingAverageHook',
17 | module_keys=('denoising_ema'),
18 | interval=1,
19 | start_iter=0,
20 | interp_cfg=dict(momentum=0.9999),
21 | priority='VERY_HIGH')
22 | ]
23 |
24 | # Do not evaluate during training because evaluation takes too much time.
25 | evaluation = None
26 |
27 | total_iters = 500000 # 500k
28 | data = dict(samples_per_gpu=16) # 8x16=128
29 |
30 | # use ddp wrapper for faster training
31 | use_ddp_wrapper = True
32 | find_unused_parameters = False
33 |
34 | runner = dict(
35 | type='DynamicIterBasedRunner',
36 | is_dynamic_ddp=False, # Note that this flag should be False.
37 | pass_training_status=True)
38 |
39 | inception_pkl = './work_dirs/inception_pkl/cifar10.pkl'
40 | metrics = dict(
41 | fid50k=dict(
42 | type='FID',
43 | num_images=50000,
44 | bgr2rgb=True,
45 | inception_pkl=inception_pkl,
46 | inception_args=dict(type='StyleGAN')))
47 |
--------------------------------------------------------------------------------
/configs/singan/metafile.yml:
--------------------------------------------------------------------------------
1 | Collections:
2 | - Metadata:
3 | Architecture:
4 | - SinGAN
5 | Name: SinGAN
6 | Paper:
7 | - https://openaccess.thecvf.com/content_ICCV_2019/html/Shaham_SinGAN_Learning_a_Generative_Model_From_a_Single_Natural_Image_ICCV_2019_paper.html
8 | README: configs/singan/README.md
9 | Models:
10 | - Config: https://github.com/open-mmlab/mmgeneration/tree/master/configs/singan/singan_balloons.py
11 | In Collection: SinGAN
12 | Metadata:
13 | Training Data: Others
14 | Name: singan_balloons
15 | Results:
16 | - Dataset: Others
17 | Metrics:
18 | Num Scales: 8.0
19 | Task: Internal Learning
20 | Weights: https://download.openmmlab.com/mmgen/singan/singan_balloons_20210406_191047-8fcd94cf.pth
21 | - Config: https://github.com/open-mmlab/mmgeneration/tree/master/configs/singan/singan_fish.py
22 | In Collection: SinGAN
23 | Metadata:
24 | Training Data: Others
25 | Name: singan_fish
26 | Results:
27 | - Dataset: Others
28 | Metrics:
29 | Num Scales: 10.0
30 | Task: Internal Learning
31 | Weights: https://download.openmmlab.com/mmgen/singan/singan_fis_20210406_201006-860d91b6.pth
32 | - Config: https://github.com/open-mmlab/mmgeneration/tree/master/configs/singan/singan_bohemian.py
33 | In Collection: SinGAN
34 | Metadata:
35 | Training Data: Others
36 | Name: singan_bohemian
37 | Results:
38 | - Dataset: Others
39 | Metrics:
40 | Num Scales: 10.0
41 | Task: Internal Learning
42 | Weights: https://download.openmmlab.com/mmgen/singan/singan_bohemian_20210406_175439-f964ee38.pth
43 |
--------------------------------------------------------------------------------
/configs/improved_ddpm/ddpm_cosine_hybird_timestep-4k_imagenet1k_64x64_b8x16_1500k.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 | '../_base_/models/improved_ddpm/ddpm_64x64.py',
3 | '../_base_/datasets/imagenet_noaug_64.py', '../_base_/default_runtime.py'
4 | ]
5 |
6 | lr_config = None
7 | checkpoint_config = dict(interval=10000, by_epoch=False, max_keep_ckpts=20)
8 | custom_hooks = [
9 | dict(
10 | type='MMGenVisualizationHook',
11 | output_dir='training_samples',
12 | res_name_list=['real_imgs', 'x_0_pred', 'x_t', 'x_t_1'],
13 | padding=1,
14 | interval=1000),
15 | dict(
16 | type='ExponentialMovingAverageHook',
17 | module_keys=('denoising_ema'),
18 | interval=1,
19 | start_iter=0,
20 | interp_cfg=dict(momentum=0.9999),
21 | priority='VERY_HIGH')
22 | ]
23 |
24 | # Do not evaluate during training because evaluation takes too much time.
25 | evaluation = None
26 |
27 | total_iters = 1500000 # 1500k
28 | data = dict(samples_per_gpu=16) # 8x16=128
29 |
30 | # use ddp wrapper for faster training
31 | use_ddp_wrapper = True
32 | find_unused_parameters = False
33 |
34 | runner = dict(
35 | type='DynamicIterBasedRunner',
36 | is_dynamic_ddp=False, # Note that this flag should be False.
37 | pass_training_status=True)
38 |
39 | inception_pkl = './work_dirs/inception_pkl/imagenet_64x64.pkl'
40 | metrics = dict(
41 | fid50k=dict(
42 | type='FID',
43 | num_images=50000,
44 | bgr2rgb=True,
45 | inception_pkl=inception_pkl,
46 | inception_args=dict(type='StyleGAN')))
47 |
--------------------------------------------------------------------------------
/configs/_base_/datasets/cifar10_noaug.py:
--------------------------------------------------------------------------------
1 | dataset_type = 'mmcls.CIFAR10'
2 |
3 | # different from mmcls, we adopt the setting used in BigGAN
4 | # Note that the pipelines below are from MMClassification. Importantly, the
5 | # `to_rgb` is set to `True` to convert images to BGR order. The default order
6 | # in Cifar10 is RGB. Thus, we have to convert it to BGR.
7 |
8 | # Cifar dataset w/o augmentations. Remove `RandomFlip` and `RandomCrop`
9 | # augmentations.
10 | img_norm_cfg = dict(
11 | mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=True)
12 | train_pipeline = [
13 | dict(type='Normalize', **img_norm_cfg),
14 | dict(type='ImageToTensor', keys=['img']),
15 | dict(type='ToTensor', keys=['gt_label']),
16 | dict(type='Collect', keys=['img', 'gt_label'])
17 | ]
18 | test_pipeline = [
19 | dict(type='Normalize', **img_norm_cfg),
20 | dict(type='ImageToTensor', keys=['img']),
21 | dict(type='Collect', keys=['img'])
22 | ]
23 |
24 | # Different from the classification task, the val/test splits also use the
25 | # training part, which is the same as in StyleGAN-ADA.
26 | data = dict(
27 | samples_per_gpu=None,
28 | workers_per_gpu=4,
29 | train=dict(
30 | type='RepeatDataset',
31 | times=500,
32 | dataset=dict(
33 | type=dataset_type,
34 | data_prefix='data/cifar10',
35 | pipeline=train_pipeline)),
36 | val=dict(
37 | type=dataset_type, data_prefix='data/cifar10', pipeline=test_pipeline),
38 | test=dict(
39 | type=dataset_type, data_prefix='data/cifar10', pipeline=test_pipeline))
40 |
--------------------------------------------------------------------------------
/configs/_base_/models/improved_ddpm/ddpm_64x64.py:
--------------------------------------------------------------------------------
1 | model = dict(
2 | type='BasicGaussianDiffusion',
3 | num_timesteps=4000,
4 | betas_cfg=dict(type='cosine'),
5 | denoising=dict(
6 | type='DenoisingUnet',
7 | image_size=64,
8 | in_channels=3,
9 | base_channels=128,
10 | resblocks_per_downsample=3,
11 | attention_res=[16, 8],
12 | use_scale_shift_norm=True,
13 | dropout=0,
14 | num_heads=4,
15 | use_rescale_timesteps=True,
16 | output_cfg=dict(mean='eps', var='learned_range')),
17 | timestep_sampler=dict(type='UniformTimeStepSampler'),
18 | ddpm_loss=[
19 | dict(
20 | type='DDPMVLBLoss',
21 | rescale_mode='constant',
22 | rescale_cfg=dict(scale=4000 / 1000),
23 | data_info=dict(
24 | mean_pred='mean_pred',
25 | mean_target='mean_posterior',
26 | logvar_pred='logvar_pred',
27 | logvar_target='logvar_posterior'),
28 | log_cfgs=[
29 | dict(
30 | type='quartile',
31 | prefix_name='loss_vlb',
32 | total_timesteps=4000),
33 | dict(type='name')
34 | ]),
35 | dict(
36 | type='DDPMMSELoss',
37 | log_cfgs=dict(
38 | type='quartile', prefix_name='loss_mse', total_timesteps=4000),
39 | )
40 | ],
41 | )
42 |
43 | train_cfg = dict(use_ema=True, real_img_key='img')
44 | test_cfg = None
45 | optimizer = dict(denoising=dict(type='AdamW', lr=1e-4, weight_decay=0))
46 |
--------------------------------------------------------------------------------
/configs/lsgan/lsgan_lsgan-archi_lr-1e-4_lsun-bedroom_128_b64x1_10m.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 | '../_base_/models/lsgan/lsgan_128x128.py',
3 | '../_base_/datasets/unconditional_imgs_128x128.py',
4 | '../_base_/default_runtime.py'
5 | ]
6 | # define dataset
7 | # you must set `samples_per_gpu` and `imgs_root`
8 | data = dict(
9 | samples_per_gpu=64, train=dict(imgs_root='./data/lsun/bedroom_train'))
10 |
11 | optimizer = dict(
12 | generator=dict(type='Adam', lr=0.0001, betas=(0.5, 0.99)),
13 | discriminator=dict(type='Adam', lr=0.0001, betas=(0.5, 0.99)))
14 |
15 | # adjust running config
16 | lr_config = None
17 | checkpoint_config = dict(interval=10000, by_epoch=False, max_keep_ckpts=20)
18 | custom_hooks = [
19 | dict(
20 | type='VisualizeUnconditionalSamples',
21 | output_dir='training_samples',
22 | interval=10000)
23 | ]
24 |
25 | evaluation = dict(
26 | type='GenerativeEvalHook',
27 | interval=10000,
28 | metrics=dict(
29 | type='FID', num_images=50000, inception_pkl=None, bgr2rgb=True),
30 | sample_kwargs=dict(sample_model='orig'))
31 |
32 | total_iters = 160000
33 | # use ddp wrapper for faster training
34 | use_ddp_wrapper = True
35 | find_unused_parameters = False
36 |
37 | runner = dict(
38 | type='DynamicIterBasedRunner',
39 | is_dynamic_ddp=False, # Note that this flag should be False.
40 | pass_training_status=True)
41 |
42 | metrics = dict(
43 | ms_ssim10k=dict(type='MS_SSIM', num_images=10000),
44 | swd16k=dict(type='SWD', num_images=16384, image_shape=(3, 128, 128)),
45 | fid50k=dict(type='FID', num_images=50000, inception_pkl=None))
46 |
--------------------------------------------------------------------------------
/configs/_base_/models/improved_ddpm/ddpm_32x32.py:
--------------------------------------------------------------------------------
1 | model = dict(
2 | type='BasicGaussianDiffusion',
3 | num_timesteps=4000,
4 | betas_cfg=dict(type='cosine'),
5 | denoising=dict(
6 | type='DenoisingUnet',
7 | image_size=32,
8 | in_channels=3,
9 | base_channels=128,
10 | resblocks_per_downsample=3,
11 | attention_res=[16, 8],
12 | use_scale_shift_norm=True,
13 | dropout=0.3,
14 | num_heads=4,
15 | use_rescale_timesteps=True,
16 | output_cfg=dict(mean='eps', var='learned_range')),
17 | timestep_sampler=dict(type='UniformTimeStepSampler'),
18 | ddpm_loss=[
19 | dict(
20 | type='DDPMVLBLoss',
21 | rescale_mode='constant',
22 | rescale_cfg=dict(scale=4000 / 1000),
23 | data_info=dict(
24 | mean_pred='mean_pred',
25 | mean_target='mean_posterior',
26 | logvar_pred='logvar_pred',
27 | logvar_target='logvar_posterior'),
28 | log_cfgs=[
29 | dict(
30 | type='quartile',
31 | prefix_name='loss_vlb',
32 | total_timesteps=4000),
33 | dict(type='name')
34 | ]),
35 | dict(
36 | type='DDPMMSELoss',
37 | log_cfgs=dict(
38 | type='quartile', prefix_name='loss_mse', total_timesteps=4000),
39 | )
40 | ],
41 | )
42 |
43 | train_cfg = dict(use_ema=True, real_img_key='img')
44 | test_cfg = None
45 | optimizer = dict(denoising=dict(type='AdamW', lr=1e-4, weight_decay=0))
46 |
--------------------------------------------------------------------------------
/configs/styleganv2/stylegan2_c2_lsun-car_384x512_b4x8.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 | '../_base_/datasets/lsun-car_pad_512.py',
3 | '../_base_/models/stylegan/stylegan2_base.py',
4 | '../_base_/default_runtime.py'
5 | ]
6 |
7 | model = dict(generator=dict(out_size=512), discriminator=dict(in_size=512))
8 |
9 | data = dict(
10 | samples_per_gpu=4,
11 | train=dict(dataset=dict(imgs_root='./data/lsun/images/car')),
12 | val=dict(imgs_root='./data/lsun/images/car'))
13 |
14 | ema_half_life = 10. # G_smoothing_kimg
15 |
16 | custom_hooks = [
17 | dict(
18 | type='VisualizeUnconditionalSamples',
19 | output_dir='training_samples',
20 | interval=5000),
21 | dict(
22 | type='ExponentialMovingAverageHook',
23 | module_keys=('generator_ema', ),
24 | interval=1,
25 | interp_cfg=dict(momentum=0.5**(32. / (ema_half_life * 1000.))),
26 | priority='VERY_HIGH')
27 | ]
28 |
29 | checkpoint_config = dict(interval=10000, by_epoch=False, max_keep_ckpts=40)
30 | lr_config = None
31 |
32 | total_iters = 1800002
33 |
34 | metrics = dict(
35 | fid50k=dict(
36 | type='FID', num_images=50000, inception_pkl=None, bgr2rgb=True),
37 | pr50k3=dict(type='PR', num_images=50000, k=3),
38 | ppl_wend=dict(type='PPL', space='W', sampling='end', num_images=50000))
39 |
40 | evaluation = dict(
41 | type='GenerativeEvalHook',
42 | interval=10000,
43 | metrics=dict(
44 | type='FID',
45 | num_images=50000,
46 | inception_pkl='work_dirs/inception_pkl/lsun-car-512-50k-rgb.pkl',
47 | bgr2rgb=True),
48 | sample_kwargs=dict(sample_model='ema'))
49 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/error-report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Error report
3 | about: Create a report to help us improve
4 | title: ''
5 | labels: ''
6 | assignees: ''
7 | ---
8 |
9 | Thanks for your error report and we appreciate it a lot.
10 |
11 | **Checklist**
12 |
13 | 1. I have searched related issues but cannot get the expected help.
14 | 2. I have read the [FAQ documentation](https://mmgeneration.readthedocs.io/en/latest/faq.html) but cannot get the expected help.
15 | 3. The bug has not been fixed in the latest version.
16 |
17 | **Describe the bug**
18 | A clear and concise description of what the bug is.
19 |
20 | **Reproduction**
21 |
22 | 1. What command or script did you run?
23 |
24 | ```none
25 | A placeholder for the command.
26 | ```
27 |
28 | 2. Did you make any modifications to the code or config? Did you understand what you have modified?
29 | 3. What dataset did you use?
30 |
31 | **Environment**
32 |
33 | 1. Please run `python mmgen/utils/collect_env.py` to collect necessary environment information and paste it here.
34 | 2. You may add additional information that may be helpful for locating the problem, such as
35 | - How you installed PyTorch \[e.g., pip, conda, source\]
36 | - Other environment variables that may be related (such as `$PATH`, `$LD_LIBRARY_PATH`, `$PYTHONPATH`, etc.)
37 |
38 | **Error traceback**
39 | If applicable, paste the error traceback here.
40 |
41 | ```none
42 | A placeholder for the traceback.
43 | ```
44 |
45 | **Bug fix**
46 | If you have already identified the reason, you can provide the information here. If you are willing to create a PR to fix it, please also leave a comment here and that would be much appreciated!
47 |
--------------------------------------------------------------------------------
/configs/_base_/datasets/imagenet_rgb.py:
--------------------------------------------------------------------------------
1 | # dataset settings
2 | dataset_type = 'mmcls.ImageNet'
3 |
4 | # Note that the pipelines below are from MMClassification
5 | img_norm_cfg = dict(
6 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
7 | train_pipeline = [
8 | dict(type='LoadImageFromFile'),
9 | dict(type='RandomResizedCrop', size=224, backend='pillow'),
10 | dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
11 | dict(type='Normalize', **img_norm_cfg),
12 | dict(type='ImageToTensor', keys=['img']),
13 | dict(type='ToTensor', keys=['gt_label']),
14 | dict(type='Collect', keys=['img', 'gt_label'])
15 | ]
16 |
17 | test_pipeline = [
18 | dict(type='LoadImageFromFile'),
19 | dict(type='Resize', size=(256, -1), backend='pillow'),
20 | dict(type='CenterCrop', crop_size=224),
21 | dict(type='Normalize', **img_norm_cfg),
22 | dict(type='ImageToTensor', keys=['img']),
23 | dict(type='Collect', keys=['img'])
24 | ]
25 |
26 | data = dict(
27 | samples_per_gpu=64,
28 | workers_per_gpu=2,
29 | train=dict(
30 | type=dataset_type,
31 | data_prefix='data/imagenet/train',
32 | pipeline=train_pipeline),
33 | val=dict(
34 | type=dataset_type,
35 | data_prefix='data/imagenet/val',
36 | ann_file='data/imagenet/meta/val.txt',
37 | pipeline=test_pipeline),
38 | test=dict(
39 | # replace `data/val` with `data/test` for standard test
40 | type=dataset_type,
41 | data_prefix='data/imagenet/val',
42 | ann_file='data/imagenet/meta/val.txt',
43 | pipeline=test_pipeline))
44 |
--------------------------------------------------------------------------------
/configs/_base_/datasets/cifar10_random_noise.py:
--------------------------------------------------------------------------------
1 | dataset_type = 'mmcls.CIFAR10'
2 |
3 | # cifar dataset without augmentation
4 | # different from mmcls, we adopt the setting used in BigGAN
5 | # Note that the pipelines below are from MMClassification. Importantly, the
6 | # `to_rgb` is set to `True` to convert images to BGR order. The default order
7 | # in Cifar10 is RGB. Thus, we have to convert it to BGR.
8 |
9 | # Follow the pipeline in
10 | # https://github.com/pfnet-research/sngan_projection/blob/master/datasets/cifar10.py
11 | # Only the `RandomImgNoise` augmentation is adopted.
12 | img_norm_cfg = dict(
13 | mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=True)
14 | train_pipeline = [
15 | dict(type='Normalize', **img_norm_cfg),
16 | dict(type='RandomImgNoise', keys=['img']),
17 | dict(type='ImageToTensor', keys=['img']),
18 | dict(type='ToTensor', keys=['gt_label']),
19 | dict(type='Collect', keys=['img', 'gt_label'])
20 | ]
21 | test_pipeline = [
22 | dict(type='Normalize', **img_norm_cfg),
23 | dict(type='ImageToTensor', keys=['img']),
24 | dict(type='Collect', keys=['img'])
25 | ]
26 |
27 | # Different from the classification task, the val/test splits also use the
28 | # training part, which is the same as in StyleGAN-ADA.
29 | data = dict(
30 | samples_per_gpu=None,
31 | workers_per_gpu=4,
32 | train=dict(
33 | type=dataset_type, data_prefix='data/cifar10',
34 | pipeline=train_pipeline),
35 | val=dict(
36 | type=dataset_type, data_prefix='data/cifar10', pipeline=test_pipeline),
37 | test=dict(
38 | type=dataset_type, data_prefix='data/cifar10', pipeline=test_pipeline))
39 |
--------------------------------------------------------------------------------
/tests/test_modules/test_mspie_archs.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from copy import deepcopy
3 |
4 | import torch
5 |
6 | from mmgen.models.architectures.stylegan import MSStyleGANv2Generator
7 |
8 |
9 | class TestMSStyleGAN2:
10 |     """Output-shape checks for `MSStyleGANv2Generator` forward passes."""
11 |     @classmethod
12 |     def setup_class(cls):
13 |         cls.generator_cfg = dict(out_size=32, style_channels=16)
14 |         cls.disc_cfg = dict(in_size=32, with_adaptive_pool=True)
15 |
16 |     def test_msstylegan2_cpu(self):
17 |         """Build the generator with several `mix_prob` values; check shapes."""
18 |         # test normal forward
19 |         cfg_ = deepcopy(self.generator_cfg)
20 |         g = MSStyleGANv2Generator(**cfg_)
21 |         res = g(None, num_batches=2)
22 |         assert res.shape == (2, 3, 32, 32)
23 |
24 |         # set mix_prob to 1 and to 0 so that those code paths get covered
25 |         cfg_ = deepcopy(self.generator_cfg)
26 |         cfg_['mix_prob'] = 1
27 |         g = MSStyleGANv2Generator(**cfg_)
28 |         res = g(torch.randn, num_batches=2)
29 |         assert res.shape == (2, 3, 32, 32)
30 |
31 |         cfg_ = deepcopy(self.generator_cfg)
32 |         cfg_['mix_prob'] = 0
33 |         g = MSStyleGANv2Generator(**cfg_)
34 |         res = g(torch.randn, num_batches=2)
35 |         assert res.shape == (2, 3, 32, 32)
36 |
37 |         cfg_ = deepcopy(self.generator_cfg)
38 |         cfg_['mix_prob'] = 1
39 |         g = MSStyleGANv2Generator(**cfg_)
40 |         res = g(None, num_batches=2)
41 |         assert res.shape == (2, 3, 32, 32)
42 |
43 |         cfg_ = deepcopy(self.generator_cfg)
44 |         cfg_['mix_prob'] = 0
45 |         g = MSStyleGANv2Generator(**cfg_)
46 |         res = g(None, num_batches=2)
47 |         assert res.shape == (2, 3, 32, 32)
48 |
--------------------------------------------------------------------------------
/tests/test_ops/test_conv_gradfix.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from functools import partial
3 |
4 | import pytest
5 | import torch
6 | import torch.nn as nn
7 | from torch.autograd import gradgradcheck
8 |
9 | from mmgen.ops import conv2d, conv_transpose2d
10 |
11 |
12 | class TestCond2d:
13 |     """Checks `conv2d` from `mmgen.ops`: output shape plus `gradgradcheck`."""
14 |     @classmethod
15 |     def setup_class(cls):
16 |         cls.input = torch.randn((1, 3, 32, 32))
17 |         cls.weight = nn.Parameter(torch.randn(1, 3, 3, 3))
18 |
19 |     @pytest.mark.skipif(
20 |         not torch.cuda.is_available()
21 |         or not hasattr(torch.backends.cudnn, 'allow_tf32'),
22 |         reason='requires cuda')
23 |     def test_conv2d_cuda(self):
24 |         x = self.input.cuda()
25 |         weight = self.weight.cuda()
26 |         res = conv2d(x, weight, None, 1, 1)
27 |         assert res.shape == (1, 1, 32, 32)
28 |         gradgradcheck(partial(conv2d, weight=weight, padding=1, stride=1), x)
29 |
30 |
31 | class TestCond2dTansposed:
32 |     """Checks `conv_transpose2d` from `mmgen.ops`: shape plus `gradgradcheck`."""
33 |     @classmethod
34 |     def setup_class(cls):
35 |         cls.input = torch.randn((1, 3, 32, 32))
36 |         cls.weight = nn.Parameter(torch.randn(3, 1, 3, 3))
37 |
38 |     @pytest.mark.skipif(
39 |         not torch.cuda.is_available()
40 |         or not hasattr(torch.backends.cudnn, 'allow_tf32'),
41 |         reason='requires cuda')
42 |     def test_conv2d_transposed_cuda(self):
43 |         x = self.input.cuda()
44 |         weight = self.weight.cuda()
45 |         res = conv_transpose2d(x, weight, None, 1, 1)
46 |         assert res.shape == (1, 1, 32, 32)
47 |         gradgradcheck(
48 |             partial(conv_transpose2d, weight=weight, padding=1, stride=1), x)
49 |
--------------------------------------------------------------------------------
/configs/improved_ddpm/ddpm_cosine_hybird_timestep-4k_drop0.3_imagenet1k_64x64_b8x16_1500k.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 | '../_base_/models/improved_ddpm/ddpm_64x64.py',
3 | '../_base_/datasets/imagenet_noaug_64.py', '../_base_/default_runtime.py'
4 | ]
5 |
6 | # set dropout prob as 0.3
7 | model = dict(denoising=dict(dropout=0.3))
8 |
9 | lr_config = None
10 | checkpoint_config = dict(interval=10000, by_epoch=False, max_keep_ckpts=20)
11 | custom_hooks = [
12 | dict(
13 | type='MMGenVisualizationHook',
14 | output_dir='training_samples',
15 | res_name_list=['real_imgs', 'x_0_pred', 'x_t', 'x_t_1'],
16 | padding=1,
17 | interval=1000),
18 | dict(
19 | type='ExponentialMovingAverageHook',
20 | module_keys=('denoising_ema'),
21 | interval=1,
22 | start_iter=0,
23 | interp_cfg=dict(momentum=0.9999),
24 | priority='VERY_HIGH')
25 | ]
26 |
27 | # Do not evaluate during training because evaluation takes too much time.
28 | evaluation = None
29 |
30 | total_iters = 1500000 # 1500k
31 | data = dict(samples_per_gpu=16) # 8x16=128
32 |
33 | # use ddp wrapper for faster training
34 | use_ddp_wrapper = True
35 | find_unused_parameters = False
36 |
37 | runner = dict(
38 | type='DynamicIterBasedRunner',
39 | is_dynamic_ddp=False, # Note that this flag should be False.
40 | pass_training_status=True)
41 |
42 | inception_pkl = './work_dirs/inception_pkl/imagenet_64x64.pkl'
43 | metrics = dict(
44 | fid50k=dict(
45 | type='FID',
46 | num_images=50000,
47 | bgr2rgb=True,
48 | inception_pkl=inception_pkl,
49 | inception_args=dict(type='StyleGAN')))
50 |
--------------------------------------------------------------------------------
/configs/lsgan/lsgan_dcgan-archi_lr-1e-3_celeba-cropped_64_b128x1_12m.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 | '../_base_/models/dcgan/dcgan_64x64.py',
3 | '../_base_/datasets/unconditional_imgs_64x64.py',
4 | '../_base_/default_runtime.py'
5 | ]
6 | model = dict(gan_loss=dict(type='GANLoss', gan_type='lsgan'))
7 | # define dataset
8 | # you must set `samples_per_gpu` and `imgs_root`
9 | data = dict(
10 | samples_per_gpu=128,
11 | train=dict(imgs_root='./data/celeba-cropped/cropped_images_aligned_png/'))
12 |
13 | optimizer = dict(
14 | generator=dict(type='Adam', lr=0.001, betas=(0.5, 0.99)),
15 | discriminator=dict(type='Adam', lr=0.001, betas=(0.5, 0.99)))
16 |
17 | # adjust running config
18 | lr_config = None
19 | checkpoint_config = dict(interval=10000, by_epoch=False, max_keep_ckpts=20)
20 | custom_hooks = [
21 | dict(
22 | type='VisualizeUnconditionalSamples',
23 | output_dir='training_samples',
24 | interval=10000)
25 | ]
26 |
27 | evaluation = dict(
28 | type='GenerativeEvalHook',
29 | interval=10000,
30 | metrics=dict(
31 | type='FID', num_images=50000, inception_pkl=None, bgr2rgb=True),
32 | sample_kwargs=dict(sample_model='orig'))
33 |
34 | total_iters = 100000
35 | # use ddp wrapper for faster training
36 | use_ddp_wrapper = True
37 | find_unused_parameters = False
38 |
39 | runner = dict(
40 | type='DynamicIterBasedRunner',
41 | is_dynamic_ddp=False, # Note that this flag should be False.
42 | pass_training_status=True)
43 |
44 | metrics = dict(
45 | ms_ssim10k=dict(type='MS_SSIM', num_images=10000),
46 | swd16k=dict(type='SWD', num_images=16384, image_shape=(3, 64, 64)),
47 | fid50k=dict(type='FID', num_images=50000, inception_pkl=None))
48 |
--------------------------------------------------------------------------------
/mmgen/ops/stylegan3/ops/filtered_lrelu_rd.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include "filtered_lrelu.cu"
10 |
11 | // Template/kernel specializations for sign read mode.
12 | // NOTE(review): the template argument lists (<T, index_t, signWrite, signRead>) were missing in this listing and are restored to match upstream StyleGAN3; confirm against filtered_lrelu.cu.
13 | // Full op, 32-bit indexing.
14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int32_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
16 |
17 | // Full op, 64-bit indexing.
18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int64_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
20 |
21 | // Activation/signs only for generic variant. 64-bit indexing.
22 | template void* choose_filtered_lrelu_act_kernel<c10::Half, false, true>(void);
23 | template void* choose_filtered_lrelu_act_kernel<float,     false, true>(void);
24 | template void* choose_filtered_lrelu_act_kernel<double,    false, true>(void);
25 |
26 | // Copy filters to constant memory.
27 | template cudaError_t copy_filters<false, true>(cudaStream_t stream);
28 |
--------------------------------------------------------------------------------
/mmgen/ops/stylegan3/ops/filtered_lrelu_wr.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include "filtered_lrelu.cu"
10 |
11 | // Template/kernel specializations for sign write mode.
12 | // NOTE(review): the template argument lists (<T, index_t, signWrite, signRead>) were missing in this listing and are restored to match upstream StyleGAN3; confirm against filtered_lrelu.cu.
13 | // Full op, 32-bit indexing.
14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int32_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
16 |
17 | // Full op, 64-bit indexing.
18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int64_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
20 |
21 | // Activation/signs only for generic variant. 64-bit indexing.
22 | template void* choose_filtered_lrelu_act_kernel<c10::Half, true, false>(void);
23 | template void* choose_filtered_lrelu_act_kernel<float,     true, false>(void);
24 | template void* choose_filtered_lrelu_act_kernel<double,    true, false>(void);
25 |
26 | // Copy filters to constant memory.
27 | template cudaError_t copy_filters<true, false>(cudaStream_t stream);
28 |
--------------------------------------------------------------------------------
/mmgen/utils/dist_util.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import numpy as np
3 | import torch
4 | import torch.distributed as dist
5 | from mmcv.runner import get_dist_info
6 |
7 |
8 | def check_dist_init():
9 |     """Return whether `torch.distributed` is available and initialized."""
10 |     return dist.is_available() and dist.is_initialized()
11 |
12 |
13 | def sync_random_seed(seed=None, device='cuda'):
14 |     """Make sure different ranks share the same seed.
15 |
16 |     All workers must call this function, otherwise it will deadlock. It is
17 |     generally used in `DistributedSampler`, where the seed must be identical
18 |     across all processes of the distributed group: every rank then shuffles
19 |     the data indices in the same order and selects its own non-overlapped
20 |     part of that list.
21 |
22 |     Args:
23 |         seed (int, Optional): The seed. If None, one is drawn at random.
24 |             Default to None.
25 |         device (str): The device where the seed will be put on.
26 |             Default to 'cuda'.
27 |
28 |     Returns:
29 |         int: Seed to be used (rank 0's seed when running distributed).
30 |     """
31 |     if seed is None:
32 |         seed = np.random.randint(2**31)
33 |     assert isinstance(seed, int)
34 |
35 |     rank, world_size = get_dist_info()
36 |
37 |     if world_size == 1:
38 |         return seed
39 |
40 |     if rank == 0:
41 |         random_num = torch.tensor(seed, dtype=torch.int32, device=device)
42 |     else:
43 |         random_num = torch.tensor(0, dtype=torch.int32, device=device)
44 |     dist.broadcast(random_num, src=0)
45 |     return random_num.item()
46 |
--------------------------------------------------------------------------------
/mmgen/ops/stylegan3/ops/filtered_lrelu_ns.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include "filtered_lrelu.cu"
10 |
11 | // Template/kernel specializations for no signs mode (no gradients required).
12 |
13 | // Full op, 32-bit indexing.
14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
16 |
17 | // Full op, 64-bit indexing.
18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
20 |
21 | // Activation/signs only for generic variant. 64-bit indexing.
22 | template void* choose_filtered_lrelu_act_kernel(void);
23 | template void* choose_filtered_lrelu_act_kernel(void);
24 | template void* choose_filtered_lrelu_act_kernel(void);
25 |
26 | // Copy filters to constant memory.
27 | template cudaError_t copy_filters(cudaStream_t stream);
28 |
--------------------------------------------------------------------------------
/configs/lsgan/lsgan_dcgan-archi_lr-1e-4_lsun-bedroom_64_b128x1_12m.py:
--------------------------------------------------------------------------------
# LSGAN config built on the DCGAN 64x64 base model, the 64x64 unconditional
# image dataset and the default runtime.
_base_ = [
    '../_base_/models/dcgan/dcgan_64x64.py',
    '../_base_/datasets/unconditional_imgs_64x64.py',
    '../_base_/default_runtime.py',
]

# Override the discriminator output settings and select the `lsgan` loss.
model = dict(
    discriminator=dict(
        output_scale=4,
        out_channels=1,
    ),
    gan_loss=dict(
        type='GANLoss',
        gan_type='lsgan',
    ),
)

# Dataset: `samples_per_gpu` and `imgs_root` must be set by the user.
data = dict(
    samples_per_gpu=128,
    train=dict(
        imgs_root='./data/lsun/bedroom_train',
    ),
)

# Generator and discriminator share the same Adam settings.
optimizer = dict(
    generator=dict(
        type='Adam',
        lr=0.0001,
        betas=(0.5, 0.99),
    ),
    discriminator=dict(
        type='Adam',
        lr=0.0001,
        betas=(0.5, 0.99),
    ),
)

# Running config: no lr schedule, iteration-based checkpoints.
lr_config = None
checkpoint_config = dict(
    interval=10000,
    by_epoch=False,
    max_keep_ckpts=20,
)

# Visualize training samples at the same interval as the checkpoints.
custom_hooks = [
    dict(
        type='VisualizeUnconditionalSamples',
        output_dir='training_samples',
        interval=10000,
    ),
]

# FID evaluation during training; `inception_pkl` is left unset (None) here.
evaluation = dict(
    type='GenerativeEvalHook',
    interval=10000,
    metrics=dict(
        type='FID',
        num_images=50000,
        inception_pkl=None,
        bgr2rgb=True,
    ),
    sample_kwargs=dict(
        sample_model='orig',
    ),
)

total_iters = 100000

# use ddp wrapper for faster training
use_ddp_wrapper = True
find_unused_parameters = False

runner = dict(
    type='DynamicIterBasedRunner',
    # Note that this flag should be False.
    is_dynamic_ddp=False,
    pass_training_status=True,
)

# Metrics definitions, keyed by metric name.
metrics = dict(
    ms_ssim10k=dict(
        type='MS_SSIM',
        num_images=10000,
    ),
    swd16k=dict(
        type='SWD',
        num_images=16384,
        image_shape=(3, 64, 64),
    ),
    fid50k=dict(
        type='FID',
        num_images=50000,
        inception_pkl=None,
    ),
)
49 |
--------------------------------------------------------------------------------
/configs/styleganv2/stylegan2_c2_ffhq_1024_b4x8.py:
--------------------------------------------------------------------------------
"""Config for the `config-f` setting in StyleGAN2."""

_base_ = [
    '../_base_/datasets/ffhq_flip.py',
    '../_base_/models/stylegan/stylegan2_base.py',
    '../_base_/default_runtime.py',
]

# G_smoothing_kimg: half-life used below to derive the EMA momentum.
ema_half_life = 10.

# Generator output and discriminator input both work at 1024 resolution.
model = dict(
    generator=dict(out_size=1024),
    discriminator=dict(in_size=1024),
)

data = dict(
    samples_per_gpu=4,
    train=dict(
        dataset=dict(imgs_root='./data/ffhq/images'),
    ),
    val=dict(imgs_root='./data/ffhq/images'),
)

custom_hooks = [
    # Dump sample images every 5000 iterations.
    dict(
        type='VisualizeUnconditionalSamples',
        output_dir='training_samples',
        interval=5000,
    ),
    # Update `generator_ema` every iteration.
    # NOTE(review): the 32 in the momentum formula is presumably the total
    # batch size (4 samples x 8 GPUs per the file name) -- confirm.
    dict(
        type='ExponentialMovingAverageHook',
        module_keys=('generator_ema', ),
        interval=1,
        interp_cfg=dict(
            momentum=0.5**(32. / (ema_half_life * 1000.)),
        ),
        priority='VERY_HIGH',
    ),
]

# Metrics definitions, keyed by metric name.
metrics = dict(
    fid50k=dict(
        type='FID',
        num_images=50000,
        inception_pkl='work_dirs/inception_pkl/ffhq-1024-50k-rgb.pkl',
        bgr2rgb=True,
    ),
    pr50k3=dict(
        type='PR',
        num_images=50000,
        k=3,
    ),
    ppl_wend=dict(
        type='PPL',
        space='W',
        sampling='end',
        num_images=50000,
    ),
)

# FID evaluation during training, sampling from the EMA model.
evaluation = dict(
    type='GenerativeEvalHook',
    interval=10000,
    metrics=dict(
        type='FID',
        num_images=50000,
        inception_pkl='work_dirs/inception_pkl/ffhq-1024-50k-rgb.pkl',
        bgr2rgb=True,
    ),
    sample_kwargs=dict(sample_model='ema'),
)

checkpoint_config = dict(
    interval=10000,
    by_epoch=False,
    max_keep_ckpts=30,
)
lr_config = None

total_iters = 800002
54 |
--------------------------------------------------------------------------------
/configs/lsgan/lsgan_dcgan-archi_lr-1e-4_celeba-cropped_128_b64x1_10m.py:
--------------------------------------------------------------------------------
# LSGAN config built on the DCGAN 128x128 base model, the 128x128
# unconditional image dataset and the default runtime.
_base_ = [
    '../_base_/models/dcgan/dcgan_128x128.py',
    '../_base_/datasets/unconditional_imgs_128x128.py',
    '../_base_/default_runtime.py',
]

# Override the discriminator output settings and select the `lsgan` loss.
model = dict(
    discriminator=dict(
        output_scale=4,
        out_channels=1,
    ),
    gan_loss=dict(
        type='GANLoss',
        gan_type='lsgan',
    ),
)

# Dataset: `samples_per_gpu` and `imgs_root` must be set by the user.
data = dict(
    samples_per_gpu=64,
    train=dict(
        imgs_root='./data/celeba-cropped/cropped_images_aligned_png/',
    ),
)

# Generator and discriminator share the same Adam settings.
optimizer = dict(
    generator=dict(
        type='Adam',
        lr=0.0001,
        betas=(0.5, 0.99),
    ),
    discriminator=dict(
        type='Adam',
        lr=0.0001,
        betas=(0.5, 0.99),
    ),
)

# Running config: no lr schedule, iteration-based checkpoints.
lr_config = None
checkpoint_config = dict(
    interval=10000,
    by_epoch=False,
    max_keep_ckpts=20,
)

# Visualize training samples at the same interval as the checkpoints.
custom_hooks = [
    dict(
        type='VisualizeUnconditionalSamples',
        output_dir='training_samples',
        interval=10000,
    ),
]

# FID evaluation during training; `inception_pkl` is left unset (None) here.
evaluation = dict(
    type='GenerativeEvalHook',
    interval=10000,
    metrics=dict(
        type='FID',
        num_images=50000,
        inception_pkl=None,
        bgr2rgb=True,
    ),
    sample_kwargs=dict(
        sample_model='orig',
    ),
)

total_iters = 160000

# use ddp wrapper for faster training
use_ddp_wrapper = True
find_unused_parameters = False

runner = dict(
    type='DynamicIterBasedRunner',
    # Note that this flag should be False.
    is_dynamic_ddp=False,
    pass_training_status=True,
)

# Metrics definitions, keyed by metric name.
metrics = dict(
    ms_ssim10k=dict(
        type='MS_SSIM',
        num_images=10000,
    ),
    swd16k=dict(
        type='SWD',
        num_images=16384,
        image_shape=(3, 128, 128),
    ),
    fid50k=dict(
        type='FID',
        num_images=50000,
        inception_pkl=None,
    ),
)
50 |
--------------------------------------------------------------------------------