├── docs ├── zh_cn │ ├── changelog.md │ ├── tutorials │ │ ├── config.md │ │ ├── customize_dataset.md │ │ ├── customize_models.md │ │ ├── customize_runtime.md │ │ ├── customize_losses.md │ │ └── index.rst │ ├── quick_run.md │ ├── faq.md │ ├── Makefile │ ├── index.rst │ ├── make.bat │ ├── stat.py │ └── api.rst └── en │ ├── switch_language.md │ ├── _static │ └── css │ │ └── readthedocs.css │ ├── tutorials │ └── index.rst │ ├── Makefile │ ├── index.rst │ ├── make.bat │ ├── faq.md │ ├── stat.py │ └── api.rst ├── requirements ├── readthedocs.txt ├── mminstall.txt ├── tests.txt ├── runtime.txt └── docs.txt ├── requirements.txt ├── tests ├── test_models │ ├── test_base_gan.py │ └── test_base_ddpm.py ├── data │ ├── image │ │ ├── baboon.png │ │ └── img_root │ │ │ ├── GT05.jpg │ │ │ ├── baboon.png │ │ │ ├── audi_a5.jpeg │ │ │ ├── horse │ │ │ └── horse.jpeg │ │ │ └── grass │ │ │ └── grass1.jpeg │ ├── paired │ │ ├── test │ │ │ ├── 3.jpg │ │ │ └── 33_AB.jpg │ │ └── train │ │ │ ├── 1.jpg │ │ │ └── 2.jpg │ └── unpaired │ │ ├── testA │ │ └── 5.jpg │ │ ├── testB │ │ └── 6.jpg │ │ ├── trainA │ │ ├── 1.jpg │ │ └── 2.jpg │ │ └── trainB │ │ ├── 3.jpg │ │ └── 4.jpg ├── test_datasets │ ├── test_quicktest_dataset.py │ ├── test_dataset_wrappers.py │ ├── test_singan_dataset.py │ ├── test_unconditional_image_dataset.py │ ├── test_persistent_worker.py │ └── test_pipelines │ │ ├── test_loading.py │ │ └── test_compose.py ├── test_utils │ └── test_io_utils.py ├── test_modules │ ├── test_lpips.py │ ├── test_fid_inception.py │ └── test_mspie_archs.py └── test_ops │ └── test_conv_gradfix.py ├── configs ├── sagan │ ├── sagan_128_cvt_studioGAN.py │ └── sagan_32_cvt_studioGAN.py ├── sngan_proj │ ├── sngan_proj_128_cvt_studioGAN.py │ └── sngan_proj_32_cvt_studioGAN.py ├── _base_ │ ├── datasets │ │ ├── singan.py │ │ ├── unconditional_imgs_128x128.py │ │ ├── unconditional_imgs_flip_256x256.py │ │ ├── unconditional_imgs_flip_512x512.py │ │ ├── unconditional_imgs_64x64.py │ │ ├── Inception_Score.py │ │ ├── 
unconditional_imgs_flip_lanczos_resize_256x256.py │ │ ├── lsun-car_pad_512.py │ │ ├── cifar10_rgb.py │ │ ├── grow_scale_imgs_ffhq_styleganv1.py │ │ ├── cifar10_inception_stat.py │ │ ├── grow_scale_imgs_128x128.py │ │ ├── ffhq_flip.py │ │ ├── lsun_stylegan.py │ │ ├── cifar10_nopad.py │ │ ├── cifar10.py │ │ ├── grow_scale_imgs_celeba-hq.py │ │ ├── cifar10_noaug.py │ │ ├── imagenet_rgb.py │ │ └── cifar10_random_noise.py │ ├── models │ │ ├── sngan_proj │ │ │ ├── sngan_proj_32x32.py │ │ │ └── sngan_proj_128x128.py │ │ ├── dcgan │ │ │ ├── dcgan_64x64.py │ │ │ └── dcgan_128x128.py │ │ ├── lsgan │ │ │ └── lsgan_128x128.py │ │ ├── sagan │ │ │ ├── sagan_32x32.py │ │ │ └── sagan_128x128.py │ │ ├── biggan │ │ │ ├── biggan_32x32.py │ │ │ └── biggan_128x128.py │ │ ├── wgangp │ │ │ └── wgangp_base.py │ │ ├── stylegan │ │ │ ├── styleganv1_base.py │ │ │ ├── stylegan2_base.py │ │ │ └── stylegan3_base.py │ │ ├── singan │ │ │ └── singan.py │ │ ├── pix2pix │ │ │ └── pix2pix_vanilla_unet_bn.py │ │ ├── pggan │ │ │ └── pggan_128x128.py │ │ └── improved_ddpm │ │ │ ├── ddpm_64x64.py │ │ │ └── ddpm_32x32.py │ ├── default_metrics.py │ └── default_runtime.py ├── positional_encoding_in_gans │ ├── singan_interp-pad_balloons.py │ ├── singan_interp-pad_disc-nobn_balloons.py │ ├── singan_interp-pad_disc-nobn_fish.py │ ├── singan_csg_fish.py │ ├── singan_csg_bohemian.py │ ├── singan_spe-dim4_fish.py │ ├── singan_spe-dim4_bohemian.py │ ├── singan_spe-dim8_bohemian.py │ ├── stylegan2_c2_ffhq_256_b3x8_1100k.py │ └── stylegan2_c2_ffhq_512_b3x8_1100k.py ├── styleganv2 │ ├── stylegan2_c2_apex_fp16_PL-R1-no-scaler_ffhq_256_b4x8_800k.py │ ├── stylegan2_c2_fp16-globalG-partialD_PL-R1-no-scaler_ffhq_256_b4x8_800k.py │ ├── stylegan2_c2_fp16_partial-GD_PL-no-scaler_ffhq_256_b4x8_800k.py │ ├── stylegan2_c2_apex_fp16_quicktest_ffhq_256_b4x8_800k.py │ ├── stylegan2_c2_fp16_quicktest_ffhq_256_b4x8_800k.py │ ├── stylegan2_c2_fp16-global_quicktest_ffhq_256_b4x8_800k.py │ ├── stylegan2_c2_lsun-cat_256_b4x8_800k.py │ 
├── stylegan2_c2_lsun-horse_256_b4x8_800k.py │ ├── stylegan2_c2_lsun-church_256_b4x8_800k.py │ ├── stylegan2_c2_lsun-car_384x512_b4x8.py │ └── stylegan2_c2_ffhq_1024_b4x8.py ├── styleganv3 │ ├── stylegan3_t_afhqv2_512_b4x8_cvt_official_rgb.py │ ├── stylegan3_t_ffhq_1024_b4x8_cvt_official_rgb.py │ ├── stylegan3_t_ffhqu_256_b4x8_cvt_official_rgb.py │ ├── stylegan3_r_ffhqu_256_b4x8_cvt_official_rgb.py │ ├── stylegan3_r_ffhq_1024_b4x8_cvt_official_rgb.py │ └── stylegan3_r_afhqv2_512_b4x8_cvt_official_rgb.py ├── wgan-gp │ ├── wgangp_GN_GP-50_lsun-bedroom_128_b64x1_160kiter.py │ ├── wgangp_GN_celeba-cropped_128_b64x1_160kiter.py │ └── metafile.yml ├── ada │ └── metafile.yml ├── dcgan │ ├── dcgan_lsun-bedroom_64x64_b128x1_5e.py │ └── dcgan_celeba-cropped_64_b128x1_300k.py ├── biggan │ ├── biggan_128x128_cvt_BigGAN-PyTorch_rgb.py │ ├── biggan-deep_128x128_cvt_hugging-face_rgb.py │ ├── biggan-deep_256x256_cvt_hugging-face_rgb.py │ └── biggan-deep_512x512_cvt_hugging-face_rgb.py ├── singan │ ├── singan_balloons.py │ ├── singan_fish.py │ ├── singan_bohemian.py │ └── metafile.yml ├── pggan │ ├── pggan_celeba-hq_1024_g8_12Mimg.py │ ├── pggan_lsun-bedroom_128_g8_12Mimgs.py │ └── pggan_celeba-cropped_128_g8_12Mimgs.py ├── styleganv1 │ ├── metafile.yml │ ├── styleganv1_ffhq_256_g8_25Mimg.py │ └── styleganv1_ffhq_1024_g8_25Mimg.py ├── improved_ddpm │ ├── ddpm_cosine_hybird_timestep-4k_drop0.3_cifar10_32x32_b8x16_500k.py │ ├── ddpm_cosine_hybird_timestep-4k_imagenet1k_64x64_b8x16_1500k.py │ └── ddpm_cosine_hybird_timestep-4k_drop0.3_imagenet1k_64x64_b8x16_1500k.py └── lsgan │ ├── lsgan_lsgan-archi_lr-1e-4_lsun-bedroom_128_b64x1_10m.py │ ├── lsgan_dcgan-archi_lr-1e-3_celeba-cropped_64_b128x1_12m.py │ ├── lsgan_dcgan-archi_lr-1e-4_lsun-bedroom_64_b128x1_12m.py │ └── lsgan_dcgan-archi_lr-1e-4_celeba-cropped_128_b64x1_10m.py ├── mmgen ├── core │ ├── optimizer │ │ └── __init__.py │ ├── scheduler │ │ └── __init__.py │ ├── runners │ │ ├── __init__.py │ │ └── apex_amp_utils.py │ ├── 
__init__.py │ ├── hooks │ │ ├── __init__.py │ │ └── pggan_fetch_data_hook.py │ ├── evaluation │ │ └── __init__.py │ └── registry.py ├── models │ ├── architectures │ │ ├── arcface │ │ │ └── __init__.py │ │ ├── dcgan │ │ │ └── __init__.py │ │ ├── lsgan │ │ │ └── __init__.py │ │ ├── wgan_gp │ │ │ └── __init__.py │ │ ├── cyclegan │ │ │ └── __init__.py │ │ ├── pix2pix │ │ │ └── __init__.py │ │ ├── sngan_proj │ │ │ └── __init__.py │ │ ├── ddpm │ │ │ └── __init__.py │ │ ├── singan │ │ │ └── __init__.py │ │ ├── lpips │ │ │ └── __init__.py │ │ ├── stylegan │ │ │ ├── modules │ │ │ │ └── __init__.py │ │ │ ├── __init__.py │ │ │ └── ada │ │ │ │ └── misc.py │ │ ├── common.py │ │ ├── pggan │ │ │ └── __init__.py │ │ └── biggan │ │ │ └── __init__.py │ ├── diffusions │ │ ├── __init__.py │ │ └── sampler.py │ ├── common │ │ ├── __init__.py │ │ └── dist_utils.py │ ├── translation_models │ │ └── __init__.py │ ├── __init__.py │ ├── gans │ │ └── __init__.py │ ├── builder.py │ └── losses │ │ └── __init__.py ├── datasets │ ├── samplers │ │ └── __init__.py │ ├── quick_test_dataset.py │ ├── pipelines │ │ └── __init__.py │ ├── __init__.py │ └── dataset_wrappers.py ├── ops │ ├── __init__.py │ └── stylegan3 │ │ ├── ops │ │ ├── __init__.py │ │ ├── bias_act.h │ │ ├── filtered_lrelu_rd.cu │ │ ├── filtered_lrelu_wr.cu │ │ └── filtered_lrelu_ns.cu │ │ └── __init__.py ├── utils │ ├── __init__.py │ ├── logger.py │ └── dist_util.py ├── apis │ └── __init__.py ├── version.py └── __init__.py ├── .github ├── ISSUE_TEMPLATE │ ├── general_questions.md │ ├── config.yml │ ├── 4-documentation.yml │ ├── feature_request.md │ ├── 2-feature-request.yml │ ├── 3-new-model.yml │ └── error-report.md └── workflows │ ├── scripts │ └── get_mmcv_var.sh │ ├── publish-to-pypi.yml │ ├── lint.yml │ └── test_mim.yml ├── .readthedocs.yml ├── tools ├── eval.sh ├── dist_train.sh ├── dist_eval.sh ├── slurm_eval.sh ├── slurm_eval_multi_gpu.sh ├── slurm_train.sh ├── publish_model.py └── misc │ └── print_config.py ├── CITATION.cff ├── 
MANIFEST.in ├── .circleci ├── docker │ └── Dockerfile ├── scripts │ └── get_mmcv_var.sh └── config.yml ├── model-index.yml ├── setup.cfg ├── docker └── Dockerfile └── LICENSES.md /docs/zh_cn/changelog.md: -------------------------------------------------------------------------------- 1 | # 版本更新日志 2 | -------------------------------------------------------------------------------- /requirements/readthedocs.txt: -------------------------------------------------------------------------------- 1 | mmcv 2 | torch 3 | torchvision 4 | -------------------------------------------------------------------------------- /docs/zh_cn/tutorials/config.md: -------------------------------------------------------------------------------- 1 | # Tutorial 1: 配置系统 (config files) 2 | -------------------------------------------------------------------------------- /docs/zh_cn/tutorials/customize_dataset.md: -------------------------------------------------------------------------------- 1 | # Tutorial 2: 自定义数据集 2 | -------------------------------------------------------------------------------- /docs/zh_cn/tutorials/customize_models.md: -------------------------------------------------------------------------------- 1 | # Tutorial 3: 自定义模型 2 | -------------------------------------------------------------------------------- /docs/zh_cn/tutorials/customize_runtime.md: -------------------------------------------------------------------------------- 1 | # Tutorial 6: 自定义配置 2 | -------------------------------------------------------------------------------- /docs/zh_cn/tutorials/customize_losses.md: -------------------------------------------------------------------------------- 1 | # Tutorial 4: 损失函数模块的设计思路 2 | -------------------------------------------------------------------------------- /requirements/mminstall.txt: -------------------------------------------------------------------------------- 1 | mmcls>=0.18.0 2 | mmcv-full>=1.3.0,<=1.8.0 3 | 
-------------------------------------------------------------------------------- /docs/zh_cn/quick_run.md: -------------------------------------------------------------------------------- 1 | # 1: 在标准的数据集上训练和推理现有的模型 2 | 3 | ## 用现有的生成模型来生成图像 4 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | -r requirements/runtime.txt 2 | -r requirements/tests.txt 3 | -------------------------------------------------------------------------------- /tests/test_models/test_base_gan.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | -------------------------------------------------------------------------------- /configs/sagan/sagan_128_cvt_studioGAN.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/models/sagan/sagan_128x128.py'] 2 | -------------------------------------------------------------------------------- /configs/sagan/sagan_32_cvt_studioGAN.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/models/sagan/sagan_32x32.py'] 2 | -------------------------------------------------------------------------------- /tests/data/image/baboon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-mmlab/mmgeneration/HEAD/tests/data/image/baboon.png -------------------------------------------------------------------------------- /configs/sngan_proj/sngan_proj_128_cvt_studioGAN.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/models/sngan_proj/sngan_proj_128x128.py'] 2 | -------------------------------------------------------------------------------- /configs/sngan_proj/sngan_proj_32_cvt_studioGAN.py: 
-------------------------------------------------------------------------------- 1 | _base_ = ['../_base_/models/sngan_proj/sngan_proj_32x32.py'] 2 | -------------------------------------------------------------------------------- /tests/data/paired/test/3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-mmlab/mmgeneration/HEAD/tests/data/paired/test/3.jpg -------------------------------------------------------------------------------- /tests/data/paired/train/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-mmlab/mmgeneration/HEAD/tests/data/paired/train/1.jpg -------------------------------------------------------------------------------- /tests/data/paired/train/2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-mmlab/mmgeneration/HEAD/tests/data/paired/train/2.jpg -------------------------------------------------------------------------------- /tests/data/paired/test/33_AB.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-mmlab/mmgeneration/HEAD/tests/data/paired/test/33_AB.jpg -------------------------------------------------------------------------------- /tests/data/unpaired/testA/5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-mmlab/mmgeneration/HEAD/tests/data/unpaired/testA/5.jpg -------------------------------------------------------------------------------- /tests/data/unpaired/testB/6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-mmlab/mmgeneration/HEAD/tests/data/unpaired/testB/6.jpg -------------------------------------------------------------------------------- /tests/data/unpaired/trainA/1.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-mmlab/mmgeneration/HEAD/tests/data/unpaired/trainA/1.jpg -------------------------------------------------------------------------------- /tests/data/unpaired/trainA/2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-mmlab/mmgeneration/HEAD/tests/data/unpaired/trainA/2.jpg -------------------------------------------------------------------------------- /tests/data/unpaired/trainB/3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-mmlab/mmgeneration/HEAD/tests/data/unpaired/trainB/3.jpg -------------------------------------------------------------------------------- /tests/data/unpaired/trainB/4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-mmlab/mmgeneration/HEAD/tests/data/unpaired/trainB/4.jpg -------------------------------------------------------------------------------- /tests/data/image/img_root/GT05.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-mmlab/mmgeneration/HEAD/tests/data/image/img_root/GT05.jpg -------------------------------------------------------------------------------- /requirements/tests.txt: -------------------------------------------------------------------------------- 1 | coverage < 7.0.0 2 | # codecov 3 | flake8 4 | interrogate 5 | isort==4.3.21 6 | pytest 7 | pytest-runner 8 | -------------------------------------------------------------------------------- /tests/data/image/img_root/baboon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-mmlab/mmgeneration/HEAD/tests/data/image/img_root/baboon.png 
-------------------------------------------------------------------------------- /requirements/runtime.txt: -------------------------------------------------------------------------------- 1 | mmcls 2 | ninja 3 | numpy 4 | prettytable 5 | requests 6 | scikit-image 7 | scipy 8 | tqdm 9 | yapf 10 | -------------------------------------------------------------------------------- /tests/data/image/img_root/audi_a5.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-mmlab/mmgeneration/HEAD/tests/data/image/img_root/audi_a5.jpeg -------------------------------------------------------------------------------- /tests/data/image/img_root/horse/horse.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-mmlab/mmgeneration/HEAD/tests/data/image/img_root/horse/horse.jpeg -------------------------------------------------------------------------------- /tests/data/image/img_root/grass/grass1.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-mmlab/mmgeneration/HEAD/tests/data/image/img_root/grass/grass1.jpeg -------------------------------------------------------------------------------- /mmgen/core/optimizer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .builder import build_optimizers 3 | 4 | __all__ = ['build_optimizers'] 5 | -------------------------------------------------------------------------------- /mmgen/models/architectures/arcface/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from .id_loss import IDLossModel 3 | 4 | __all__ = ['IDLossModel'] 5 | -------------------------------------------------------------------------------- /mmgen/core/scheduler/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .lr_updater import LinearLrUpdaterHook 3 | 4 | __all__ = ['LinearLrUpdaterHook'] 5 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/general_questions.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: General questions 3 | about: Ask general questions to get help 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | --- 8 | -------------------------------------------------------------------------------- /docs/en/switch_language.md: -------------------------------------------------------------------------------- 1 | ## English 2 | 3 | ## 简体中文 4 | -------------------------------------------------------------------------------- /mmgen/datasets/samplers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .distributed_sampler import DistributedSampler 3 | 4 | __all__ = ['DistributedSampler'] 5 | -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | python: 4 | version: 3.7 5 | install: 6 | - requirements: requirements/docs.txt 7 | - requirements: requirements/readthedocs.txt 8 | -------------------------------------------------------------------------------- /mmgen/core/runners/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from .dynamic_iterbased_runner import DynamicIterBasedRunner 3 | 4 | __all__ = ['DynamicIterBasedRunner'] 5 | -------------------------------------------------------------------------------- /mmgen/models/architectures/dcgan/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .generator_discriminator import DCGANDiscriminator, DCGANGenerator 3 | 4 | __all__ = ['DCGANGenerator', 'DCGANDiscriminator'] 5 | -------------------------------------------------------------------------------- /mmgen/models/architectures/lsgan/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .generator_discriminator import LSGANDiscriminator, LSGANGenerator 3 | 4 | __all__ = ['LSGANDiscriminator', 'LSGANGenerator'] 5 | -------------------------------------------------------------------------------- /mmgen/models/architectures/wgan_gp/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from .generator_discriminator import WGANGPDiscriminator, WGANGPGenerator 3 | 4 | __all__ = ['WGANGPDiscriminator', 'WGANGPGenerator'] 5 | -------------------------------------------------------------------------------- /tools/eval.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -x 4 | 5 | CONFIG=$1 6 | CKPT=$2 7 | PY_ARGS=${@:3} 8 | 9 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 10 | python -u tools/evaluation.py ${CONFIG} ${CKPT} --launcher="none" ${PY_ARGS} 11 | -------------------------------------------------------------------------------- /docs/zh_cn/faq.md: -------------------------------------------------------------------------------- 1 | # 常见问题解答 2 | 3 | 我们在这里罗列了许多用户遇到的一些常见的问题和相应的解决方案。如果您发现任何常见问题并有办法帮助他人解决该问题,欢迎丰富该列表。如果此处的内容未涵盖您的问题,请使用[提供的模版](https://github.com/open-mmlab/mmgeneration/blob/master/.github/ISSUE_TEMPLATE/error-report.md)来创建新问题,并确保将模版中所有必需的信息填写完整。 4 | -------------------------------------------------------------------------------- /docs/zh_cn/tutorials/index.rst: -------------------------------------------------------------------------------- 1 | .. toctree:: 2 | :maxdepth: 2 3 | 4 | config.md 5 | customize_dataset.md 6 | customize_models.md 7 | customize_losses.md 8 | ddp_train_gans.md 9 | customize_runtime.md 10 | applications.md 11 | -------------------------------------------------------------------------------- /mmgen/models/diffusions/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from .base_diffusion import BasicGaussianDiffusion 3 | from .sampler import UniformTimeStepSampler 4 | 5 | __all__ = ['BasicGaussianDiffusion', 'UniformTimeStepSampler'] 6 | -------------------------------------------------------------------------------- /mmgen/models/common/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .dist_utils import AllGatherLayer 3 | from .model_utils import GANImageBuffer, set_requires_grad 4 | 5 | __all__ = ['set_requires_grad', 'AllGatherLayer', 'GANImageBuffer'] 6 | -------------------------------------------------------------------------------- /mmgen/ops/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .conv2d_gradfix import conv2d, conv_transpose2d 3 | from .stylegan3.ops import bias_act, filtered_lrelu 4 | 5 | __all__ = ['conv2d', 'conv_transpose2d', 'filtered_lrelu', 'bias_act'] 6 | -------------------------------------------------------------------------------- /mmgen/models/architectures/cyclegan/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from .generator_discriminator import ResnetGenerator 3 | from .modules import ResidualBlockWithDropout 4 | 5 | __all__ = ['ResnetGenerator', 'ResidualBlockWithDropout'] 6 | -------------------------------------------------------------------------------- /docs/en/_static/css/readthedocs.css: -------------------------------------------------------------------------------- 1 | .header-logo { 2 | background-image: url("https://user-images.githubusercontent.com/12726765/114528756-de55af80-9c7b-11eb-94d7-d3224ada1585.png"); 3 | background-size: 173px 40px; 4 | height: 40px; 5 | width: 173px; 6 | } 7 | -------------------------------------------------------------------------------- /docs/en/tutorials/index.rst: -------------------------------------------------------------------------------- 1 | .. toctree:: 2 | :maxdepth: 2 3 | 4 | config.md 5 | customize_dataset.md 6 | customize_models.md 7 | customize_losses.md 8 | inception_stat.md 9 | ddp_train_gans.md 10 | customize_runtime.md 11 | applications.md 12 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you use this software, please cite it as below." 
3 | authors: 4 | - name: "MMGeneration Contributors" 5 | title: "OpenMMLab's next-generation toolbox for generative models" 6 | date-released: 2020-07-10 7 | url: "https://github.com/open-mmlab/mmgeneration" 8 | license: Apache-2.0 9 | -------------------------------------------------------------------------------- /configs/_base_/datasets/singan.py: -------------------------------------------------------------------------------- 1 | dataset_type = 'SinGANDataset' 2 | 3 | data = dict( 4 | samples_per_gpu=1, 5 | workers_per_gpu=4, 6 | drop_last=False, 7 | train=dict( 8 | type=dataset_type, 9 | img_path=None, # need to set 10 | min_size=25, 11 | max_size=250, 12 | scale_factor_init=0.75)) 13 | -------------------------------------------------------------------------------- /mmgen/core/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .evaluation import * # noqa: F401, F403 3 | from .hooks import * # noqa: F401, F403 4 | from .optimizer import * # noqa: F401, F403 5 | from .registry import * # noqa: F401, F403 6 | from .runners import * # noqa: F401, F403 7 | from .scheduler import * # noqa: F401, F403 8 | -------------------------------------------------------------------------------- /configs/positional_encoding_in_gans/singan_interp-pad_balloons.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../singan/singan_balloons.py'] 2 | 3 | model = dict( 4 | type='PESinGAN', 5 | generator=dict( 6 | type='SinGANMSGeneratorPE', interp_pad=True, noise_with_pad=True)) 7 | 8 | train_cfg = dict(fixed_noise_with_pad=True) 9 | 10 | dist_params = dict(backend='nccl') 11 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include mmgen/model-index.yml 2 | recursive-include mmgen/configs *.py *.yml 3 | 
recursive-include mmgen/tools *.sh *.py 4 | 5 | include requirements/*.txt 6 | include mmgen/VERSION 7 | include mmgen/.mim/model-index.yml 8 | include mmgen/.mim/demo/*/* 9 | recursive-include mmgen/.mim/configs *.py *.yml 10 | recursive-include mmgen/.mim/tools *.sh *.py 11 | -------------------------------------------------------------------------------- /mmgen/models/architectures/pix2pix/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .generator_discriminator import PatchDiscriminator, UnetGenerator 3 | from .modules import UnetSkipConnectionBlock, generation_init_weights 4 | 5 | __all__ = [ 6 | 'PatchDiscriminator', 'UnetGenerator', 'UnetSkipConnectionBlock', 7 | 'generation_init_weights' 8 | ] 9 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: false 2 | 3 | contact_links: 4 | - name: 💬 Forum 5 | url: https://github.com/open-mmlab/mmgeneration/discussions 6 | about: Ask general usage questions and discuss with other MMGeneration community members 7 | - name: 🌐 Explore OpenMMLab 8 | url: https://openmmlab.com/ 9 | about: Get know more about OpenMMLab 10 | -------------------------------------------------------------------------------- /mmgen/models/translation_models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from .base_translation_model import BaseTranslationModel 3 | from .cyclegan import CycleGAN 4 | from .pix2pix import Pix2Pix 5 | from .static_translation_gan import StaticTranslationGAN 6 | 7 | __all__ = [ 8 | 'Pix2Pix', 'CycleGAN', 'BaseTranslationModel', 'StaticTranslationGAN' 9 | ] 10 | -------------------------------------------------------------------------------- /mmgen/models/architectures/sngan_proj/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .generator_discriminator import ProjDiscriminator, SNGANGenerator 3 | from .modules import SNGANDiscHeadResBlock, SNGANDiscResBlock, SNGANGenResBlock 4 | 5 | __all__ = [ 6 | 'ProjDiscriminator', 'SNGANGenerator', 'SNGANGenResBlock', 7 | 'SNGANDiscResBlock', 'SNGANDiscHeadResBlock' 8 | ] 9 | -------------------------------------------------------------------------------- /mmgen/models/architectures/ddpm/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from .denoising import DenoisingUnet 3 | from .modules import (DenoisingDownsample, DenoisingResBlock, 4 | DenoisingUpsample, TimeEmbedding) 5 | 6 | __all__ = [ 7 | 'DenoisingUnet', 'TimeEmbedding', 'DenoisingDownsample', 8 | 'DenoisingUpsample', 'DenoisingResBlock' 9 | ] 10 | -------------------------------------------------------------------------------- /configs/positional_encoding_in_gans/singan_interp-pad_disc-nobn_balloons.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../singan/singan_balloons.py'] 2 | 3 | model = dict( 4 | type='PESinGAN', 5 | generator=dict( 6 | type='SinGANMSGeneratorPE', interp_pad=True, noise_with_pad=True), 7 | discriminator=dict(norm_cfg=None)) 8 | 9 | train_cfg = dict(fixed_noise_with_pad=True) 10 | 11 | dist_params = dict(backend='nccl') 12 | -------------------------------------------------------------------------------- /mmgen/models/architectures/singan/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .generator_discriminator import (SinGANMultiScaleDiscriminator, 3 | SinGANMultiScaleGenerator) 4 | from .positional_encoding import SinGANMSGeneratorPE 5 | 6 | __all__ = [ 7 | 'SinGANMultiScaleDiscriminator', 'SinGANMultiScaleGenerator', 8 | 'SinGANMSGeneratorPE' 9 | ] 10 | -------------------------------------------------------------------------------- /mmgen/models/architectures/lpips/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | r""" 3 | The lpips module was adapted from https://github.com/rosinality/stylegan2-pytorch/tree/master/lpips , # noqa 4 | and you can see the origin implementation in https://github.com/richzhang/PerceptualSimilarity/tree/master/lpips # noqa 5 | """ 6 | from .perceptual_loss import PerceptualLoss 7 | 8 | __all__ = ['PerceptualLoss'] 9 | -------------------------------------------------------------------------------- /mmgen/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .collect_env import collect_env 3 | from .dist_util import check_dist_init, sync_random_seed 4 | from .io_utils import MMGEN_CACHE_DIR, download_from_url 5 | from .logger import get_root_logger 6 | 7 | __all__ = [ 8 | 'collect_env', 'get_root_logger', 'download_from_url', 'check_dist_init', 9 | 'MMGEN_CACHE_DIR', 'sync_random_seed' 10 | ] 11 | -------------------------------------------------------------------------------- /configs/styleganv2/stylegan2_c2_apex_fp16_PL-R1-no-scaler_ffhq_256_b4x8_800k.py: -------------------------------------------------------------------------------- 1 | """Config for the `config-f` setting in StyleGAN2.""" 2 | 3 | _base_ = ['./stylegan2_c2_ffhq_256_b4x8_800k.py'] 4 | 5 | model = dict( 6 | disc_auxiliary_loss=dict(use_apex_amp=False), 7 | gen_auxiliary_loss=dict(use_apex_amp=False), 8 | ) 9 | 10 | total_iters = 800002 11 | 12 | apex_amp = dict(mode='gan', init_args=dict(opt_level='O1', num_losses=2)) 13 | resume_from = None 14 | -------------------------------------------------------------------------------- /configs/styleganv3/stylegan3_t_afhqv2_512_b4x8_cvt_official_rgb.py: -------------------------------------------------------------------------------- 1 | _base_ = ['./stylegan3_base.py'] 2 | 3 | synthesis_cfg = { 4 | 'type': 'SynthesisNetwork', 5 | 'channel_base': 32768, 6 | 'channel_max': 512, 7 | 'magnitude_ema_beta': 0.999 8 | } 9 | model = 
class TestQuickTest:
    """Smoke test for ``QuickTestImageDataset``."""

    @classmethod
    def setup_class(cls):
        # A single dataset instance is shared by every test in this class.
        cls.dataset = QuickTestImageDataset(size=(256, 256))

    def test_quicktest_dataset(self):
        """Check the reported length and the shape of one fetched sample."""
        dataset = self.dataset
        assert len(dataset) == 10000

        sample = dataset[2]
        assert sample['real_img'].shape == (3, 256, 256)
Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /configs/styleganv3/stylegan3_t_ffhq_1024_b4x8_cvt_official_rgb.py: -------------------------------------------------------------------------------- 1 | _base_ = ['./stylegan3_base.py'] 2 | 3 | synthesis_cfg = { 4 | 'type': 'SynthesisNetwork', 5 | 'channel_base': 32768, 6 | 'channel_max': 512, 7 | 'magnitude_ema_beta': 0.999 8 | } 9 | 10 | model = dict( 11 | type='StaticUnconditionalGAN', 12 | generator=dict( 13 | out_size=1024, 14 | img_channels=3, 15 | synthesis_cfg=synthesis_cfg, 16 | rgb2bgr=True), 17 | discriminator=dict(in_size=1024)) 18 | -------------------------------------------------------------------------------- /configs/wgan-gp/wgangp_GN_GP-50_lsun-bedroom_128_b64x1_160kiter.py: -------------------------------------------------------------------------------- 1 | _base_ = ['./wgangp_GN_celeba-cropped_128_b64x1_160kiter.py'] 2 | 3 | model = dict(disc_auxiliary_loss=[ 4 | dict( 5 | type='GradientPenaltyLoss', 6 | loss_weight=50, 7 | norm_mode='HWC', 8 | data_info=dict( 9 | discriminator='disc', real_data='real_imgs', 10 | fake_data='fake_imgs')) 11 | ]) 12 | 13 | data = dict( 14 | samples_per_gpu=64, train=dict(imgs_root='./data/lsun/bedroom_train')) 15 | -------------------------------------------------------------------------------- /configs/styleganv3/stylegan3_t_ffhqu_256_b4x8_cvt_official_rgb.py: -------------------------------------------------------------------------------- 1 | _base_ = ['./stylegan3_base.py'] 2 | 3 | synthesis_cfg = { 4 | 'type': 'SynthesisNetwork', 5 | 'channel_base': 16384, 6 | 'channel_max': 512, 7 | 'magnitude_ema_beta': 0.999 8 | } 9 | model = dict( 10 | type='StaticUnconditionalGAN', 11 | generator=dict( 12 | out_size=256, 13 | 
def test_download_from_file():
    """Download one image into a temporary directory and check it exists."""
    img_url = 'https://user-images.githubusercontent.com/12726765/114528756-de55af80-9c7b-11eb-94d7-d3224ada1585.png'  # noqa
    # The temporary directory is removed again when the ``with`` block exits,
    # so the existence check has to happen inside it.
    with TemporaryDirectory() as download_root:
        saved_path = download_from_url(url=img_url, dest_dir=download_root)
        assert os.path.exists(saved_path)
2 | from .architectures import * # noqa: F401, F403 3 | from .builder import MODELS, MODULES, build_model, build_module 4 | from .common import * # noqa: F401, F403 5 | from .diffusions import * # noqa: F401, F403 6 | from .gans import * # noqa: F401, F403 7 | from .losses import * # noqa: F401, F403 8 | from .misc import * # noqa: F401, F403 9 | from .translation_models import * # noqa: F401, F403 10 | 11 | __all__ = ['build_model', 'MODELS', 'build_module', 'MODULES'] 12 | -------------------------------------------------------------------------------- /tools/dist_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | CONFIG=$1 4 | GPUS=$2 5 | NNODES=${NNODES:-1} 6 | NODE_RANK=${NODE_RANK:-0} 7 | PORT=${PORT:-29500} 8 | MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} 9 | 10 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 11 | python -m torch.distributed.launch \ 12 | --nnodes=$NNODES \ 13 | --node_rank=$NODE_RANK \ 14 | --master_addr=$MASTER_ADDR \ 15 | --nproc_per_node=$GPUS \ 16 | --master_port=$PORT \ 17 | $(dirname "$0")/train.py \ 18 | $CONFIG \ 19 | --seed 0 \ 20 | --launcher pytorch ${@:3} 21 | -------------------------------------------------------------------------------- /configs/positional_encoding_in_gans/singan_interp-pad_disc-nobn_fish.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../singan/singan_fish.py'] 2 | 3 | model = dict( 4 | type='PESinGAN', 5 | generator=dict( 6 | type='SinGANMSGeneratorPE', interp_pad=True, noise_with_pad=True), 7 | discriminator=dict(norm_cfg=None)) 8 | 9 | train_cfg = dict(fixed_noise_with_pad=True) 10 | 11 | data = dict( 12 | train=dict( 13 | img_path='./data/singan/fish-crop.jpg', 14 | min_size=25, 15 | max_size=300, 16 | )) 17 | 18 | dist_params = dict(backend='nccl') 19 | -------------------------------------------------------------------------------- /mmgen/models/gans/__init__.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .base_gan import BaseGAN 3 | from .basic_conditional_gan import BasicConditionalGAN 4 | from .mspie_stylegan2 import MSPIEStyleGAN2 5 | from .progressive_growing_unconditional_gan import ProgressiveGrowingGAN 6 | from .singan import PESinGAN, SinGAN 7 | from .static_unconditional_gan import StaticUnconditionalGAN 8 | 9 | __all__ = [ 10 | 'BaseGAN', 'StaticUnconditionalGAN', 'ProgressiveGrowingGAN', 'SinGAN', 11 | 'MSPIEStyleGAN2', 'PESinGAN', 'BasicConditionalGAN' 12 | ] 13 | -------------------------------------------------------------------------------- /mmgen/ops/stylegan3/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | # empty 10 | from .ops import filtered_lrelu 11 | 12 | __all__ = ['filtered_lrelu'] 13 | -------------------------------------------------------------------------------- /tools/dist_eval.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | CONFIG=$1 4 | CHECKPOINT=$2 5 | GPUS=$3 6 | NNODES=${NNODES:-1} 7 | NODE_RANK=${NODE_RANK:-0} 8 | PORT=${PORT:-29500} 9 | MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} 10 | 11 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 12 | python -m torch.distributed.launch \ 13 | --nnodes=$NNODES \ 14 | --node_rank=$NODE_RANK \ 15 | --master_addr=$MASTER_ADDR \ 16 | --nproc_per_node=$GPUS \ 17 | --master_port=$PORT \ 18 | $(dirname "$0")/test.py \ 19 | $CONFIG \ 20 | $CHECKPOINT \ 21 | --launcher pytorch \ 22 | ${@:4} 23 | -------------------------------------------------------------------------------- /.circleci/docker/Dockerfile: -------------------------------------------------------------------------------- 1 | ARG PYTORCH="1.8.1" 2 | ARG CUDA="10.2" 3 | ARG CUDNN="7" 4 | 5 | FROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}-devel 6 | 7 | # To fix GPG key error when running apt-get update 8 | RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub 9 | RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/7fa2af80.pub 10 | 11 | RUN apt-get update && apt-get install -y ninja-build libglib2.0-0 libsm6 libxrender-dev libxext6 libgl1-mesa-glx 12 | -------------------------------------------------------------------------------- /configs/styleganv3/stylegan3_r_ffhqu_256_b4x8_cvt_official_rgb.py: -------------------------------------------------------------------------------- 1 | _base_ = ['./stylegan3_base.py'] 2 | 3 | synthesis_cfg = { 4 | 'type': 'SynthesisNetwork', 5 | 'channel_base': 32768, 6 | 'channel_max': 1024, 7 | 'magnitude_ema_beta': 0.999, 
8 | 'conv_kernel': 1, 9 | 'use_radial_filters': True 10 | } 11 | model = dict( 12 | type='StaticUnconditionalGAN', 13 | generator=dict( 14 | out_size=256, 15 | img_channels=3, 16 | rgb2bgr=True, 17 | synthesis_cfg=synthesis_cfg), 18 | discriminator=dict(in_size=256, channel_multiplier=1)) 19 | -------------------------------------------------------------------------------- /mmgen/core/hooks/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .ceph_hooks import PetrelUploadHook 3 | from .ema_hook import ExponentialMovingAverageHook 4 | from .pggan_fetch_data_hook import PGGANFetchDataHook 5 | from .pickle_data_hook import PickleDataHook 6 | from .visualization import VisualizationHook 7 | from .visualize_training_samples import VisualizeUnconditionalSamples 8 | 9 | __all__ = [ 10 | 'VisualizeUnconditionalSamples', 'PGGANFetchDataHook', 11 | 'ExponentialMovingAverageHook', 'VisualizationHook', 'PickleDataHook', 12 | 'PetrelUploadHook' 13 | ] 14 | -------------------------------------------------------------------------------- /tests/test_models/test_base_ddpm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
def test_uniform_sampler():
    """Check shape and value range of timesteps drawn by the sampler."""
    sampler = UniformTimeStepSampler(10)

    # Exercise both the implicit call syntax and the explicit ``__call__``,
    # in that order, with the same expectations for each.
    for draw in (sampler, sampler.__call__):
        timesteps = draw(2)
        assert timesteps.shape == torch.Size([2])
        assert timesteps.max() < 10 and timesteps.min() >= 0
-------------------------------------------------------------------------------- /configs/_base_/default_metrics.py: -------------------------------------------------------------------------------- 1 | metrics = dict( 2 | fid50k=dict(type='FID', num_images=50000), 3 | pr50k3=dict(type='PR', num_images=50000, k=3), 4 | is50k=dict(type='IS', num_images=50000), 5 | ppl_zfull=dict(type='PPL', space='Z', sampling='full', num_images=50000), 6 | ppl_wfull=dict(type='PPL', space='W', sampling='full', num_images=50000), 7 | ppl_zend=dict(type='PPL', space='Z', sampling='end', num_images=50000), 8 | ppl_wend=dict(type='PPL', space='W', sampling='end', num_images=50000), 9 | ms_ssim10k=dict(type='MS_SSIM', num_images=10000), 10 | swd16k=dict(type='SWD', num_images=16384)) 11 | -------------------------------------------------------------------------------- /model-index.yml: -------------------------------------------------------------------------------- 1 | Import: 2 | - configs/ada/metafile.yml 3 | - configs/biggan/metafile.yml 4 | - configs/cyclegan/metafile.yml 5 | - configs/dcgan/metafile.yml 6 | - configs/ggan/metafile.yml 7 | - configs/improved_ddpm/metafile.yml 8 | - configs/lsgan/metafile.yml 9 | - configs/pggan/metafile.yml 10 | - configs/pix2pix/metafile.yml 11 | - configs/positional_encoding_in_gans/metafile.yml 12 | - configs/sagan/metafile.yml 13 | - configs/singan/metafile.yml 14 | - configs/sngan_proj/metafile.yml 15 | - configs/styleganv1/metafile.yml 16 | - configs/styleganv2/metafile.yml 17 | - configs/styleganv3/metafile.yml 18 | - configs/wgan-gp/metafile.yml 19 | -------------------------------------------------------------------------------- /configs/styleganv3/stylegan3_r_ffhq_1024_b4x8_cvt_official_rgb.py: -------------------------------------------------------------------------------- 1 | _base_ = ['./stylegan3_base.py'] 2 | 3 | synthesis_cfg = { 4 | 'type': 'SynthesisNetwork', 5 | 'channel_base': 65536, 6 | 'channel_max': 1024, 7 | 
'magnitude_ema_beta': 0.999, 8 | 'conv_kernel': 1, 9 | 'use_radial_filters': True 10 | } 11 | 12 | r1_gamma = 32.8 13 | d_reg_interval = 16 14 | 15 | model = dict( 16 | type='StaticUnconditionalGAN', 17 | generator=dict( 18 | out_size=1024, 19 | img_channels=3, 20 | synthesis_cfg=synthesis_cfg, 21 | rgb2bgr=True), 22 | discriminator=dict(type='StyleGAN2Discriminator', in_size=1024)) 23 | -------------------------------------------------------------------------------- /configs/_base_/models/dcgan/dcgan_64x64.py: -------------------------------------------------------------------------------- 1 | # define GAN model 2 | model = dict( 3 | type='StaticUnconditionalGAN', 4 | generator=dict(type='DCGANGenerator', output_scale=64, base_channels=1024), 5 | discriminator=dict( 6 | type='DCGANDiscriminator', 7 | input_scale=64, 8 | output_scale=4, 9 | out_channels=1), 10 | gan_loss=dict(type='GANLoss', gan_type='vanilla')) 11 | 12 | train_cfg = dict(disc_steps=1) 13 | test_cfg = None 14 | 15 | # define optimizer 16 | optimizer = dict( 17 | generator=dict(type='Adam', lr=0.0002, betas=(0.5, 0.999)), 18 | discriminator=dict(type='Adam', lr=0.0002, betas=(0.5, 0.999))) 19 | -------------------------------------------------------------------------------- /mmgen/models/architectures/stylegan/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
def get_module_device(module):
    """Get the device of a module from its first parameter.

    Args:
        module (nn.Module): A module that contains at least one parameter.

    Returns:
        int | torch.device: The value of ``get_device()`` of the first
            parameter when that parameter lives on a CUDA device, otherwise
            ``torch.device('cpu')``.

    Raises:
        ValueError: If the module has no parameters.
    """
    try:
        first_param = next(module.parameters())
    except StopIteration:
        raise ValueError('The input module should contain parameters.')

    if not first_param.is_cuda:
        return torch.device('cpu')

    # NOTE: ``get_device()`` yields the device index, not a ``torch.device``.
    return first_param.get_device()
"${TORCH_VER_ARR[@]}" 16 | MMCV_TORCH=${MMCV_TORCH%?} # Remove the last dot 17 | 18 | echo "export MMCV_CUDA=${MMCV_CUDA}" >> $BASH_ENV 19 | echo "export MMCV_TORCH=${MMCV_TORCH}" >> $BASH_ENV 20 | -------------------------------------------------------------------------------- /.github/workflows/scripts/get_mmcv_var.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | TORCH=$1 4 | CUDA=$2 5 | 6 | # 10.2 -> cu102 7 | MMCV_CUDA="cu`echo ${CUDA} | tr -d '.'`" 8 | 9 | # MMCV only provides pre-compiled packages for torch 1.x.0 10 | # which works for any subversions of torch 1.x. 11 | # We force the torch version to be 1.x.0 to ease package searching 12 | # and avoid unnecessary rebuild during MMCV's installation. 13 | TORCH_VER_ARR=(${TORCH//./ }) 14 | TORCH_VER_ARR[2]=0 15 | printf -v MMCV_TORCH "%s." "${TORCH_VER_ARR[@]}" 16 | MMCV_TORCH=${MMCV_TORCH%?} # Remove the last dot 17 | 18 | echo "MMCV_CUDA=${MMCV_CUDA}" >> $GITHUB_ENV 19 | echo "MMCV_TORCH=${MMCV_TORCH}" >> $GITHUB_ENV 20 | -------------------------------------------------------------------------------- /tools/slurm_eval.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -x 4 | 5 | PARTITION=$1 6 | JOB_NAME=$2 7 | CONFIG=$3 8 | CKPT=$4 9 | GPUS=${GPUS:-1} 10 | GPUS_PER_NODE=${GPUS_PER_NODE:-1} 11 | CPUS_PER_TASK=${CPUS_PER_TASK:-5} 12 | PY_ARGS=${@:5} 13 | SRUN_ARGS=${SRUN_ARGS:-""} 14 | 15 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 16 | srun -p ${PARTITION} \ 17 | --job-name=${JOB_NAME} \ 18 | --gres=gpu:${GPUS_PER_NODE} \ 19 | --ntasks=${GPUS} \ 20 | --ntasks-per-node=${GPUS_PER_NODE} \ 21 | --cpus-per-task=${CPUS_PER_TASK} \ 22 | --kill-on-bad-exit=1 \ 23 | ${SRUN_ARGS} \ 24 | python -u tools/evaluation.py ${CONFIG} ${CKPT} --launcher="none" ${PY_ARGS} 25 | -------------------------------------------------------------------------------- 
/configs/_base_/models/dcgan/dcgan_128x128.py: -------------------------------------------------------------------------------- 1 | # define GAN model 2 | model = dict( 3 | type='StaticUnconditionalGAN', 4 | generator=dict( 5 | type='DCGANGenerator', output_scale=128, base_channels=1024), 6 | discriminator=dict( 7 | type='DCGANDiscriminator', 8 | input_scale=128, 9 | output_scale=4, 10 | out_channels=100), 11 | gan_loss=dict(type='GANLoss', gan_type='vanilla')) 12 | 13 | train_cfg = dict(disc_steps=1) 14 | test_cfg = None 15 | 16 | # define optimizer 17 | optimizer = dict( 18 | generator=dict(type='Adam', lr=0.0002, betas=(0.5, 0.999)), 19 | discriminator=dict(type='Adam', lr=0.0002, betas=(0.5, 0.999))) 20 | -------------------------------------------------------------------------------- /configs/_base_/models/lsgan/lsgan_128x128.py: -------------------------------------------------------------------------------- 1 | # define GAN model 2 | model = dict( 3 | type='StaticUnconditionalGAN', 4 | generator=dict( 5 | type='LSGANGenerator', 6 | output_scale=128, 7 | base_channels=256, 8 | noise_size=1024), 9 | discriminator=dict( 10 | type='LSGANDiscriminator', input_scale=128, base_channels=64), 11 | gan_loss=dict(type='GANLoss', gan_type='lsgan')) 12 | 13 | train_cfg = dict(disc_steps=1) 14 | test_cfg = None 15 | 16 | # define optimizer 17 | optimizer = dict( 18 | generator=dict(type='Adam', lr=0.0002, betas=(0.5, 0.999)), 19 | discriminator=dict(type='Adam', lr=0.0002, betas=(0.5, 0.999))) 20 | -------------------------------------------------------------------------------- /tools/slurm_eval_multi_gpu.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -x 4 | 5 | PARTITION=$1 6 | JOB_NAME=$2 7 | CONFIG=$3 8 | CKPT=$4 9 | GPUS=${GPUS:-8} 10 | GPUS_PER_NODE=${GPUS_PER_NODE:-8} 11 | CPUS_PER_TASK=${CPUS_PER_TASK:-5} 12 | PY_ARGS=${@:5} 13 | SRUN_ARGS=${SRUN_ARGS:-""} 14 | 15 | PYTHONPATH="$(dirname 
$0)/..":$PYTHONPATH \ 16 | srun -p ${PARTITION} \ 17 | --job-name=${JOB_NAME} \ 18 | --gres=gpu:${GPUS_PER_NODE} \ 19 | --ntasks=${GPUS} \ 20 | --ntasks-per-node=${GPUS_PER_NODE} \ 21 | --cpus-per-task=${CPUS_PER_TASK} \ 22 | --kill-on-bad-exit=1 \ 23 | ${SRUN_ARGS} \ 24 | python -u tools/evaluation.py ${CONFIG} ${CKPT} --launcher="slurm" ${PY_ARGS} 25 | -------------------------------------------------------------------------------- /configs/styleganv2/stylegan2_c2_fp16-globalG-partialD_PL-R1-no-scaler_ffhq_256_b4x8_800k.py: -------------------------------------------------------------------------------- 1 | """Config for the `config-f` setting in StyleGAN2.""" 2 | 3 | _base_ = ['./stylegan2_c2_ffhq_256_b4x8_800k.py'] 4 | 5 | model = dict( 6 | generator=dict(out_size=256, fp16_enabled=True), 7 | discriminator=dict(in_size=256, fp16_enabled=False, num_fp16_scales=4), 8 | ) 9 | 10 | total_iters = 800000 11 | 12 | # use ddp wrapper for faster training 13 | use_ddp_wrapper = True 14 | find_unused_parameters = False 15 | 16 | runner = dict( 17 | fp16_loss_scaler=dict(init_scale=512), 18 | is_dynamic_ddp= # noqa 19 | False, # Note that this flag should be False to use DDP wrapper. 
20 | ) 21 | -------------------------------------------------------------------------------- /tools/slurm_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -x 4 | 5 | PARTITION=$1 6 | JOB_NAME=$2 7 | CONFIG=$3 8 | WORK_DIR=$4 9 | GPUS=${GPUS:-8} 10 | GPUS_PER_NODE=${GPUS_PER_NODE:-8} 11 | CPUS_PER_TASK=${CPUS_PER_TASK:-5} 12 | PY_ARGS=${@:5} 13 | SRUN_ARGS=${SRUN_ARGS:-""} 14 | 15 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 16 | srun -p ${PARTITION} \ 17 | --job-name=${JOB_NAME} \ 18 | --gres=gpu:${GPUS_PER_NODE} \ 19 | --ntasks=${GPUS} \ 20 | --ntasks-per-node=${GPUS_PER_NODE} \ 21 | --cpus-per-task=${CPUS_PER_TASK} \ 22 | --kill-on-bad-exit=1 \ 23 | ${SRUN_ARGS} \ 24 | python -u tools/train.py ${CONFIG} --work-dir=${WORK_DIR} --launcher="slurm" ${PY_ARGS} 25 | -------------------------------------------------------------------------------- /requirements/docs.txt: -------------------------------------------------------------------------------- 1 | click 2 | docutils==0.16.0 3 | m2r 4 | mmcls==0.18.0 5 | myst-parser 6 | opencv-python!=4.5.5.62,!=4.5.5.64 7 | # Skip problematic opencv-python versions 8 | # MMCV depends opencv-python instead of headless, thus we install opencv-python 9 | # Due to a bug from upstream, we skip this two version 10 | # https://github.com/opencv/opencv-python/issues/602 11 | # https://github.com/opencv/opencv/issues/21366 12 | # It seems to be fixed in https://github.com/opencv/opencv/pull/21382opencv-python 13 | prettytable 14 | -e git+https://github.com/open-mmlab/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme 15 | scipy 16 | sphinx==4.0.2 17 | sphinx-copybutton 18 | sphinx_markdown_tables 19 | -------------------------------------------------------------------------------- /configs/styleganv3/stylegan3_r_afhqv2_512_b4x8_cvt_official_rgb.py: -------------------------------------------------------------------------------- 1 | _base_ = 
['./stylegan3_base.py'] 2 | 3 | synthesis_cfg = { 4 | 'type': 'SynthesisNetwork', 5 | 'channel_base': 65536, 6 | 'channel_max': 1024, 7 | 'magnitude_ema_beta': 0.999, 8 | 'conv_kernel': 1, 9 | 'use_radial_filters': True 10 | } 11 | model = dict( 12 | type='StaticUnconditionalGAN', 13 | generator=dict( 14 | type='StyleGANv3Generator', 15 | noise_size=512, 16 | style_channels=512, 17 | out_size=512, 18 | img_channels=3, 19 | rgb2bgr=True, 20 | synthesis_cfg=synthesis_cfg), 21 | discriminator=dict(type='StyleGAN2Discriminator', in_size=512)) 22 | -------------------------------------------------------------------------------- /mmgen/models/architectures/stylegan/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .generator_discriminator_v1 import (StyleGAN1Discriminator, 3 | StyleGANv1Generator) 4 | from .generator_discriminator_v2 import (StyleGAN2Discriminator, 5 | StyleGANv2Generator) 6 | from .generator_discriminator_v3 import StyleGANv3Generator 7 | from .mspie import MSStyleGAN2Discriminator, MSStyleGANv2Generator 8 | 9 | __all__ = [ 10 | 'StyleGAN2Discriminator', 'StyleGANv2Generator', 'StyleGANv1Generator', 11 | 'StyleGAN1Discriminator', 'MSStyleGAN2Discriminator', 12 | 'MSStyleGANv2Generator', 'StyleGANv3Generator' 13 | ] 14 | -------------------------------------------------------------------------------- /mmgen/core/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from .eval_hooks import GenerativeEvalHook, TranslationEvalHook 3 | from .evaluation import (make_metrics_table, make_vanilla_dataloader, 4 | offline_evaluation, online_evaluation) 5 | from .metric_utils import slerp 6 | from .metrics import (IS, MS_SSIM, PR, SWD, GaussianKLD, ms_ssim, 7 | sliced_wasserstein) 8 | 9 | __all__ = [ 10 | 'MS_SSIM', 'SWD', 'ms_ssim', 'sliced_wasserstein', 'offline_evaluation', 11 | 'online_evaluation', 'PR', 'IS', 'slerp', 'GenerativeEvalHook', 12 | 'make_metrics_table', 'make_vanilla_dataloader', 'GaussianKLD', 13 | 'TranslationEvalHook' 14 | ] 15 | -------------------------------------------------------------------------------- /docs/en/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/zh_cn/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 
9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /mmgen/models/architectures/pggan/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .generator_discriminator import PGGANDiscriminator, PGGANGenerator 3 | from .modules import (EqualizedLR, EqualizedLRConvDownModule, 4 | EqualizedLRConvModule, EqualizedLRConvUpModule, 5 | EqualizedLRLinearModule, MiniBatchStddevLayer, 6 | PGGANNoiseTo2DFeat, PixelNorm, equalized_lr) 7 | 8 | __all__ = [ 9 | 'EqualizedLR', 'equalized_lr', 'EqualizedLRConvModule', 10 | 'EqualizedLRLinearModule', 'EqualizedLRConvUpModule', 11 | 'EqualizedLRConvDownModule', 'PixelNorm', 'MiniBatchStddevLayer', 12 | 'PGGANNoiseTo2DFeat', 'PGGANGenerator', 'PGGANDiscriminator' 13 | ] 14 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [bdist_wheel] 2 | universal=1 3 | 4 | [aliases] 5 | test=pytest 6 | 7 | [yapf] 8 | based_on_style=pep8 9 | blank_line_before_nested_class_or_def=true 10 | split_before_expression_after_opening_paren=true 11 | 12 | [isort] 13 | line_length=79 14 | multi_line_output=0 15 | extra_standard_library=argparse,inspect,contextlib,hashlib,subprocess,unittest,tempfile,copy,pkg_resources,logging,pickle,platform,setuptools,abc,collections,functools,os,math,time,warnings,random,shutil,sys 16 | 
known_first_party=mmgen 17 | known_third_party=PIL,click,clip,cv2,imageio,mmcls,mmcv,numpy,prettytable,pytest,pytorch_sphinx_theme,recommonmark,requests,scipy,torch,torchvision,tqdm,ts 18 | no_lines_before=STDLIB,LOCALFOLDER 19 | default_section=THIRDPARTY 20 | -------------------------------------------------------------------------------- /configs/positional_encoding_in_gans/singan_csg_fish.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../singan/singan_fish.py'] 2 | 3 | num_scales = 10 # start from zero 4 | model = dict( 5 | type='PESinGAN', 6 | generator=dict( 7 | type='SinGANMSGeneratorPE', 8 | num_scales=num_scales, 9 | padding=1, 10 | pad_at_head=False, 11 | first_stage_in_channels=2, 12 | positional_encoding=dict(type='CSG')), 13 | discriminator=dict(num_scales=num_scales)) 14 | 15 | train_cfg = dict(first_fixed_noises_ch=2) 16 | 17 | data = dict( 18 | train=dict( 19 | img_path='./data/singan/fish-crop.jpg', 20 | min_size=25, 21 | max_size=300, 22 | )) 23 | 24 | dist_params = dict(backend='nccl') 25 | total_iters = 22000 26 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/4-documentation.yml: -------------------------------------------------------------------------------- 1 | name: 📚 Documentation 2 | description: Report an issue related to https://mmgeneration.readthedocs.io/en/latest/. 3 | labels: "docs" 4 | title: "[Docs] " 5 | 6 | body: 7 | - type: textarea 8 | attributes: 9 | label: 📚 The doc issue 10 | description: > 11 | A clear and concise description of what content in https://mmgeneration.readthedocs.io/en/latest/ is an issue. 12 | validations: 13 | required: true 14 | 15 | - type: textarea 16 | attributes: 17 | label: Suggest a potential alternative/fix 18 | description: > 19 | Tell us how we could improve the documentation in this regard. 20 | - type: markdown 21 | attributes: 22 | value: > 23 | Thanks for contributing 🎉! 
24 | -------------------------------------------------------------------------------- /configs/positional_encoding_in_gans/singan_csg_bohemian.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../singan/singan_bohemian.py'] 2 | 3 | num_scales = 10 # start from zero 4 | model = dict( 5 | type='PESinGAN', 6 | generator=dict( 7 | type='SinGANMSGeneratorPE', 8 | num_scales=num_scales, 9 | padding=1, 10 | pad_at_head=False, 11 | first_stage_in_channels=2, 12 | positional_encoding=dict(type='CSG')), 13 | discriminator=dict(num_scales=num_scales)) 14 | 15 | train_cfg = dict(first_fixed_noises_ch=2) 16 | 17 | data = dict( 18 | train=dict( 19 | img_path='./data/singan/bohemian.png', 20 | min_size=25, 21 | max_size=500, 22 | )) 23 | 24 | dist_params = dict(backend='nccl') 25 | total_iters = 22000 26 | -------------------------------------------------------------------------------- /mmgen/version.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | __version__ = '0.7.3' 3 | 4 | 5 | def parse_version_info(version_str): 6 | """Parse version information. 7 | 8 | Args: 9 | version_str (str): Version string. 10 | 11 | Returns: 12 | tuple: Version information in tuple. 13 | """ 14 | version_info = [] 15 | for x in version_str.split('.'): 16 | if x.isdigit(): 17 | version_info.append(int(x)) 18 | elif x.find('rc') != -1: 19 | patch_version = x.split('rc') 20 | version_info.append(int(patch_version[0])) 21 | version_info.append(f'rc{patch_version[1]}') 22 | return tuple(version_info) 23 | 24 | 25 | version_info = parse_version_info(__version__) 26 | -------------------------------------------------------------------------------- /mmgen/datasets/quick_test_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | import torch 3 | from torch.utils.data import Dataset 4 | 5 | from .builder import DATASETS 6 | 7 | 8 | @DATASETS.register_module() 9 | class QuickTestImageDataset(Dataset): 10 | """Dataset for quickly testing the correctness. 11 | 12 | Args: 13 | size (tuple[int]): The size of the images. Defaults to `None`. 14 | """ 15 | 16 | def __init__(self, *args, size=None, **kwargs): 17 | super().__init__() 18 | self.size = size 19 | self.img_tensor = torch.randn(3, self.size[0], self.size[1]) 20 | 21 | def __len__(self): 22 | return 10000 23 | 24 | def __getitem__(self, idx): 25 | return dict(real_img=self.img_tensor) 26 | -------------------------------------------------------------------------------- /mmgen/datasets/pipelines/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .augmentation import (CenterCropLongEdge, Flip, NumpyPad, 3 | RandomCropLongEdge, RandomImgNoise, Resize) 4 | from .compose import Compose 5 | from .crop import Crop, FixedCrop 6 | from .formatting import Collect, ImageToTensor, ToTensor 7 | from .loading import LoadImageFromFile 8 | from .normalize import Normalize 9 | 10 | __all__ = [ 11 | 'LoadImageFromFile', 12 | 'Compose', 13 | 'ImageToTensor', 14 | 'Collect', 15 | 'ToTensor', 16 | 'Flip', 17 | 'Resize', 18 | 'RandomImgNoise', 19 | 'RandomCropLongEdge', 20 | 'CenterCropLongEdge', 21 | 'Normalize', 22 | 'NumpyPad', 23 | 'Crop', 24 | 'FixedCrop', 25 | ] 26 | -------------------------------------------------------------------------------- /tests/test_datasets/test_dataset_wrappers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from torch.utils.data import Dataset 3 | 4 | from mmgen.datasets import RepeatDataset 5 | 6 | 7 | def test_repeat_dataset(): 8 | 9 | class ToyDataset(Dataset): 10 | 11 | def __init__(self): 12 | super(ToyDataset, self).__init__() 13 | self.members = [1, 2, 3, 4, 5] 14 | 15 | def __len__(self): 16 | return len(self.members) 17 | 18 | def __getitem__(self, idx): 19 | return self.members[idx % 5] 20 | 21 | toy_dataset = ToyDataset() 22 | repeat_dataset = RepeatDataset(toy_dataset, 2) 23 | assert len(repeat_dataset) == 10 24 | assert repeat_dataset[2] == 3 25 | assert repeat_dataset[8] == 4 26 | -------------------------------------------------------------------------------- /docs/zh_cn/index.rst: -------------------------------------------------------------------------------- 1 | 欢迎来到 MMGeneration 的用户手册! 2 | ======================================= 3 | 4 | .. toctree:: 5 | :maxdepth: 2 6 | :caption: Get Started 7 | 8 | get_started.md 9 | modelzoo_statistics.md 10 | 11 | .. toctree:: 12 | :maxdepth: 2 13 | :caption: Quick Run 14 | 15 | quick_run.md 16 | 17 | .. toctree:: 18 | :maxdepth: 2 19 | :caption: Tutorials 20 | 21 | tutorials/index.rst 22 | 23 | .. toctree:: 24 | :maxdepth: 2 25 | :caption: Notes 26 | 27 | changelog.md 28 | faq.md 29 | 30 | .. toctree:: 31 | :caption: Switch Language 32 | 33 | switch_language.md 34 | 35 | .. 
toctree:: 36 | :caption: API Reference 37 | 38 | api.rst 39 | 40 | Indices and tables 41 | ================== 42 | 43 | * :ref:`genindex` 44 | * :ref:`search` 45 | -------------------------------------------------------------------------------- /configs/ada/metafile.yml: -------------------------------------------------------------------------------- 1 | Collections: 2 | - Metadata: 3 | Architecture: 4 | - ADA 5 | Name: ADA 6 | Paper: 7 | - https://arxiv.org/pdf/2006.06676.pdf 8 | README: configs/ada/README.md 9 | Models: 10 | - Config: https://github.com/open-mmlab/mmgeneration/tree/master/configs/styleganv3/stylegan3_t_ada_fp16_gamma6.6_metfaces_1024_b4x8.py 11 | In Collection: ADA 12 | Metadata: 13 | Training Data: Others 14 | Name: stylegan3_t_ada_fp16_gamma6.6_metfaces_1024_b4x8 15 | Results: 16 | - Dataset: Others 17 | Metrics: 18 | FID50k: 15.09 19 | Iter: 130000.0 20 | Log: '[log]' 21 | Task: Tricks for GANs 22 | Weights: https://download.openmmlab.com/mmgen/stylegan3/stylegan3_t_ada_fp16_gamma6.6_metfaces_1024_b4x8_best_fid_iter_130000_20220401_115101-f2ef498e.pth 23 | -------------------------------------------------------------------------------- /docs/en/index.rst: -------------------------------------------------------------------------------- 1 | Welcome to MMGeneration's documentation! 2 | ======================================== 3 | 4 | .. toctree:: 5 | :maxdepth: 2 6 | :caption: Get Started 7 | 8 | get_started.md 9 | modelzoo_statistics.md 10 | 11 | .. toctree:: 12 | :maxdepth: 2 13 | :caption: Quick Run 14 | 15 | quick_run.md 16 | 17 | .. toctree:: 18 | :maxdepth: 2 19 | :caption: Tutorials 20 | 21 | tutorials/index.rst 22 | 23 | .. toctree:: 24 | :maxdepth: 2 25 | :caption: Notes 26 | 27 | changelog.md 28 | faq.md 29 | 30 | .. toctree:: 31 | :caption: Switch Language 32 | 33 | switch_language.md 34 | 35 | ..
toctree:: 36 | :caption: API Reference 37 | 38 | api.rst 39 | 40 | Indices and tables 41 | ================== 42 | 43 | * :ref:`genindex` 44 | * :ref:`search` 45 | -------------------------------------------------------------------------------- /configs/styleganv2/stylegan2_c2_fp16_partial-GD_PL-no-scaler_ffhq_256_b4x8_800k.py: -------------------------------------------------------------------------------- 1 | """Config for the `config-f` setting in StyleGAN2.""" 2 | 3 | _base_ = ['./stylegan2_c2_ffhq_256_b4x8_800k.py'] 4 | 5 | model = dict( 6 | generator=dict(out_size=256, num_fp16_scales=4), 7 | discriminator=dict(in_size=256, num_fp16_scales=4), 8 | disc_auxiliary_loss=dict(data_info=dict(loss_scaler='loss_scaler')), 9 | # gen_auxiliary_loss=dict(data_info=dict(loss_scaler='loss_scaler')), 10 | ) 11 | 12 | total_iters = 800002 13 | 14 | # use ddp wrapper for faster training 15 | use_ddp_wrapper = True 16 | find_unused_parameters = False 17 | 18 | runner = dict( 19 | fp16_loss_scaler=dict(init_scale=512), 20 | is_dynamic_ddp= # noqa 21 | False, # Note that this flag should be False to use DDP wrapper. 22 | ) 23 | -------------------------------------------------------------------------------- /mmgen/models/architectures/biggan/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from .generator_discriminator import BigGANDiscriminator, BigGANGenerator 3 | from .generator_discriminator_deep import (BigGANDeepDiscriminator, 4 | BigGANDeepGenerator) 5 | from .modules import (BigGANConditionBN, BigGANDeepDiscResBlock, 6 | BigGANDeepGenResBlock, BigGANDiscResBlock, 7 | BigGANGenResBlock, SelfAttentionBlock, SNConvModule) 8 | 9 | __all__ = [ 10 | 'BigGANGenerator', 'BigGANGenResBlock', 'BigGANConditionBN', 11 | 'BigGANDiscriminator', 'SelfAttentionBlock', 'BigGANDiscResBlock', 12 | 'BigGANDeepDiscriminator', 'BigGANDeepGenerator', 'BigGANDeepDiscResBlock', 13 | 'BigGANDeepGenResBlock', 'SNConvModule' 14 | ] 15 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | --- 8 | 9 | **Describe the feature** 10 | 11 | **Motivation** 12 | A clear and concise description of the motivation of the feature. 13 | Ex1. It is inconvenient when \[....\]. 14 | Ex2. There is a recent paper \[....\], which is very helpful for \[....\]. 15 | 16 | **Related resources** 17 | If there is an official code release or third-party implementations, please also provide the information here, which would be very helpful. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 21 | If you would like to implement the feature and create a PR, please leave a comment here and that would be much appreciated. 
22 | -------------------------------------------------------------------------------- /configs/wgan-gp/wgangp_GN_celeba-cropped_128_b64x1_160kiter.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/datasets/unconditional_imgs_128x128.py', 3 | '../_base_/models/wgangp/wgangp_base.py' 4 | ] 5 | 6 | data = dict( 7 | samples_per_gpu=64, 8 | train=dict(imgs_root='./data/celeba-cropped/cropped_images_aligned_png/')) 9 | 10 | checkpoint_config = dict(interval=10000, by_epoch=False) 11 | log_config = dict(interval=100, hooks=[dict(type='TextLoggerHook')]) 12 | 13 | custom_hooks = [ 14 | dict( 15 | type='VisualizeUnconditionalSamples', 16 | output_dir='training_samples', 17 | interval=1000) 18 | ] 19 | 20 | lr_config = None 21 | total_iters = 160000 22 | 23 | metrics = dict( 24 | ms_ssim10k=dict(type='MS_SSIM', num_images=10000), 25 | swd16k=dict(type='SWD', num_images=16384, image_shape=(3, 128, 128))) 26 | -------------------------------------------------------------------------------- /configs/_base_/datasets/unconditional_imgs_128x128.py: -------------------------------------------------------------------------------- 1 | dataset_type = 'UnconditionalImageDataset' 2 | 3 | train_pipeline = [ 4 | dict(type='LoadImageFromFile', key='real_img', io_backend='disk'), 5 | dict(type='Resize', keys=['real_img'], scale=(128, 128)), 6 | dict( 7 | type='Normalize', 8 | keys=['real_img'], 9 | mean=[127.5] * 3, 10 | std=[127.5] * 3, 11 | to_rgb=False), 12 | dict(type='ImageToTensor', keys=['real_img']), 13 | dict(type='Collect', keys=['real_img'], meta_keys=['real_img_path']) 14 | ] 15 | 16 | # `samples_per_gpu` and `imgs_root` need to be set. 
17 | data = dict( 18 | samples_per_gpu=None, 19 | workers_per_gpu=4, 20 | train=dict(type=dataset_type, imgs_root=None, pipeline=train_pipeline), 21 | val=dict(type=dataset_type, imgs_root=None, pipeline=train_pipeline)) 22 | -------------------------------------------------------------------------------- /configs/_base_/datasets/unconditional_imgs_flip_256x256.py: -------------------------------------------------------------------------------- 1 | dataset_type = 'UnconditionalImageDataset' 2 | 3 | train_pipeline = [ 4 | dict(type='LoadImageFromFile', key='real_img', io_backend='disk'), 5 | dict(type='Resize', keys=['real_img'], scale=(256, 256)), 6 | dict(type='Flip', keys=['real_img'], direction='horizontal'), 7 | dict( 8 | type='Normalize', 9 | keys=['real_img'], 10 | mean=[127.5] * 3, 11 | std=[127.5] * 3, 12 | to_rgb=False), 13 | dict(type='ImageToTensor', keys=['real_img']), 14 | dict(type='Collect', keys=['real_img'], meta_keys=['real_img_path']) 15 | ] 16 | 17 | # `samples_per_gpu` and `imgs_root` need to be set. 18 | data = dict( 19 | samples_per_gpu=None, 20 | workers_per_gpu=4, 21 | train=dict(type=dataset_type, imgs_root=None, pipeline=train_pipeline)) 22 | -------------------------------------------------------------------------------- /configs/_base_/datasets/unconditional_imgs_flip_512x512.py: -------------------------------------------------------------------------------- 1 | dataset_type = 'UnconditionalImageDataset' 2 | 3 | train_pipeline = [ 4 | dict(type='LoadImageFromFile', key='real_img', io_backend='disk'), 5 | dict(type='Resize', keys=['real_img'], scale=(512, 512)), 6 | dict(type='Flip', keys=['real_img'], direction='horizontal'), 7 | dict( 8 | type='Normalize', 9 | keys=['real_img'], 10 | mean=[127.5] * 3, 11 | std=[127.5] * 3, 12 | to_rgb=False), 13 | dict(type='ImageToTensor', keys=['real_img']), 14 | dict(type='Collect', keys=['real_img'], meta_keys=['real_img_path']) 15 | ] 16 | 17 | # `samples_per_gpu` and `imgs_root` need to be set. 
18 | data = dict( 19 | samples_per_gpu=None, 20 | workers_per_gpu=4, 21 | train=dict(type=dataset_type, imgs_root=None, pipeline=train_pipeline)) 22 | -------------------------------------------------------------------------------- /configs/dcgan/dcgan_lsun-bedroom_64x64_b128x1_5e.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/models/dcgan/dcgan_64x64.py', 3 | '../_base_/datasets/unconditional_imgs_64x64.py', 4 | '../_base_/default_runtime.py' 5 | ] 6 | 7 | # define dataset 8 | # you must set `samples_per_gpu` and `imgs_root` 9 | data = dict( 10 | samples_per_gpu=128, train=dict(imgs_root='data/lsun/bedroom_train')) 11 | 12 | # adjust running config 13 | lr_config = None 14 | checkpoint_config = dict(interval=100000, by_epoch=False) 15 | custom_hooks = [ 16 | dict( 17 | type='VisualizeUnconditionalSamples', 18 | output_dir='training_samples', 19 | interval=10000) 20 | ] 21 | 22 | total_iters = 1500002 23 | 24 | metrics = dict( 25 | ms_ssim10k=dict(type='MS_SSIM', num_images=10000), 26 | swd16k=dict(type='SWD', num_images=16384, image_shape=(3, 64, 64))) 27 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | ARG PYTORCH="1.8.0" 2 | ARG CUDA="11.1" 3 | ARG CUDNN="8" 4 | 5 | FROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}-devel 6 | 7 | ENV TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0+PTX" 8 | ENV TORCH_NVCC_FLAGS="-Xfatbin -compress-all" 9 | ENV CMAKE_PREFIX_PATH="$(dirname $(which conda))/../" 10 | 11 | RUN apt-get update && apt-get install -y ffmpeg libsm6 libxext6 git ninja-build libglib2.0-0 libsm6 libxrender-dev libxext6 \ 12 | && apt-get clean \ 13 | && rm -rf /var/lib/apt/lists/* 14 | 15 | # Install MMCV 16 | RUN pip install mmcv-full==1.3.16 -f https://download.openmmlab.com/mmcv/dist/cu111/torch1.8.0/index.html 17 | 18 | # Install MMGeneration 19 | RUN conda 
clean --all 20 | RUN git clone https://github.com/open-mmlab/mmgeneration.git /mmgen 21 | WORKDIR /mmgen 22 | ENV FORCE_CUDA="1" 23 | RUN pip install -r requirements.txt 24 | RUN pip install --no-cache-dir -e . 25 | -------------------------------------------------------------------------------- /tests/test_datasets/test_singan_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import os.path as osp 3 | 4 | from mmgen.datasets import SinGANDataset 5 | 6 | 7 | class TestSinGANDataset(object): 8 | 9 | @classmethod 10 | def setup_class(cls): 11 | cls.imgs_root = osp.join( 12 | osp.dirname(osp.dirname(__file__)), 'data/image/baboon.png') 13 | cls.min_size = 25 14 | cls.max_size = 250 15 | cls.scale_factor_init = 0.75 16 | 17 | def test_singan_dataset(self): 18 | dataset = SinGANDataset( 19 | self.imgs_root, 20 | min_size=self.min_size, 21 | max_size=self.max_size, 22 | scale_factor_init=self.scale_factor_init) 23 | assert len(dataset) == 1000000 24 | 25 | data_dict = dataset[0] 26 | assert all([f'real_scale{i}' in data_dict for i in range(10)]) 27 | -------------------------------------------------------------------------------- /configs/_base_/datasets/unconditional_imgs_64x64.py: -------------------------------------------------------------------------------- 1 | dataset_type = 'UnconditionalImageDataset' 2 | 3 | train_pipeline = [ 4 | dict( 5 | type='LoadImageFromFile', 6 | key='real_img', 7 | io_backend='disk', 8 | ), 9 | dict(type='Resize', keys=['real_img'], scale=(64, 64)), 10 | dict( 11 | type='Normalize', 12 | keys=['real_img'], 13 | mean=[127.5] * 3, 14 | std=[127.5] * 3, 15 | to_rgb=False), 16 | dict(type='ImageToTensor', keys=['real_img']), 17 | dict(type='Collect', keys=['real_img'], meta_keys=['real_img_path']) 18 | ] 19 | 20 | # `samples_per_gpu` and `imgs_root` need to be set. 
21 | data = dict( 22 | samples_per_gpu=None, 23 | workers_per_gpu=4, 24 | train=dict(type=dataset_type, imgs_root=None, pipeline=train_pipeline), 25 | val=dict(type=dataset_type, imgs_root=None, pipeline=train_pipeline)) 26 | -------------------------------------------------------------------------------- /.github/workflows/publish-to-pypi.yml: -------------------------------------------------------------------------------- 1 | name: deploy 2 | 3 | on: push 4 | 5 | concurrency: 6 | group: ${{ github.workflow }}-${{ github.ref }} 7 | cancel-in-progress: true 8 | 9 | jobs: 10 | build-n-publish: 11 | runs-on: ubuntu-latest 12 | if: startsWith(github.event.ref, 'refs/tags') 13 | steps: 14 | - uses: actions/checkout@v2 15 | - name: Set up Python 3.7 16 | uses: actions/setup-python@v1 17 | with: 18 | python-version: 3.7 19 | - name: Build MMGeneration 20 | run: | 21 | pip install torch==1.8.1+cpu torchvision==0.9.1+cpu -f https://download.pytorch.org/whl/torch_stable.html 22 | pip install wheel 23 | python setup.py sdist 24 | - name: Publish distribution to PyPI 25 | run: | 26 | pip install twine 27 | twine upload dist/* -u __token__ -p ${{ secrets.pypi_password }} 28 | -------------------------------------------------------------------------------- /docs/en/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 
23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /LICENSES.md: -------------------------------------------------------------------------------- 1 | # Licenses for special operations 2 | 3 | In this file, we list the operations with other licenses instead of Apache 2.0. Users should be careful about adopting these operations in any commercial matters. 4 | 5 | | Operation | Files | License | 6 | | :------------------: | :-----------------------------------------------------------------------------------------------------------------------------------: | :------------: | 7 | | conv2d_gradfix | [mmgen/ops/conv2d_gradfix.py](https://github.com/open-mmlab/mmgeneration/blob/master/mmgen/ops/conv2d_gradfix.py) | NVIDIA License | 8 | | compute_pr_distances | [mmgen/core/evaluation/metric_utils.py](https://github.com/open-mmlab/mmgeneration/blob/master/mmgen/core/evaluation/metric_utils.py) | NVIDIA License | 9 | -------------------------------------------------------------------------------- /docs/zh_cn/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. 
Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /configs/_base_/models/sagan/sagan_32x32.py: -------------------------------------------------------------------------------- 1 | # define GAN model 2 | model = dict( 3 | type='BasiccGAN', 4 | generator=dict( 5 | type='SAGANGenerator', 6 | output_scale=32, 7 | base_channels=256, 8 | attention_cfg=dict(type='SelfAttentionBlock'), 9 | attention_after_nth_block=2, 10 | with_spectral_norm=True), 11 | discriminator=dict( 12 | type='ProjDiscriminator', 13 | input_scale=32, 14 | base_channels=128, 15 | attention_cfg=dict(type='SelfAttentionBlock'), 16 | attention_after_nth_block=1, 17 | with_spectral_norm=True), 18 | gan_loss=dict(type='GANLoss', gan_type='hinge')) 19 | 20 | train_cfg = dict(disc_steps=5) 21 | test_cfg = None 22 | 23 | # define optimizer 24 | optimizer = dict( 25 | generator=dict(type='Adam', lr=0.0002, betas=(0.5, 0.999)), 26 | discriminator=dict(type='Adam', lr=0.0002, betas=(0.5, 0.999))) 27 | -------------------------------------------------------------------------------- /configs/_base_/models/sagan/sagan_128x128.py: -------------------------------------------------------------------------------- 1 | # define GAN model 2 | model = dict( 3 | type='BasiccGAN', 4 | generator=dict( 5 | type='SAGANGenerator', 6 | output_scale=128, 7 | base_channels=64, 8 | attention_cfg=dict(type='SelfAttentionBlock'), 9 | attention_after_nth_block=4, 10 | with_spectral_norm=True), 11 | discriminator=dict( 12 | type='ProjDiscriminator', 13 | input_scale=128, 14 | base_channels=64, 15 | 
attention_cfg=dict(type='SelfAttentionBlock'), 16 | attention_after_nth_block=1, 17 | with_spectral_norm=True), 18 | gan_loss=dict(type='GANLoss', gan_type='hinge')) 19 | 20 | train_cfg = dict(disc_steps=1) 21 | test_cfg = None 22 | 23 | # define optimizer 24 | optimizer = dict( 25 | generator=dict(type='Adam', lr=0.0001, betas=(0.0, 0.999)), 26 | discriminator=dict(type='Adam', lr=0.0004, betas=(0.0, 0.999))) 27 | -------------------------------------------------------------------------------- /configs/positional_encoding_in_gans/singan_spe-dim4_fish.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../singan/singan_fish.py'] 2 | 3 | embedding_dim = 4 4 | num_scales = 10 # start from zero 5 | model = dict( 6 | type='PESinGAN', 7 | generator=dict( 8 | type='SinGANMSGeneratorPE', 9 | num_scales=num_scales, 10 | padding=1, 11 | pad_at_head=False, 12 | first_stage_in_channels=embedding_dim * 2, 13 | positional_encoding=dict( 14 | type='SPE', 15 | embedding_dim=embedding_dim, 16 | padding_idx=0, 17 | init_size=512, 18 | div_half_dim=False, 19 | center_shift=200)), 20 | discriminator=dict(num_scales=num_scales)) 21 | 22 | data = dict( 23 | train=dict( 24 | img_path='./data/singan/fish-crop.jpg', 25 | min_size=25, 26 | max_size=300, 27 | )) 28 | 29 | dist_params = dict(backend='nccl') 30 | total_iters = 22000 31 | -------------------------------------------------------------------------------- /docs/en/faq.md: -------------------------------------------------------------------------------- 1 | # FAQ 2 | 3 | We list some common troubles faced by many users and their corresponding solutions here. Feel free to enrich the list if you find any frequent issues and have ways to help others to solve them. 
If the contents here do not cover your issue, please create an issue using the [provided templates](https://github.com/open-mmlab/mmgeneration/blob/master/.github/ISSUE_TEMPLATE/error-report.md) and make sure you fill in all required information in the template. 4 | 5 | ## Installation 6 | 7 | - Compatible MMGeneration and MMCV versions are shown below. Please choose the correct version of MMCV to avoid installation issues. 8 | 9 | | MMGeneration version | MMCV version | 10 | | :------------------: | :--------------: | 11 | | master | mmcv-full>=1.3.0 | 12 | 13 | Note: You need to run `pip uninstall mmcv` first if you have mmcv installed. 14 | If mmcv and mmcv-full are both installed, there will be a `ModuleNotFoundError`. 15 | -------------------------------------------------------------------------------- /tests/test_datasets/test_unconditional_image_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import os.path as osp 3 | 4 | from mmgen.datasets import UnconditionalImageDataset 5 | 6 | 7 | class TestUnconditionalImageDataset(object): 8 | 9 | @classmethod 10 | def setup_class(cls): 11 | cls.imgs_root = osp.join(osp.dirname(__file__), '..', 'data/image') 12 | cls.default_pipeline = [ 13 | dict(type='LoadImageFromFile', io_backend='disk', key='real_img') 14 | ] 15 | 16 | def test_unconditional_imgs_dataset(self): 17 | dataset = UnconditionalImageDataset( 18 | self.imgs_root, pipeline=self.default_pipeline) 19 | assert len(dataset) == 6 20 | img = dataset[2]['real_img'] 21 | assert img.ndim == 3 22 | assert repr(dataset) == ( 23 | f'dataset_name: {dataset.__class__}, ' 24 | f'total {6} images in imgs_root: {self.imgs_root}') 25 | -------------------------------------------------------------------------------- /configs/_base_/datasets/Inception_Score.py: -------------------------------------------------------------------------------- 1 | dataset_type = 'UnconditionalImageDataset' 2 | 3 | # Note that the `Resize` operation with the `pillow` backend and 4 | # `bicubic` interpolation is required for correct IS evaluation 5 | val_pipeline = [ 6 | dict( 7 | type='LoadImageFromFile', 8 | key='real_img', 9 | io_backend='disk', 10 | ), 11 | dict( 12 | type='Resize', 13 | keys=['real_img'], 14 | scale=(299, 299), 15 | backend='pillow', 16 | interpolation='bicubic'), 17 | dict( 18 | type='Normalize', 19 | keys=['real_img'], 20 | mean=[127.5] * 3, 21 | std=[127.5] * 3, 22 | to_rgb=True), 23 | dict(type='ImageToTensor', keys=['real_img']), 24 | dict(type='Collect', keys=['real_img'], meta_keys=['real_img_path']) 25 | ] 26 | 27 | data = dict( 28 | samples_per_gpu=None, 29 | workers_per_gpu=4, 30 | val=dict(type=dataset_type, imgs_root=None, pipeline=val_pipeline)) 31 | -------------------------------------------------------------------------------- /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: lint
2 | 3 | on: [push, pull_request] 4 | 5 | concurrency: 6 | group: ${{ github.workflow }}-${{ github.ref }} 7 | cancel-in-progress: true 8 | 9 | jobs: 10 | lint: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v2 14 | - name: Set up Python 3.7 15 | uses: actions/setup-python@v2 16 | with: 17 | python-version: 3.7 18 | - name: Install pre-commit hook 19 | run: | 20 | pip install pre-commit 21 | pre-commit install 22 | - name: Linting 23 | run: pre-commit run --all-files 24 | - name: Check docstring coverage 25 | run: | 26 | pip install interrogate 27 | interrogate -v --ignore-init-method --ignore-module --ignore-nested-functions --exclude mmgen/ops --ignore-regex "__repr__" --fail-under 50 mmgen 28 | - name: Setup tmate session 29 | if: ${{ failure() }} 30 | uses: mxschmitt/action-tmate@v3 31 | -------------------------------------------------------------------------------- /mmgen/core/registry.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from mmcv.utils import Registry, build_from_cfg 3 | 4 | METRICS = Registry('metric') 5 | 6 | 7 | def build(cfg, registry, default_args=None): 8 | """Build a module. 9 | 10 | Args: 11 | cfg (dict, list[dict]): The config of modules, is is either a dict 12 | or a list of configs. 13 | registry (:obj:`Registry`): A registry the module belongs to. 14 | default_args (dict, optional): Default arguments to build the module. 15 | Defaults to None. 16 | Returns: 17 | nn.Module: A built nn module. 
18 | """ 19 | if isinstance(cfg, list): 20 | modules = [ 21 | build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg 22 | ] 23 | return modules 24 | 25 | return build_from_cfg(cfg, registry, default_args) 26 | 27 | 28 | def build_metric(cfg): 29 | """Build a metric calculator.""" 30 | return build(cfg, METRICS) 31 | -------------------------------------------------------------------------------- /configs/positional_encoding_in_gans/singan_spe-dim4_bohemian.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../singan/singan_bohemian.py'] 2 | 3 | embedding_dim = 4 4 | num_scales = 10 # start from zero 5 | model = dict( 6 | type='PESinGAN', 7 | generator=dict( 8 | type='SinGANMSGeneratorPE', 9 | num_scales=num_scales, 10 | padding=1, 11 | pad_at_head=False, 12 | first_stage_in_channels=embedding_dim * 2, 13 | positional_encoding=dict( 14 | type='SPE', 15 | embedding_dim=embedding_dim, 16 | padding_idx=0, 17 | init_size=512, 18 | div_half_dim=False, 19 | center_shift=200)), 20 | discriminator=dict(num_scales=num_scales)) 21 | 22 | train_cfg = dict(first_fixed_noises_ch=embedding_dim * 2) 23 | 24 | data = dict( 25 | train=dict( 26 | img_path='./data/singan/bohemian.png', 27 | min_size=25, 28 | max_size=500, 29 | )) 30 | 31 | dist_params = dict(backend='nccl') 32 | total_iters = 22000 33 | -------------------------------------------------------------------------------- /configs/positional_encoding_in_gans/singan_spe-dim8_bohemian.py: -------------------------------------------------------------------------------- 1 | _base_ = ['../singan/singan_bohemian.py'] 2 | 3 | embedding_dim = 4 4 | num_scales = 10 # start from zero 5 | model = dict( 6 | type='PESinGAN', 7 | generator=dict( 8 | type='SinGANMSGeneratorPE', 9 | num_scales=num_scales, 10 | padding=1, 11 | pad_at_head=False, 12 | first_stage_in_channels=embedding_dim * 2, 13 | positional_encoding=dict( 14 | type='SPE', 15 | embedding_dim=embedding_dim, 16 | 
padding_idx=0, 17 | init_size=512, 18 | div_half_dim=False, 19 | center_shift=200)), 20 | discriminator=dict(num_scales=num_scales)) 21 | 22 | train_cfg = dict(first_fixed_noises_ch=embedding_dim * 2) 23 | 24 | data = dict( 25 | train=dict( 26 | img_path='./data/singan/bohemian.png', 27 | min_size=25, 28 | max_size=500, 29 | )) 30 | 31 | dist_params = dict(backend='nccl') 32 | total_iters = 22000 33 | -------------------------------------------------------------------------------- /configs/_base_/default_runtime.py: -------------------------------------------------------------------------------- 1 | checkpoint_config = dict(interval=10000, by_epoch=False) 2 | # yapf:disable 3 | log_config = dict( 4 | interval=100, 5 | hooks=[ 6 | dict(type='TextLoggerHook'), 7 | # dict(type='TensorboardLoggerHook'), 8 | ]) 9 | # yapf:enable 10 | 11 | custom_hooks = [ 12 | dict( 13 | type='VisualizeUnconditionalSamples', 14 | output_dir='training_samples', 15 | interval=1000) 16 | ] 17 | 18 | # use dynamic runner 19 | runner = dict( 20 | type='DynamicIterBasedRunner', 21 | is_dynamic_ddp=True, 22 | pass_training_status=True) 23 | 24 | dist_params = dict(backend='nccl') 25 | log_level = 'INFO' 26 | load_from = None 27 | resume_from = None 28 | workflow = [('train', 10000)] 29 | find_unused_parameters = True 30 | cudnn_benchmark = True 31 | 32 | # disable opencv multithreading to avoid system being overloaded 33 | opencv_num_threads = 0 34 | # set multi-process start method as `fork` to speed up the training 35 | mp_start_method = 'fork' 36 | -------------------------------------------------------------------------------- /configs/styleganv2/stylegan2_c2_apex_fp16_quicktest_ffhq_256_b4x8_800k.py: -------------------------------------------------------------------------------- 1 | """Config for the `config-f` setting in StyleGAN2.""" 2 | 3 | _base_ = ['./stylegan2_c2_ffhq_256_b4x8_800k.py'] 4 | 5 | model = dict( 6 | generator=dict(out_size=256), 7 | discriminator=dict(in_size=256, 
"""Config for the `config-f` setting in StyleGAN2."""

# Everything not overridden below is inherited from the base config.
_base_ = ['./stylegan2_c2_ffhq_256_b4x8_800k.py']

# 256x256 generator/discriminator, both with `num_fp16_scales=4`.
model = dict(
    generator=dict(out_size=256, num_fp16_scales=4),
    discriminator=dict(in_size=256, num_fp16_scales=4),
    # Only the discriminator auxiliary loss is handed the 'loss_scaler' entry;
    # the generator counterpart below is intentionally left commented out.
    disc_auxiliary_loss=dict(data_info=dict(loss_scaler='loss_scaler')),
    # gen_auxiliary_loss=dict(data_info=dict(loss_scaler='loss_scaler')),
)

# Quick-test dataset type, so no `imgs_root` is configured here.
# NOTE(review): presumably yields synthetic images of the given size - confirm
# against `QuickTestImageDataset`.
dataset_type = 'QuickTestImageDataset'
data = dict(
    samples_per_gpu=2,
    train=dict(type=dataset_type, size=(256, 256)),
    val=dict(type=dataset_type, size=(256, 256)))

# Log every iteration.
log_config = dict(interval=1)

total_iters = 800002

# Initial scale for the runner's fp16 loss scaler.
runner = dict(fp16_loss_scaler=dict(init_scale=512))

# FID over 50000 images every 10000 iterations, sampling from the 'ema' model.
evaluation = dict(
    type='GenerativeEvalHook',
    interval=10000,
    metrics=dict(
        type='FID', num_images=50000, inception_pkl=None, bgr2rgb=True),
    sample_kwargs=dict(sample_model='ema'))
class AllGatherLayer(autograd.Function):
    """All gather layer with backward propagation path.

    Indeed, this module is to make ``dist.all_gather()`` in the backward graph.
    Such kind of operation has been widely used in Moco and other contrastive
    learning algorithms.
    """

    @staticmethod
    def forward(ctx, x):
        """Gather ``x`` from every rank.

        Args:
            ctx: Autograd context (unused; nothing is needed for backward).
            x (torch.Tensor): Local tensor contributed by this rank.

        Returns:
            tuple[torch.Tensor]: One tensor per rank, ordered by rank, each
                shaped like ``x``.
        """
        # One receive buffer per rank, matching the local tensor.
        output = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
        dist.all_gather(output, x)
        return tuple(output)

    @staticmethod
    def backward(ctx, *grad_outputs):
        """Return the gradient that belongs to this rank's input.

        Args:
            ctx: Autograd context (unused).
            *grad_outputs (torch.Tensor): Gradients w.r.t. each gathered
                tensor, in rank order.

        Returns:
            torch.Tensor: The gradient for the slot this rank contributed.
        """
        # Only the slot filled by this rank's ``x`` maps back to the local
        # input. The previous ``zeros_like`` placeholder was overwritten
        # immediately, so neither it nor the saved input is required.
        return grad_outputs[dist.get_rank()]
class TestLpips:
    """Output-shape checks for ``PerceptualLoss`` on CPU and on CUDA."""

    @classmethod
    def setup_class(cls):
        # Shared constructor flag for both tests below.
        cls.pretrained = False

    def test_lpips(self):
        loss_fn = PerceptualLoss(use_gpu=False, pretrained=self.pretrained)
        first, second = (torch.randn((2, 3, 256, 256)) for _ in range(2))
        distance = loss_fn(first, second)
        # One value per sample, with singleton trailing dimensions.
        assert distance.shape == (2, 1, 1, 1)

    @pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
    def test_lpips_cuda(self):
        loss_fn = PerceptualLoss(use_gpu=True, pretrained=self.pretrained)
        first, second = (
            torch.randn((2, 3, 256, 256)).cuda() for _ in range(2))
        distance = loss_fn(first, second)
        assert distance.shape == (2, 1, 1, 1)
class TestPersistentWorker:
    """Smoke test: a dataloader builds with ``persistent_workers=True``."""

    @classmethod
    def setup_class(cls):
        tests_dir = osp.dirname(__file__)
        # Dataset config pointing at the image fixtures next to the tests.
        cls.data_cfg = dict(
            type='UnconditionalImageDataset',
            imgs_root=osp.join(tests_dir, '..', 'data/image'),
            pipeline=[
                dict(
                    type='LoadImageFromFile',
                    io_backend='disk',
                    key='real_img')
            ],
            test_mode=False)
        # Dataloader keyword arguments, including the option under test.
        cls.config = dict(
            samples_per_gpu=1,
            workers_per_gpu=4,
            drop_last=True,
            persistent_workers=True)

    def test_persistent_worker(self):
        # Passes as long as building with the persistent-worker settings
        # does not raise.
        build_dataloader(build_dataset(self.data_cfg), **self.config)
def digit_version(version_str):
    """Convert a dotted version string into a list of ints for comparison.

    Purely numeric fields are kept as integers. A field of the form
    ``<N>rc<M>`` becomes the two entries ``N - 1`` and ``M``, so a release
    candidate sorts before the final release. Any other field is ignored.

    Args:
        version_str (str): Version such as ``'1.3.0'`` or ``'1.3.0rc1'``.

    Returns:
        list[int]: Comparable numeric form of the version.
    """
    parts = []
    for field in version_str.split('.'):
        if field.isdigit():
            parts.append(int(field))
            continue
        if 'rc' not in field:
            # Neither a plain number nor a release candidate: skip it.
            continue
        pieces = field.split('rc')
        parts.extend((int(pieces[0]) - 1, int(pieces[1])))
    return parts
discriminator=dict(type='Adam', lr=0.0002, betas=(0.0, 0.999))) 33 | -------------------------------------------------------------------------------- /mmgen/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .builder import build_dataloader, build_dataset 3 | from .dataset_wrappers import RepeatDataset 4 | from .grow_scale_image_dataset import GrowScaleImgDataset 5 | from .paired_image_dataset import PairedImageDataset 6 | from .pipelines import (Collect, Compose, Flip, ImageToTensor, 7 | LoadImageFromFile, Normalize, Resize, ToTensor) 8 | from .quick_test_dataset import QuickTestImageDataset 9 | from .samplers import DistributedSampler 10 | from .singan_dataset import SinGANDataset 11 | from .unconditional_image_dataset import UnconditionalImageDataset 12 | from .unpaired_image_dataset import UnpairedImageDataset 13 | 14 | __all__ = [ 15 | 'build_dataloader', 'build_dataset', 'LoadImageFromFile', 16 | 'DistributedSampler', 'UnconditionalImageDataset', 'Compose', 'ToTensor', 17 | 'ImageToTensor', 'Collect', 'Flip', 'Resize', 'RepeatDataset', 'Normalize', 18 | 'GrowScaleImgDataset', 'SinGANDataset', 'PairedImageDataset', 19 | 'UnpairedImageDataset', 'QuickTestImageDataset' 20 | ] 21 | -------------------------------------------------------------------------------- /configs/_base_/datasets/unconditional_imgs_flip_lanczos_resize_256x256.py: -------------------------------------------------------------------------------- 1 | dataset_type = 'UnconditionalImageDataset' 2 | 3 | train_pipeline = [ 4 | dict(type='LoadImageFromFile', key='real_img'), 5 | dict( 6 | type='Resize', 7 | keys=['real_img'], 8 | scale=(256, 256), 9 | interpolation='lanczos', 10 | backend='pillow'), 11 | dict(type='Flip', keys=['real_img'], direction='horizontal'), 12 | dict( 13 | type='Normalize', 14 | keys=['real_img'], 15 | mean=[127.5] * 3, 16 | std=[127.5] * 3, 17 | to_rgb=False), 18 
We strongly appreciate you creating a PR to implement this feature [here](https://github.com/open-mmlab/mmgeneration/pulls)!
# Base model settings for BigGAN at 128x128 with 1000 classes.
model = dict(
    type='BasiccGAN',
    generator=dict(
        type='BigGANGenerator',
        output_scale=128,
        noise_size=120,
        num_classes=1000,
        base_channels=96,
        shared_dim=128,
        with_shared_embedding=True,
        sn_eps=1e-6,
        init_type='ortho',
        act_cfg=dict(type='ReLU', inplace=True),
        split_noise=True,
        auto_sync_bn=False),
    # The discriminator mirrors the generator's scale, class count, base
    # width, `sn_eps`, init type and activation.
    discriminator=dict(
        type='BigGANDiscriminator',
        input_scale=128,
        num_classes=1000,
        base_channels=96,
        sn_eps=1e-6,
        init_type='ortho',
        act_cfg=dict(type='ReLU', inplace=True),
        with_spectral_norm=True),
    gan_loss=dict(type='GANLoss', gan_type='hinge'))

# 8 discriminator steps per generator step, batch accumulation over 8 steps,
# EMA enabled.
train_cfg = dict(
    disc_steps=8, gen_steps=1, batch_accumulation_steps=8, use_ema=True)
test_cfg = None
# Separate Adam optimizers; the discriminator lr is 4x the generator lr.
optimizer = dict(
    generator=dict(type='Adam', lr=0.0001, betas=(0.0, 0.999), eps=1e-6),
    discriminator=dict(type='Adam', lr=0.0004, betas=(0.0, 0.999), eps=1e-6))
#!/usr/bin/env python
"""Generate ``modelzoo_statistics.md`` from the per-method config READMEs.

Run from ``docs/zh_cn``: every ``../../configs/*/README.md`` is scanned for
checkpoint download links and a short summary page is written to the current
directory.
"""
import glob
import os.path as osp
import re

url_prefix = 'https://github.com/open-mmlab/mmgeneration/blob/master/'

# Relative path from this docs directory up to the repository root.
repo_root_prefix = '../../'


def readme_dir_url(readme_path):
    """Return the GitHub URL of the directory containing ``readme_path``.

    Args:
        readme_path (str): Path such as ``'../../configs/dcgan/README.md'``.

    Returns:
        str: ``url_prefix`` followed by the config directory.
    """
    # Swap the whole leading '../../' exactly once. Replacing every '../'
    # would insert ``url_prefix`` twice and yield a broken link.
    return osp.dirname(readme_path.replace(repo_root_prefix, url_prefix, 1))


files = sorted(glob.glob('../../configs/*/README.md'))

stats = []
titles = []
num_ckpts = 0

for f in files:
    url = readme_dir_url(f)

    # READMEs may contain non-ASCII text; do not rely on the locale default.
    with open(f, 'r', encoding='utf-8') as content_file:
        content = content_file.read()

    # The first line of each README is its '# Title' heading.
    title = content.split('\n')[0].replace('# ', '')

    titles.append(title)
    # Unique checkpoint links that mention 'mmgen'.
    ckpts = set(x.lower().strip()
                for x in re.findall(r'https?://download.(.*?)\.pth', content)
                if 'mmgen' in x)

    num_ckpts += len(ckpts)
    statsmsg = f"""
\t* [{title}]({url}) ({len(ckpts)} ckpts)
"""
    stats.append((title, ckpts, statsmsg))

msglist = '\n'.join(x for _, _, x in stats)

modelzoo = f"""
# Model Zoo Statistics

* Number of papers: {len(titles)}
* Number of checkpoints: {num_ckpts}
{msglist}
"""

with open('modelzoo_statistics.md', 'w', encoding='utf-8') as f:
    f.write(modelzoo)
# Compose the model, dataset and runtime base configs; the settings below
# override them.
_base_ = [
    '../_base_/models/dcgan/dcgan_64x64.py',
    '../_base_/datasets/unconditional_imgs_64x64.py',
    '../_base_/default_runtime.py'
]

# define dataset
# you must set `samples_per_gpu` and `imgs_root`
data = dict(
    samples_per_gpu=128,
    train=dict(imgs_root='data/celeba-cropped/cropped_images_aligned_png'))

# adjust running config
lr_config = None
# Checkpoint every 10000 iterations and keep at most 20 of them.
checkpoint_config = dict(interval=10000, by_epoch=False, max_keep_ckpts=20)
# Same interval as checkpointing for the training-sample visualization.
custom_hooks = [
    dict(
        type='VisualizeUnconditionalSamples',
        output_dir='training_samples',
        interval=10000)
]

total_iters = 300002

# use ddp wrapper for faster training
use_ddp_wrapper = True
find_unused_parameters = False

runner = dict(
    type='DynamicIterBasedRunner',
    is_dynamic_ddp=False,  # Note that this flag should be False.
    pass_training_status=True)

# Evaluation metrics: MS-SSIM on 10000 images and SWD on 16384 images with
# `image_shape=(3, 64, 64)`.
metrics = dict(
    ms_ssim10k=dict(type='MS_SSIM', num_images=10000),
    swd16k=dict(type='SWD', num_images=16384, image_shape=(3, 64, 64)))
# Stand-alone BigGAN 128x128 config (no `_base_` inheritance).
# NOTE(review): the file name suggests weights converted from BigGAN-PyTorch
# in RGB order - confirm.
model = dict(
    type='BasiccGAN',
    generator=dict(
        type='BigGANGenerator',
        output_scale=128,
        noise_size=120,
        num_classes=1000,
        base_channels=96,
        shared_dim=128,
        with_shared_embedding=True,
        sn_eps=1e-6,
        init_type='ortho',
        act_cfg=dict(type='ReLU', inplace=True),
        split_noise=True,
        auto_sync_bn=False,
        # NOTE(review): presumably flips the generator output channel order to
        # match a BGR pipeline - confirm against `BigGANGenerator`.
        rgb2bgr=True),
    discriminator=dict(
        type='BigGANDiscriminator',
        input_scale=128,
        num_classes=1000,
        base_channels=96,
        sn_eps=1e-6,
        init_type='ortho',
        act_cfg=dict(type='ReLU', inplace=True),
        with_spectral_norm=True),
    gan_loss=dict(type='GANLoss', gan_type='hinge'))

# 8 discriminator steps per generator step, batch accumulation over 8 steps,
# EMA enabled.
train_cfg = dict(
    disc_steps=8, gen_steps=1, batch_accumulation_steps=8, use_ema=True)
test_cfg = None
# Separate Adam optimizers; the discriminator lr is 4x the generator lr.
optimizer = dict(
    generator=dict(type='Adam', lr=0.0001, betas=(0.0, 0.999), eps=1e-6),
    discriminator=dict(type='Adam', lr=0.0004, betas=(0.0, 0.999), eps=1e-6))
| train_cfg = dict( 12 | noise_weight_init=0.1, 13 | iters_per_scale=2000, 14 | ) 15 | 16 | # test_cfg = dict( 17 | # _delete_ = True 18 | # pkl_data = 'path to pkl data' 19 | # ) 20 | 21 | data = dict(train=dict(img_path='./data/singan/balloons.png')) 22 | 23 | optimizer = None 24 | lr_config = None 25 | checkpoint_config = dict(by_epoch=False, interval=2000, max_keep_ckpts=3) 26 | 27 | custom_hooks = [ 28 | dict( 29 | type='MMGenVisualizationHook', 30 | output_dir='visual', 31 | interval=500, 32 | bgr2rgb=True, 33 | res_name_list=['fake_imgs', 'recon_imgs', 'real_imgs']), 34 | dict( 35 | type='PickleDataHook', 36 | output_dir='pickle', 37 | interval=-1, 38 | after_run=True, 39 | data_name_list=['noise_weights', 'fixed_noises', 'curr_stage']) 40 | ] 41 | 42 | total_iters = 18000 43 | -------------------------------------------------------------------------------- /mmgen/datasets/dataset_wrappers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .builder import DATASETS 3 | 4 | 5 | @DATASETS.register_module() 6 | class RepeatDataset: 7 | """A wrapper of repeated dataset. 8 | 9 | The length of repeated dataset will be `times` larger than the original 10 | dataset. This is useful when the data loading time is long but the dataset 11 | is small. Using RepeatDataset can reduce the data loading time between 12 | epochs. 13 | 14 | Args: 15 | dataset (:obj:`Dataset`): The dataset to be repeated. 16 | times (int): Repeat times. 17 | """ 18 | 19 | def __init__(self, dataset, times): 20 | self.dataset = dataset 21 | self.times = times 22 | 23 | self._ori_len = len(self.dataset) 24 | 25 | def __getitem__(self, idx): 26 | """Get item at each call. 27 | 28 | Args: 29 | idx (int): Index for getting each item. 30 | """ 31 | return self.dataset[idx % self._ori_len] 32 | 33 | def __len__(self): 34 | """Length of the dataset. 35 | 36 | Returns: 37 | int: Length of the dataset. 
# Cached construction of constant tensors. Avoids CPU=>GPU copy when the
# same constant is used multiple times.

_constant_cache = {}


def constant(value, shape=None, dtype=None, device=None, memory_format=None):
    """Return a cached tensor holding ``value``.

    Tensors are memoised on the value's contents together with the requested
    shape, dtype, device and memory format. Repeated calls with the same
    arguments return the very same tensor object.

    Args:
        value: Anything ``np.asarray`` accepts.
        shape (Sequence[int], optional): Shape to broadcast the value to.
        dtype (torch.dtype, optional): Defaults to the torch default dtype.
        device (torch.device, optional): Defaults to the CPU device.
        memory_format (torch.memory_format, optional): Defaults to
            ``torch.contiguous_format``.

    Returns:
        torch.Tensor: The (possibly shared) constant tensor.
    """
    value = np.asarray(value)
    shape = None if shape is None else tuple(shape)
    dtype = torch.get_default_dtype() if dtype is None else dtype
    device = torch.device('cpu') if device is None else device
    if memory_format is None:
        memory_format = torch.contiguous_format

    cache_key = (value.shape, value.dtype, value.tobytes(), shape, dtype,
                 device, memory_format)
    cached = _constant_cache.get(cache_key, None)
    if cached is not None:
        return cached

    # Cache miss: build the tensor from a private copy of the array.
    built = torch.as_tensor(value.copy(), dtype=dtype, device=device)
    if shape is not None:
        built, _ = torch.broadcast_tensors(built, torch.empty(shape))
    built = built.contiguous(memory_format=memory_format)
    _constant_cache[cache_key] = built
    return built
@HOOKS.register_module()
class PGGANFetchDataHook(Hook):
    """Hook that refreshes the dataloader before PGGAN training data is fetched.

    Args:
        interval (int, optional): How often (in iterations) the dataloader is
            updated. Defaults to 1.
    """

    def __init__(self, interval=1):
        super().__init__()
        self.interval = interval

    def before_fetch_train_data(self, runner):
        """Update the runner's dataloader with the model's next scale.

        Args:
            runner (object): The runner.
        """
        if not self.every_n_iters(runner, self.interval):
            return

        # Look through a parallel wrapper to reach the underlying model.
        model = runner.model
        if is_module_wrapper(model):
            model = model.module

        next_scale = model._next_scale_int
        if isinstance(next_scale, torch.Tensor):
            # The dataloader receives a plain Python number, not a tensor.
            next_scale = next_scale.item()
        runner.data_loader.update_dataloader(next_scale)
discriminator=dict( 5 | type='WGANGPDiscriminator', 6 | in_channel=3, 7 | in_scale=128, 8 | conv_module_cfg=dict( 9 | conv_cfg=None, 10 | kernel_size=3, 11 | stride=1, 12 | padding=1, 13 | bias=True, 14 | act_cfg=dict(type='LeakyReLU', negative_slope=0.2), 15 | norm_cfg=dict(type='GN'), 16 | order=('conv', 'norm', 'act'))), 17 | gan_loss=dict(type='GANLoss', gan_type='wgan'), 18 | disc_auxiliary_loss=[ 19 | dict( 20 | type='GradientPenaltyLoss', 21 | loss_weight=10, 22 | norm_mode='HWC', 23 | data_info=dict( 24 | discriminator='disc', 25 | real_data='real_imgs', 26 | fake_data='fake_imgs')) 27 | ]) 28 | 29 | train_cfg = dict(disc_steps=5) 30 | 31 | test_cfg = None 32 | 33 | optimizer = dict( 34 | generator=dict(type='Adam', lr=0.0001, betas=(0.5, 0.9)), 35 | discriminator=dict(type='Adam', lr=0.0001, betas=(0.5, 0.9))) 36 | -------------------------------------------------------------------------------- /configs/_base_/models/stylegan/styleganv1_base.py: -------------------------------------------------------------------------------- 1 | model = dict( 2 | type='StyleGANV1', 3 | generator=dict( 4 | type='StyleGANv1Generator', out_size=None, style_channels=512), 5 | discriminator=dict(type='StyleGAN1Discriminator', in_size=None), 6 | gan_loss=dict(type='GANLoss', gan_type='wgan-logistic-ns'), 7 | disc_auxiliary_loss=[ 8 | dict( 9 | type='R1GradientPenalty', 10 | loss_weight=10, 11 | norm_mode='HWC', 12 | data_info=dict( 13 | discriminator='disc_partial', real_data='real_imgs')) 14 | ]) 15 | 16 | train_cfg = dict( 17 | use_ema=True, 18 | transition_kimgs=600, 19 | optimizer_cfg=dict( 20 | generator=dict(type='Adam', lr=0.001, betas=(0.0, 0.99)), 21 | discriminator=dict(type='Adam', lr=0.001, betas=(0.0, 0.99))), 22 | g_lr_base=0.001, 23 | d_lr_base=0.001, 24 | g_lr_schedule=dict({ 25 | '128': 0.0015, 26 | '256': 0.002, 27 | '512': 0.003, 28 | '1024': 0.003 29 | }), 30 | d_lr_schedule=dict({ 31 | '128': 0.0015, 32 | '256': 0.002, 33 | '512': 0.003, 34 | '1024': 0.003 
# Compose the SinGAN model, dataset and runtime base configs.
_base_ = [
    '../_base_/models/singan/singan.py', '../_base_/datasets/singan.py',
    '../_base_/default_runtime.py'
]

num_scales = 10  # start from zero
# Generator and discriminator share the same number of scales.
model = dict(
    generator=dict(num_scales=num_scales),
    discriminator=dict(num_scales=num_scales))

train_cfg = dict(
    noise_weight_init=0.1,
    iters_per_scale=2000,
)

# test_cfg = dict(
#     _delete_ = True
#     pkl_data = 'path to pkl data'
# )

# The single training image, with its min/max size settings.
data = dict(
    train=dict(
        img_path='./data/singan/fish-crop.jpg', min_size=25, max_size=300))

optimizer = None
lr_config = None
# Checkpoint every 2000 iterations (same as `iters_per_scale`), keep 3.
checkpoint_config = dict(by_epoch=False, interval=2000, max_keep_ckpts=3)

custom_hooks = [
    # Save fake / reconstructed / real images every 500 iterations.
    dict(
        type='MMGenVisualizationHook',
        output_dir='visual',
        interval=500,
        bgr2rgb=True,
        res_name_list=['fake_imgs', 'recon_imgs', 'real_imgs']),
    # Pickle the listed training state with `interval=-1, after_run=True`.
    # NOTE(review): presumably this dumps only once, after the run - confirm
    # against `PickleDataHook`.
    dict(
        type='PickleDataHook',
        output_dir='pickle',
        interval=-1,
        after_run=True,
        data_name_list=['noise_weights', 'fixed_noises', 'curr_stage'])
]

# Equals (num_scales + 1) * iters_per_scale = 11 * 2000.
total_iters = 22000
18 | # pkl_data = 'path to pkl data' 19 | # ) 20 | 21 | data = dict( 22 | train=dict( 23 | img_path='./data/singan/bohemian.png', min_size=25, max_size=500)) 24 | 25 | optimizer = None 26 | lr_config = None 27 | checkpoint_config = dict(by_epoch=False, interval=2000, max_keep_ckpts=3) 28 | 29 | custom_hooks = [ 30 | dict( 31 | type='MMGenVisualizationHook', 32 | output_dir='visual', 33 | interval=500, 34 | bgr2rgb=True, 35 | res_name_list=['fake_imgs', 'recon_imgs', 'real_imgs']), 36 | dict( 37 | type='PickleDataHook', 38 | output_dir='pickle', 39 | interval=-1, 40 | after_run=True, 41 | data_name_list=['noise_weights', 'fixed_noises', 'curr_stage']) 42 | ] 43 | 44 | total_iters = 22000 45 | -------------------------------------------------------------------------------- /docs/en/api.rst: -------------------------------------------------------------------------------- 1 | mmgen.apis 2 | -------------- 3 | .. automodule:: mmgen.apis 4 | :members: 5 | 6 | mmgen.core 7 | -------------- 8 | 9 | evaluation 10 | ^^^^^^^^^^ 11 | .. automodule:: mmgen.core.evaluation 12 | :members: 13 | 14 | hooks 15 | ^^^^^^^^^^ 16 | .. automodule:: mmgen.core.hooks 17 | :members: 18 | 19 | optimizer 20 | ^^^^^^^^^^ 21 | .. automodule:: mmgen.core.optimizer 22 | :members: 23 | 24 | runners 25 | ^^^^^^^^^^ 26 | .. automodule:: mmgen.core.runners 27 | :members: 28 | 29 | scheduler 30 | ^^^^^^^^^^ 31 | .. automodule:: mmgen.core.scheduler 32 | :members: 33 | 34 | 35 | 36 | mmgen.datasets 37 | -------------- 38 | 39 | datasets 40 | ^^^^^^^^^^ 41 | .. automodule:: mmgen.datasets 42 | :members: 43 | 44 | pipelines 45 | ^^^^^^^^^^ 46 | .. automodule:: mmgen.datasets.pipelines 47 | :members: 48 | 49 | mmgen.models 50 | -------------- 51 | 52 | architectures 53 | ^^^^^^^^^^ 54 | .. automodule:: mmgen.models.architectures 55 | :members: 56 | 57 | common 58 | ^^^^^^^^^^ 59 | .. automodule:: mmgen.models.common 60 | :members: 61 | 62 | gans 63 | ^^^^^^^^^^^^ 64 | .. 
# Model, training and optimizer settings for BigGAN-Deep at 128x128 with
# 1000 classes.
# NOTE(review): the file name suggests the weights were converted from the
# Hugging Face release and use RGB channel order -- confirm.
model = dict(
    type='BasiccGAN',
    generator=dict(
        type='BigGANDeepGenerator',
        output_scale=128,
        noise_size=128,
        num_classes=1000,
        base_channels=128,
        shared_dim=128,
        with_shared_embedding=True,
        sn_eps=1e-6,
        sn_style='torch',
        init_type='ortho',
        act_cfg=dict(type='ReLU', inplace=True),
        concat_noise=True,
        auto_sync_bn=False,
        rgb2bgr=True),
    discriminator=dict(
        type='BigGANDeepDiscriminator',
        # Same resolution as the generator's ``output_scale``.
        input_scale=128,
        num_classes=1000,
        base_channels=128,
        sn_eps=1e-6,
        sn_style='torch',
        init_type='ortho',
        act_cfg=dict(type='ReLU', inplace=True),
        with_spectral_norm=True),
    gan_loss=dict(type='GANLoss', gan_type='hinge'))

# Eight discriminator steps per generator step, with gradients accumulated
# over eight batches; an EMA copy of the generator is enabled.
train_cfg = dict(
    disc_steps=8, gen_steps=1, batch_accumulation_steps=8, use_ema=True)
test_cfg = None
# The discriminator learning rate is four times the generator's.
optimizer = dict(
    generator=dict(type='Adam', lr=0.0001, betas=(0.0, 0.999), eps=1e-6),
    discriminator=dict(type='Adam', lr=0.0004, betas=(0.0, 0.999), eps=1e-6))
# Model, training and optimizer settings for BigGAN-Deep at 512x512 with
# 1000 classes.
# NOTE(review): the file name suggests the weights were converted from the
# Hugging Face release and use RGB channel order -- confirm.
model = dict(
    type='BasiccGAN',
    generator=dict(
        type='BigGANDeepGenerator',
        output_scale=512,
        noise_size=128,
        num_classes=1000,
        base_channels=128,
        shared_dim=128,
        with_shared_embedding=True,
        sn_eps=1e-6,
        sn_style='torch',
        init_type='ortho',
        act_cfg=dict(type='ReLU', inplace=True),
        concat_noise=True,
        auto_sync_bn=False,
        rgb2bgr=True),
    discriminator=dict(
        type='BigGANDeepDiscriminator',
        # Same resolution as the generator's ``output_scale``.
        input_scale=512,
        num_classes=1000,
        base_channels=128,
        sn_eps=1e-6,
        sn_style='torch',
        init_type='ortho',
        act_cfg=dict(type='ReLU', inplace=True),
        with_spectral_norm=True),
    gan_loss=dict(type='GANLoss', gan_type='hinge'))

# Eight discriminator steps per generator step, with gradients accumulated
# over eight batches; an EMA copy of the generator is enabled.
train_cfg = dict(
    disc_steps=8, gen_steps=1, batch_accumulation_steps=8, use_ema=True)
test_cfg = None
# The discriminator learning rate is four times the generator's.
optimizer = dict(
    generator=dict(type='Adam', lr=0.0001, betas=(0.0, 0.999), eps=1e-6),
    discriminator=dict(type='Adam', lr=0.0004, betas=(0.0, 0.999), eps=1e-6))
def build(cfg, registry, default_args=None):
    """Instantiate one module, or a list of modules, from config.

    Args:
        cfg (dict, list[dict]): A single module config, or a list of module
            configs.
        registry (:obj:`Registry`): The registry the module type is looked
            up in.
        default_args (dict, optional): Default arguments forwarded to
            ``build_from_cfg`` for every config. Defaults to None.

    Returns:
        nn.Module: The built module. For a list input, an ``nn.ModuleList``
        holding the built modules in the order of ``cfg``.
    """
    # A single config is built directly; only lists get wrapped.
    if not isinstance(cfg, list):
        return build_from_cfg(cfg, registry, default_args)

    built = []
    for single_cfg in cfg:
        built.append(build_from_cfg(single_cfg, registry, default_args))
    return nn.ModuleList(built)


def build_model(cfg, train_cfg=None, test_cfg=None):
    """Build a model (GAN) registered in ``MODELS``.

    ``train_cfg`` and ``test_cfg`` are handed to the builder as default
    arguments.
    """
    defaults = dict(train_cfg=train_cfg, test_cfg=test_cfg)
    return build(cfg, MODELS, defaults)


def build_module(cfg, default_args=None):
    """Build a module, or a list of modules, registered in ``MODULES``."""
    return build(cfg, MODULES, default_args)
-------------------------------------------------------------------------------- /docs/zh_cn/api.rst: -------------------------------------------------------------------------------- 1 | 接口参考手册 2 | ================= 3 | 4 | mmgen.apis 5 | -------------- 6 | .. automodule:: mmgen.apis 7 | :members: 8 | 9 | mmgen.core 10 | -------------- 11 | 12 | evaluation 13 | ^^^^^^^^^^ 14 | .. automodule:: mmgen.core.evaluation 15 | :members: 16 | 17 | hooks 18 | ^^^^^^^^^^ 19 | .. automodule:: mmgen.core.hooks 20 | :members: 21 | 22 | optimizer 23 | ^^^^^^^^^^ 24 | .. automodule:: mmgen.core.optimizer 25 | :members: 26 | 27 | runners 28 | ^^^^^^^^^^ 29 | .. automodule:: mmgen.core.runners 30 | :members: 31 | 32 | scheduler 33 | ^^^^^^^^^^ 34 | .. automodule:: mmgen.core.scheduler 35 | :members: 36 | 37 | mmgen.datasets 38 | -------------- 39 | 40 | datasets 41 | ^^^^^^^^^^ 42 | .. automodule:: mmgen.datasets 43 | :members: 44 | 45 | pipelines 46 | ^^^^^^^^^^ 47 | .. automodule:: mmgen.datasets.pipelines 48 | :members: 49 | 50 | mmgen.models 51 | -------------- 52 | 53 | architectures 54 | ^^^^^^^^^^^^^ 55 | .. automodule:: mmgen.models.architectures 56 | :members: 57 | 58 | common 59 | ^^^^^^^^^^ 60 | .. automodule:: mmgen.models.common 61 | :members: 62 | 63 | gans 64 | ^^^^^^^^^^^^ 65 | .. automodule:: mmgen.models.gans 66 | :members: 67 | 68 | losses 69 | ^^^^^^^^^^^^ 70 | ..
@MODULES.register_module()
class UniformTimeStepSampler:
    """Uniform timestep sampler for DDPM-based models.

    Every timestep is drawn with the same probability.

    Args:
        num_timesteps (int): Total timesteps of the diffusion process.
    """

    def __init__(self, num_timesteps):
        self.num_timesteps = num_timesteps
        # One equal weight per timestep; handed to numpy as ``p`` below.
        self.prob = [1 / num_timesteps for _ in range(num_timesteps)]

    def sample(self, batch_size):
        """Draw a batch of timesteps.

        Args:
            batch_size (int): Number of timesteps to draw.

        Returns:
            torch.Tensor: Sampled timesteps as a long tensor with
            ``batch_size`` entries.
        """
        # numpy's RNG is used on purpose so that the implementation stays
        # consistent with the official ones.
        drawn = np.random.choice(
            self.num_timesteps, size=(batch_size, ), p=self.prob)
        return torch.from_numpy(drawn).long()

    def __call__(self, batch_size):
        """Shortcut for :meth:`sample`."""
        return self.sample(batch_size)
def test_load_image_from_file():
    """LoadImageFromFile accepts both ``Path`` and ``str`` image paths."""
    data_root = Path(__file__).parent / '..' / '..' / 'data'
    path_baboon = data_root / 'image' / 'baboon.png'
    img_baboon = mmcv.imread(str(path_baboon), flag='color')

    image_loader = LoadImageFromFile(io_backend='disk', key='gt')

    # read gt image: first from a Path object, then from a plain str
    for gt_path in (path_baboon, str(path_baboon)):
        results = image_loader(dict(gt_path=gt_path))
        assert results['gt'].shape == (480, 500, 3)
        np.testing.assert_almost_equal(results['gt'], img_baboon)
        # the loader always reports the path back as a str
        assert results['gt_path'] == str(path_baboon)

    expected_repr = image_loader.__class__.__name__ + (
        '(io_backend=disk, key=gt, '
        'flag=color, save_original_img=False)')
    assert repr(image_loader) == expected_repr
def parse_args():
    """Parse the command line: input and output checkpoint filenames."""
    parser = argparse.ArgumentParser(
        description='Process a checkpoint to be published')
    parser.add_argument('in_file', help='input checkpoint filename')
    parser.add_argument('out_file', help='output checkpoint filename')
    args = parser.parse_args()
    return args


def process_checkpoint(in_file, out_file):
    """Strip a checkpoint for publishing and tag it with time and hash.

    The checkpoint is loaded on CPU, its ``optimizer`` entry (if any) is
    dropped, and the result is written to ``out_file``. The written file is
    then renamed to ``<out_file without .pth>_<YYYYmmdd_HHMMSS>-<sha8>.pth``,
    where ``sha8`` is the first eight hex digits of the file's SHA-256.

    Args:
        in_file (str): Path of the checkpoint to process.
        out_file (str): Path the processed checkpoint is first written to.

    Returns:
        str: Path of the final, renamed checkpoint file.
    """
    # Local imports keep this function self-contained without touching the
    # file-level import block.
    import hashlib
    import os

    # NOTE: ``torch.load`` unpickles the file; only run this on checkpoints
    # from a trusted source.
    checkpoint = torch.load(in_file, map_location='cpu')
    # remove optimizer for smaller file size
    if 'optimizer' in checkpoint:
        del checkpoint['optimizer']
    # if it is necessary to remove some sensitive data in checkpoint['meta'],
    # add the code here.
    torch.save(checkpoint, out_file)

    # Hash in chunks so large checkpoints are not read into memory at once.
    # The hex digest matches the leading field of ``sha256sum`` output.
    digest = hashlib.sha256()
    with open(out_file, 'rb') as f:
        for chunk in iter(lambda: f.read(1 << 20), b''):
            digest.update(chunk)
    sha = digest.hexdigest()

    time = datetime.now().strftime('%Y%m%d_%H%M%S')
    # ``str.rstrip('.pth')`` would strip any trailing '.', 'p', 't', 'h'
    # characters (e.g. 'path.pth' -> 'pa'); remove the suffix explicitly.
    suffix = '.pth'
    stem = out_file[:-len(suffix)] if out_file.endswith(suffix) else out_file
    final_file = stem + f'_{time}-{sha[:8]}.pth'
    # Rename synchronously instead of spawning an un-awaited ``mv`` process.
    os.replace(out_file, final_file)
    return final_file


def main():
    """Script entry point."""
    args = parse_args()
    process_checkpoint(args.in_file, args.out_file)


if __name__ == '__main__':
    main()
43 | - name: test commands of mim 44 | run: mim search mmgen 45 | -------------------------------------------------------------------------------- /configs/_base_/datasets/cifar10_rgb.py: -------------------------------------------------------------------------------- 1 | dataset_type = 'mmcls.CIFAR10' 2 | 3 | # different from mmcls, we adopt the setting used in BigGAN 4 | # Note that the pipelines below are from MMClassification 5 | img_norm_cfg = dict( 6 | mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=False) 7 | train_pipeline = [ 8 | dict(type='RandomCrop', size=32, padding=4), 9 | dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'), 10 | dict(type='Normalize', **img_norm_cfg), 11 | dict(type='ImageToTensor', keys=['img']), 12 | dict(type='ToTensor', keys=['gt_label']), 13 | dict(type='Collect', keys=['img', 'gt_label']) 14 | ] 15 | test_pipeline = [ 16 | dict(type='Normalize', **img_norm_cfg), 17 | dict(type='ImageToTensor', keys=['img']), 18 | dict(type='Collect', keys=['img']) 19 | ] 20 | 21 | # Different from the classification task, the val/test split also use the 22 | # training part, which is the same to StyleGAN-ADA. 
# Dataset settings for progressive-growing training on FFHQ (StyleGANv1).
dataset_type = 'GrowScaleImgDataset'

train_pipeline = [
    dict(
        type='LoadImageFromFile',
        key='real_img',
        io_backend='disk',
    ),
    dict(type='Flip', keys=['real_img'], direction='horizontal'),
    dict(
        type='Normalize',
        keys=['real_img'],
        mean=[127.5, 127.5, 127.5],
        std=[127.5, 127.5, 127.5],
        to_rgb=False),
    dict(type='ImageToTensor', keys=['real_img']),
    dict(type='Collect', keys=['real_img'], meta_keys=['real_img_path'])
]

data = dict(
    # Equal to the per-GPU batch size of the first scale ('4': 64) below.
    samples_per_gpu=64,
    workers_per_gpu=4,
    train=dict(
        # Reuse ``dataset_type`` instead of repeating the literal, as the
        # sibling ``grow_scale_imgs_128x128.py`` config does.
        type=dataset_type,
        # Image roots keyed by resolution.
        imgs_roots={
            '1024': './data/ffhq/images',
            '256': './data/ffhq/ffhq_imgs/ffhq_256',
            '64': './data/ffhq/ffhq_imgs/ffhq_64'
        },
        pipeline=train_pipeline,
        gpu_samples_base=4,
        # Per-GPU batch size keyed by scale.
        # NOTE(review): sibling configs say these values should be changed
        # with the total GPU number -- presumably true here as well.
        gpu_samples_per_scale={
            '4': 64,
            '8': 32,
            '16': 16,
            '32': 8,
            '64': 4,
            '128': 4,
            '256': 4,
            '512': 4,
            '1024': 4
        },
        len_per_stage=300000))
optimizer = None 8 | checkpoint_config = dict(interval=10000, by_epoch=False, max_keep_ckpts=20) 9 | 10 | data = dict( 11 | samples_per_gpu=64, 12 | train=dict( 13 | imgs_roots={'128': './data/celeba-cropped/cropped_images_aligned_png'}, 14 | gpu_samples_base=4, 15 | # note that this should be changed with total gpu number 16 | gpu_samples_per_scale={ 17 | '4': 64, 18 | '8': 32, 19 | '16': 16, 20 | '32': 8, 21 | '64': 4 22 | })) 23 | 24 | custom_hooks = [ 25 | dict( 26 | type='VisualizeUnconditionalSamples', 27 | output_dir='training_samples', 28 | interval=5000), 29 | dict(type='PGGANFetchDataHook', interval=1), 30 | dict( 31 | type='ExponentialMovingAverageHook', 32 | module_keys=('generator_ema', ), 33 | interval=1, 34 | priority='VERY_HIGH') 35 | ] 36 | 37 | lr_config = None 38 | 39 | total_iters = 280000 40 | 41 | metrics = dict( 42 | ms_ssim10k=dict(type='MS_SSIM', num_images=10000), 43 | swd16k=dict(type='SWD', num_images=16384, image_shape=(3, 128, 128))) 44 | -------------------------------------------------------------------------------- /configs/styleganv1/metafile.yml: -------------------------------------------------------------------------------- 1 | Collections: 2 | - Metadata: 3 | Architecture: 4 | - StyleGANv1 5 | Name: StyleGANv1 6 | Paper: 7 | - https://openaccess.thecvf.com/content_CVPR_2019/html/Karras_A_Style-Based_Generator_Architecture_for_Generative_Adversarial_Networks_CVPR_2019_paper.html 8 | README: configs/styleganv1/README.md 9 | Models: 10 | - Config: https://github.com/open-mmlab/mmgeneration/tree/master/configs/styleganv1/styleganv1_ffhq_256_g8_25Mimg.py 11 | In Collection: StyleGANv1 12 | Metadata: 13 | Training Data: FFHQ 14 | Name: styleganv1_ffhq_256_g8_25Mimg 15 | Results: 16 | - Dataset: FFHQ 17 | Metrics: 18 | FID50k: 6.09 19 | P&R50k_full: 70.228/27.050 20 | Task: Unconditional GANs 21 | Weights: https://download.openmmlab.com/mmgen/styleganv1/styleganv1_ffhq_256_g8_25Mimg_20210407_161748-0094da86.pth 22 | - Config: 
# Base model settings for Pix2Pix: a U-Net generator, a patch discriminator,
# both with batch norm, trained with a vanilla GAN loss plus an L1 pixel loss.
source_domain = None  # set by user
target_domain = None  # set by user
# model settings
model = dict(
    type='Pix2Pix',
    generator=dict(
        type='UnetGenerator',
        in_channels=3,
        out_channels=3,
        num_down=8,
        base_channels=64,
        norm_cfg=dict(type='BN'),
        use_dropout=True,
        init_cfg=dict(type='normal', gain=0.02)),
    discriminator=dict(
        type='PatchDiscriminator',
        # NOTE(review): 6 = 3 + 3, presumably the source and target images
        # concatenated along the channel axis -- confirm.
        in_channels=6,
        base_channels=64,
        num_conv=3,
        norm_cfg=dict(type='BN'),
        init_cfg=dict(type='normal', gain=0.02)),
    gan_loss=dict(
        type='GANLoss',
        gan_type='vanilla',
        real_label_val=1.0,
        fake_label_val=0.0,
        loss_weight=1.0),
    default_domain=target_domain,
    reachable_domains=[target_domain],
    related_domains=[target_domain, source_domain],
    gen_auxiliary_loss=dict(
        type='L1Loss',
        loss_weight=100.0,
        loss_name='pixel_loss',
        # NOTE(review): ``target_domain`` is None in this file, so these
        # f-strings evaluate to 'fake_None' / 'real_None' here. Configs that
        # set the domains presumably have to override ``data_info`` (and the
        # domain fields above) themselves -- confirm.
        data_info=dict(
            pred=f'fake_{target_domain}', target=f'real_{target_domain}'),
        reduction='mean'))
# model training and testing settings
train_cfg = None
test_cfg = None
dataset_type = 'mmcls.CIFAR10'

# This config is used for extracting the Inception statistics of the CIFAR-10
# dataset.

# Different from mmcls, we adopt the normalization setting used in BigGAN.
# Note that the pipelines below are from MMClassification.
# The default channel order in CIFAR-10 is RGB, so `to_rgb` is set to `False`.
img_norm_cfg = dict(
    mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=False)
# No augmentation (crop or flip) is applied in the train pipeline.
train_pipeline = [
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=['gt_label']),
    dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='Collect', keys=['img'])
]

# Different from the classification task, the val/test splits also use the
# training part, which is the same as in StyleGAN-ADA.
data = dict(
    # NOTE(review): left as None here; presumably it has to be set by the
    # config that inherits from this file -- confirm.
    samples_per_gpu=None,
    workers_per_gpu=4,
    train=dict(
        type=dataset_type, data_prefix='data/cifar10',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type, data_prefix='data/cifar10', pipeline=test_pipeline),
    test=dict(
        type=dataset_type, data_prefix='data/cifar10', pipeline=test_pipeline))
# apex is optional: ``amp`` stays ``None`` when it is not installed.
try:
    from apex import amp
except ImportError:
    amp = None


def apex_amp_initialize(models, optimizers, init_args=None, mode='gan'):
    """Initialize apex.amp for mixed-precision training.

    Args:
        models (nn.Module | list[Module]): Modules to be wrapped with apex.amp.
        optimizers (dict): Optimizers to be wrapped. In ``'gan'`` mode the
            entries under the keys ``'generator'`` and ``'discriminator'``
            are wrapped and written back into this dict in place.
        init_args (dict | None, optional): Keyword arguments forwarded to
            ``amp.initialize``. Defaults to None.
        mode (str, optional): The mode used to initialize apex.amp.
            Different modes lead to different wrapping of models and
            optimizers. Only ``'gan'`` is implemented. Defaults to 'gan'.

    Returns:
        Module, dict: The wrapped module(s) and the ``optimizers`` dict.

    Raises:
        NotImplementedError: If ``mode`` is not ``'gan'``.
        ImportError: If apex is not installed.
    """
    if mode != 'gan':
        raise NotImplementedError(
            f'Cannot initialize apex.amp with mode {mode}')

    # Fail with a clear message instead of an AttributeError on ``None``.
    if amp is None:
        raise ImportError(
            'apex is required for apex.amp mixed-precision training, '
            'but it could not be imported.')

    init_args = init_args or dict()

    # amp.initialize takes a list of optimizers, so pack the two GAN
    # optimizers in a fixed order and unpack them the same way afterwards.
    _optimizers = [optimizers['generator'], optimizers['discriminator']]
    models, _optimizers = amp.initialize(models, _optimizers, **init_args)
    optimizers['generator'] = _optimizers[0]
    optimizers['discriminator'] = _optimizers[1]

    return models, optimizers
-------------------------------------------------------------------------------- /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | version: 2.1 2 | 3 | # this allows you to use CircleCI's dynamic configuration feature 4 | setup: true 5 | 6 | # the path-filtering orb is required to continue a pipeline based on 7 | # the path of an updated fileset 8 | orbs: 9 | path-filtering: circleci/path-filtering@0.1.2 10 | 11 | workflows: 12 | # the always-run workflow is always triggered, regardless of the pipeline parameters. 13 | always-run: 14 | jobs: 15 | # the path-filtering/filter job determines which pipeline 16 | # parameters to update. 17 | - path-filtering/filter: 18 | name: check-updated-files 19 | # 3-column, whitespace-delimited mapping. One mapping per 20 | # line: 21 | # 22 | mapping: | 23 | mmgen/.* lint_only false 24 | requirements/.* lint_only false 25 | tests/.* lint_only false 26 | tools/.* lint_only false 27 | configs/.* lint_only false 28 | .circleci/.* lint_only false 29 | base-revision: master 30 | # this is the path of the configuration we should trigger once 31 | # path filtering and pipeline parameter value updates are 32 | # complete. In this case, we are using the parent dynamic 33 | # configuration itself. 34 | config-path: .circleci/test.yml 35 | -------------------------------------------------------------------------------- /tools/misc/print_config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
import argparse

from mmcv import Config, DictAction


def parse_args():
    """Parse the command line of the config-printing tool.

    Returns:
        argparse.Namespace: Holds ``config`` (path of the config file) and
        ``cfg_options`` (overrides collected by ``DictAction``, or None when
        the option is not given).
    """
    cfg_options_help = (
        'override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')

    arg_parser = argparse.ArgumentParser(description='Print the whole config')
    arg_parser.add_argument('config', help='config file path')
    arg_parser.add_argument(
        '--cfg-options', nargs='+', action=DictAction, help=cfg_options_help)
    return arg_parser.parse_args()


def main():
    """Load the config, apply command-line overrides and print it."""
    args = parse_args()
    cfg = Config.fromfile(args.config)

    # Overrides from ``--cfg-options`` are merged on top of the file content.
    overrides = args.cfg_options
    if overrides is not None:
        cfg.merge_from_dict(overrides)

    # Import the modules listed under ``custom_imports``, if there are any.
    if cfg.get('custom_imports', None):
        from mmcv.utils import import_modules_from_strings
        import_modules_from_strings(**cfg['custom_imports'])

    print(f'Config:\n{cfg.pretty_text}')


if __name__ == '__main__':
    main()
key='real_img', 24 | io_backend='disk', 25 | ), 26 | dict( 27 | type='Normalize', 28 | keys=['real_img'], 29 | mean=[127.5] * 3, 30 | std=[127.5] * 3, 31 | to_rgb=True), 32 | dict(type='ImageToTensor', keys=['real_img']), 33 | dict(type='Collect', keys=['real_img'], meta_keys=['real_img_path']) 34 | ] 35 | 36 | # `samples_per_gpu` and `imgs_root` need to be set. 37 | data = dict( 38 | samples_per_gpu=None, 39 | workers_per_gpu=4, 40 | train=dict( 41 | type='RepeatDataset', 42 | times=100, 43 | dataset=dict( 44 | type=dataset_type, imgs_root=None, pipeline=train_pipeline)), 45 | val=dict(type=dataset_type, imgs_root=None, pipeline=val_pipeline)) 46 | -------------------------------------------------------------------------------- /configs/_base_/datasets/lsun_stylegan.py: -------------------------------------------------------------------------------- 1 | # Style-based GANs do not perform any augmentation for the LSUN datasets 2 | dataset_type = 'UnconditionalImageDataset' 3 | 4 | train_pipeline = [ 5 | dict( 6 | type='LoadImageFromFile', 7 | key='real_img', 8 | io_backend='disk', 9 | ), 10 | dict( 11 | type='Normalize', 12 | keys=['real_img'], 13 | mean=[127.5] * 3, 14 | std=[127.5] * 3, 15 | to_rgb=False), 16 | dict(type='ImageToTensor', keys=['real_img']), 17 | dict(type='Collect', keys=['real_img'], meta_keys=['real_img_path']) 18 | ] 19 | 20 | val_pipeline = [ 21 | dict( 22 | type='LoadImageFromFile', 23 | key='real_img', 24 | io_backend='disk', 25 | ), 26 | dict( 27 | type='Normalize', 28 | keys=['real_img'], 29 | mean=[127.5] * 3, 30 | std=[127.5] * 3, 31 | to_rgb=True), 32 | dict(type='ImageToTensor', keys=['real_img']), 33 | dict(type='Collect', keys=['real_img'], meta_keys=['real_img_path']) 34 | ] 35 | 36 | # `samples_per_gpu` and `imgs_root` need to be set. 
37 | data = dict( 38 | samples_per_gpu=None, 39 | workers_per_gpu=4, 40 | train=dict( 41 | type='RepeatDataset', 42 | times=100, 43 | dataset=dict( 44 | type=dataset_type, imgs_root=None, pipeline=train_pipeline)), 45 | val=dict(type=dataset_type, imgs_root=None, pipeline=val_pipeline)) 46 | -------------------------------------------------------------------------------- /configs/_base_/datasets/cifar10_nopad.py: -------------------------------------------------------------------------------- 1 | dataset_type = 'mmcls.CIFAR10' 2 | 3 | # different from mmcls, we adopt the setting used in BigGAN 4 | # Note that the pipelines below are from MMClassification. Importantly, the 5 | # `to_rgb` is set to `True` to convert image to BGR orders. The default order 6 | # in Cifar10 is RGB. Thus, we have to convert it to BGR. 7 | img_norm_cfg = dict( 8 | mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=True) 9 | train_pipeline = [ 10 | dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'), 11 | dict(type='Normalize', **img_norm_cfg), 12 | dict(type='ImageToTensor', keys=['img']), 13 | dict(type='ToTensor', keys=['gt_label']), 14 | dict(type='Collect', keys=['img', 'gt_label']) 15 | ] 16 | test_pipeline = [ 17 | dict(type='Normalize', **img_norm_cfg), 18 | dict(type='ImageToTensor', keys=['img']), 19 | dict(type='Collect', keys=['img']) 20 | ] 21 | 22 | # Different from the classification task, the val/test split also use the 23 | # training part, which is the same to StyleGAN-ADA. 
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto.  Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.

//------------------------------------------------------------------------
// CUDA kernel parameters.

struct bias_act_kernel_params
{
    const void* x;      // [sizeX]
    const void* b;      // [sizeB] or NULL
    const void* xref;   // [sizeX] or NULL
    const void* yref;   // [sizeX] or NULL
    const void* dy;     // [sizeX] or NULL
    void*       y;      // [sizeX]

    int         grad;
    int         act;
    float       alpha;
    float       gain;
    float       clamp;

    int         sizeX;
    int         sizeB;
    int         stepB;
    int         loopX;
};

//------------------------------------------------------------------------
// CUDA kernel selection.

// Declaration only; the definition is not part of this header.
// NOTE(review): T is presumably the tensor element type - confirm against
// the accompanying .cu file.
template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p);

//------------------------------------------------------------------------
/ (ema_half_life * 1000.))), 34 | priority='VERY_HIGH') 35 | ] 36 | 37 | total_iters = 670000 38 | 39 | metrics = dict( 40 | fid50k=dict( 41 | type='FID', 42 | num_images=50000, 43 | inception_pkl='work_dirs/inception_pkl/ffhq-256-50k-rgb.pkl', 44 | bgr2rgb=True), 45 | pr50k3=dict(type='PR', num_images=50000, k=3), 46 | ppl_wend=dict(type='PPL', space='W', sampling='end', num_images=50000)) 47 | -------------------------------------------------------------------------------- /configs/_base_/models/stylegan/stylegan2_base.py: -------------------------------------------------------------------------------- 1 | # define GAN model 2 | 3 | d_reg_interval = 16 4 | g_reg_interval = 4 5 | 6 | g_reg_ratio = g_reg_interval / (g_reg_interval + 1) 7 | d_reg_ratio = d_reg_interval / (d_reg_interval + 1) 8 | 9 | model = dict( 10 | type='StaticUnconditionalGAN', 11 | generator=dict( 12 | type='StyleGANv2Generator', 13 | out_size=None, # Need to be set. 14 | style_channels=512, 15 | ), 16 | discriminator=dict( 17 | type='StyleGAN2Discriminator', 18 | in_size=None, # Need to be set. 19 | ), 20 | gan_loss=dict(type='GANLoss', gan_type='wgan-logistic-ns'), 21 | disc_auxiliary_loss=dict( 22 | type='R1GradientPenalty', 23 | loss_weight=10. / 2. * d_reg_interval, 24 | interval=d_reg_interval, 25 | norm_mode='HWC', 26 | data_info=dict(real_data='real_imgs', discriminator='disc')), 27 | gen_auxiliary_loss=dict( 28 | type='GeneratorPathRegularizer', 29 | loss_weight=2. 
* g_reg_interval, 30 | pl_batch_shrink=2, 31 | interval=g_reg_interval, 32 | data_info=dict(generator='gen', num_batches='batch_size'))) 33 | 34 | train_cfg = dict(use_ema=True) 35 | test_cfg = None 36 | 37 | # define optimizer 38 | optimizer = dict( 39 | generator=dict( 40 | type='Adam', lr=0.002 * g_reg_ratio, betas=(0, 0.99**g_reg_ratio)), 41 | discriminator=dict( 42 | type='Adam', lr=0.002 * d_reg_ratio, betas=(0, 0.99**d_reg_ratio))) 43 | -------------------------------------------------------------------------------- /configs/styleganv2/stylegan2_c2_lsun-cat_256_b4x8_800k.py: -------------------------------------------------------------------------------- 1 | """Note that this config is just for testing.""" 2 | 3 | _base_ = [ 4 | '../_base_/datasets/lsun_stylegan.py', 5 | '../_base_/models/stylegan/stylegan2_base.py', 6 | '../_base_/default_runtime.py' 7 | ] 8 | 9 | model = dict(generator=dict(out_size=256), discriminator=dict(in_size=256)) 10 | 11 | data = dict( 12 | samples_per_gpu=4, train=dict(dataset=dict(imgs_root='./data/lsun-cat'))) 13 | 14 | ema_half_life = 10. # G_smoothing_kimg 15 | 16 | custom_hooks = [ 17 | dict( 18 | type='VisualizeUnconditionalSamples', 19 | output_dir='training_samples', 20 | interval=5000), 21 | dict( 22 | type='ExponentialMovingAverageHook', 23 | module_keys=('generator_ema', ), 24 | interval=1, 25 | interp_cfg=dict(momentum=0.5**(32. 
/ (ema_half_life * 1000.))), 26 | priority='VERY_HIGH') 27 | ] 28 | 29 | checkpoint_config = dict(interval=10000, by_epoch=False, max_keep_ckpts=30) 30 | lr_config = None 31 | 32 | log_config = dict( 33 | interval=100, 34 | hooks=[ 35 | dict(type='TextLoggerHook'), 36 | # dict(type='TensorboardLoggerHook'), 37 | ]) 38 | 39 | total_iters = 800002 # need to modify 40 | 41 | metrics = dict( 42 | fid50k=dict( 43 | type='FID', num_images=50000, inception_pkl=None, bgr2rgb=True), 44 | pr50k3=dict(type='PR', num_images=50000, k=3), 45 | ppl_wend=dict(type='PPL', space='W', sampling='end', num_images=50000)) 46 | -------------------------------------------------------------------------------- /configs/_base_/datasets/cifar10.py: -------------------------------------------------------------------------------- 1 | dataset_type = 'mmcls.CIFAR10' 2 | 3 | # different from mmcls, we adopt the setting used in BigGAN 4 | # Note that the pipelines below are from MMClassification. Importantly, the 5 | # `to_rgb` is set to `True` to convert image to BGR orders. The default order 6 | # in Cifar10 is RGB. Thus, we have to convert it to BGR. 7 | img_norm_cfg = dict( 8 | mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=True) 9 | train_pipeline = [ 10 | dict(type='RandomCrop', size=32, padding=4), 11 | dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'), 12 | dict(type='Normalize', **img_norm_cfg), 13 | dict(type='ImageToTensor', keys=['img']), 14 | dict(type='ToTensor', keys=['gt_label']), 15 | dict(type='Collect', keys=['img', 'gt_label']) 16 | ] 17 | test_pipeline = [ 18 | dict(type='Normalize', **img_norm_cfg), 19 | dict(type='ImageToTensor', keys=['img']), 20 | dict(type='Collect', keys=['img']) 21 | ] 22 | 23 | # Different from the classification task, the val/test split also use the 24 | # training part, which is the same to StyleGAN-ADA. 
25 | data = dict( 26 | samples_per_gpu=None, 27 | workers_per_gpu=4, 28 | train=dict( 29 | type=dataset_type, data_prefix='data/cifar10', 30 | pipeline=train_pipeline), 31 | val=dict( 32 | type=dataset_type, data_prefix='data/cifar10', pipeline=test_pipeline), 33 | test=dict( 34 | type=dataset_type, data_prefix='data/cifar10', pipeline=test_pipeline)) 35 | -------------------------------------------------------------------------------- /configs/styleganv2/stylegan2_c2_lsun-horse_256_b4x8_800k.py: -------------------------------------------------------------------------------- 1 | """Note that this config is just for testing.""" 2 | 3 | _base_ = [ 4 | '../_base_/datasets/lsun_stylegan.py', 5 | '../_base_/models/stylegan/stylegan2_base.py', 6 | '../_base_/default_runtime.py' 7 | ] 8 | 9 | model = dict(generator=dict(out_size=256), discriminator=dict(in_size=256)) 10 | 11 | data = dict( 12 | samples_per_gpu=4, train=dict(dataset=dict(imgs_root='./data/lsun-horse'))) 13 | 14 | ema_half_life = 10. # G_smoothing_kimg 15 | 16 | custom_hooks = [ 17 | dict( 18 | type='VisualizeUnconditionalSamples', 19 | output_dir='training_samples', 20 | interval=5000), 21 | dict( 22 | type='ExponentialMovingAverageHook', 23 | module_keys=('generator_ema', ), 24 | interval=1, 25 | interp_cfg=dict(momentum=0.5**(32. 
/ (ema_half_life * 1000.))), 26 | priority='VERY_HIGH') 27 | ] 28 | 29 | checkpoint_config = dict(interval=10000, by_epoch=False, max_keep_ckpts=30) 30 | lr_config = None 31 | 32 | log_config = dict( 33 | interval=100, 34 | hooks=[ 35 | dict(type='TextLoggerHook'), 36 | # dict(type='TensorboardLoggerHook'), 37 | ]) 38 | 39 | total_iters = 800002 # need to modify 40 | 41 | metrics = dict( 42 | fid50k=dict( 43 | type='FID', num_images=50000, inception_pkl=None, bgr2rgb=True), 44 | pr50k3=dict(type='PR', num_images=50000, k=3), 45 | ppl_wend=dict(type='PPL', space='W', sampling='end', num_images=50000)) 46 | -------------------------------------------------------------------------------- /configs/_base_/datasets/grow_scale_imgs_celeba-hq.py: -------------------------------------------------------------------------------- 1 | dataset_type = 'GrowScaleImgDataset' 2 | 3 | train_pipeline = [ 4 | dict( 5 | type='LoadImageFromFile', 6 | key='real_img', 7 | io_backend='disk', 8 | ), 9 | dict(type='Flip', keys=['real_img'], direction='horizontal'), 10 | dict( 11 | type='Normalize', 12 | keys=['real_img'], 13 | mean=[127.5] * 3, 14 | std=[127.5] * 3, 15 | to_rgb=False), 16 | dict(type='ImageToTensor', keys=['real_img']), 17 | dict(type='Collect', keys=['real_img'], meta_keys=['real_img_path']) 18 | ] 19 | 20 | # `samples_per_gpu` and `imgs_root` need to be set. 21 | data = dict( 22 | # samples per gpu should be the same as the first scale, e.g. 
'4': 64 23 | # in this case 24 | samples_per_gpu=None, 25 | workers_per_gpu=4, 26 | train=dict( 27 | type=dataset_type, 28 | # just an example 29 | imgs_roots={ 30 | '64': './data/celebahq/imgs_64', 31 | '256': './data/celebahq/imgs_256', 32 | '512': './data/celebahq/imgs_512', 33 | '1024': './data/celebahq/imgs_1024' 34 | }, 35 | pipeline=train_pipeline, 36 | gpu_samples_base=4, 37 | # note that this should be changed with total gpu number 38 | gpu_samples_per_scale={ 39 | '4': 64, 40 | '8': 32, 41 | '16': 16, 42 | '32': 8, 43 | '64': 4 44 | }, 45 | len_per_stage=300000)) 46 | -------------------------------------------------------------------------------- /configs/styleganv2/stylegan2_c2_lsun-church_256_b4x8_800k.py: -------------------------------------------------------------------------------- 1 | """Note that this config is just for testing.""" 2 | 3 | _base_ = [ 4 | '../_base_/datasets/lsun_stylegan.py', 5 | '../_base_/models/stylegan/stylegan2_base.py', 6 | '../_base_/default_runtime.py' 7 | ] 8 | 9 | model = dict(generator=dict(out_size=256), discriminator=dict(in_size=256)) 10 | 11 | data = dict( 12 | samples_per_gpu=4, 13 | train=dict(dataset=dict(imgs_root='./data/lsun-church'))) 14 | 15 | ema_half_life = 10. # G_smoothing_kimg 16 | 17 | custom_hooks = [ 18 | dict( 19 | type='VisualizeUnconditionalSamples', 20 | output_dir='training_samples', 21 | interval=5000), 22 | dict( 23 | type='ExponentialMovingAverageHook', 24 | module_keys=('generator_ema', ), 25 | interval=1, 26 | interp_cfg=dict(momentum=0.5**(32. 
/ (ema_half_life * 1000.))), 27 | priority='VERY_HIGH') 28 | ] 29 | 30 | checkpoint_config = dict(interval=10000, by_epoch=False, max_keep_ckpts=30) 31 | lr_config = None 32 | 33 | log_config = dict( 34 | interval=100, 35 | hooks=[ 36 | dict(type='TextLoggerHook'), 37 | # dict(type='TensorboardLoggerHook'), 38 | ]) 39 | 40 | total_iters = 800002 # need to modify 41 | 42 | metrics = dict( 43 | fid50k=dict( 44 | type='FID', num_images=50000, inception_pkl=None, bgr2rgb=True), 45 | pr50k3=dict(type='PR', num_images=50000, k=3), 46 | ppl_wend=dict(type='PPL', space='W', sampling='end', num_images=50000)) 47 | -------------------------------------------------------------------------------- /configs/positional_encoding_in_gans/stylegan2_c2_ffhq_256_b3x8_1100k.py: -------------------------------------------------------------------------------- 1 | """Config for the `config-f` setting in StyleGAN2.""" 2 | 3 | _base_ = [ 4 | '../_base_/datasets/ffhq_flip.py', 5 | '../_base_/models/stylegan/stylegan2_base.py', 6 | '../_base_/default_runtime.py' 7 | ] 8 | 9 | model = dict(generator=dict(out_size=256), discriminator=dict(in_size=256)) 10 | 11 | data = dict( 12 | samples_per_gpu=3, 13 | train=dict(dataset=dict(imgs_root='./data/ffhq/ffhq_imgs/ffhq_256'))) 14 | 15 | ema_half_life = 10. # G_smoothing_kimg 16 | 17 | custom_hooks = [ 18 | dict( 19 | type='VisualizeUnconditionalSamples', 20 | output_dir='training_samples', 21 | interval=5000), 22 | dict( 23 | type='ExponentialMovingAverageHook', 24 | module_keys=('generator_ema', ), 25 | interval=1, 26 | interp_cfg=dict(momentum=0.5**(32. 
/ (ema_half_life * 1000.))), 27 | priority='VERY_HIGH') 28 | ] 29 | 30 | metrics = dict( 31 | fid50k=dict( 32 | type='FID', 33 | num_images=50000, 34 | inception_pkl='work_dirs/inception_pkl/ffhq-256-50k-rgb.pkl', 35 | bgr2rgb=True), 36 | pr10k3=dict(type='PR', num_images=10000, k=3)) 37 | 38 | checkpoint_config = dict(interval=10000, by_epoch=False, max_keep_ckpts=30) 39 | lr_config = None 40 | 41 | log_config = dict( 42 | interval=100, 43 | hooks=[ 44 | dict(type='TextLoggerHook'), 45 | # dict(type='TensorboardLoggerHook'), 46 | ]) 47 | 48 | total_iters = 1100002 49 | -------------------------------------------------------------------------------- /configs/positional_encoding_in_gans/stylegan2_c2_ffhq_512_b3x8_1100k.py: -------------------------------------------------------------------------------- 1 | """Config for the `config-f` setting in StyleGAN2.""" 2 | 3 | _base_ = [ 4 | '../_base_/datasets/ffhq_flip.py', 5 | '../_base_/models/stylegan/stylegan2_base.py', 6 | '../_base_/default_runtime.py' 7 | ] 8 | 9 | model = dict(generator=dict(out_size=512), discriminator=dict(in_size=512)) 10 | 11 | data = dict( 12 | samples_per_gpu=3, 13 | train=dict(dataset=dict(imgs_root='./data/ffhq/ffhq_imgs/ffhq_512'))) 14 | 15 | ema_half_life = 10. # G_smoothing_kimg 16 | 17 | custom_hooks = [ 18 | dict( 19 | type='VisualizeUnconditionalSamples', 20 | output_dir='training_samples', 21 | interval=5000), 22 | dict( 23 | type='ExponentialMovingAverageHook', 24 | module_keys=('generator_ema', ), 25 | interval=1, 26 | interp_cfg=dict(momentum=0.5**(32. 
/ (ema_half_life * 1000.))), 27 | priority='VERY_HIGH') 28 | ] 29 | 30 | metrics = dict( 31 | fid50k=dict( 32 | type='FID', 33 | num_images=50000, 34 | inception_pkl='work_dirs/inception_pkl/ffhq-512-50k-rgb.pkl', 35 | bgr2rgb=True), 36 | pr10k3=dict(type='PR', num_images=10000, k=3)) 37 | 38 | checkpoint_config = dict(interval=10000, by_epoch=False, max_keep_ckpts=30) 39 | lr_config = None 40 | 41 | log_config = dict( 42 | interval=100, 43 | hooks=[ 44 | dict(type='TextLoggerHook'), 45 | # dict(type='TensorboardLoggerHook'), 46 | ]) 47 | 48 | total_iters = 1100002 49 | -------------------------------------------------------------------------------- /configs/_base_/models/pggan/pggan_128x128.py: -------------------------------------------------------------------------------- 1 | # define GAN model 2 | model = dict( 3 | type='ProgressiveGrowingGAN', 4 | generator=dict(type='PGGANGenerator', out_scale=128, noise_size=512), 5 | discriminator=dict(type='PGGANDiscriminator', in_scale=128), 6 | gan_loss=dict(type='GANLoss', gan_type='wgan'), 7 | disc_auxiliary_loss=[ 8 | dict( 9 | type='DiscShiftLoss', 10 | loss_weight=0.001 * 0.5, 11 | data_info=dict(pred='disc_pred_fake')), 12 | dict( 13 | type='DiscShiftLoss', 14 | loss_weight=0.001 * 0.5, 15 | data_info=dict(pred='disc_pred_real')), 16 | dict( 17 | type='GradientPenaltyLoss', 18 | loss_weight=10, 19 | norm_mode='HWC', 20 | data_info=dict( 21 | discriminator='disc_partial', 22 | real_data='real_imgs', 23 | fake_data='fake_imgs')) 24 | ]) 25 | 26 | train_cfg = dict( 27 | use_ema=True, 28 | nkimgs_per_scale={ 29 | '4': 600, 30 | '8': 1200, 31 | '16': 1200, 32 | '32': 1200, 33 | '64': 1200, 34 | '128': 12000 35 | }, 36 | transition_kimgs=600, 37 | optimizer_cfg=dict( 38 | generator=dict(type='Adam', lr=0.001, betas=(0., 0.99)), 39 | discriminator=dict(type='Adam', lr=0.001, betas=(0., 0.99))), 40 | g_lr_base=0.001, 41 | d_lr_base=0.001, 42 | g_lr_schedule={'128': 0.0015}, 43 | d_lr_schedule={'128': 0.0015}) 44 | 45 | 
test_cfg = None 46 | -------------------------------------------------------------------------------- /configs/_base_/models/stylegan/stylegan3_base.py: -------------------------------------------------------------------------------- 1 | # define GAN model 2 | 3 | d_reg_interval = 16 4 | g_reg_interval = 4 5 | 6 | g_reg_ratio = g_reg_interval / (g_reg_interval + 1) 7 | d_reg_ratio = d_reg_interval / (d_reg_interval + 1) 8 | 9 | model = dict( 10 | type='StaticUnconditionalGAN', 11 | generator=dict( 12 | type='StyleGANv3Generator', 13 | noise_size=512, 14 | style_channels=512, 15 | out_size=None, # Need to be set. 16 | img_channels=3, 17 | ), 18 | discriminator=dict( 19 | type='StyleGAN2Discriminator', 20 | in_size=None, # Need to be set. 21 | ), 22 | gan_loss=dict(type='GANLoss', gan_type='wgan-logistic-ns'), 23 | disc_auxiliary_loss=dict( 24 | type='R1GradientPenalty', 25 | loss_weight=10. / 2. * d_reg_interval, 26 | interval=d_reg_interval, 27 | norm_mode='HWC', 28 | data_info=dict(real_data='real_imgs', discriminator='disc')), 29 | gen_auxiliary_loss=dict( 30 | type='GeneratorPathRegularizer', 31 | loss_weight=2. 
* g_reg_interval, 32 | pl_batch_shrink=2, 33 | interval=g_reg_interval, 34 | data_info=dict(generator='gen', num_batches='batch_size'))) 35 | 36 | train_cfg = dict(use_ema=True) 37 | test_cfg = None 38 | 39 | # define optimizer 40 | optimizer = dict( 41 | generator=dict( 42 | type='Adam', lr=0.0025 * g_reg_ratio, betas=(0, 0.99**g_reg_ratio)), 43 | discriminator=dict( 44 | type='Adam', lr=0.002 * d_reg_ratio, betas=(0, 0.99**d_reg_ratio))) 45 | -------------------------------------------------------------------------------- /configs/styleganv1/styleganv1_ffhq_1024_g8_25Mimg.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/models/stylegan/styleganv1_base.py', 3 | '../_base_/datasets/grow_scale_imgs_ffhq_styleganv1.py', 4 | '../_base_/default_runtime.py', 5 | ] 6 | 7 | model = dict(generator=dict(out_size=1024), discriminator=dict(in_size=1024)) 8 | 9 | train_cfg = dict( 10 | nkimgs_per_scale={ 11 | '8': 1200, 12 | '16': 1200, 13 | '32': 1200, 14 | '64': 1200, 15 | '128': 1200, 16 | '256': 1200, 17 | '512': 1200, 18 | '1024': 166000 19 | }) 20 | 21 | checkpoint_config = dict(interval=5000, by_epoch=False, max_keep_ckpts=20) 22 | lr_config = None 23 | 24 | ema_half_life = 10. # G_smoothing_kimg 25 | 26 | custom_hooks = [ 27 | dict( 28 | type='VisualizeUnconditionalSamples', 29 | output_dir='training_samples', 30 | interval=5000), 31 | dict(type='PGGANFetchDataHook', interval=1), 32 | dict( 33 | type='ExponentialMovingAverageHook', 34 | module_keys=('generator_ema', ), 35 | interval=1, 36 | interp_cfg=dict(momentum=0.5**(32. 
# Improved-DDPM training config: the 32x32 DDPM model with the un-augmented
# CIFAR-10 dataset and the default runtime settings.
_base_ = [
    '../_base_/models/improved_ddpm/ddpm_32x32.py',
    '../_base_/datasets/cifar10_noaug.py', '../_base_/default_runtime.py'
]

lr_config = None
checkpoint_config = dict(interval=10000, by_epoch=False, max_keep_ckpts=20)
custom_hooks = [
    dict(
        type='MMGenVisualizationHook',
        output_dir='training_samples',
        res_name_list=['real_imgs', 'x_0_pred', 'x_t', 'x_t_1'],
        padding=1,
        interval=1000),
    dict(
        type='ExponentialMovingAverageHook',
        # A one-element tuple needs the trailing comma; without it the
        # parentheses are a no-op and the value is a plain string.
        module_keys=('denoising_ema', ),
        interval=1,
        start_iter=0,
        interp_cfg=dict(momentum=0.9999),
        priority='VERY_HIGH')
]

# Do not evaluate during training because evaluation takes too much time.
evaluation = None

total_iters = 500000  # 500k
data = dict(samples_per_gpu=16)  # 8x16=128

# use ddp wrapper for faster training
use_ddp_wrapper = True
find_unused_parameters = False

runner = dict(
    type='DynamicIterBasedRunner',
    is_dynamic_ddp=False,  # Note that this flag should be False.
    pass_training_status=True)

inception_pkl = './work_dirs/inception_pkl/cifar10.pkl'
metrics = dict(
    fid50k=dict(
        type='FID',
        num_images=50000,
        bgr2rgb=True,
        inception_pkl=inception_pkl,
        inception_args=dict(type='StyleGAN')))
# Improved-DDPM training config: the 64x64 DDPM model with the un-augmented
# 64x64 ImageNet dataset and the default runtime settings.
_base_ = [
    '../_base_/models/improved_ddpm/ddpm_64x64.py',
    '../_base_/datasets/imagenet_noaug_64.py', '../_base_/default_runtime.py'
]

lr_config = None
checkpoint_config = dict(interval=10000, by_epoch=False, max_keep_ckpts=20)
custom_hooks = [
    dict(
        type='MMGenVisualizationHook',
        output_dir='training_samples',
        res_name_list=['real_imgs', 'x_0_pred', 'x_t', 'x_t_1'],
        padding=1,
        interval=1000),
    dict(
        type='ExponentialMovingAverageHook',
        # A one-element tuple needs the trailing comma; without it the
        # parentheses are a no-op and the value is a plain string.
        module_keys=('denoising_ema', ),
        interval=1,
        start_iter=0,
        interp_cfg=dict(momentum=0.9999),
        priority='VERY_HIGH')
]

# Do not evaluate during training because evaluation takes too much time.
evaluation = None

total_iters = 1500000  # 1500k
data = dict(samples_per_gpu=16)  # 8x16=128

# use ddp wrapper for faster training
use_ddp_wrapper = True
find_unused_parameters = False

runner = dict(
    type='DynamicIterBasedRunner',
    is_dynamic_ddp=False,  # Note that this flag should be False.
    pass_training_status=True)

inception_pkl = './work_dirs/inception_pkl/imagenet_64x64.pkl'
metrics = dict(
    fid50k=dict(
        type='FID',
        num_images=50000,
        bgr2rgb=True,
        inception_pkl=inception_pkl,
        inception_args=dict(type='StyleGAN')))
26 | data = dict( 27 | samples_per_gpu=None, 28 | workers_per_gpu=4, 29 | train=dict( 30 | type='RepeatDataset', 31 | times=500, 32 | dataset=dict( 33 | type=dataset_type, 34 | data_prefix='data/cifar10', 35 | pipeline=train_pipeline)), 36 | val=dict( 37 | type=dataset_type, data_prefix='data/cifar10', pipeline=test_pipeline), 38 | test=dict( 39 | type=dataset_type, data_prefix='data/cifar10', pipeline=test_pipeline)) 40 | -------------------------------------------------------------------------------- /configs/_base_/models/improved_ddpm/ddpm_64x64.py: -------------------------------------------------------------------------------- 1 | model = dict( 2 | type='BasicGaussianDiffusion', 3 | num_timesteps=4000, 4 | betas_cfg=dict(type='cosine'), 5 | denoising=dict( 6 | type='DenoisingUnet', 7 | image_size=64, 8 | in_channels=3, 9 | base_channels=128, 10 | resblocks_per_downsample=3, 11 | attention_res=[16, 8], 12 | use_scale_shift_norm=True, 13 | dropout=0, 14 | num_heads=4, 15 | use_rescale_timesteps=True, 16 | output_cfg=dict(mean='eps', var='learned_range')), 17 | timestep_sampler=dict(type='UniformTimeStepSampler'), 18 | ddpm_loss=[ 19 | dict( 20 | type='DDPMVLBLoss', 21 | rescale_mode='constant', 22 | rescale_cfg=dict(scale=4000 / 1000), 23 | data_info=dict( 24 | mean_pred='mean_pred', 25 | mean_target='mean_posterior', 26 | logvar_pred='logvar_pred', 27 | logvar_target='logvar_posterior'), 28 | log_cfgs=[ 29 | dict( 30 | type='quartile', 31 | prefix_name='loss_vlb', 32 | total_timesteps=4000), 33 | dict(type='name') 34 | ]), 35 | dict( 36 | type='DDPMMSELoss', 37 | log_cfgs=dict( 38 | type='quartile', prefix_name='loss_mse', total_timesteps=4000), 39 | ) 40 | ], 41 | ) 42 | 43 | train_cfg = dict(use_ema=True, real_img_key='img') 44 | test_cfg = None 45 | optimizer = dict(denoising=dict(type='AdamW', lr=1e-4, weight_decay=0)) 46 | -------------------------------------------------------------------------------- 
/configs/lsgan/lsgan_lsgan-archi_lr-1e-4_lsun-bedroom_128_b64x1_10m.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/models/lsgan/lsgan_128x128.py', 3 | '../_base_/datasets/unconditional_imgs_128x128.py', 4 | '../_base_/default_runtime.py' 5 | ] 6 | # define dataset 7 | # you must set `samples_per_gpu` and `imgs_root` 8 | data = dict( 9 | samples_per_gpu=64, train=dict(imgs_root='./data/lsun/bedroom_train')) 10 | 11 | optimizer = dict( 12 | generator=dict(type='Adam', lr=0.0001, betas=(0.5, 0.99)), 13 | discriminator=dict(type='Adam', lr=0.0001, betas=(0.5, 0.99))) 14 | 15 | # adjust running config 16 | lr_config = None 17 | checkpoint_config = dict(interval=10000, by_epoch=False, max_keep_ckpts=20) 18 | custom_hooks = [ 19 | dict( 20 | type='VisualizeUnconditionalSamples', 21 | output_dir='training_samples', 22 | interval=10000) 23 | ] 24 | 25 | evaluation = dict( 26 | type='GenerativeEvalHook', 27 | interval=10000, 28 | metrics=dict( 29 | type='FID', num_images=50000, inception_pkl=None, bgr2rgb=True), 30 | sample_kwargs=dict(sample_model='orig')) 31 | 32 | total_iters = 160000 33 | # use ddp wrapper for faster training 34 | use_ddp_wrapper = True 35 | find_unused_parameters = False 36 | 37 | runner = dict( 38 | type='DynamicIterBasedRunner', 39 | is_dynamic_ddp=False, # Note that this flag should be False. 
40 | pass_training_status=True) 41 | 42 | metrics = dict( 43 | ms_ssim10k=dict(type='MS_SSIM', num_images=10000), 44 | swd16k=dict(type='SWD', num_images=16384, image_shape=(3, 128, 128)), 45 | fid50k=dict(type='FID', num_images=50000, inception_pkl=None)) 46 | -------------------------------------------------------------------------------- /configs/_base_/models/improved_ddpm/ddpm_32x32.py: -------------------------------------------------------------------------------- 1 | model = dict( 2 | type='BasicGaussianDiffusion', 3 | num_timesteps=4000, 4 | betas_cfg=dict(type='cosine'), 5 | denoising=dict( 6 | type='DenoisingUnet', 7 | image_size=32, 8 | in_channels=3, 9 | base_channels=128, 10 | resblocks_per_downsample=3, 11 | attention_res=[16, 8], 12 | use_scale_shift_norm=True, 13 | dropout=0.3, 14 | num_heads=4, 15 | use_rescale_timesteps=True, 16 | output_cfg=dict(mean='eps', var='learned_range')), 17 | timestep_sampler=dict(type='UniformTimeStepSampler'), 18 | ddpm_loss=[ 19 | dict( 20 | type='DDPMVLBLoss', 21 | rescale_mode='constant', 22 | rescale_cfg=dict(scale=4000 / 1000), 23 | data_info=dict( 24 | mean_pred='mean_pred', 25 | mean_target='mean_posterior', 26 | logvar_pred='logvar_pred', 27 | logvar_target='logvar_posterior'), 28 | log_cfgs=[ 29 | dict( 30 | type='quartile', 31 | prefix_name='loss_vlb', 32 | total_timesteps=4000), 33 | dict(type='name') 34 | ]), 35 | dict( 36 | type='DDPMMSELoss', 37 | log_cfgs=dict( 38 | type='quartile', prefix_name='loss_mse', total_timesteps=4000), 39 | ) 40 | ], 41 | ) 42 | 43 | train_cfg = dict(use_ema=True, real_img_key='img') 44 | test_cfg = None 45 | optimizer = dict(denoising=dict(type='AdamW', lr=1e-4, weight_decay=0)) 46 | -------------------------------------------------------------------------------- /configs/styleganv2/stylegan2_c2_lsun-car_384x512_b4x8.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/datasets/lsun-car_pad_512.py', 3 | 
'../_base_/models/stylegan/stylegan2_base.py', 4 | '../_base_/default_runtime.py' 5 | ] 6 | 7 | model = dict(generator=dict(out_size=512), discriminator=dict(in_size=512)) 8 | 9 | data = dict( 10 | samples_per_gpu=4, 11 | train=dict(dataset=dict(imgs_root='./data/lsun/images/car')), 12 | val=dict(imgs_root='./data/lsun/images/car')) 13 | 14 | ema_half_life = 10. # G_smoothing_kimg 15 | 16 | custom_hooks = [ 17 | dict( 18 | type='VisualizeUnconditionalSamples', 19 | output_dir='training_samples', 20 | interval=5000), 21 | dict( 22 | type='ExponentialMovingAverageHook', 23 | module_keys=('generator_ema', ), 24 | interval=1, 25 | interp_cfg=dict(momentum=0.5**(32. / (ema_half_life * 1000.))), 26 | priority='VERY_HIGH') 27 | ] 28 | 29 | checkpoint_config = dict(interval=10000, by_epoch=False, max_keep_ckpts=40) 30 | lr_config = None 31 | 32 | total_iters = 1800002 33 | 34 | metrics = dict( 35 | fid50k=dict( 36 | type='FID', num_images=50000, inception_pkl=None, bgr2rgb=True), 37 | pr50k3=dict(type='PR', num_images=50000, k=3), 38 | ppl_wend=dict(type='PPL', space='W', sampling='end', num_images=50000)) 39 | 40 | evaluation = dict( 41 | type='GenerativeEvalHook', 42 | interval=10000, 43 | metrics=dict( 44 | type='FID', 45 | num_images=50000, 46 | inception_pkl='work_dirs/inception_pkl/lsun-car-512-50k-rgb.pkl', 47 | bgr2rgb=True), 48 | sample_kwargs=dict(sample_model='ema')) 49 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/error-report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Error report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | --- 8 | 9 | Thanks for your error report and we appreciate it a lot. 10 | 11 | **Checklist** 12 | 13 | 1. I have searched related issues but cannot get the expected help. 14 | 2. 
I have read the [FAQ documentation](https://mmgeneration.readthedocs.io/en/latest/faq.html) but cannot get the expected help. 15 | 3. The bug has not been fixed in the latest version. 16 | 17 | **Describe the bug** 18 | A clear and concise description of what the bug is. 19 | 20 | **Reproduction** 21 | 22 | 1. What command or script did you run? 23 | 24 | ```none 25 | A placeholder for the command. 26 | ``` 27 | 28 | 2. Did you make any modifications to the code or config? Did you understand what you have modified? 29 | 3. What dataset did you use? 30 | 31 | **Environment** 32 | 33 | 1. Please run `python mmgen/utils/collect_env.py` to collect necessary environment information and paste it here. 34 | 2. You may add additional information that may be helpful for locating the problem, such as 35 | - How you installed PyTorch \[e.g., pip, conda, source\] 36 | - Other environment variables that may be related (such as `$PATH`, `$LD_LIBRARY_PATH`, `$PYTHONPATH`, etc.) 37 | 38 | **Error traceback** 39 | If applicable, paste the error traceback here. 40 | 41 | ```none 42 | A placeholder for traceback. 43 | ``` 44 | 45 | **Bug fix** 46 | If you have already identified the reason, you can provide the information here. If you are willing to create a PR to fix it, please also leave a comment here and that would be much appreciated! 
47 | -------------------------------------------------------------------------------- /configs/_base_/datasets/imagenet_rgb.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | dataset_type = 'mmcls.ImageNet' 3 | 4 | # Note that the pipelines below are from MMClassification 5 | img_norm_cfg = dict( 6 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) 7 | train_pipeline = [ 8 | dict(type='LoadImageFromFile'), 9 | dict(type='RandomResizedCrop', size=224, backend='pillow'), 10 | dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'), 11 | dict(type='Normalize', **img_norm_cfg), 12 | dict(type='ImageToTensor', keys=['img']), 13 | dict(type='ToTensor', keys=['gt_label']), 14 | dict(type='Collect', keys=['img', 'gt_label']) 15 | ] 16 | 17 | test_pipeline = [ 18 | dict(type='LoadImageFromFile'), 19 | dict(type='Resize', size=(256, -1), backend='pillow'), 20 | dict(type='CenterCrop', crop_size=224), 21 | dict(type='Normalize', **img_norm_cfg), 22 | dict(type='ImageToTensor', keys=['img']), 23 | dict(type='Collect', keys=['img']) 24 | ] 25 | 26 | data = dict( 27 | samples_per_gpu=64, 28 | workers_per_gpu=2, 29 | train=dict( 30 | type=dataset_type, 31 | data_prefix='data/imagenet/train', 32 | pipeline=train_pipeline), 33 | val=dict( 34 | type=dataset_type, 35 | data_prefix='data/imagenet/val', 36 | ann_file='data/imagenet/meta/val.txt', 37 | pipeline=test_pipeline), 38 | test=dict( 39 | # replace `data/val` with `data/test` for standard test 40 | type=dataset_type, 41 | data_prefix='data/imagenet/val', 42 | ann_file='data/imagenet/meta/val.txt', 43 | pipeline=test_pipeline)) 44 | -------------------------------------------------------------------------------- /configs/_base_/datasets/cifar10_random_noise.py: -------------------------------------------------------------------------------- 1 | dataset_type = 'mmcls.CIFAR10' 2 | 3 | # cifar dataset without augmentation 4 | # different 
from mmcls, we adopt the setting used in BigGAN 5 | # Note that the pipelines below are from MMClassification. Importantly, 6 | # `to_rgb` is set to `True` to convert the image to BGR order. The default order 7 | # in Cifar10 is RGB. Thus, we have to convert it to BGR. 8 | 9 | # Follow the pipeline in 10 | # https://github.com/pfnet-research/sngan_projection/blob/master/datasets/cifar10.py 11 | # Only the `RandomImgNoise` augmentation is adopted. 12 | img_norm_cfg = dict( 13 | mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=True) 14 | train_pipeline = [ 15 | dict(type='Normalize', **img_norm_cfg), 16 | dict(type='RandomImgNoise', keys=['img']), 17 | dict(type='ImageToTensor', keys=['img']), 18 | dict(type='ToTensor', keys=['gt_label']), 19 | dict(type='Collect', keys=['img', 'gt_label']) 20 | ] 21 | test_pipeline = [ 22 | dict(type='Normalize', **img_norm_cfg), 23 | dict(type='ImageToTensor', keys=['img']), 24 | dict(type='Collect', keys=['img']) 25 | ] 26 | 27 | # Different from the classification task, the val/test splits also use the 28 | # training part, which is the same as in StyleGAN-ADA. 29 | data = dict( 30 | samples_per_gpu=None, 31 | workers_per_gpu=4, 32 | train=dict( 33 | type=dataset_type, data_prefix='data/cifar10', 34 | pipeline=train_pipeline), 35 | val=dict( 36 | type=dataset_type, data_prefix='data/cifar10', pipeline=test_pipeline), 37 | test=dict( 38 | type=dataset_type, data_prefix='data/cifar10', pipeline=test_pipeline)) 39 | -------------------------------------------------------------------------------- /tests/test_modules/test_mspie_archs.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from copy import deepcopy 3 | 4 | import torch 5 | 6 | from mmgen.models.architectures.stylegan import MSStyleGANv2Generator 7 | 8 | 9 | class TestMSStyleGAN2: 10 | 11 | @classmethod 12 | def setup_class(cls): 13 | cls.generator_cfg = dict(out_size=32, style_channels=16) 14 | cls.disc_cfg = dict(in_size=32, with_adaptive_pool=True) 15 | 16 | def test_msstylegan2_cpu(self): 17 | 18 | # test normal forward 19 | cfg_ = deepcopy(self.generator_cfg) 20 | g = MSStyleGANv2Generator(**cfg_) 21 | res = g(None, num_batches=2) 22 | assert res.shape == (2, 3, 32, 32) 23 | 24 | # set mix_prob as 1.0 and 0 to force cover lines 25 | cfg_ = deepcopy(self.generator_cfg) 26 | cfg_['mix_prob'] = 1 27 | g = MSStyleGANv2Generator(**cfg_) 28 | res = g(torch.randn, num_batches=2) 29 | assert res.shape == (2, 3, 32, 32) 30 | 31 | cfg_ = deepcopy(self.generator_cfg) 32 | cfg_['mix_prob'] = 0 33 | g = MSStyleGANv2Generator(**cfg_) 34 | res = g(torch.randn, num_batches=2) 35 | assert res.shape == (2, 3, 32, 32) 36 | 37 | cfg_ = deepcopy(self.generator_cfg) 38 | cfg_['mix_prob'] = 1 39 | g = MSStyleGANv2Generator(**cfg_) 40 | res = g(None, num_batches=2) 41 | assert res.shape == (2, 3, 32, 32) 42 | 43 | cfg_ = deepcopy(self.generator_cfg) 44 | cfg_['mix_prob'] = 0 45 | g = MSStyleGANv2Generator(**cfg_) 46 | res = g(None, num_batches=2) 47 | assert res.shape == (2, 3, 32, 32) 48 | -------------------------------------------------------------------------------- /tests/test_ops/test_conv_gradfix.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from functools import partial 3 | 4 | import pytest 5 | import torch 6 | import torch.nn as nn 7 | from torch.autograd import gradgradcheck 8 | 9 | from mmgen.ops import conv2d, conv_transpose2d 10 | 11 | 12 | class TestCond2d: 13 | 14 | @classmethod 15 | def setup_class(cls): 16 | cls.input = torch.randn((1, 3, 32, 32)) 17 | cls.weight = nn.Parameter(torch.randn(1, 3, 3, 3)) 18 | 19 | @pytest.mark.skipif( 20 | not torch.cuda.is_available() 21 | or not hasattr(torch.backends.cudnn, 'allow_tf32'), 22 | reason='requires cuda') 23 | def test_conv2d_cuda(self): 24 | x = self.input.cuda() 25 | weight = self.weight.cuda() 26 | res = conv2d(x, weight, None, 1, 1) 27 | assert res.shape == (1, 1, 32, 32) 28 | gradgradcheck(partial(conv2d, weight=weight, padding=1, stride=1), x) 29 | 30 | 31 | class TestCond2dTansposed: 32 | 33 | @classmethod 34 | def setup_class(cls): 35 | cls.input = torch.randn((1, 3, 32, 32)) 36 | cls.weight = nn.Parameter(torch.randn(3, 1, 3, 3)) 37 | 38 | @pytest.mark.skipif( 39 | not torch.cuda.is_available() 40 | or not hasattr(torch.backends.cudnn, 'allow_tf32'), 41 | reason='requires cuda') 42 | def test_conv2d_transposed_cuda(self): 43 | x = self.input.cuda() 44 | weight = self.weight.cuda() 45 | res = conv_transpose2d(x, weight, None, 1, 1) 46 | assert res.shape == (1, 1, 32, 32) 47 | gradgradcheck( 48 | partial(conv_transpose2d, weight=weight, padding=1, stride=1), x) 49 | -------------------------------------------------------------------------------- /configs/improved_ddpm/ddpm_cosine_hybird_timestep-4k_drop0.3_imagenet1k_64x64_b8x16_1500k.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/models/improved_ddpm/ddpm_64x64.py', 3 | '../_base_/datasets/imagenet_noaug_64.py', '../_base_/default_runtime.py' 4 | ] 5 | 6 | # set dropout prob as 0.3 7 | model = dict(denoising=dict(dropout=0.3)) 8 | 9 | lr_config = None 10 | checkpoint_config = dict(interval=10000, by_epoch=False, 
max_keep_ckpts=20) 11 | custom_hooks = [ 12 | dict( 13 | type='MMGenVisualizationHook', 14 | output_dir='training_samples', 15 | res_name_list=['real_imgs', 'x_0_pred', 'x_t', 'x_t_1'], 16 | padding=1, 17 | interval=1000), 18 | dict( 19 | type='ExponentialMovingAverageHook', 20 | module_keys=('denoising_ema'), 21 | interval=1, 22 | start_iter=0, 23 | interp_cfg=dict(momentum=0.9999), 24 | priority='VERY_HIGH') 25 | ] 26 | 27 | # Do not evaluate during training because evaluation takes too much time. 28 | evaluation = None 29 | 30 | total_iters = 1500000 # 1500k 31 | data = dict(samples_per_gpu=16) # 8x16=128 32 | 33 | # use ddp wrapper for faster training 34 | use_ddp_wrapper = True 35 | find_unused_parameters = False 36 | 37 | runner = dict( 38 | type='DynamicIterBasedRunner', 39 | is_dynamic_ddp=False, # Note that this flag should be False. 40 | pass_training_status=True) 41 | 42 | inception_pkl = './work_dirs/inception_pkl/imagenet_64x64.pkl' 43 | metrics = dict( 44 | fid50k=dict( 45 | type='FID', 46 | num_images=50000, 47 | bgr2rgb=True, 48 | inception_pkl=inception_pkl, 49 | inception_args=dict(type='StyleGAN'))) 50 | -------------------------------------------------------------------------------- /configs/lsgan/lsgan_dcgan-archi_lr-1e-3_celeba-cropped_64_b128x1_12m.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/models/dcgan/dcgan_64x64.py', 3 | '../_base_/datasets/unconditional_imgs_64x64.py', 4 | '../_base_/default_runtime.py' 5 | ] 6 | model = dict(gan_loss=dict(type='GANLoss', gan_type='lsgan')) 7 | # define dataset 8 | # you must set `samples_per_gpu` and `imgs_root` 9 | data = dict( 10 | samples_per_gpu=128, 11 | train=dict(imgs_root='./data/celeba-cropped/cropped_images_aligned_png/')) 12 | 13 | optimizer = dict( 14 | generator=dict(type='Adam', lr=0.001, betas=(0.5, 0.99)), 15 | discriminator=dict(type='Adam', lr=0.001, betas=(0.5, 0.99))) 16 | 17 | # adjust running config 18 | 
lr_config = None 19 | checkpoint_config = dict(interval=10000, by_epoch=False, max_keep_ckpts=20) 20 | custom_hooks = [ 21 | dict( 22 | type='VisualizeUnconditionalSamples', 23 | output_dir='training_samples', 24 | interval=10000) 25 | ] 26 | 27 | evaluation = dict( 28 | type='GenerativeEvalHook', 29 | interval=10000, 30 | metrics=dict( 31 | type='FID', num_images=50000, inception_pkl=None, bgr2rgb=True), 32 | sample_kwargs=dict(sample_model='orig')) 33 | 34 | total_iters = 100000 35 | # use ddp wrapper for faster training 36 | use_ddp_wrapper = True 37 | find_unused_parameters = False 38 | 39 | runner = dict( 40 | type='DynamicIterBasedRunner', 41 | is_dynamic_ddp=False, # Note that this flag should be False. 42 | pass_training_status=True) 43 | 44 | metrics = dict( 45 | ms_ssim10k=dict(type='MS_SSIM', num_images=10000), 46 | swd16k=dict(type='SWD', num_images=16384, image_shape=(3, 64, 64)), 47 | fid50k=dict(type='FID', num_images=50000, inception_pkl=None)) 48 | -------------------------------------------------------------------------------- /mmgen/ops/stylegan3/ops/filtered_lrelu_rd.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include "filtered_lrelu.cu" 10 | 11 | // Template/kernel specializations for sign read mode. 12 | 13 | // Full op, 32-bit indexing. 
14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 16 | 17 | // Full op, 64-bit indexing. 18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 20 | 21 | // Activation/signs only for generic variant. 64-bit indexing. 22 | template void* choose_filtered_lrelu_act_kernel(void); 23 | template void* choose_filtered_lrelu_act_kernel(void); 24 | template void* choose_filtered_lrelu_act_kernel(void); 25 | 26 | // Copy filters to constant memory. 27 | template cudaError_t copy_filters(cudaStream_t stream); 28 | -------------------------------------------------------------------------------- /mmgen/ops/stylegan3/ops/filtered_lrelu_wr.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include "filtered_lrelu.cu" 10 | 11 | // Template/kernel specializations for sign write mode. 12 | 13 | // Full op, 32-bit indexing. 
14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 16 | 17 | // Full op, 64-bit indexing. 18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 20 | 21 | // Activation/signs only for generic variant. 64-bit indexing. 22 | template void* choose_filtered_lrelu_act_kernel(void); 23 | template void* choose_filtered_lrelu_act_kernel(void); 24 | template void* choose_filtered_lrelu_act_kernel(void); 25 | 26 | // Copy filters to constant memory. 27 | template cudaError_t copy_filters(cudaStream_t stream); 28 | -------------------------------------------------------------------------------- /mmgen/utils/dist_util.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import numpy as np 3 | import torch 4 | import torch.distributed as dist 5 | from mmcv.runner import get_dist_info 6 | 7 | 8 | def check_dist_init(): 9 | return dist.is_available() and dist.is_initialized() 10 | 11 | 12 | def sync_random_seed(seed=None, device='cuda'): 13 | """Make sure different ranks share the same seed. 14 | 15 | All workers must call 16 | this function, otherwise it will deadlock. This method is generally used in 17 | `DistributedSampler`, because the seed should be identical across all 18 | processes in the distributed group. 19 | In distributed sampling, different ranks should sample non-overlapped 20 | data in the dataset. Therefore, this function is used to make sure that 21 | each rank shuffles the data indices in the same order based 22 | on the same seed. 
Then different ranks could use different indices 23 | to select non-overlapped data from the same data list. 24 | Args: 25 | seed (int, Optional): The seed. Default to None. 26 | device (str): The device where the seed will be put on. 27 | Default to 'cuda'. 28 | Returns: 29 | int: Seed to be used. 30 | """ 31 | if seed is None: 32 | seed = np.random.randint(2**31) 33 | assert isinstance(seed, int) 34 | 35 | rank, world_size = get_dist_info() 36 | 37 | if world_size == 1: 38 | return seed 39 | 40 | if rank == 0: 41 | random_num = torch.tensor(seed, dtype=torch.int32, device=device) 42 | else: 43 | random_num = torch.tensor(0, dtype=torch.int32, device=device) 44 | dist.broadcast(random_num, src=0) 45 | return random_num.item() 46 | -------------------------------------------------------------------------------- /mmgen/ops/stylegan3/ops/filtered_lrelu_ns.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include "filtered_lrelu.cu" 10 | 11 | // Template/kernel specializations for no signs mode (no gradients required). 12 | 13 | // Full op, 32-bit indexing. 14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 16 | 17 | // Full op, 64-bit indexing. 
18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 20 | 21 | // Activation/signs only for generic variant. 64-bit indexing. 22 | template void* choose_filtered_lrelu_act_kernel(void); 23 | template void* choose_filtered_lrelu_act_kernel(void); 24 | template void* choose_filtered_lrelu_act_kernel(void); 25 | 26 | // Copy filters to constant memory. 27 | template cudaError_t copy_filters(cudaStream_t stream); 28 | -------------------------------------------------------------------------------- /configs/lsgan/lsgan_dcgan-archi_lr-1e-4_lsun-bedroom_64_b128x1_12m.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/models/dcgan/dcgan_64x64.py', 3 | '../_base_/datasets/unconditional_imgs_64x64.py', 4 | '../_base_/default_runtime.py' 5 | ] 6 | model = dict( 7 | discriminator=dict(output_scale=4, out_channels=1), 8 | gan_loss=dict(type='GANLoss', gan_type='lsgan')) 9 | # define dataset 10 | # you must set `samples_per_gpu` and `imgs_root` 11 | data = dict( 12 | samples_per_gpu=128, train=dict(imgs_root='./data/lsun/bedroom_train')) 13 | 14 | optimizer = dict( 15 | generator=dict(type='Adam', lr=0.0001, betas=(0.5, 0.99)), 16 | discriminator=dict(type='Adam', lr=0.0001, betas=(0.5, 0.99))) 17 | 18 | # adjust running config 19 | lr_config = None 20 | checkpoint_config = dict(interval=10000, by_epoch=False, max_keep_ckpts=20) 21 | custom_hooks = [ 22 | dict( 23 | type='VisualizeUnconditionalSamples', 24 | output_dir='training_samples', 25 | interval=10000) 26 | ] 27 | 28 | evaluation = dict( 29 | type='GenerativeEvalHook', 30 | interval=10000, 31 | metrics=dict( 32 | type='FID', num_images=50000, inception_pkl=None, bgr2rgb=True), 33 | sample_kwargs=dict(sample_model='orig')) 34 | 35 | total_iters = 100000 36 | # use ddp 
wrapper for faster training 37 | use_ddp_wrapper = True 38 | find_unused_parameters = False 39 | 40 | runner = dict( 41 | type='DynamicIterBasedRunner', 42 | is_dynamic_ddp=False, # Note that this flag should be False. 43 | pass_training_status=True) 44 | 45 | metrics = dict( 46 | ms_ssim10k=dict(type='MS_SSIM', num_images=10000), 47 | swd16k=dict(type='SWD', num_images=16384, image_shape=(3, 64, 64)), 48 | fid50k=dict(type='FID', num_images=50000, inception_pkl=None)) 49 | -------------------------------------------------------------------------------- /configs/styleganv2/stylegan2_c2_ffhq_1024_b4x8.py: -------------------------------------------------------------------------------- 1 | """Config for the `config-f` setting in StyleGAN2.""" 2 | 3 | _base_ = [ 4 | '../_base_/datasets/ffhq_flip.py', 5 | '../_base_/models/stylegan/stylegan2_base.py', 6 | '../_base_/default_runtime.py' 7 | ] 8 | 9 | ema_half_life = 10. # G_smoothing_kimg 10 | 11 | model = dict(generator=dict(out_size=1024), discriminator=dict(in_size=1024)) 12 | 13 | data = dict( 14 | samples_per_gpu=4, 15 | train=dict(dataset=dict(imgs_root='./data/ffhq/images')), 16 | val=dict(imgs_root='./data/ffhq/images')) 17 | 18 | custom_hooks = [ 19 | dict( 20 | type='VisualizeUnconditionalSamples', 21 | output_dir='training_samples', 22 | interval=5000), 23 | dict( 24 | type='ExponentialMovingAverageHook', 25 | module_keys=('generator_ema', ), 26 | interval=1, 27 | interp_cfg=dict(momentum=0.5**(32. 
/ (ema_half_life * 1000.))), 28 | priority='VERY_HIGH') 29 | ] 30 | 31 | metrics = dict( 32 | fid50k=dict( 33 | type='FID', 34 | num_images=50000, 35 | inception_pkl='work_dirs/inception_pkl/ffhq-1024-50k-rgb.pkl', 36 | bgr2rgb=True), 37 | pr50k3=dict(type='PR', num_images=50000, k=3), 38 | ppl_wend=dict(type='PPL', space='W', sampling='end', num_images=50000)) 39 | 40 | evaluation = dict( 41 | type='GenerativeEvalHook', 42 | interval=10000, 43 | metrics=dict( 44 | type='FID', 45 | num_images=50000, 46 | inception_pkl='work_dirs/inception_pkl/ffhq-1024-50k-rgb.pkl', 47 | bgr2rgb=True), 48 | sample_kwargs=dict(sample_model='ema')) 49 | 50 | checkpoint_config = dict(interval=10000, by_epoch=False, max_keep_ckpts=30) 51 | lr_config = None 52 | 53 | total_iters = 800002 54 | -------------------------------------------------------------------------------- /configs/lsgan/lsgan_dcgan-archi_lr-1e-4_celeba-cropped_128_b64x1_10m.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/models/dcgan/dcgan_128x128.py', 3 | '../_base_/datasets/unconditional_imgs_128x128.py', 4 | '../_base_/default_runtime.py' 5 | ] 6 | model = dict( 7 | discriminator=dict(output_scale=4, out_channels=1), 8 | gan_loss=dict(type='GANLoss', gan_type='lsgan')) 9 | # define dataset 10 | # you must set `samples_per_gpu` and `imgs_root` 11 | data = dict( 12 | samples_per_gpu=64, 13 | train=dict(imgs_root='./data/celeba-cropped/cropped_images_aligned_png/')) 14 | 15 | optimizer = dict( 16 | generator=dict(type='Adam', lr=0.0001, betas=(0.5, 0.99)), 17 | discriminator=dict(type='Adam', lr=0.0001, betas=(0.5, 0.99))) 18 | 19 | # adjust running config 20 | lr_config = None 21 | checkpoint_config = dict(interval=10000, by_epoch=False, max_keep_ckpts=20) 22 | custom_hooks = [ 23 | dict( 24 | type='VisualizeUnconditionalSamples', 25 | output_dir='training_samples', 26 | interval=10000) 27 | ] 28 | 29 | evaluation = dict( 30 | type='GenerativeEvalHook', 
31 | interval=10000, 32 | metrics=dict( 33 | type='FID', num_images=50000, inception_pkl=None, bgr2rgb=True), 34 | sample_kwargs=dict(sample_model='orig')) 35 | 36 | total_iters = 160000 37 | # use ddp wrapper for faster training 38 | use_ddp_wrapper = True 39 | find_unused_parameters = False 40 | 41 | runner = dict( 42 | type='DynamicIterBasedRunner', 43 | is_dynamic_ddp=False, # Note that this flag should be False. 44 | pass_training_status=True) 45 | 46 | metrics = dict( 47 | ms_ssim10k=dict(type='MS_SSIM', num_images=10000), 48 | swd16k=dict(type='SWD', num_images=16384, image_shape=(3, 128, 128)), 49 | fid50k=dict(type='FID', num_images=50000, inception_pkl=None)) 50 | --------------------------------------------------------------------------------